/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Cost and gradient of a {@link SentimentModel} over one batch of training trees.
 *
 * <p>{@link #calculate(double[])} loads {@code theta} into the model (mutating it),
 * forward-propagates a deep copy of every tree, backpropagates each node's weighted softmax
 * classification error, and stores:
 * <ul>
 *   <li>in {@code value}: the summed node error divided by the batch size, plus L2 penalties on
 *       the binary transforms, binary classifiers, binary tensors, unary classifiers and word
 *       vectors;</li>
 *   <li>in {@code derivative}: the matching gradient, flattened in that same parameter order.</li>
 * </ul>
 * When {@code trainOptions.nThreads != 1} the batch is split across a {@link MulticoreWrapper}
 * and the per-chunk derivatives are summed before scaling.
 */
public class SentimentCostAndGradient
extends AbstractCachingDiffFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SentimentCostAndGradient.class);
    private final SentimentModel model;
    private final List<Tree> trainingBatch;

    /**
     * @param model model whose parameters are overwritten from {@code theta} on every
     *     {@link #calculate(double[])} call
     * @param trainingBatch trees to score; each is deep-copied before it is propagated
     */
    public SentimentCostAndGradient(SentimentModel model, List<Tree> trainingBatch) {
        this.model = model;
        this.trainingBatch = trainingBatch;
    }

    /** Returns the length of the flattened parameter vector, as reported by the model. */
    @Override
    public int domainDimension() {
        return model.totalParamSize();
    }

    /**
     * Sums the prediction errors that {@link #backpropDerivativesAndError} stored on the
     * internal nodes of {@code tree}. Leaves carry no error and contribute 0.
     */
    private static double sumError(Tree tree) {
        if (tree.isLeaf()) {
            return 0.0;
        }
        if (tree.isPreTerminal()) {
            return RNNCoreAnnotations.getPredictionError(tree);
        }
        double childError = 0.0;
        for (Tree child : tree.children()) {
            childError += sumError(child);
        }
        return RNNCoreAnnotations.getPredictionError(tree) + childError;
    }

    /**
     * Returns the index of the largest entry of {@code predictions}; ties resolve to the lowest
     * index, and 0 is returned for an empty matrix.
     */
    private static int getPredictedClass(SimpleMatrix predictions) {
        int argmax = 0;
        for (int i = 1; i < predictions.getNumElements(); ++i) {
            if (predictions.get(i) > predictions.get(argmax)) {
                argmax = i;
            }
        }
        return argmax;
    }

    /**
     * Scores one chunk of trees and returns its unscaled, unregularized derivatives and summed
     * error.
     *
     * <p>Each tree is deep-copied first, because forward propagation and backpropagation write
     * annotations onto the nodes they visit.
     */
    private ModelDerivatives scoreDerivatives(List<Tree> trainingBatch) {
        ModelDerivatives derivatives = new ModelDerivatives(model);
        List<Tree> forwardPropTrees = Generics.newArrayList();
        for (Tree tree : trainingBatch) {
            Tree trainingTree = tree.deepCopy();
            forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
        }
        for (Tree tree : forwardPropTrees) {
            backpropDerivativesAndError(tree, derivatives.binaryTD, derivatives.binaryCD, derivatives.binaryTensorTD, derivatives.unaryCD, derivatives.wordVectorD);
            derivatives.error += sumError(tree);
        }
        return derivatives;
    }

    /**
     * Evaluates the objective and its gradient at {@code theta}, storing them in {@code value}
     * and {@code derivative}.
     *
     * <p>Side effect: {@code theta} is loaded into the model via {@code vectorToParams} before
     * scoring. An empty training batch contributes no data term, so only the regularization
     * terms remain.
     */
    @Override
    public void calculate(double[] theta) {
        model.vectorToParams(theta);

        ModelDerivatives derivatives;
        if (model.op.trainOptions.nThreads == 1) {
            derivatives = scoreDerivatives(trainingBatch);
        } else {
            MulticoreWrapper<List<Tree>, ModelDerivatives> wrapper =
                new MulticoreWrapper<List<Tree>, ModelDerivatives>(model.op.trainOptions.nThreads, new ScoringProcessor());
            for (List<Tree> chunk : CollectionUtils.partitionIntoFolds(trainingBatch, wrapper.nThreads())) {
                wrapper.put(chunk);
            }
            wrapper.join();
            // Sum the independent per-chunk results into a single accumulator.
            derivatives = new ModelDerivatives(model);
            while (wrapper.peek()) {
                derivatives.add(wrapper.poll());
            }
        }

        // Average the data term over the batch. For an empty batch, 1.0 / 0 is Infinity, and
        // 0 * Infinity is NaN. That would silently poison every gradient entry and the cost, so
        // an empty batch (which has no data term) uses a scale of 0 instead.
        double scale = trainingBatch.isEmpty() ? 0.0 : 1.0 / trainingBatch.size();

        this.value = derivatives.error * scale;
        this.value += scaleAndRegularize(derivatives.binaryTD, model.binaryTransform, scale,
            model.op.trainOptions.regTransformMatrix, false);
        // The last column of each classifier is excluded from the L2 penalty.
        this.value += scaleAndRegularize(derivatives.binaryCD, model.binaryClassification, scale,
            model.op.trainOptions.regClassification, true);
        this.value += scaleAndRegularizeTensor(derivatives.binaryTensorTD, model.binaryTensors, scale,
            model.op.trainOptions.regTransformTensor);
        this.value += scaleAndRegularize(derivatives.unaryCD, model.unaryClassification, scale,
            model.op.trainOptions.regClassification, false, true);
        // Only words that occurred in this batch have a gradient entry; the rest are zero-filled
        // and not penalized (activeMatricesOnly = true).
        this.value += scaleAndRegularize(derivatives.wordVectorD, model.wordVectors, scale,
            model.op.trainOptions.regWordVector, true, false);

        this.derivative = NeuralUtils.paramsToVector(theta.length,
            derivatives.binaryTD.valueIterator(),
            derivatives.binaryCD.valueIterator(),
            SimpleTensor.iteratorSimpleMatrix(derivatives.binaryTensorTD.valueIterator()),
            derivatives.unaryCD.values().iterator(),
            derivatives.wordVectorD.values().iterator());
    }

    /**
     * Returns a copy of {@code matrix} with its last column set to zero. It is used to exclude
     * that column (presumably the bias column paired with
     * {@code NeuralUtils.concatenateWithBias}) from the L2 penalty.
     */
    private static SimpleMatrix zeroLastColumn(SimpleMatrix matrix) {
        SimpleMatrix copy = new SimpleMatrix(matrix);
        copy.insertIntoThis(0, copy.numCols() - 1, new SimpleMatrix(copy.numRows(), 1));
        return copy;
    }

    /**
     * Replaces every gradient in {@code derivatives} with {@code scale * gradient + regCost * W},
     * where W is the matching entry of {@code currentMatrices}.
     *
     * @param dropBiasColumn if true, W's last column is zeroed before it is used, so that column
     *     is neither penalized nor regularized
     * @return the L2 cost {@code regCost / 2 * sum(W .* W)} summed over all entries
     */
    private static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives, TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices, double scale, double regCost, boolean dropBiasColumn) {
        double cost = 0.0;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
            SimpleMatrix gradient = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            SimpleMatrix regMatrix = dropBiasColumn ? zeroLastColumn(entry.getValue()) : entry.getValue();
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), gradient.scale(scale).plus(regMatrix.scale(regCost)));
            cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Map-keyed variant of
     * {@link #scaleAndRegularize(TwoDimensionalMap, TwoDimensionalMap, double, double, boolean)}.
     *
     * @param activeMatricesOnly if true, keys with no gradient entry get a zero gradient and no
     *     regularization cost; if false, every key is expected to have a gradient entry
     * @param dropBiasColumn if true, W's last column is zeroed before use
     * @return the accumulated L2 cost
     */
    private static double scaleAndRegularize(Map<String, SimpleMatrix> derivatives, Map<String, SimpleMatrix> currentMatrices, double scale, double regCost, boolean activeMatricesOnly, boolean dropBiasColumn) {
        double cost = 0.0;
        for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
            SimpleMatrix current = entry.getValue();
            SimpleMatrix gradient = derivatives.get(entry.getKey());
            if (activeMatricesOnly && gradient == null) {
                derivatives.put(entry.getKey(), new SimpleMatrix(current.numRows(), current.numCols()));
                continue;
            }
            SimpleMatrix regMatrix = dropBiasColumn ? zeroLastColumn(current) : current;
            derivatives.put(entry.getKey(), gradient.scale(scale).plus(regMatrix.scale(regCost)));
            cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Tensor variant of the scale-and-regularize step: each gradient becomes
     * {@code scale * gradient + regCost * T}.
     *
     * @return the L2 cost {@code regCost / 2 * sum(T .* T)} summed over all tensors
     */
    private static double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives, TwoDimensionalMap<String, String, SimpleTensor> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentMatrices) {
            SimpleTensor current = entry.getValue();
            SimpleTensor gradient = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), gradient.scale(scale).plus(current.scale(regCost)));
            cost += current.elementMult(current).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Backpropagates from the root of a forward-propagated tree, starting with a zero error
     * signal from above.
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD, TwoDimensionalMap<String, String, SimpleMatrix> binaryCD, TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD, Map<String, SimpleMatrix> wordVectorD) {
        SimpleMatrix delta = new SimpleMatrix(model.op.numHid, 1);
        backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
    }

    /**
     * Projects a node's classification error back onto its hidden vector. It multiplies by the
     * transposed classifier, keeps the first {@code numHid} rows (the remaining rows are
     * presumably the bias), then scales element-wise by the node's tanh derivative.
     */
    private SimpleMatrix deltaFromClassification(SimpleMatrix classification, SimpleMatrix deltaClass, SimpleMatrix currentVectorDerivative) {
        return classification.transpose().mult(deltaClass)
            .extractMatrix(0, model.op.numHid, 0, 1)
            .elementMult(currentVectorDerivative);
    }

    /**
     * Accumulates this subtree's gradients into the given maps.
     *
     * <p>Each node's weighted cross-entropy error is stored as an annotation, which
     * {@link #sumError} later reads. Requires {@link #forwardPropagateTree} to have already
     * annotated every node with its node vector and predictions.
     *
     * @param deltaUp error signal reaching this node's vector from its parent
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD, TwoDimensionalMap<String, String, SimpleMatrix> binaryCD, TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD, Map<String, SimpleMatrix> wordVectorD, SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
        String category = model.basicCategory(tree.label().value());

        // One-hot gold vector. Nodes without a gold class (goldClass < 0) keep an all-zero vector
        // and receive a zero classification delta.
        int goldClass = RNNCoreAnnotations.getGoldClass(tree);
        SimpleMatrix goldLabel = new SimpleMatrix(model.numClasses, 1);
        if (goldClass >= 0) {
            goldLabel.set(goldClass, 1.0);
        }
        double nodeWeight = model.op.trainOptions.getClassWeight(goldClass);
        SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);
        SimpleMatrix deltaClass = goldClass >= 0
            ? predictions.minus(goldLabel).scale(nodeWeight)
            : new SimpleMatrix(predictions.numRows(), predictions.numCols());
        SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());

        // Weighted cross-entropy: -nodeWeight * sum(gold .* log(predictions)).
        double error = -NeuralUtils.elementwiseApplyLog(predictions).elementMult(goldLabel).elementSum() * nodeWeight;
        RNNCoreAnnotations.setPredictionError(tree, error);

        SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);

        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).plus(localCD));
            SimpleMatrix deltaFull = deltaFromClassification(model.getUnaryClassification(category), deltaClass, currentVectorDerivative).plus(deltaUp);
            // Word-vector gradients are only created for words that actually occur in the batch.
            String word = model.getVocabWord(tree.children()[0].label().value());
            SimpleMatrix oldWordVectorD = wordVectorD.get(word);
            wordVectorD.put(word, oldWordVectorD == null ? deltaFull : oldWordVectorD.plus(deltaFull));
            return;
        }

        Tree leftChild = tree.children()[0];
        Tree rightChild = tree.children()[1];
        String leftCategory = model.basicCategory(leftChild.label().value());
        String rightCategory = model.basicCategory(rightChild.label().value());
        if (model.op.combineClassification) {
            unaryCD.put("", unaryCD.get("").plus(localCD));
        } else {
            binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
        }
        SimpleMatrix deltaFull = deltaFromClassification(model.getBinaryClassification(leftCategory, rightCategory), deltaClass, currentVectorDerivative).plus(deltaUp);

        SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(leftChild);
        SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(rightChild);
        SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
        SimpleMatrix transformGradient = deltaFull.mult(childrenVector.transpose());
        binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(transformGradient));

        SimpleMatrix deltaDown;
        if (model.op.useTensors) {
            SimpleTensor tensorGradient = getTensorGradient(deltaFull, leftVector, rightVector);
            binaryTensorTD.put(leftCategory, rightCategory, binaryTensorTD.get(leftCategory, rightCategory).plus(tensorGradient));
            deltaDown = computeTensorDeltaDown(deltaFull, leftVector, rightVector, model.getBinaryTransform(leftCategory, rightCategory), model.getBinaryTensor(leftCategory, rightCategory));
        } else {
            // W^T * delta has one extra trailing row from W's bias column; it is ignored below.
            deltaDown = model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
        }

        // deltaDown stacks the left child's delta above the right child's; split it.
        int hiddenSize = deltaFull.numRows();
        SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
        SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
        SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, hiddenSize, 0, 1);
        SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(hiddenSize, hiddenSize * 2, 0, 1);
        backpropDerivativesAndError(leftChild, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, leftDerivative.elementMult(leftDeltaDown));
        backpropDerivativesAndError(rightChild, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, rightDerivative.elementMult(rightDeltaDown));
    }

    /**
     * Returns the error signal sent to the stacked children vector [left; right] when tensors
     * are in use.
     *
     * <p>The result is {@code sum_k delta_k * (V_k + V_k^T) * [left; right]} plus the non-bias
     * part of {@code W^T * delta}, where V_k is the k-th tensor slice.
     */
    private static SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector, SimpleMatrix W, SimpleTensor Wt) {
        SimpleMatrix transformDelta = W.transpose().mult(deltaFull);
        SimpleMatrix transformDeltaNoBias = transformDelta.extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
        int size = deltaFull.getNumElements();
        SimpleMatrix deltaTensor = new SimpleMatrix(size * 2, 1);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; ++slice) {
            SimpleMatrix scaledFullVector = fullVector.scale(deltaFull.get(slice));
            SimpleMatrix sliceMatrix = Wt.getSlice(slice);
            deltaTensor = deltaTensor.plus(sliceMatrix.plus(sliceMatrix.transpose()).mult(scaledFullVector));
        }
        return deltaTensor.plus(transformDeltaNoBias);
    }

    /**
     * Returns the tensor gradient whose k-th slice is {@code delta_k * x * x^T}, with
     * {@code x = [left; right]}.
     */
    private static SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector) {
        int size = deltaFull.getNumElements();
        SimpleTensor tensorGradient = new SimpleTensor(size * 2, size * 2, size);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; ++slice) {
            tensorGradient.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVector.transpose()));
        }
        return tensorGradient;
    }

    /**
     * Computes bottom-up, for every internal node, the tanh node vector, the softmax class
     * predictions and the predicted class, and stores them on the node's {@link CoreLabel}.
     *
     * <p>Preterminals take their vector from the model's word vector. Binary nodes combine their
     * two children's vectors through the binary transform, plus the bilinear tensor term when
     * {@code useTensors} is set.
     *
     * @throws AssertionError if a leaf is reached directly, a non-preterminal has one child or
     *     more than two children, or a node's label is not a {@link CoreLabel}
     */
    public void forwardPropagateTree(Tree tree) {
        if (tree.isLeaf()) {
            log.info("SentimentCostAndGradient: warning: We reached leaves in forwardPropagate: " + tree);
            throw new AssertionError("We should not have reached leaves in forwardPropagate");
        }
        SimpleMatrix classification;
        SimpleMatrix nodeVector;
        if (tree.isPreTerminal()) {
            classification = model.getUnaryClassification(tree.label().value());
            String word = tree.children()[0].label().value();
            nodeVector = NeuralUtils.elementwiseApplyTanh(model.getWordVector(word));
        } else if (tree.children().length == 1) {
            log.info("SentimentCostAndGradient: warning: Non-preterminal nodes of size 1: " + tree);
            throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
        } else if (tree.children().length == 2) {
            Tree leftChild = tree.children()[0];
            Tree rightChild = tree.children()[1];
            forwardPropagateTree(leftChild);
            forwardPropagateTree(rightChild);
            String leftCategory = leftChild.label().value();
            String rightCategory = rightChild.label().value();
            SimpleMatrix W = model.getBinaryTransform(leftCategory, rightCategory);
            classification = model.getBinaryClassification(leftCategory, rightCategory);
            SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(leftChild);
            SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(rightChild);
            SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
            if (model.op.useTensors) {
                SimpleTensor tensor = model.getBinaryTensor(leftCategory, rightCategory);
                SimpleMatrix tensorIn = NeuralUtils.concatenate(leftVector, rightVector);
                SimpleMatrix tensorOut = tensor.bilinearProducts(tensorIn);
                nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector).plus(tensorOut));
            } else {
                nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector));
            }
        } else {
            log.info("SentimentCostAndGradient: warning: Tree not correctly binarized: " + tree);
            throw new AssertionError("Tree not correctly binarized");
        }

        SimpleMatrix predictions = NeuralUtils.softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));
        int index = getPredictedClass(predictions);
        if (!(tree.label() instanceof CoreLabel)) {
            log.info("SentimentCostAndGradient: warning: No CoreLabels in nodes: " + tree);
            throw new AssertionError("Expected CoreLabels in the nodes");
        }
        CoreLabel label = (CoreLabel) tree.label();
        label.set(RNNCoreAnnotations.Predictions.class, predictions);
        label.set(RNNCoreAnnotations.PredictedClass.class, index);
        label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
    }

    /**
     * Worker that scores one chunk of the batch for {@link MulticoreWrapper}. It keeps no state
     * of its own, so {@link #newInstance()} returns the same object.
     */
    class ScoringProcessor
    implements ThreadsafeProcessor<List<Tree>, ModelDerivatives> {
        ScoringProcessor() {
        }

        @Override
        public ModelDerivatives process(List<Tree> trainingBatch) {
            return SentimentCostAndGradient.this.scoreDerivatives(trainingBatch);
        }

        @Override
        public ThreadsafeProcessor<List<Tree>, ModelDerivatives> newInstance() {
            return this;
        }
    }

    /**
     * Accumulator for the raw (unscaled, unregularized) gradients and summed error of one chunk
     * of trees.
     *
     * <p>Dense parameter groups start as zero matrices/tensors shaped like the model's
     * parameters. Word-vector gradients start empty and gain an entry only for words that occur.
     */
    private static class ModelDerivatives {
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTD;
        public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD;
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;
        public final Map<String, SimpleMatrix> unaryCD;
        public final Map<String, SimpleMatrix> wordVectorD;
        public double error = 0.0;

        public ModelDerivatives(SentimentModel model) {
            this.binaryTD = initDerivatives(model.binaryTransform);
            this.binaryTensorTD = model.op.useTensors ? initTensorDerivatives(model.binaryTensors) : TwoDimensionalMap.treeMap();
            this.binaryCD = !model.op.combineClassification ? initDerivatives(model.binaryClassification) : TwoDimensionalMap.treeMap();
            this.unaryCD = initDerivatives(model.unaryClassification);
            this.wordVectorD = Generics.newTreeMap();
        }

        /** Adds {@code other}'s gradients and error into this accumulator. */
        public void add(ModelDerivatives other) {
            addMatrices(this.binaryTD, other.binaryTD);
            addTensors(this.binaryTensorTD, other.binaryTensorTD);
            addMatrices(this.binaryCD, other.binaryCD);
            addMatrices(this.unaryCD, other.unaryCD);
            addMatrices(this.wordVectorD, other.wordVectorD);
            this.error += other.error;
        }

        /**
         * {@code first += second}: shared keys are summed, and keys present only in
         * {@code second} are copied by reference.
         */
        public static void addMatrices(TwoDimensionalMap<String, String, SimpleMatrix> first, TwoDimensionalMap<String, String, SimpleMatrix> second) {
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : first) {
                if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : second) {
                if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /** Tensor counterpart of {@link #addMatrices(TwoDimensionalMap, TwoDimensionalMap)}. */
        public static void addTensors(TwoDimensionalMap<String, String, SimpleTensor> first, TwoDimensionalMap<String, String, SimpleTensor> second) {
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : first) {
                if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : second) {
                if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /** Map counterpart of {@link #addMatrices(TwoDimensionalMap, TwoDimensionalMap)}. */
        public static void addMatrices(Map<String, SimpleMatrix> first, Map<String, SimpleMatrix> second) {
            for (Map.Entry<String, SimpleMatrix> entry : first.entrySet()) {
                if (second.containsKey(entry.getKey())) {
                    first.put(entry.getKey(), entry.getValue().plus(second.get(entry.getKey())));
                }
            }
            for (Map.Entry<String, SimpleMatrix> entry : second.entrySet()) {
                if (!first.containsKey(entry.getKey())) {
                    first.put(entry.getKey(), entry.getValue());
                }
            }
        }

        /** Returns zero matrices with the same keys and shapes as {@code map}. */
        private static TwoDimensionalMap<String, String, SimpleMatrix> initDerivatives(TwoDimensionalMap<String, String, SimpleMatrix> map) {
            TwoDimensionalMap<String, String, SimpleMatrix> derivatives = TwoDimensionalMap.treeMap();
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : map) {
                SimpleMatrix value = entry.getValue();
                derivatives.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(value.numRows(), value.numCols()));
            }
            return derivatives;
        }

        /** Returns zero tensors with the same keys and shapes as {@code map}. */
        private static TwoDimensionalMap<String, String, SimpleTensor> initTensorDerivatives(TwoDimensionalMap<String, String, SimpleTensor> map) {
            TwoDimensionalMap<String, String, SimpleTensor> derivatives = TwoDimensionalMap.treeMap();
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : map) {
                SimpleTensor value = entry.getValue();
                derivatives.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleTensor(value.numRows(), value.numCols(), value.numSlices()));
            }
            return derivatives;
        }

        /** Returns zero matrices with the same keys and shapes as {@code map}, in a sorted map. */
        private static Map<String, SimpleMatrix> initDerivatives(Map<String, SimpleMatrix> map) {
            TreeMap<String, SimpleMatrix> derivatives = Generics.newTreeMap();
            for (Map.Entry<String, SimpleMatrix> entry : map.entrySet()) {
                SimpleMatrix value = entry.getValue();
                derivatives.put(entry.getKey(), new SimpleMatrix(value.numRows(), value.numCols()));
            }
            return derivatives;
        }
    }
}

