"""Test C function dealing with Leech lattice vectors mod 3 of type 3

In this script we test functions dealing with vectors of type 3 in the 
Leech lattice modulo 3. These functions are implemented in file 
gen_leech3.c and available in the extension mmgroup.generators.

We use the terminology defined in
the document *The C interface of the mmgroup project*, 
section *Description of the mmgroup.generators extension*.
"""

from __future__ import absolute_import, division, print_function
from __future__ import  unicode_literals



from random import randint, choices #, shuffle, sample
from functools import reduce
from operator import __or__
from numbers import Integral
from collections import defaultdict
from multiprocessing import Pool, TimeoutError

import numpy as np
import scipy
import scipy.special
from scipy.stats import chisquare
import pytest

from mmgroup import MM0
from mmgroup.mat24 import MAT24_ORDER, ploop_theta
from mmgroup.mat24 import bw24 as mat24_bw24
from mmgroup.generators import gen_leech3to2_type3
from mmgroup.generators import gen_leech3_op_vector_word
from mmgroup.generators import gen_leech3_op_vector_atom
from mmgroup.generators import gen_leech2_op_word
from mmgroup.generators import gen_leech2_op_atom
from mmgroup.generators import gen_leech2_type_selftest

from mmgroup.tests.test_gen_xi.test_gen_type4 import rand_xsp2co1_elem
from mmgroup.tests.test_gen_xi.test_gen_type4 import create_test_elements
from mmgroup.tests.test_gen_xi.test_gen_type4 import mul_v3
from mmgroup.tests.test_gen_xi.test_gen_type4 import mul_v2
from mmgroup.tests.test_gen_xi.test_gen_type4 import str_v3
from mmgroup.tests.test_gen_xi.test_gen_type4 import weight_v3
from mmgroup.tests.test_gen_xi.test_gen_type4 import chisquare_v3



#####################################################################
# Creating test vectors
#####################################################################

# Define STD_TYPE3 = -(5,1,1,...,1) to be the standard type-3 
# vector in the Leech lattice.

# Here is STD_TYPE3 (mod 3) in Leech lattice mod 3 encoding.
# In that encoding bit i of the low 24 bits marks coordinate i as
# 1 (mod 3), and bit i + 24 marks coordinate i as -1 (mod 3). Since
# -5 = 1 (mod 3), coordinate 0 sets bit 0 (0x000001); coordinates
# 1,...,23 are -1 and set the high bits 0xfffffe << 24.
STD_TYPE3_MOD3 = 0xfffffe000001
# Here is STD_TYPE3 (mod 2) in Leech lattice encoding
STD_TYPE3_MOD2 = 0x800800







#####################################################################
# Test mapping of type 3 Leech vectors modulo 3 to vectors modulo 2
#####################################################################


def v3_to_v2(v3):
    """Map type-3 vector from Leech lattice mod 3 to Leech lattice mod 2

    Here parameter ``v3`` is a type-3 vector in the Leech lattice mod 3 
    in Leech lattice mod 3 encoding. 

    The function returns the type-3 vector in the Leech lattice mod 2
    corresponding to ``v3`` in Leech lattice encoding. This result is
    unique.

    This function is a wrapper for the C function ``gen_leech3to2_type3`` 
    in file ``gen_leech3.c``. That C function returns 0 if ``v3`` is not
    of type 3 in the Leech lattice mod 3. This wrapper does not pass
    that 0 on: it raises AssertionError instead, showing ``v3`` and its
    weight, since callers in this test module expect a type-3 input.
    """
    result = gen_leech3to2_type3(v3)
    # A zero result means that the C function did not recognize v3
    # as a type-3 vector; report the offending vector and its weight.
    assert result != 0, (str_v3(v3), weight_v3(v3), hex(result))
    return result




@pytest.mark.gen_xi
def test_type3(verbose = 0):
    r"""Test conversion of type-3 vectors 

    Let STD_TYPE3 be the type-3 vector in the Leech lattice defined
    above. Let STD_TYPE3_MOD2 and STD_TYPE3_MOD3 be the images of
    STD_TYPE3 in the Leech lattice mod 2 and mod 3, respectively.

    For a set of elements g of the group ``G_x_0`` we convert 
    STD_TYPE3_MOD3  * g to  a vector v2 in the Leech lattice mod 2  
    with function ``v3_to_v2`` and we check that the result is equal 
    to  STD_TYPE3_MOD2 * g. We use function ``create_test_elements`` 
    for generating the elements g. 

    Finally we check that the images STD_TYPE3_MOD3 * g cover all
    four weights 9, 12, 21, and 24 that a type-3 vector in the
    Leech lattice mod 3 may have.

    Raises ValueError if a conversion yields an unexpected result.
    """
    # Start vectors are the same for every test element g
    v3_st = STD_TYPE3_MOD3 
    v2_st = STD_TYPE3_MOD2

    def print_header(ntest, g):
        # Print the test number, the start vector and the element g
        print("\nTEST %s" % (ntest+1))
        print("v3_start = " , str_v3(v3_st))
        print("g =", g)

    # weights[w] counts the test cases where v3_st * g has weight w
    weights = defaultdict(int)
    for ntest, data in enumerate(create_test_elements()):
        g = MM0(data) 
        if verbose:
            print_header(ntest, g)
        v3 = mul_v3(v3_st, g)
        w = weight_v3(v3)  
        weights[w] += 1        
        v2_ref = mul_v2(v2_st, g) & 0xffffff
        v2 = v3_to_v2(v3) 
        ok = v2 == v2_ref 
        if verbose or not ok:
            if not verbose:
                # The header has not been printed yet in this case
                print_header(ntest, g)
            print("v3 = v3_st*g =",str_v3(v3))
            print("weight =",w)
            print("v2 obtained= ", hex(v2))
            print("v2 expected= ", hex(v2_ref))
            if not ok:
                ERR = "Error in operation mod 3"
                raise ValueError(ERR)
    print("weights =", dict(weights))
    # The test elements must produce type-3 vectors of all possible weights
    assert set(weights.keys()) == set([9, 12, 21, 24])
    
    
    
#####################################################################
# Chisquare test of random type 3 Leech vectors modulo 3
#####################################################################



def binom(n, k):
    """Return binomial coefficient n choose k

    Here ``n`` and ``k`` are integers. The result is an exact Python
    integer computed with ``scipy.special.comb(n, k, exact=True)``;
    it is 0 if ``k > n``.
    """
    # exact=True uses integer arithmetic, so no floating-point rounding
    # is involved, even for large arguments.
    return int(scipy.special.comb(n, k, exact=True))

# The dictionary contains the number DATA_GEOMETRY[w] of type-3 
# vectors in the Leech lattice modulo 3 of weight w. 
# This table is obtained from :cite:`Iva99`, Lemma 4.4.1
DATA_GEOMETRY = {
  24: 24 * 2**12, 
   9: 759 * 16 * 2**8,
  21: binom(24,3) * 2**12,
  12: 2576 * 2**11,
}  
 
# Number of type-3 vectors in the Leech lattice
NUM_LEECH_TYPE3 = 2**24 - 2**12
# .. must be equal to the sum of the values in DATA_GEOMETRY
assert sum(DATA_GEOMETRY.values()) ==  NUM_LEECH_TYPE3
# Inverse of the number 3**24 of vectors in the Leech lattice mod 3
I_NUMV3 = 3.0**(-24)

# Assemble a dictionary DICT_P for a chisquare test.
# That dictionary maps a bit weight w to itself, if that bit weight
# occurs as a weight of a type-3 vector in the Leech lattice mod 3 
# with sufficiently high probability (i.e. at least MIN_P). It maps
# w to 1 if w occurs as such a weight with lower probability; the
# label 1 is just a bucket collecting these rare weights (1 is not a
# key of DATA_GEOMETRY). Finally, DICT_P maps 0 to 0.
# Let P be a dictionary that maps y to the probability that a random 
# vector v in the Leech lattice mod 3 has type 3 and that 
# DICT_P[weight(v)] is equal to y. Let P[0] be the probability that
# such a random vector v is not of type 3.
BLOCKSIZE = 1000000  # Number of random vectors generated for one chisquare test
DICT_P = defaultdict(int)
MIN_P = 1.0 / 40000
P = defaultdict(float)
for w, num in DATA_GEOMETRY.items():
    # Probability that a uniform random vector mod 3 is one of the
    # 'num' type-3 vectors of weight w
    p = num * I_NUMV3
    assert 0 <= p < 1
    DICT_P[w] = 1 if p < MIN_P else w
    P[DICT_P[w]] += p
P[0] = 1.0 - NUM_LEECH_TYPE3 * I_NUMV3
DICT_P[0] = 0

# Smallest nonzero probability in P (not used within this module)
P_MIN = min([x for x in P.values() if x > 0])

# Possible contributions of coordinate 0 to a vector in Leech lattice
# mod 3 encoding: 0 (coordinate is 0), 1 (low bit set, coordinate is
# 1 mod 3), 0x1000000 (high bit set, coordinate is -1 mod 3).
# Shifting left by i moves such a contribution to coordinate i.
RANDMOD3 = [0, 1, 0x1000000]
def rand_v3():
    """Return a uniform random vector of GF(3)^{24}.

    Each of the 24 coordinates is chosen independently and uniformly
    from RANDMOD3. The vector is returned as an integer in
    **Leech lattice mod 3 encoding**.
    """
    vector = 0
    for position, digit in enumerate(choices(RANDMOD3, k=24)):
        vector += digit << position
    return vector
    
def rand_v3_dict(n = BLOCKSIZE):
    """Create n random vectors and group them according to their weight

    We create ``n`` random vectors in GF(3)^{24} with function
    ``rand_v3``. We return a dictionary ``d`` with the following
    entries:

    d[0] counts the vectors not of type 3 in the Leech lattice mod 3.
    d[w] counts the vectors v of type 3 such that  DICT_P[weight(v)]
    is equal to w.

    Then the value d[i] should be about P[i] * n.

    A vector is considered to be of type 3 if the C function
    ``gen_leech3to2_type3`` returns a nonzero value for it. If that
    function accepts a vector whose weight is not a key of
    DATA_GEOMETRY, an AssertionError is raised.
    """
    d = defaultdict(int)
    for _ in range(n):
        v3 = rand_v3()
        if gen_leech3to2_type3(v3) == 0:
            d[0] += 1
            continue
        # Coordinate i is nonzero iff bit i or bit i + 24 is set
        # (rand_v3 never sets both), so this is the weight of v3.
        w = mat24_bw24((v3 | (v3 >> 24)) & 0xffffff)
        # A type-3 vector must have a weight listed in DATA_GEOMETRY.
        # Do not let an unexpected weight fall through to DICT_P's
        # default value 0; that would miscount it as 'not type 3' and
        # insert a spurious key into the shared dictionary DICT_P.
        assert w in DATA_GEOMETRY, (hex(v3), w)
        d[DICT_P[w]] += 1
    return d






@pytest.mark.slow
@pytest.mark.very_slow
@pytest.mark.gen_xi
def test_chisq_type3(verbose = 0):
    """Test distribution of weights of type-3 vectors

    Each attempt draws a block of random vectors in GF(3)^{24} via
    ``rand_v3_dict`` (BLOCKSIZE vectors by default). The vectors are
    classified by whether they have type 3 in the Leech lattice mod 3
    and by their weight, as described there. The resulting counts are
    compared with the probabilities in ``P`` by a chisquare test.

    An attempt fails if its p-value is less than 0.01. The test passes
    as soon as one attempt succeeds. At most 4 attempts are made, and
    ValueError is raised if all of them fail.
    """
    threshold = 0.01
    attempts = 4
    print("Check distribution of type-3 vectors mod 3") 
    for attempt in range(attempts):
        counts = rand_v3_dict()
        chisq, p = chisquare_v3(counts, P)
        # Report every retry and every failing attempt
        report = bool(verbose) or attempt > 0 or p < threshold
        if report:
            print("Chisq = %.3f, p = %.4f" % (chisq, p))
        if p >= threshold:
            return
    raise ValueError("Chisquare test failed") 



