from os import path
import numpy as np
import random
import warnings
from pydub import AudioSegment
from g2p_en import G2p
from typing import List, Tuple

# Grapheme-to-phoneme converter (g2p_en), shared by every Morshu instance to turn text into phoneme tokens.
g2p = G2p()

# The source recording that all generated audio is cut from; it is loaded from the directory containing this module.
morshu_wav_fp = path.join(path.split(__file__)[0], 'morshu.wav')
morshu_wav = AudioSegment.from_wav(morshu_wav_fp)

# Record that contains each recognizable phoneme in the morshu audio file, in the order it is spoken. Fields:
#   phoneme  - the ARPAbet phoneme ('' marks a stretch with no recognizable phoneme, e.g. a pause)
#   timing   - the time that phoneme ENDS in milliseconds; a row's audio starts where the previous row ends
#   priority - how the phoneme sounds compared to other occurrences of it (higher is preferred when choosing a single
#              phoneme; '' rows use 0)
# The table starts and ends with a '' row, so every real phoneme has a neighbor on both sides.
morshu_rec = np.rec.array([
    # typos in comments are intentional
    # first line of dialog
    ('', 160, 0), ('L', 250, 2), ('AE', 348, 2), ('M', 420, 2), ('P', 510, 1),  # lamp
    ('OY', 700, 2), ('L', 835, 1), ('', 1090, 0),  # oil
    ('R', 1180, 2), ('OW', 1300, 2), ('', 1390, 0), ('P', 1490, 2), ('', 1850, 0),  # rope
    ('B', 1895, 2), ('AA', 2090, 2), ('M', 2235, 2), ('Z', 2390, 2),  # bombs
    ('', 2780, 0), ('Y', 2840, 2), ('UW', 2960, 2),  # you
    ('W', 3030, 2), ('AA', 3110, 2), ('N', 3150, 1), ('IH', 3240, 2), ('T', 3370, 2), ('', 3810, 0),  # won it
    ('IH', 3960, 2), ('T', 4070, 2), ('Y', 4260, 2), ('UH', 4400, 2), ('R', 4510, 2), ('Z', 4600, 2),  # it yours
    ('M', 4675, 2), ('AY', 4810, 2), ('', 4885, 0),  # my
    ('F', 4930, 2), ('R', 4980, 2), ('EH', 5100, 2), ('N', 5240, 2), ('D', 5300, 2), ('', 5520, 0),  # friend
    ('AE', 5630, 2), ('Z', 5740, 2), ('L', 5870, 2), ('AO', 6000, 2), ('NG', 6140, 2),  # as long
    ('AE', 6170, 1), ('Z', 6265, 2), ('Y', 6300, 2), ('UW', 6380, 2),  # as you
    ('HH', 6450, 2), ('AE', 6510, 1), ('V', 6580, 2),  # have
    ('IH', 6640, 2), ('N', 6670, 2), ('AH', 6747, 2), ('F', 6855, 2),  # enough
    ('R', 6960, 2), ('UW', 7060, 2), ('B', 7170, 1), ('IY', 7340, 2), ('Z', 7520, 2), ('', 8236, 0),  # rubies

    # second line of dialog
    ('S', 8407, 2), ('AA', 8495, 2), ('R', 8570, 2), ('IY', 8630, 1),  # sorry
    ('L', 8740, 2), ('IH', 8811, 2), ('NG', 8942, 2), ('K', 9014, 2), ('', 9251, 0),  # link
    ('AY', 9384, 2), ('', 9467, 0), ('K', 9512, 2), ('AE', 9640, 2), ('N', 9716, 2), ('', 9844, 0),  # i can
    ('G', 9894, 2), ('IH', 9985, 2), ('V', 10060, 2), ('', 10149, 0),  # give
    ('K', 10256, 2), ('R', 10297, 2), ('EH', 10383, 2), ('IH', 10482, 1), ('', 10564, 0), ('T', 10617, 2),  # cre-it
    ('', 10962, 0), ('K', 11019, 2), ('AH', 11100, 2), ('M', 11229, 2), ('B', 11246, 2), ('AE', 11369, 2),  # come ba-
    ('', 11511, 0), ('W', 11590, 2), ('EH', 11622, 1), ('N', 11705, 2),  # when
    ('Y', 11755, 2), ('UH', 11808, 2), ('R', 11864, 2), ('AH', 11959, 2),  # you're a
    ('L', 12095, 2), ('IH', 12202, 2), ('L', 12386, 2),  # lil
    ('', 12596, 0), ('M', 12748, 2), ('M', 12888, 2), ('M', 13037, 2), ('M', 13196, 2), ('', 13426, 0),  # MMMM
    ('R', 13494, 2), ('IH', 13589, 2), ('', 13632, 0), ('CH', 13773, 2), ('ER', 13991, 2), ('', 13992, 0)  # richer
], names=('phoneme', 'timing', 'priority'))

# Substitutes for phonemes that Morshu never says (some of these are tentative). Each key maps to the list of
# phonemes from the Morshu audio that replaces it; Morshu.substitute_similar_phonemes applies the mapping.
similar_phonemes = dict(
    AW=['AE', 'UW'],  # as in "cow"
    DH=['D'],  # as in "the"
    EY=['EH', 'IY'],  # as in "say"
    JH=['CH'],  # as in "joy"
    SH=['CH'],  # as in "she"
    TH=['D'],  # as in "thin"
    ZH=['CH'],  # as in "measure"
)


class Morshu:
    """
    Generate speech by stitching together phoneme segments cut from the Morshu audio file.

    Call load_text() to convert text into audio. The result is stored in out_audio, and audio_segment_timings /
    get_frame_idx_from_millis() map times in the output back to times (and video frames) in the Morshu audio.
    """

    def __init__(self):
        # The text most recently passed to load_text(); load_text() falls back to it when called without text.
        self.input_str = ""
        # The raw g2p output (phonemes, spaces and punctuation tokens) for input_str, set by load_text().
        self.input_phonemes = []

        # Characters that represent a stop in the audio. Whenever a non-phoneme token produced by g2p contains any of
        # these characters, silence of length stop_length is added.
        self.stop_chars = '.,?!:;()\n'

        # Length in milliseconds of the silence added for a space between words.
        self.space_length = 20
        # Length in milliseconds of the silence added for a stop character.
        self.stop_length = 100

        # When True, the 'priority' column of morshu_rec is used as the base score when choosing between occurrences
        # of a single phoneme (see get_best_morshu_single_phoneme).
        self.use_phoneme_priority = True

        # The audio segment generated by load_text().
        self.out_audio = AudioSegment.empty()

        # Record of segment timings in the output audio. Each entry represents the time that a new morshu segment
        # begins.
        #
        # The first field of each entry (named 'output') is the time in milliseconds when a new segment begins in the
        # output audio. The second field (named 'morshu') is the time when that segment starts in the morshu audio, or
        # -1 when the segment is inserted silence that doesn't come from the morshu audio.
        #
        # Example: If this record contains the entry (2000, 1895), that means at 2 seconds into the output audio, the
        # morshu segment that begins at 1.895 seconds will be played (when morshu says 'B' in 'bombs').
        #
        # Before any audio is generated (or when the generated audio is empty) this is the placeholder record (0, 0).
        self.audio_segment_timings = np.rec.array((0, 0), names=('output', 'morshu'))

    def load_text(self, text: str = None) -> AudioSegment:
        """
        Generate audio from the given text. The input_str, input_phonemes, and audio_segment_timings variables are also
        updated.

        The text is converted to tokens with g2p. Consecutive phoneme tokens (one word) are matched against the morshu
        audio as a group; a space token adds space_length milliseconds of silence, and a non-phoneme token containing
        any of stop_chars adds stop_length milliseconds of silence. A UserWarning is issued if the result is empty.

        :param text: The text to use. If omitted, the input_str variable is used instead.

        :return: The generated audio. It's also stored in the out_audio variable.
        """
        if text is None:
            text = self.input_str
        self.input_str = text
        # Each line break becomes three commas, presumably so g2p emits comma tokens and the pause is three stops long.
        text = text.replace('\n', ',,,')

        phonemes = g2p(text)
        self.input_phonemes = list(phonemes)

        # output audio, at the frame rate of the morshu audio
        output = AudioSegment.empty().set_frame_rate(morshu_wav.frame_rate)

        # milliseconds marking each time a new morshu audio segment is used
        audio_out_millis = []

        # milliseconds marking the beginning of each segment in the morshu audio (-1 for inserted silence)
        audio_morshu_millis = []

        # Build the membership set once instead of scanning the g2p.phonemes list for every token.
        known_phonemes = set(g2p.phonemes)
        last_index = len(phonemes) - 1

        # segment of multiple phonemes in one word (phonemes between pauses)
        phoneme_segment = []
        for i, token in enumerate(phonemes):
            is_phoneme = token in known_phonemes
            if is_phoneme:
                phoneme_segment.append(token)
            # A non-phoneme token (space or punctuation) ends the current word, and so does the end of the input.
            if not is_phoneme or i == last_index:
                output = self.append_best_morshu_phoneme_segment(output, phoneme_segment, audio_out_millis,
                                                                 audio_morshu_millis)
                phoneme_segment = []
            if token == ' ':
                output = self.append_audio_segment(output, AudioSegment.silent(self.space_length), -1, audio_out_millis,
                                                   audio_morshu_millis)
            elif not is_phoneme and any(c in self.stop_chars for c in token):
                # Per-character check, so multi-character punctuation tokens like '...' also produce a stop.
                output = self.append_audio_segment(output, AudioSegment.silent(self.stop_length), -1, audio_out_millis,
                                                   audio_morshu_millis)

        if len(output) == 0:
            warnings.warn('returned audio segment is empty', UserWarning)
            self.audio_segment_timings = np.rec.array((0, 0), names=('output', 'morshu'))
        else:
            self.audio_segment_timings = np.rec.array(tuple(zip(audio_out_millis, audio_morshu_millis)),
                                                      names=('output', 'morshu'))
        self.out_audio = output
        return output

    def get_frame_idx_from_millis(self, millis: int) -> int:
        """
        Get the morshu frame from the given time in milliseconds.

        :param millis: Time in the output audio in milliseconds (converted with int()).

        :return: The morshu frame index that occurs at that time in the generated audio (the morshu video is 10 fps),
            or -1 if the output is playing inserted silence at that time.
        """
        millis = int(millis)
        # atleast_1d turns the 0-d placeholder record into a one-entry array so it can be indexed like real timings.
        timings = np.atleast_1d(self.audio_segment_timings)
        # Segment start times are non-decreasing, so this is the last segment starting at or before millis
        # (times before the first segment map to the first segment).
        idx = max(int(np.searchsorted(timings['output'], millis, side='right')) - 1, 0)

        output_segment_start, morshu_segment_start = timings[idx]
        if morshu_segment_start == -1:
            return -1

        morshu_frame = (morshu_segment_start + (millis - output_segment_start)) // 100  # 10 fps, 1 frame per 100 millis
        return morshu_frame

    @staticmethod
    def substitute_similar_phonemes(phonemes: List[str]) -> List[str]:
        """
        Return a copy of the phoneme list with substitutions applied.

        The replacement phonemes are stored in the global similar_phonemes dictionary. These are phonemes that Morshu
        doesn't say in his two lines of dialog. These phonemes may sound slightly different than expected, and may be
        updated to be more accurate later.

        The emphasis number (0, 1 or 2) at the end of vowel phonemes is removed to simplify things.

        :param phonemes: A list of phonemes to parse through. It is not modified.

        :return: A new list of phonemes.
        """
        substituted = []
        for phoneme in phonemes:
            # remove emphasis number
            if phoneme.endswith(('0', '1', '2')):
                phoneme = phoneme[:-1]
            substituted.extend(similar_phonemes.get(phoneme, [phoneme]))
        return substituted

    @staticmethod
    def append_audio_segment(audio_out: AudioSegment, audio_segment: AudioSegment, morshu_millis_start: int,
                             audio_out_millis: List[int], audio_morshu_millis: List[int]) -> AudioSegment:
        """
        Helper function to append one audio segment to another and record where the new segment starts.

        :param audio_out: The full audio to append to.

        :param audio_segment: The audio segment that will be appended.

        :param morshu_millis_start: The time in milliseconds that the audio segment begins in the morshu audio. Use -1
            if this audio doesn't appear in the Morshu audio (like if it's silence).

        :param audio_out_millis: A list of milliseconds representing when new segments begin in the output audio.
            The current length of audio_out is appended to it.

        :param audio_morshu_millis: A list of milliseconds representing when the segment begins in the morshu audio.
            morshu_millis_start is appended to it.

        :return: audio_out with audio_segment appended.
        """
        audio_out_millis.append(len(audio_out))
        audio_morshu_millis.append(morshu_millis_start)
        return audio_out + audio_segment

    @staticmethod
    def get_phoneme_sequence_occurrences(phonemes: List[str]) -> List[Tuple[int, int]]:
        """
        Get all occurrences of a given phoneme sequence in the morshu audio.

        :param phonemes: The consecutive phonemes to search for.

        :return: List of tuples containing (start_millis, end_millis), in order of appearance. Empty if phonemes is
            empty or the sequence never occurs.
        """
        length = len(phonemes)
        if length == 0:
            return []

        names = morshu_rec['phoneme']
        timings = morshu_rec['timing']
        occurrences = []
        # Start at 1 because a match's start time is the previous row's end time. The upper bound includes the final
        # window of the table.
        for i in range(1, len(morshu_rec) - length + 1):
            if (names[i:i + length] == phonemes).all():
                occurrences.append((timings[i - 1], timings[i + length - 1]))
        return occurrences

    @staticmethod
    def _contains_vowel(phoneme: str) -> bool:
        """Return True if the phoneme contains any of the letters A, E, I, O or U (true of every ARPAbet vowel)."""
        return any(c in phoneme for c in "AEIOU")

    def _context_priority(self, index: int, preceding: str, succeeding: str) -> int:
        """
        Score how well the morshu_rec row at index fits between the given preceding and succeeding phonemes.

        The base score is the row's priority (or 0 when use_phoneme_priority is False). Each side then adds 10 for an
        exact match with the morshu neighbor, or 5 when both the morshu neighbor and the requested neighbor are vowels.
        """
        names = morshu_rec['phoneme']
        # Neighbors outside the table are treated as '' instead of wrapping around or raising IndexError.
        morshu_preceding = names[index - 1] if index > 0 else ''
        morshu_succeeding = names[index + 1] if index + 1 < len(names) else ''

        priority = morshu_rec['priority'][index] if self.use_phoneme_priority else 0
        for morshu_neighbor, neighbor in ((morshu_preceding, preceding), (morshu_succeeding, succeeding)):
            if morshu_neighbor == neighbor:
                priority += 10
            elif Morshu._contains_vowel(morshu_neighbor) and Morshu._contains_vowel(neighbor):
                priority += 5
        return priority

    def get_best_morshu_single_phoneme(self, phoneme: str, preceding: str = "", succeeding: str = "") \
            -> Tuple[AudioSegment, int]:
        """
        Find the best morshu audio segment of the given phoneme.

        This compares the given surrounding phonemes with the phonemes in the morshu audio to determine the best one.
        Segments that match the same preceding or succeeding phoneme get the highest priority (+10 per side), and
        moderate priority (+5 per side) is given if the phonemes both contain vowels. If two segments have the same
        priority, a random one is chosen.

        For example, if we're looking for the phoneme 'K' with nothing before it and 'IH' after it, the 'K' in either
        "can't" or "come" will be chosen instead of "credit", because the preceding phoneme matches (nothing) and the
        succeeding phoneme contains a vowel, so it's close enough.

        :param phoneme: The phoneme to search for.

        :param preceding: The phoneme that comes before the searching phoneme.

        :param succeeding: The phoneme that comes after the searching phoneme.

        :return: A tuple containing the audio segment and the time that the segment starts in the morshu audio. If the
            phoneme never occurs in the morshu audio, an empty segment and 0 are returned.
        """
        phoneme_indices = np.where(morshu_rec['phoneme'] == phoneme)[0]
        if len(phoneme_indices) == 0:
            return AudioSegment.empty(), 0

        priorities = [self._context_priority(i, preceding, succeeding) for i in phoneme_indices]
        highest_priority = max(priorities)
        # list of phoneme indices of the highest priority
        best_indices = [i for i, priority in zip(phoneme_indices, priorities) if priority == highest_priority]

        index = random.choice(best_indices)
        timings = morshu_rec['timing']
        # A row's audio starts where the previous row ends; the first row starts at the beginning of the recording.
        start = timings[index - 1] if index > 0 else 0
        segment = morshu_wav[start:timings[index]]
        return segment, start

    def append_best_morshu_phoneme_segment(self, output: AudioSegment, phonemes: List[str],
                                           audio_out_millis: List[int] = None,
                                           audio_morshu_millis: List[int] = None) -> AudioSegment:
        """
        Search for a phoneme segment that appears in the morshu audio, and append it to the given audio output. If a
        segment more than 1 length can't be found, get_best_morshu_single_phoneme will be used to find the best one.

        The phonemes are consumed left to right: at each step the longest prefix of the remaining phonemes that occurs
        in the morshu audio is appended (a random occurrence if there are several).

        :param output: The audio to append the best segment to.

        :param phonemes: The phoneme segment to search for. It is not modified.

        :param audio_out_millis: A list of milliseconds representing when new segments begin in the output audio. If
            omitted, the timings are not recorded anywhere.

        :param audio_morshu_millis: A list of milliseconds representing when the segment begins in the morshu audio. If
            omitted, the timings are not recorded anywhere.

        :return: The audio segment with the new segment appended to it.
        """
        # append_audio_segment always records timings, so give it throwaway lists when the caller passed none.
        if audio_out_millis is None:
            audio_out_millis = []
        if audio_morshu_millis is None:
            audio_morshu_millis = []

        phonemes = Morshu.substitute_similar_phonemes(phonemes)
        if len(phonemes) == 1:
            segment, start = self.get_best_morshu_single_phoneme(phonemes[0])
            return Morshu.append_audio_segment(output, segment, start, audio_out_millis, audio_morshu_millis)

        # the last appended phoneme, used as context if we need to search for a single phoneme
        preceding = ""
        position = 0
        while position < len(phonemes):
            remaining = phonemes[position:]

            # Extend the sequence for as long as it still occurs in the morshu audio.
            sequence_length = 0
            segment = AudioSegment.empty()
            start = 0
            while sequence_length < len(remaining):
                occurrences = Morshu.get_phoneme_sequence_occurrences(remaining[:sequence_length + 1])
                if len(occurrences) == 0:
                    break
                start, end = random.choice(occurrences)
                segment = morshu_wav[start:end]
                sequence_length += 1

            # Find the best single phoneme if a longer segment wasn't found. This also covers a phoneme that never
            # occurs in the morshu audio (sequence_length == 0): it is still consumed so the loop always makes progress.
            if sequence_length <= 1:
                sequence_length = 1
                succeeding = remaining[1] if len(remaining) > 1 else ""
                segment, start = self.get_best_morshu_single_phoneme(remaining[0], preceding, succeeding)

            output = Morshu.append_audio_segment(output, segment, start, audio_out_millis, audio_morshu_millis)

            preceding = remaining[sequence_length - 1]
            position += sequence_length

        return output
