from itertools import product
import numpy as np

class BaseEncoding:
    """Plain ASCII encoding: one character per ``uint8`` byte."""

    @classmethod
    def from_string(cls, sequence):
        """Return a ``uint8`` array holding the code point of each character."""
        code_points = list(map(ord, sequence))
        return np.array(code_points, dtype=np.uint8)

    @classmethod
    def from_bytes(cls, sequence):
        """Return ``sequence`` unchanged; bytes already are this encoding."""
        return sequence

    @classmethod
    def to_bytes(cls, sequence):
        """Return ``sequence`` unchanged; this encoding already is bytes."""
        return sequence

    @classmethod
    def to_string(cls, byte_sequence):
        """Return the string made of the character for each byte value."""
        return "".join(map(chr, byte_sequence))

class ACTGTwoBitEncoding:
    """Pack A/C/T/G ASCII bytes into two bits per base, four bases per byte.

    The 2-bit codes are A=0, C=1, T=2, G=3.  Within a packed byte the first
    base occupies the two least-significant bits.  Encoding ignores letter
    case (only the low five bits of each input byte are used); decoding
    always produces lowercase ASCII letters.
    """

    letters = ["A", "C", "T", "G"]
    bitcodes = ["00", "01",
                "10", "11"]
    # reverse[code] is the low five bits (byte & 31) of the ASCII letter for
    # that code: A/a -> 1, C/c -> 3, T/t -> 20, G/g -> 7.
    reverse = np.array([1, 3, 20, 7], dtype=np.uint8)
    # Table indexed by a 16-bit value high*256 + low (both already masked with
    # 31); it yields 4*code(high) + code(low).  Entries for any other value
    # stay 0, so unknown input is silently encoded as two 'A' codes.
    _lookup_2bytes_to_4bits = np.zeros(256*256, dtype=np.uint8)
    # Widen to uint16 before multiplying: 256 does not fit in uint8, and
    # NumPy >= 2 raises OverflowError instead of silently upcasting.
    _lookup_2bytes_to_4bits[
        256 * reverse.astype(np.uint16)[:, None] + reverse
    ] = 4 * np.arange(4)[:, None] + np.arange(4)
    _shift_4bits = (4*np.arange(2, dtype=np.uint8))
    _shift_2bits = 2*np.arange(4, dtype=np.uint8)

    @classmethod
    def convert_2bytes_to_4bits(cls, two_bytes):
        """Map each ``uint16`` (two masked letters) to a 4-bit pair code.

        Raises ``AssertionError`` if ``two_bytes`` is not of dtype ``uint16``.
        """
        assert two_bytes.dtype == np.uint16, two_bytes.dtype
        return cls._lookup_2bytes_to_4bits[two_bytes]

    @classmethod
    def join_4bits_to_byte(cls, four_bits):
        """Combine rows of two 4-bit codes into one byte each.

        ``four_bits`` is reduced along axis 1; column 0 lands in the low
        nibble and column 1 in the high nibble.
        """
        return np.bitwise_or.reduce(four_bits << cls._shift_4bits, axis=1)

    @classmethod
    def complement(cls, char):
        """Return the base-wise complement (A<->T, C<->G) of packed data.

        XOR-ing every byte with 0b10101010 flips the high bit of each 2-bit
        code, which swaps 0<->2 (A<->T) and 1<->3 (C<->G).  The result has
        the same dtype as ``char``.
        """
        complements = np.packbits([1, 0, 1, 0, 1, 0, 1, 0])
        dtype = char.dtype
        return (char.view(np.uint8) ^ complements).view(dtype)

    @classmethod
    def from_bytes(cls, sequence):
        """Encode a ``uint8`` array of ASCII letters into packed 2-bit codes.

        ``sequence.size`` must be a multiple of four (four bases per output
        byte); otherwise, or if the dtype is not ``uint8``, an
        ``AssertionError`` is raised.  Returns a 1-D ``uint8`` array with one
        byte per four input letters.
        """
        assert sequence.dtype==np.uint8
        assert sequence.size % 4 == 0, sequence.size
        # Keep only the low five bits so upper- and lowercase letters agree.
        sequence = sequence & 31
        # NOTE: view(np.uint16) uses the host byte order; the lookup table
        # treats the first byte of each pair as the low byte, i.e. it assumes
        # a little-endian host.
        four_bits = cls.convert_2bytes_to_4bits(sequence.view(np.uint16))
        codes = cls.join_4bits_to_byte(four_bits.reshape(-1, 2))
        assert codes.dtype == np.uint8, codes.dtype
        return codes.flatten().view(np.uint8)

    @classmethod
    def from_string(cls, string):
        """Encode a string of A/C/T/G letters; see :meth:`from_bytes`."""
        byte_repr = np.array([ord(c) for c in string], dtype=np.uint8)
        return cls.from_bytes(byte_repr)

    @classmethod
    def to_string(cls, bits):
        """Decode packed 2-bit codes into a lowercase ``str``."""
        byte_repr = cls.to_bytes(bits)
        return "".join(chr(b) for b in byte_repr)

    @classmethod
    def to_bytes(cls, sequence):
        """Decode a 1-D packed ``uint8`` array into lowercase ASCII bytes.

        Each input byte expands to four output bytes, lowest 2-bit group
        first.  Raises ``AssertionError`` if the dtype is not ``uint8``.
        """
        assert sequence.dtype==np.uint8
        bit_mask = np.uint8(3) # last two bits
        all_bytes = (sequence[:, None]>>cls._shift_2bits) & bit_mask
        # reverse[] holds letter & 31; adding 96 yields the lowercase letter.
        return cls.reverse[all_bytes.flatten()]+96


class SimpleEncoding(ACTGTwoBitEncoding):
    """Two-bit A/C/T/G encoder that translates one input byte at a time.

    Only the encoding path is overridden: each ASCII byte goes through a
    256-entry table and four consecutive codes are packed into one byte,
    first base in the lowest two bits.  Bytes other than upper- or lowercase
    A/C/T/G map to code 0.
    """

    # ASCII byte -> 2-bit code, for both cases of each letter.
    _lookup_byte_to_2bits = np.zeros(256, dtype=np.uint8)
    _lookup_byte_to_2bits[[ord("a"), ord("A")]] = 0
    _lookup_byte_to_2bits[[ord("c"), ord("C")]] = 1
    _lookup_byte_to_2bits[[ord("t"), ord("T")]] = 2
    _lookup_byte_to_2bits[[ord("g"), ord("G")]] = 3

    # Bit offsets (0, 2, 4, 6) of the four codes inside a packed byte.
    _shift_2bits = 2 * np.arange(4, dtype=np.uint8)

    @classmethod
    def convert_byte_to_2bits(cls, one_byte):
        """Translate a ``uint8`` array of ASCII letters into 2-bit codes."""
        assert one_byte.dtype == np.uint8, one_byte.dtype
        table = cls._lookup_byte_to_2bits
        return table[one_byte]

    @classmethod
    def join_2bits_to_byte(cls, two_bits_vector):
        """OR together the shifted codes along the last axis."""
        shifted = two_bits_vector << cls._shift_2bits
        return np.bitwise_or.reduce(shifted, axis=-1)

    @classmethod
    def from_bytes(cls, sequence):
        """Encode a ``uint8`` letter array whose size is a multiple of four."""
        assert sequence.dtype==np.uint8
        assert sequence.size % 4 == 0, sequence.size
        codes_per_base = cls.convert_byte_to_2bits(sequence)
        quads = codes_per_base.reshape(-1, 4)
        packed = cls.join_2bits_to_byte(quads)
        return packed.flatten()

def _build_twobit_reverse_table():
    """Return a 256-entry table reversing the four 2-bit groups of a byte."""
    byte_values = np.arange(256)
    reversed_values = (
        ((byte_values & 0b00000011) << 6)
        | ((byte_values & 0b00001100) << 2)
        | ((byte_values & 0b00110000) >> 2)
        | ((byte_values & 0b11000000) >> 6)
    )
    return reversed_values.astype(np.uint8)


# Built once at import time instead of on every call to twobit_swap().
_TWOBIT_REVERSE_TABLE = _build_twobit_reverse_table()


def twobit_swap(number):
    """Reverse the order of the 2-bit groups in every element of ``number``.

    Each byte has its four 2-bit groups reversed through a lookup table, and
    the bytes of each element are then swapped, so the groups of a multi-byte
    element end up in fully reversed order.

    Parameters
    ----------
    number : numpy array of an integer dtype
        Values to transform; the input is not modified.

    Returns
    -------
    A new array with the same dtype as ``number``.
    """
    dtype = number.dtype
    # Fancy indexing copies, so the shared table is never aliased or mutated.
    new_bytes = _TWOBIT_REVERSE_TABLE[number.view(np.uint8)]
    return new_bytes.view(dtype).byteswap()
