/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.nn.ranking;

import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class CDAEParamInitializer
extends DefaultParamInitializer {
    private static final CDAEParamInitializer INSTANCE = new CDAEParamInitializer();
    public static final String USER_WEIGHT_KEY = "uw";
    public static int numUsers = 0;

    public static CDAEParamInitializer getInstance() {
        return INSTANCE;
    }

    public int numParams(NeuralNetConfiguration conf) {
        FeedForwardLayer layerConf = (FeedForwardLayer)conf.getLayer();
        return super.numParams(conf) + numUsers * layerConf.getNOut();
    }

    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map params = super.init(conf, paramsView, initializeParams);
        FeedForwardLayer layerConf = (FeedForwardLayer)conf.getLayer();
        int nIn = layerConf.getNIn();
        int nOut = layerConf.getNOut();
        int nWeightParams = nIn * nOut;
        int nUserWeightParams = numUsers * nOut;
        INDArray userWeightView = paramsView.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)(nWeightParams + nOut), (int)(nWeightParams + nOut + nUserWeightParams))});
        params.put(USER_WEIGHT_KEY, this.createUserWeightMatrix(conf, userWeightView, initializeParams));
        conf.addVariable(USER_WEIGHT_KEY);
        return params;
    }

    protected INDArray createUserWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, boolean initializeParameters) {
        FeedForwardLayer layerConf = (FeedForwardLayer)conf.getLayer();
        if (initializeParameters) {
            Distribution dist = Distributions.createDistribution((org.deeplearning4j.nn.conf.distribution.Distribution)layerConf.getDist());
            return this.createWeightMatrix(numUsers, layerConf.getNOut(), layerConf.getWeightInit(), dist, weightParamView, true);
        }
        return this.createWeightMatrix(numUsers, layerConf.getNOut(), null, null, weightParamView, false);
    }

    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        Map out = super.getGradientsFromFlattened(conf, gradientView);
        FeedForwardLayer layerConf = (FeedForwardLayer)conf.getLayer();
        int nIn = layerConf.getNIn();
        int nOut = layerConf.getNOut();
        int nWeightParams = nIn * nOut;
        int nUserWeightParams = numUsers * nOut;
        INDArray userWeightGradientView = gradientView.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)(nWeightParams + nOut), (int)(nWeightParams + nOut + nUserWeightParams))}).reshape('f', numUsers, nOut);
        out.put(USER_WEIGHT_KEY, userWeightGradientView);
        return out;
    }
}

