#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from os.path import join

import torch
from torch import nn
import torch.utils.model_zoo as model_zoo

import numpy as np

# Normalization layer class used by every module in this file. Constructors and
# the weight-init loops look this global up when they run, so set_bn() can
# rebind it before models are built.
BatchNorm = nn.BatchNorm2d


def get_model_url(data="imagenet", name="dla34", hash="ba72cf86"):
    """Return the download URL of a pretrained DLA checkpoint.

    Args:
        data: dataset sub-directory on the model server (e.g. "imagenet").
        name: model name (e.g. "dla34").
        hash: checkpoint hash appended to the file name.

    Returns:
        A URL of the form ``http://dl.yf.io/dla/models/<data>/<name>-<hash>.pth``.
    """
    # URLs always use "/" separators; os.path.join would insert "\" on Windows.
    import posixpath

    return posixpath.join(
        "http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash)
    )


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
    """Two 3x3 conv + BN layers with an additive shortcut and ReLU.

    ``forward(x, residual=None)`` adds ``residual`` (or ``x`` itself when it is
    ``None``) to the second BN output before the final ReLU, so the caller must
    supply a shape-matching residual whenever ``stride > 1`` or the channel
    count changes.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        conv_opts = dict(
            kernel_size=3, padding=dilation, bias=False, dilation=dilation
        )
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv_opts)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_opts)
        self.bn2 = BatchNorm(planes)
        self.stride = stride

    def forward(self, x, residual=None):
        shortcut = x if residual is None else residual
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual 1x1 -> 3x3 -> 1x1 bottleneck unit.

    The inner width is ``planes // Bottleneck.expansion``, read from the class
    attribute at construction time (the factory functions assign it).
    ``forward(x, residual=None)`` adds ``residual`` (default: ``x``) before the
    final ReLU.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        inner = planes // Bottleneck.expansion
        self.conv1 = nn.Conv2d(inplanes, inner, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(inner)
        self.conv2 = nn.Conv2d(
            inner,
            inner,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
        )
        self.bn2 = BatchNorm(inner)
        self.conv3 = nn.Conv2d(inner, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        shortcut = x if residual is None else residual
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)


class BottleneckX(nn.Module):
    """ResNeXt-style bottleneck: 1x1 -> grouped 3x3 -> 1x1 with a shortcut.

    ``BottleneckX.cardinality`` (read at construction time) sets both the
    number of groups of the 3x3 conv and the inner width
    ``planes * cardinality // 32``. ``forward(x, residual=None)`` adds
    ``residual`` (default: ``x``) before the final ReLU.
    """

    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BottleneckX, self).__init__()
        groups = BottleneckX.cardinality
        inner = planes * groups // 32
        self.conv1 = nn.Conv2d(inplanes, inner, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(inner)
        self.conv2 = nn.Conv2d(
            inner,
            inner,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
            groups=groups,
        )
        self.bn2 = BatchNorm(inner)
        self.conv3 = nn.Conv2d(inner, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        shortcut = x if residual is None else residual
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)


class Root(nn.Module):
    """Aggregation node: concat inputs -> conv -> BN -> (+ first input) -> ReLU.

    Args:
        in_channels: total channels of all inputs concatenated along dim 1.
        out_channels: output channels.
        kernel_size: conv kernel size; padding ``(kernel_size - 1) // 2`` keeps
            the spatial size unchanged for odd kernel sizes.
        residual: if true, add the first input to the BN output before the
            ReLU (that input must then have ``out_channels`` channels).
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        # Use kernel_size for the conv itself so it agrees with the padding
        # derived from it; a fixed 1x1 kernel with that padding would grow the
        # spatial size for kernel_size > 1 and break the residual add.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            bias=False,
            padding=(kernel_size - 1) // 2,
        )
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *children):
        x = self.bn(self.conv(torch.cat(children, 1)))
        if self.residual:
            x += children[0]
        return self.relu(x)


class Tree(nn.Module):
    """One hierarchical-aggregation stage of DLA.

    With ``levels == 1`` the tree is two ``block`` units followed by a ``Root``
    that merges ``(x2, x1, *children)``. With ``levels > 1`` it is two sub-trees
    of depth ``levels - 1``: the first sub-tree's output is appended to
    ``children`` and handed to the second sub-tree, so the root at the end of
    that chain receives every collected output.

    Args:
        levels: tree depth (``1`` means one block pair plus a root).
        block: residual unit class, called as
            ``block(in_ch, out_ch, stride, dilation=...)``.
        in_channels: input channels.
        out_channels: output channels.
        stride: stride of the first block; when > 1 the input is also
            max-pooled by the same factor to build ``bottom``/the residual.
        level_root: if true, ``bottom`` is appended to ``children`` (and so
            concatenated into the root).
        root_dim: root input width; ``0`` means ``2 * out_channels``.
        root_kernel_size: forwarded to ``Root``.
        dilation: forwarded to ``block``.
        root_residual: forwarded to ``Root``.
    """

    def __init__(
        self,
        levels,
        block,
        in_channels,
        out_channels,
        stride=1,
        level_root=False,
        root_dim=0,
        root_kernel_size=1,
        dilation=1,
        root_residual=False,
    ):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels

        # Sub-modules are registered in the same order as before
        # (tree1, tree2, root, project) so init RNG order and keys match.
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
            self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
        else:
            shared = dict(
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual,
            )
            self.tree1 = Tree(
                levels - 1, block, in_channels, out_channels, stride,
                root_dim=0, **shared
            )
            self.tree2 = Tree(
                levels - 1, block, out_channels, out_channels,
                root_dim=root_dim + out_channels, **shared
            )

        self.level_root = level_root
        self.root_dim = root_dim
        self.levels = levels
        self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
        self.project = None
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                ),
                BatchNorm(out_channels),
            )

    def forward(self, x, residual=None, children=None):
        # `residual` is accepted for call compatibility; it is recomputed here.
        if children is None:
            children = []
        bottom = x if self.downsample is None else self.downsample(x)
        residual = bottom if self.project is None else self.project(bottom)
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels != 1:
            # The shared list is extended in place and passed down the chain.
            children.append(x1)
            return self.tree2(x1, children=children)
        x2 = self.tree2(x1)
        return self.root(x2, x1, *children)


class DLA(nn.Module):
    """Deep Layer Aggregation backbone.

    Layout: a 7x7 stem (``base_layer``), two plain conv levels (``level0``,
    ``level1``), four ``Tree`` levels (``level2``-``level5``), then an average
    pool and a 1x1-conv classifier (``fc``). Levels 1-5 each halve the spatial
    size, so ``level{i}`` has stride ``2**i`` relative to the input.

    Args:
        levels: six ints; ``levels[0]``/``levels[1]`` are conv counts of the
            plain levels, ``levels[2:]`` the depth of each ``Tree``.
        channels: six output widths, one per level.
        num_classes: classifier output channels.
        block: residual unit class used inside the trees.
        residual_root: forwarded to ``Tree(root_residual=...)``.
        return_levels: if true, ``forward`` returns the six level outputs
            instead of classifier logits.
        pool_size: kernel size of the final ``nn.AvgPool2d``.
        linear_root: accepted but unused by this class.
    """

    def __init__(
        self,
        levels,
        channels,
        num_classes=1000,
        block=BasicBlock,
        residual_root=False,
        return_levels=False,
        pool_size=7,
        linear_root=False,
    ):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        self.num_classes = num_classes
        # Stem expects 3 input channels.
        self.base_layer = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
            BatchNorm(channels[0]),
            nn.ReLU(inplace=True),
        )
        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2
        )
        self.level2 = Tree(
            levels[2],
            block,
            channels[1],
            channels[2],
            2,
            level_root=False,
            root_residual=residual_root,
        )
        self.level3 = Tree(
            levels[3],
            block,
            channels[2],
            channels[3],
            2,
            level_root=True,
            root_residual=residual_root,
        )
        self.level4 = Tree(
            levels[4],
            block,
            channels[3],
            channels[4],
            2,
            level_root=True,
            root_residual=residual_root,
        )
        self.level5 = Tree(
            levels[5],
            block,
            channels[4],
            channels[5],
            2,
            level_root=True,
            root_residual=residual_root,
        )

        self.avgpool = nn.AvgPool2d(pool_size)
        self.fc = nn.Conv2d(
            channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        # Conv weights ~ N(0, sqrt(2 / fan_out)) with fan_out = kh * kw * out;
        # BN starts as the identity affine transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_level(self, block, inplanes, planes, blocks, stride=1):
        # NOTE(review): not called anywhere in this file, and it cannot work
        # with the block classes defined here: none of BasicBlock, Bottleneck
        # or BottleneckX accepts a ``downsample`` keyword, and the loop builds
        # follow-up blocks from ``inplanes`` rather than ``planes``. Left as-is.
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.MaxPool2d(stride, stride=stride),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
                BatchNorm(planes),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample=downsample))
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        """Stack ``convs`` x (3x3 conv -> BN -> ReLU); only the first conv is strided."""
        modules = []
        for i in range(convs):
            modules.extend(
                [
                    nn.Conv2d(
                        inplanes,
                        planes,
                        kernel_size=3,
                        stride=stride if i == 0 else 1,
                        padding=dilation,
                        bias=False,
                        dilation=dilation,
                    ),
                    BatchNorm(planes),
                    nn.ReLU(inplace=True),
                ]
            )
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return the six level outputs, or flattened logits when not ``return_levels``."""
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, "level{}".format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        x = self.avgpool(x)
        x = self.fc(x)
        return x.view(x.size(0), -1)

    def load_model(self, pretrained=True, data="imagenet", name="dla34", hash="ba72cf86", num_classes=None):
        """Load pretrained backbone weights into this model, in place.

        A temporary classifier sized from the checkpoint is swapped in so
        ``load_state_dict`` matches, and the original ``self.fc`` is restored
        afterwards. The checkpoint's classifier weights are therefore
        discarded.

        Args:
            pretrained: if false, nothing is loaded and the model is untouched.
            data: dataset sub-directory of the model URL, or a path prefix that
                is concatenated with ``name`` when ``name`` ends in ``.pth``.
            name: model name for the URL, or a local ``.pth`` file name.
            hash: checkpoint hash used in the URL file name.
            num_classes: ignored; the class count is read from the checkpoint.
                Kept for backward compatibility.
        """
        if not pretrained:
            # Building a throwaway head here would crash for num_classes=None
            # and is discarded anyway.
            return

        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies values onto the parameters' own devices.
        if name.endswith(".pth"):
            model_weights = torch.load(data + name, map_location="cpu")
        else:
            model_url = get_model_url(data, name, hash)
            model_weights = model_zoo.load_url(model_url, map_location="cpu")
        # Assumes the checkpoint's last entry is the classifier bias, whose
        # length is its class count.
        ckpt_classes = len(model_weights[list(model_weights.keys())[-1]])

        fc = self.fc
        self.fc = nn.Conv2d(
            self.channels[-1],
            ckpt_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        try:
            self.load_state_dict(model_weights)
        finally:
            # Always restore the real head, even if loading fails.
            self.fc = fc


def dla34(pretrained, num_classes=None, **kwargs):  # DLA-34
    """Build DLA-34 (BasicBlock units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    levels = [1, 1, 1, 2, 2, 1]
    channels = [16, 32, 64, 128, 256, 512]
    model = DLA(levels, channels, block=BasicBlock, **kwargs)
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla34",
        hash="ba72cf86",
        num_classes=num_classes,
    )
    return model


def dla46_c(pretrained=False, num_classes=None, **kwargs):  # DLA-46-C
    """Build DLA-46-C (Bottleneck units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    Bottleneck.expansion = 2
    model = DLA(
        [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs
    )
    # Pass the checkpoint identity by keyword: positionally the model name
    # would bind to `data`, fetching the wrong URL (and dla34's hash).
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla46_c",
        hash="2bfd52c3",
        num_classes=num_classes,
    )
    return model


def dla46x_c(pretrained=False, num_classes=None, **kwargs):  # DLA-X-46-C
    """Build DLA-X-46-C (BottleneckX units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    BottleneckX.expansion = 2
    model = DLA(
        [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla46x_c",
        hash="d761bae7",
        num_classes=num_classes,
    )
    return model


def dla60x_c(pretrained, num_classes=None, **kwargs):  # DLA-X-60-C
    """Build DLA-X-60-C (BottleneckX units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    BottleneckX.expansion = 2
    levels = [1, 1, 1, 2, 3, 1]
    channels = [16, 32, 64, 64, 128, 256]
    model = DLA(levels, channels, block=BottleneckX, **kwargs)
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla60x_c",
        hash="b870c45c",
        num_classes=num_classes,
    )
    return model


def dla60(pretrained=False, num_classes=None, **kwargs):  # DLA-60
    """Build DLA-60 (Bottleneck units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    Bottleneck.expansion = 2
    model = DLA(
        [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla60",
        hash="24839fc4",
        num_classes=num_classes,
    )
    return model


def dla60x(pretrained=False, num_classes=None, **kwargs):  # DLA-X-60
    """Build DLA-X-60 (BottleneckX units) and optionally load ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    BottleneckX.expansion = 2
    model = DLA(
        [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla60x",
        hash="d15cacda",
        num_classes=num_classes,
    )
    return model


def dla102(pretrained=False, num_classes=None, **kwargs):  # DLA-102
    """Build DLA-102 (Bottleneck units, residual roots) with optional ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    Bottleneck.expansion = 2
    model = DLA(
        [1, 1, 1, 3, 4, 1],
        [16, 32, 128, 256, 512, 1024],
        block=Bottleneck,
        residual_root=True,
        **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla102",
        hash="d94d9790",
        num_classes=num_classes,
    )
    return model


def dla102x(pretrained=False, num_classes=None, **kwargs):  # DLA-X-102
    """Build DLA-X-102 (BottleneckX units, residual roots) with optional ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    BottleneckX.expansion = 2
    model = DLA(
        [1, 1, 1, 3, 4, 1],
        [16, 32, 128, 256, 512, 1024],
        block=BottleneckX,
        residual_root=True,
        **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla102x",
        hash="ad62be81",
        num_classes=num_classes,
    )
    return model


def dla102x2(pretrained=False, num_classes=None, **kwargs):  # DLA-X-102 64
    """Build DLA-X-102 with 64-group BottleneckX units, with optional ImageNet weights.

    ``BottleneckX.cardinality`` is read when each unit is constructed, so it
    is set to 64 only while this model is built and then restored.
    Extra keyword arguments are forwarded to ``DLA``.
    """
    previous_cardinality = BottleneckX.cardinality
    BottleneckX.cardinality = 64
    try:
        model = DLA(
            [1, 1, 1, 3, 4, 1],
            [16, 32, 128, 256, 512, 1024],
            block=BottleneckX,
            residual_root=True,
            **kwargs
        )
    finally:
        # Without this, every BottleneckX model built afterwards (dla102x,
        # dla60x, ...) would silently inherit 64 groups.
        BottleneckX.cardinality = previous_cardinality
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla102x2",
        hash="262837b6",
        num_classes=num_classes,
    )
    return model


def dla169(pretrained=False, num_classes=None, **kwargs):  # DLA-169
    """Build DLA-169 (Bottleneck units, residual roots) with optional ImageNet weights.

    Extra keyword arguments are forwarded to ``DLA``.
    """
    Bottleneck.expansion = 2
    model = DLA(
        [1, 1, 2, 3, 5, 1],
        [16, 32, 128, 256, 512, 1024],
        block=Bottleneck,
        residual_root=True,
        **kwargs
    )
    # Keyword arguments: positionally the model name would bind to `data`.
    model.load_model(
        pretrained=pretrained,
        data="imagenet",
        name="dla169",
        hash="0914e092",
        num_classes=num_classes,
    )
    return model


def set_bn(bn):
    """Replace the normalization layer class used by modules built afterwards.

    Existing modules keep the layers they were built with. The new class is
    also what the weight-init loops match with ``isinstance(m, BatchNorm)``.

    Args:
        bn: callable taking a channel count, used wherever this module calls
            ``BatchNorm(channels)``.
    """
    global BatchNorm
    # Rebinding this module's global is sufficient: every constructor here
    # looks ``BatchNorm`` up when it runs. (There is no ``dla`` module name in
    # scope to update.)
    BatchNorm = bn


class Identity(nn.Module):
    """Parameter-free module whose ``forward`` returns its input object unchanged."""

    def forward(self, x):
        return x


def fill_up_weights(up):
    """Initialise a transposed conv's weights in place with a bilinear kernel.

    The kernel is written into ``weight[0, 0]`` and then copied to
    ``weight[c, 0]`` for every output slot ``c``; only index 0 of dim 1 is
    filled, which covers the depthwise case (``groups == channels``) used in
    this file.
    """
    weight = up.weight.data
    rows, cols = weight.size(2), weight.size(3)
    factor = math.ceil(rows / 2)
    center = (2 * factor - 1 - factor % 2) / (2.0 * factor)

    def tent(k):
        # 1-D triangle (linear interpolation) weight at tap k.
        return 1 - math.fabs(k / factor - center)

    row_taps = [tent(i) for i in range(rows)]
    col_taps = [tent(j) for j in range(cols)]
    for i, wi in enumerate(row_taps):
        for j, wj in enumerate(col_taps):
            weight[0, 0, i, j] = wi * wj
    for ch in range(1, weight.size(0)):
        weight[ch, 0, :, :] = weight[0, 0, :, :]


class IDAUp(nn.Module):
    """Iterative deep aggregation over a list of feature maps.

    Each input ``i`` is projected to ``out_dim`` channels (1x1 conv + BN + ReLU,
    or ``Identity`` if it already has ``out_dim`` channels). It is then
    upsampled by ``int(up_factors[i])`` with a bilinear-initialised depthwise
    ``ConvTranspose2d`` (``Identity`` for factor 1). The inputs are folded left
    to right by ``node_i`` = conv(``node_kernel``) + BN + ReLU on
    ``cat([x, layers[i]])``.

    ``forward(layers)`` returns ``(x, y)``: ``x`` is the last node output and
    ``y`` the list of all node outputs (``len(layers) - 1`` entries).
    """

    def __init__(self, node_kernel, out_dim, channels, up_factors):
        super(IDAUp, self).__init__()
        self.channels = channels
        self.out_dim = out_dim
        # Build order per input (proj, then up) is kept for RNG/key stability.
        for idx, in_dim in enumerate(channels):
            if in_dim == out_dim:
                proj = Identity()
            else:
                proj = nn.Sequential(
                    nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, bias=False),
                    BatchNorm(out_dim),
                    nn.ReLU(inplace=True),
                )
            factor = int(up_factors[idx])
            if factor == 1:
                up = Identity()
            else:
                up = nn.ConvTranspose2d(
                    out_dim,
                    out_dim,
                    factor * 2,
                    stride=factor,
                    padding=factor // 2,
                    output_padding=0,
                    groups=out_dim,
                    bias=False,
                )
                fill_up_weights(up)
            setattr(self, "proj_" + str(idx), proj)
            setattr(self, "up_" + str(idx), up)

        for idx in range(1, len(channels)):
            setattr(
                self,
                "node_" + str(idx),
                nn.Sequential(
                    nn.Conv2d(
                        out_dim * 2,
                        out_dim,
                        kernel_size=node_kernel,
                        stride=1,
                        padding=node_kernel // 2,
                        bias=False,
                    ),
                    BatchNorm(out_dim),
                    nn.ReLU(inplace=True),
                ),
            )

        # Plain Conv2d layers get N(0, sqrt(2 / fan_out)); the transposed convs
        # are not nn.Conv2d instances, so their bilinear weights survive.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, BatchNorm):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, layers):
        assert len(self.channels) == len(layers), "{} vs {} layers".format(
            len(self.channels), len(layers)
        )
        aligned = [
            getattr(self, "up_" + str(idx))(getattr(self, "proj_" + str(idx))(feat))
            for idx, feat in enumerate(layers)
        ]
        x = aligned[0]
        outputs = []
        for idx in range(1, len(aligned)):
            x = getattr(self, "node_" + str(idx))(torch.cat([x, aligned[idx]], 1))
            outputs.append(x)
        return x, outputs


class DLAUp(nn.Module):
    """Decoder that merges backbone levels from deepest to shallowest with IDAUp.

    Stage ``i`` (``ida_i``) takes the last ``i + 2`` feature maps, aggregates
    them at the resolution and width of the first of them, and replaces the
    last ``i + 1`` maps with its intermediate outputs. ``forward`` returns the
    final stage's output, i.e. a map at the resolution of ``layers[0]`` with
    ``channels[0]`` channels.

    Args:
        channels: channel count of each input level, shallowest first.
        scales: relative stride of each input level (e.g. ``[1, 2, 4, 8]``).
        in_channels: actual input channel counts; defaults to ``channels``.
            Not modified.
    """

    def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
        super(DLAUp, self).__init__()
        # Work on a private copy: the loop below overwrites entries in place,
        # which must not touch the caller's list (or self.channels), and
        # tuple inputs must be accepted.
        in_channels = list(channels if in_channels is None else in_channels)
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(
                self,
                "ida_{}".format(i),
                IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]),
            )
            # After this stage, levels j+1.. are at level j's scale and width.
            scales[j + 1 :] = scales[j]
            in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]]

    def forward(self, layers):
        layers = list(layers)
        assert len(layers) > 1
        for i in range(len(layers) - 1):
            ida = getattr(self, "ida_{}".format(i))
            x, y = ida(layers[-i - 2 :])
            layers[-i - 1 :] = y
        return x


def fill_fc_weights(layers):
    """Re-initialise every ``nn.Conv2d`` in ``layers`` (itself included).

    Weights are drawn from N(0, 0.001^2); biases, when present, are zeroed.
    """
    convs = (m for m in layers.modules() if isinstance(m, nn.Conv2d))
    for conv in convs:
        nn.init.normal_(conv.weight, std=0.001)
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)


class DLASeg(nn.Module):
    """DLA backbone + DLAUp decoder with one prediction head per ``heads`` entry.

    Args:
        base_name: name of a backbone factory in this module (e.g. "dla34"),
            looked up via ``globals()``.
        heads: mapping from head name to that head's output channel count.
        pretrained: forwarded to the backbone factory.
        down_ratio: output stride; one of 2, 4, 8, 16.
        head_conv: hidden width of each head's 3x3 conv; ``<= 0`` makes each
            head a single 1x1 conv.
        num_classes: forwarded to the backbone factory.

    ``forward(x)`` returns a one-element list holding a dict ``{head: output}``.
    """

    def __init__(self, base_name, heads, pretrained=True, down_ratio=4, head_conv=256, num_classes=None):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        self.heads = heads
        # Backbone level i has stride 2**i, so decoding from this level gives
        # an output stride of down_ratio.
        self.first_level = int(np.log2(down_ratio))
        self.base = globals()[base_name](pretrained=pretrained, return_levels=True, num_classes=num_classes)
        channels = self.base.channels
        scales = [2 ** i for i in range(len(channels[self.first_level :]))]
        self.dla_up = DLAUp(channels[self.first_level :], scales=scales)

        head_in = channels[self.first_level]
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(head_in, head_conv, kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True
                    ),
                )
                final = fc[-1]
            else:
                fc = nn.Conv2d(
                    head_in, classes, kernel_size=1, stride=1, padding=0, bias=True
                )
                final = fc
            if "hm" in head:
                # NOTE(review): presumably the focal-loss prior bias
                # -log((1 - 0.1) / 0.1) ~= -2.19 for heatmap heads - confirm.
                final.bias.data.fill_(-2.19)
            else:
                fill_fc_weights(fc)
            setattr(self, head, fc)

    def forward(self, x):
        levels = self.base(x)
        x = self.dla_up(levels[self.first_level :])
        return [{head: getattr(self, head)(x) for head in self.heads}]


# The legacy dla34up / dla60up / dla102up / dla169up helpers that used to sit
# here as a disabled string literal passed a ``pretrained_base`` argument that
# DLASeg does not accept; build models with get_pose_net() or DLASeg directly.


def get_pose_net(pretrained, num_layers, heads, head_conv=256, down_ratio=4, num_classes=1000):
    """Build a ``DLASeg`` on top of the ``dla<num_layers>`` backbone factory.

    Args:
        pretrained: forwarded to the backbone factory.
        num_layers: backbone depth suffix, e.g. ``34`` selects ``dla34``.
        heads: mapping from head name to output channel count.
        head_conv: hidden width of each head (``<= 0`` for a single 1x1 conv).
        down_ratio: output stride (2, 4, 8 or 16).
        num_classes: forwarded to the backbone factory.
    """
    options = dict(
        pretrained=pretrained,
        down_ratio=down_ratio,
        head_conv=head_conv,
        num_classes=num_classes,
    )
    return DLASeg("dla{}".format(num_layers), heads, **options)
