"""
Provide a quantized form of Adder2d (see https://arxiv.org/pdf/1912.13200.pdf).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from . import extra as ex
from .number import directquant


class Adder2d(ex.Adder2d):
    """Adder2d layer whose parameters are quantized on every forward pass.

    Construction is delegated to ``ex.Adder2d``. On each call, ``weight``
    (and ``bias``, when it is not ``None``) is passed through
    ``directquant`` with a configurable bit width. The quantized tensors
    are then handed to ``ex.adder2d_function``. The stored float
    parameters themselves are never overwritten.

    Args:
        input_channel: Forwarded unchanged to ``ex.Adder2d``.
        output_channel: Forwarded unchanged to ``ex.Adder2d``.
        kernel_size: Forwarded unchanged to ``ex.Adder2d``.
        stride: Forwarded to ``ex.Adder2d`` (default 1).
        padding: Forwarded to ``ex.Adder2d`` (default 0).
        bias: Forwarded to ``ex.Adder2d`` (default False). Presumably the
            parent leaves ``self.bias`` as ``None`` when this is False;
            confirm against ``ex.Adder2d``.
        weight_bit_width: Bit width given to ``directquant`` for the
            weight (default 8).
        bias_bit_width: Bit width given to ``directquant`` for the bias
            (default 16).
    """

    def __init__(self, input_channel, output_channel, kernel_size,
                 stride=1, padding=0, bias=False,
                 weight_bit_width=8, bias_bit_width=16):
        super().__init__(
            input_channel, output_channel, kernel_size,
            stride=stride, padding=padding, bias=bias,
        )
        # Bit widths read later by adder_forward when calling directquant.
        self.weight_bit_width = weight_bit_width
        self.bias_bit_width = bias_bit_width

    def adder_forward(self, input):
        """Quantize the parameters, then run the adder operation.

        Args:
            input: Tensor fed to ``ex.adder2d_function``. The ``__main__``
                demo in this file uses shape (10, 3, 10, 10); presumably
                (N, C_in, H, W) in general. TODO confirm.

        Returns:
            Whatever ``ex.adder2d_function`` returns for the quantized
            weight/bias and this layer's ``stride``/``padding``.
        """
        # Bias is optional; it is quantized first, matching the original
        # call order of directquant.
        quant_bias = (None if self.bias is None
                      else directquant(self.bias, self.bias_bit_width))
        quant_weight = directquant(self.weight, self.weight_bit_width)
        return ex.adder2d_function(
            input, quant_weight, quant_bias,
            stride=self.stride, padding=self.padding,
        )

    def forward(self, input):
        """Module entry point; delegates entirely to ``adder_forward``."""
        return self.adder_forward(input)


if __name__ == '__main__':
    # Smoke test: build a layer with a bias, push a random batch through
    # it, and print the resulting output shape.
    layer = Adder2d(3, 4, 3, bias=True)
    batch = torch.rand(10, 3, 10, 10)
    output = layer(batch)
    print(output.shape)
