/*
 * Decompiled with CFR 0.152.
 */
package net.librec.math.structure;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DataSet;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.RandomAccessSparseVector;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.TensorEntry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class SparseTensor
implements DataSet,
Iterable<TensorEntry>,
Serializable {
    private static final Log LOG = LogFactory.getLog(SparseTensor.class);
    private static final long serialVersionUID = 2487513413901432943L;
    public int numDimensions;
    public int[] dimensions;
    public List<Integer>[] ndKeys;
    public List<Double> values;
    private Multimap<Integer, Integer>[] keyIndices;
    private List<Integer> indexedDimensions;
    private Set<Integer> indexedDimensionsSet;
    private int userDimension;
    private int itemDimension;

    /**
     * Creates an empty sparse tensor whose dimension sizes are given by {@code dims}.
     * Delegates to the three-argument constructor with no initial keys or values.
     *
     * @param dims size of each dimension; fewer than three dimensions is rejected
     *             by the delegated constructor
     */
    public SparseTensor(int... dims) {
        this(dims, null, null);
    }

    public SparseTensor(int[] dims, List<Integer>[] nds, List<Double> vals) {
        if (dims.length < 3) {
            throw new Error("The dimension of a tensor cannot be smaller than 3!");
        }
        this.numDimensions = dims.length;
        this.dimensions = new int[this.numDimensions];
        this.ndKeys = new List[this.numDimensions];
        this.keyIndices = new Multimap[this.numDimensions];
        for (int d = 0; d < this.numDimensions; ++d) {
            this.dimensions[d] = dims[d];
            this.ndKeys[d] = nds == null ? new ArrayList<Integer>() : new ArrayList<Integer>(nds[d]);
            this.keyIndices[d] = HashMultimap.create();
        }
        this.values = vals == null ? new ArrayList<Double>() : new ArrayList<Double>(vals);
        this.indexedDimensions = new ArrayList<Integer>(this.numDimensions);
        this.indexedDimensionsSet = new HashSet<Integer>((int)((double)this.numDimensions / 0.7));
    }

    /**
     * Produces an independent copy of this tensor: keys, values, built key indices,
     * the set of indexed dimensions and the user/item dimension markers are all copied.
     * The copy is built through the constructor rather than {@code Object.clone()}.
     *
     * @return a new tensor with the same content as this one
     */
    public SparseTensor clone() {
        SparseTensor copy = new SparseTensor(dimensions);
        int d = 0;
        while (d < numDimensions) {
            copy.ndKeys[d].addAll(ndKeys[d]);
            copy.keyIndices[d].putAll(keyIndices[d]);
            ++d;
        }
        copy.values.addAll(values);
        copy.indexedDimensions.addAll(indexedDimensions);
        copy.indexedDimensionsSet.addAll(indexedDimensionsSet);
        copy.userDimension = userDimension;
        copy.itemDimension = itemDimension;
        return copy;
    }

    /**
     * Adds {@code val} to the entry stored under {@code keys}; when no such entry
     * exists, a new entry with value {@code val} is created.
     *
     * @param val  amount to add
     * @param keys one key per dimension
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    public void add(double val, int... keys) throws Exception {
        final int pos = findIndex(keys);
        if (pos < 0) {
            set(val, keys);
            return;
        }
        values.set(pos, values.get(pos) + val);
    }

    /**
     * Stores {@code val} under {@code keys}, overwriting an existing entry or
     * appending a new one. For a new entry, every already-indexed dimension gets
     * its key index extended with the new position.
     *
     * @param val  value to store
     * @param keys one key per dimension
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    public void set(double val, int... keys) throws Exception {
        final int pos = findIndex(keys);
        if (pos < 0) {
            for (int d = 0; d < numDimensions; ++d) {
                List<Integer> column = ndKeys[d];
                column.add(keys[d]);
                if (isIndexed(d)) {
                    // the entry just appended sits at the last position of the column
                    keyIndices[d].put(keys[d], column.size() - 1);
                }
            }
            values.add(val);
        } else {
            values.set(pos, val);
        }
    }

    /**
     * Deletes the entry stored under {@code keys}, if any. Because removal shifts
     * the positions of all later entries, every indexed dimension is re-indexed.
     *
     * @param keys one key per dimension
     * @return {@code true} if an entry was removed, {@code false} if none matched
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    public boolean remove(int... keys) throws Exception {
        final int pos = findIndex(keys);
        if (pos >= 0) {
            for (int d = 0; d < numDimensions; ++d) {
                ndKeys[d].remove(pos);
                if (isIndexed(d)) {
                    buildIndex(d);
                }
            }
            values.remove(pos);
            return true;
        }
        return false;
    }

    /**
     * Collects the positions of all entries whose key in the user dimension is
     * {@code user} and whose key in the item dimension is {@code item}.
     * The user dimension is indexed on demand.
     *
     * @param user key in the user dimension
     * @param item key in the item dimension
     * @return entry positions matching both keys (possibly empty)
     */
    public List<Integer> getIndices(int user, int item) {
        List<Integer> matches = new ArrayList<Integer>();
        for (int position : getIndex(userDimension, user)) {
            if (this.key(itemDimension, position) == item) {
                matches.add(position);
            }
        }
        return matches;
    }

    /**
     * Finds the keys of the last dimension for every entry whose first
     * {@code numDimensions - 1} keys equal {@code subKey}.
     *
     * <p>Candidates are taken from the key index of one "pivot" dimension. The pivot
     * must be one of the dimensions covered by {@code subKey}; if the only indexed
     * dimension is the last one (which {@code subKey} has no slot for), dimension 0
     * is indexed and used instead.
     *
     * @param subKey keys for dimensions {@code 0 .. numDimensions - 2}
     * @return last-dimension keys of the matching entries (possibly empty), or
     *         {@code null} when the tensor is empty or no entry shares the pivot key
     * @throws Exception if {@code subKey.length != numDimensions - 1}
     */
    public List<Integer> getTargetKeyFromSubKey(Integer[] subKey) throws Exception {
        if (subKey.length != this.numDimensions - 1) {
            throw new Exception("The given input does not match with the subKey dimension!");
        }
        if (this.values.isEmpty()) {
            return null;
        }
        final int targetDim = this.numDimensions - 1;
        // Choose the first indexed dimension that subKey actually covers.
        int d = -1;
        for (int indexed : this.indexedDimensions) {
            if (indexed != targetDim) {
                d = indexed;
                break;
            }
        }
        if (d < 0) {
            // nothing usable is indexed yet (none at all, or only the target dimension)
            d = 0;
            this.buildIndex(d);
        }
        Collection<Integer> indices = this.keyIndices[d].get(subKey[d]);
        if (indices == null || indices.isEmpty()) {
            return null;
        }
        List<Integer> res = new ArrayList<Integer>();
        for (int index : indices) {
            boolean found = true;
            for (int dd = 0; dd < targetDim; ++dd) {
                if (subKey[dd].intValue() != this.key(dd, index)) {
                    found = false;
                    break;
                }
            }
            if (found) {
                res.add(this.ndKeys[targetDim].get(index));
            }
        }
        return res;
    }

    /**
     * Locates the position of the entry stored under {@code keys}.
     * Uses the key index of the first indexed dimension (indexing dimension 0 when
     * nothing is indexed yet) to narrow the candidates, then compares all keys.
     *
     * @param keys one key per dimension
     * @return position of the matching entry, or {@code -1} if there is none
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    private int findIndex(int... keys) throws Exception {
        if (keys.length != numDimensions) {
            throw new Exception("The given input does not match with the tensor dimension!");
        }
        if (values.isEmpty()) {
            return -1;
        }
        if (indexedDimensions.isEmpty()) {
            buildIndex(0);
        }
        final int pivot = indexedDimensions.get(0);
        Collection<Integer> candidates = keyIndices[pivot].get(keys[pivot]);
        if (candidates == null || candidates.isEmpty()) {
            return -1;
        }
        for (int candidate : candidates) {
            int dd = 0;
            while (dd < numDimensions && keys[dd] == this.key(dd, candidate)) {
                ++dd;
            }
            if (dd == numDimensions) {
                // every dimension matched
                return candidate;
            }
        }
        return -1;
    }

    /**
     * Extracts a fiber: the vector obtained by fixing the key of every dimension
     * except {@code dim}.
     *
     * <p>Candidate entries come from the key index of a pivot dimension other than
     * {@code dim}; if no such dimension is indexed yet, dimension 0 (or 1 when
     * {@code dim} is 0) is indexed first.
     *
     * @param dim  the free dimension
     * @param keys keys of all other dimensions, in ascending dimension order
     * @return a vector of length {@code dimensions[dim]} holding the matching values
     * @throws Error if {@code keys.length != numDimensions - 1} or the tensor is empty
     */
    public RandomAccessSparseVector fiber(int dim, int... keys) {
        if (keys.length != numDimensions - 1 || size() < 1) {
            throw new Error("The input indices do not match the fiber specification!");
        }
        int pivot = -1;
        final boolean nothingIndexed = indexedDimensions.isEmpty();
        if (nothingIndexed || (indexedDimensionsSet.contains(dim) && indexedDimensionsSet.size() == 1)) {
            pivot = (dim == 0) ? 1 : 0;
            buildIndex(pivot);
        } else {
            for (int candidate : indexedDimensions) {
                if (candidate != dim) {
                    pivot = candidate;
                    break;
                }
            }
        }
        RandomAccessSparseVector res = new RandomAccessSparseVector(dimensions[dim]);
        // keys skips the free dimension, so dimensions after it are shifted by one
        final int keySlot = pivot < dim ? pivot : pivot - 1;
        Collection<Integer> indices = keyIndices[pivot].get(keys[keySlot]);
        if (indices == null || indices.isEmpty()) {
            return res;
        }
        for (int index : indices) {
            boolean matches = true;
            int slot = 0;
            for (int dd = 0; dd < numDimensions && matches; ++dd) {
                if (dd == dim) {
                    continue;
                }
                matches = keys[slot++] == this.key(dd, index);
            }
            if (matches) {
                res.set(this.key(dim, index), this.value(index));
            }
        }
        return res;
    }

    /**
     * Tells whether an entry is stored under {@code keys}.
     *
     * @param keys one key per dimension
     * @return {@code true} if such an entry exists
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    public boolean contains(int... keys) throws Exception {
        final int pos = findIndex(keys);
        return pos >= 0;
    }

    /**
     * Tells whether a key index has been built for dimension {@code d}.
     *
     * @param d dimension number
     * @return {@code true} if {@code d} is among the indexed dimensions
     */
    public boolean isIndexed(int d) {
        return indexedDimensionsSet.contains(d);
    }

    /**
     * Tells whether every dimension has the same size.
     *
     * @return {@code true} if all dimension sizes equal the size of dimension 0
     */
    public boolean isCubical() {
        final int first = dimensions[0];
        for (int d = 1; d < numDimensions; ++d) {
            if (dimensions[d] != first) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether every non-zero entry lies on the diagonal, i.e. has the same
     * key in all dimensions. Entries whose stored value is {@code 0.0} are ignored.
     *
     * @return {@code true} if no non-zero entry is off the diagonal
     */
    public boolean isDiagonal() {
        for (TensorEntry entry : this) {
            if (entry.get() == 0.0) {
                continue;
            }
            final int first = entry.key(0);
            for (int d = 1; d < numDimensions; ++d) {
                if (entry.key(d) != first) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Reads the value stored under {@code keys}.
     *
     * @param keys one key per dimension
     * @return the stored value, or {@code 0.0} when no entry exists for the keys
     * @throws Exception if the number of keys differs from the number of dimensions
     */
    public double get(int... keys) throws Exception {
        assert keys.length == numDimensions;
        final int pos = findIndex(keys);
        if (pos < 0) {
            return 0.0;
        }
        return values.get(pos);
    }

    /**
     * Randomly permutes the storage order of the entries in place: position
     * {@code i} is swapped with a position drawn via {@code Randoms.uniform} from
     * the range starting at {@code i}. Key indices of indexed dimensions are kept
     * consistent with the new positions.
     */
    public void shuffle() {
        final int len = this.size();
        for (int i = 0; i < len; ++i) {
            final int j = i + Randoms.uniform(len - i);
            if (i == j) {
                // swapping an entry with itself changes nothing
                continue;
            }
            final double temp = this.values.get(i);
            this.values.set(i, this.values.get(j));
            this.values.set(j, temp);
            for (int d = 0; d < this.numDimensions; ++d) {
                final int ikey = this.key(d, i);
                final int jkey = this.key(d, j);
                if (ikey == jkey) {
                    // Same key at both positions: the key column is unchanged and the
                    // index pairs (key, i) and (key, j) both remain valid. Running the
                    // remove/put sequence below on the set-valued multimap would drop
                    // (key, i) and corrupt the index.
                    continue;
                }
                this.ndKeys[d].set(i, jkey);
                this.ndKeys[d].set(j, ikey);
                if (!this.isIndexed(d)) {
                    continue;
                }
                this.keyIndices[d].remove(jkey, j);
                this.keyIndices[d].put(jkey, i);
                this.keyIndices[d].remove(ikey, i);
                this.keyIndices[d].put(ikey, j);
            }
        }
    }

    /**
     * (Re)builds the key index of each given dimension from scratch, mapping every
     * key to the positions of the entries that carry it, and records the dimension
     * as indexed if it was not already.
     *
     * @param dims dimensions to index
     */
    public void buildIndex(int... dims) {
        for (int d : dims) {
            final Multimap<Integer, Integer> index = keyIndices[d];
            index.clear();
            final List<Integer> column = ndKeys[d];
            for (int pos = 0, n = column.size(); pos < n; ++pos) {
                index.put(column.get(pos), pos);
            }
            // Set.add reports whether d was newly inserted
            if (indexedDimensionsSet.add(d)) {
                indexedDimensions.add(d);
            }
        }
    }

    /**
     * Builds the key index of every dimension, in ascending dimension order.
     */
    public void buildIndices() {
        for (int d = 0; d < numDimensions; ++d) {
            buildIndex(d);
        }
    }

    /**
     * Returns the positions of all entries whose key in dimension {@code d} equals
     * {@code key}, building the index for that dimension first if necessary.
     *
     * @param d   dimension number
     * @param key key within that dimension
     * @return the collection held by the dimension's key index for {@code key}
     */
    public Collection<Integer> getIndex(int d, int key) {
        final boolean alreadyIndexed = isIndexed(d);
        if (!alreadyIndexed) {
            buildIndex(d);
        }
        return keyIndices[d].get(key);
    }

    /**
     * Returns the full key tuple of the entry stored at position {@code index}.
     *
     * @param index entry position
     * @return a new array with one key per dimension
     */
    public int[] keys(int index) {
        final int[] tuple = new int[numDimensions];
        int d = 0;
        while (d < numDimensions) {
            tuple[d] = this.key(d, index);
            ++d;
        }
        return tuple;
    }

    /**
     * Returns the key in dimension {@code d} of the entry at position {@code index}.
     *
     * @param d     dimension number
     * @param index entry position
     * @return the stored key
     */
    public int key(int d, int index) {
        final List<Integer> column = ndKeys[d];
        return column.get(index);
    }

    /**
     * Returns the value of the entry at position {@code index}.
     *
     * @param index entry position
     * @return the stored value
     */
    public double value(int index) {
        return values.get(index);
    }

    /**
     * For every entry whose key in source dimension {@code sd} equals {@code key},
     * collects that entry's key in target dimension {@code td}.
     *
     * @param sd  source dimension (indexed on demand)
     * @param key key within the source dimension
     * @param td  target dimension
     * @return target-dimension keys, one per matching entry, or {@code null} if the
     *         index lookup itself yields {@code null}
     */
    public List<Integer> getRelevantKeys(int sd, int key, int td) {
        final Collection<Integer> positions = getIndex(sd, key);
        if (positions == null) {
            return null;
        }
        List<Integer> related = new ArrayList<Integer>();
        for (int pos : positions) {
            related.add(this.key(td, pos));
        }
        return related;
    }

    /**
     * Returns the number of stored entries.
     *
     * @return number of values held by this tensor
     */
    @Override
    public int size() {
        return values.size();
    }

    /**
     * Extracts a two-dimensional slice: the matrix spanned by {@code rowDim} and
     * {@code colDim} with the key of every other dimension fixed.
     *
     * <p>Candidate entries come from the key index of a pivot dimension that is
     * neither {@code rowDim} nor {@code colDim}; when no such dimension is indexed
     * yet, the lowest-numbered one is indexed first.
     *
     * @param rowDim    dimension mapped to matrix rows
     * @param colDim    dimension mapped to matrix columns
     * @param otherKeys keys of the remaining dimensions, in ascending dimension order
     * @return the slice, or {@code null} when no entry carries the pivot key
     * @throws Error if {@code otherKeys.length != numDimensions - 2}
     */
    public SequentialAccessSparseMatrix slice(int rowDim, int colDim, int... otherKeys) {
        if (otherKeys.length != numDimensions - 2) {
            throw new Error("The input dimensions do not match the tensor specification!");
        }
        final boolean rowIndexed = indexedDimensionsSet.contains(rowDim);
        final boolean colIndexed = indexedDimensionsSet.contains(colDim);
        final int indexedCount = indexedDimensions.size();
        // true when every index built so far belongs to rowDim and/or colDim
        final boolean noUsableIndex = indexedCount == 0
                || ((rowIndexed || colIndexed) && indexedCount == 1)
                || (rowIndexed && colIndexed && indexedCount == 2);
        int pivot = -1;
        if (noUsableIndex) {
            pivot = 0;
            while (pivot < numDimensions && (pivot == rowDim || pivot == colDim)) {
                ++pivot;
            }
            buildIndex(pivot);
        } else {
            for (int candidate : indexedDimensions) {
                if (candidate != rowDim && candidate != colDim) {
                    pivot = candidate;
                    break;
                }
            }
        }
        // otherKeys omits rowDim and colDim, so find the slot that belongs to the pivot
        int pivotKey = -1;
        int slot = 0;
        for (int dim = 0; dim < numDimensions; ++dim) {
            if (dim == rowDim || dim == colDim) {
                continue;
            }
            if (dim == pivot) {
                pivotKey = otherKeys[slot];
                break;
            }
            ++slot;
        }
        final Collection<Integer> indices = keyIndices[pivot].get(pivotKey);
        if (indices == null || indices.isEmpty()) {
            return null;
        }
        HashBasedTable<Integer, Integer, Double> dataTable = HashBasedTable.create();
        for (int index : indices) {
            boolean matches = true;
            int j = 0;
            for (int dd = 0; dd < numDimensions && matches; ++dd) {
                if (dd == rowDim || dd == colDim) {
                    continue;
                }
                matches = otherKeys[j++] == this.key(dd, index);
            }
            if (!matches) {
                continue;
            }
            final int row = ndKeys[rowDim].get(index);
            final int col = ndKeys[colDim].get(index);
            final double val = values.get(index);
            dataTable.put(row, col, val);
        }
        return new SequentialAccessSparseMatrix(dimensions[rowDim], dimensions[colDim], dataTable);
    }

    /**
     * Unfolds the tensor along dimension {@code n} into a matrix.
     *
     * <p>The row of an entry is its key in dimension {@code n}. Its column is
     * {@code sum over k != n of keys[k] * stride[k]}, where {@code stride[k]} is the
     * product of the sizes of all dimensions {@code m < k} with {@code m != n}.
     * The strides depend only on the tensor shape, so they are computed once up
     * front instead of once per entry.
     *
     * @param n the dimension that becomes the matrix rows
     * @return a matrix with {@code dimensions[n]} rows and one column per
     *         combination of keys of the remaining dimensions
     */
    public SequentialAccessSparseMatrix matricization(int n) {
        final int numRows = this.dimensions[n];
        final int[] strides = new int[this.numDimensions];
        int numCols = 1;
        for (int d = 0; d < this.numDimensions; ++d) {
            if (d == n) {
                continue;
            }
            // running product of the sizes of the earlier non-n dimensions
            strides[d] = numCols;
            numCols *= this.dimensions[d];
        }
        HashBasedTable<Integer, Integer, Double> dataTable = HashBasedTable.create();
        for (TensorEntry te : this) {
            final int[] keys = te.keys();
            int col = 0;
            for (int k = 0; k < this.numDimensions; ++k) {
                if (k != n) {
                    col += keys[k] * strides[k];
                }
            }
            dataTable.put(keys[n], col, te.get());
        }
        return new SequentialAccessSparseMatrix(numRows, numCols, dataTable);
    }

    /**
     * Computes the n-mode product of this tensor with a matrix along {@code dim}.
     * For each stored entry with key {@code i} in {@code dim} and each matrix row
     * {@code j}, {@code value * mat.get(j, i)} is accumulated into the result under
     * the same keys with {@code dim} replaced by {@code j}.
     *
     * @param mat matrix whose column count must equal {@code dimensions[dim]}
     * @param dim dimension along which to multiply
     * @return a new tensor whose size in {@code dim} is {@code mat.rowSize()}
     * @throws Exception if the matrix column count does not match the dimension size
     */
    public SparseTensor modeProduct(DenseMatrix mat, int dim) throws Exception {
        if (dimensions[dim] != mat.columnSize()) {
            throw new Exception("Dimensions of a tensor and a matrix do not match for n-mode product!");
        }
        final int rows = mat.rowSize();
        final int[] dims = dimensions.clone();
        dims[dim] = rows;
        SparseTensor res = new SparseTensor(dims);
        for (TensorEntry te : this) {
            final double val = te.get();
            final int[] keys = te.keys();
            final int col = keys[dim];
            for (int j = 0; j < rows; ++j) {
                // same key tuple as the source entry, except in the multiplied dimension
                int[] ks = new int[numDimensions];
                System.arraycopy(keys, 0, ks, 0, ks.length);
                ks[dim] = j;
                res.add(val * mat.get(j, col), ks);
            }
        }
        return res;
    }

    /**
     * Computes the n-mode product of this tensor with a vector along {@code dim}.
     * Every stored entry with key {@code i} in {@code dim} contributes
     * {@code value * vec.get(i)} to the result entry that has the same keys in all
     * other dimensions.
     *
     * <p>The result keeps the same number of dimensions, with {@code dim} collapsed
     * to size 1. Contributions are therefore stored under key {@code 0} in that
     * dimension, the only key that fits a dimension of size 1 under the 0-based
     * keys used by the matrix overload.
     *
     * @param vec vector whose cardinality must equal {@code dimensions[dim]}
     * @param dim dimension along which to multiply
     * @return a new tensor whose size in {@code dim} is 1
     * @throws Exception if the vector cardinality does not match the dimension size
     */
    public SparseTensor modeProduct(DenseVector vec, int dim) throws Exception {
        if (this.dimensions[dim] != vec.cardinality()) {
            throw new Exception("Dimensions of a tensor and a vector do not match for n-mode product!");
        }
        int[] dims = new int[this.numDimensions];
        for (int i = 0; i < dims.length; ++i) {
            dims[i] = i == dim ? 1 : this.dimensions[i];
        }
        SparseTensor res = new SparseTensor(dims);
        for (TensorEntry te : this) {
            double val = te.get();
            int[] keys = te.keys();
            int i = keys[dim];
            int[] ks = new int[this.numDimensions];
            for (int k = 0; k < ks.length; ++k) {
                // collapsed dimension has size 1, so its only valid key is 0
                ks[k] = k == dim ? 0 : keys[k];
            }
            res.add(val * vec.get(i), ks);
        }
        return res;
    }

    /**
     * Projects the tensor onto the user and item dimensions. Each entry is written
     * to cell (user key, item key); when several entries share a cell, the one
     * iterated last wins because the backing table keeps one value per cell.
     *
     * @return a matrix sized by the user and item dimensions
     */
    public SequentialAccessSparseMatrix rateMatrix() {
        HashBasedTable<Integer, Integer, Double> ratings = HashBasedTable.create();
        Iterator<TensorEntry> it = this.iterator();
        while (it.hasNext()) {
            TensorEntry te = it.next();
            final int u = te.key(userDimension);
            final int i = te.key(itemDimension);
            ratings.put(u, i, te.get());
        }
        final int numUsers = dimensions[userDimension];
        final int numItems = dimensions[itemDimension];
        return new SequentialAccessSparseMatrix(numUsers, numItems, ratings);
    }

    /**
     * Returns an iterator over the stored entries in storage order.
     *
     * @return a fresh {@code TensorIterator} positioned before the first entry
     */
    @Override
    public Iterator<TensorEntry> iterator() {
        final Iterator<TensorEntry> it = new TensorIterator();
        return it;
    }

    /**
     * Computes the square root of the sum of squared stored values.
     *
     * @return the norm over all stored entries
     */
    public double norm() {
        double sumOfSquares = 0.0;
        for (int pos = 0, n = values.size(); pos < n; ++pos) {
            final double v = values.get(pos);
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Computes the arithmetic mean of the stored values (absent cells are not
     * counted). For an empty tensor the division is {@code 0.0 / 0}, i.e. NaN.
     *
     * @return sum of stored values divided by the number of stored entries
     */
    public double mean() {
        double total = 0.0;
        for (int pos = 0, n = values.size(); pos < n; ++pos) {
            total += values.get(pos);
        }
        return total / (double) size();
    }

    /**
     * Computes the inner product with another tensor of identical shape: the sum,
     * over this tensor's stored entries, of the entry value times the other
     * tensor's value at the same keys.
     *
     * @param st the other tensor
     * @return the inner product
     * @throws Exception if the two tensors differ in dimension count or sizes
     */
    public double innerProduct(SparseTensor st) throws Exception {
        if (!isDimMatch(st)) {
            throw new Exception("The dimensions of two sparse tensors do not match!");
        }
        double sum = 0.0;
        for (TensorEntry te : this) {
            final double mine = te.get();
            final double theirs = st.get(te.keys());
            sum += mine * theirs;
        }
        return sum;
    }

    /**
     * Tells whether another tensor has the same number of dimensions and the same
     * size in each dimension.
     *
     * @param st the tensor to compare against
     * @return {@code true} if both shapes are identical
     */
    public boolean isDimMatch(SparseTensor st) {
        if (numDimensions != st.numDimensions) {
            return false;
        }
        for (int d = 0; d < numDimensions; ++d) {
            if (dimensions[d] != st.dimensions[d]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the number of the dimension designated as the user dimension.
     *
     * @return the user dimension
     */
    public int getUserDimension() {
        return userDimension;
    }

    /**
     * Returns the size of the dimension numbered {@code index}.
     *
     * @param index dimension number; asserted to be below the dimension count
     * @return the size of that dimension
     */
    public int getIndexDimension(int index) {
        assert index < numDimensions;
        return dimensions[index];
    }

    /**
     * Designates which dimension holds user keys.
     *
     * @param userDimension number of the user dimension
     */
    public void setUserDimension(int userDimension) {
        this.userDimension = userDimension;
    }

    /**
     * Returns the number of the dimension designated as the item dimension.
     *
     * @return the item dimension
     */
    public int getItemDimension() {
        return itemDimension;
    }

    /**
     * Designates which dimension holds item keys.
     *
     * @param itemDimension number of the item dimension
     */
    public void setItemDimension(int itemDimension) {
        this.itemDimension = itemDimension;
    }

    /**
     * Returns the array of dimension sizes. This is the internal array itself,
     * not a copy.
     *
     * @return size of each dimension
     */
    public int[] dimensions() {
        return dimensions;
    }

    /**
     * Returns the number of dimensions of this tensor.
     *
     * @return the dimension count
     */
    public int numDimensions() {
        return numDimensions;
    }

    /**
     * Renders a header line with the dimension count and entry count, followed by
     * one line per entry: its tab-separated keys and then its value.
     *
     * @return the textual form of this tensor
     */
    public String toString() {
        StringBuilder out = new StringBuilder("N-Dimension: ");
        out.append(numDimensions);
        out.append(", Size: ");
        out.append(size());
        out.append("\n");
        final int count = values.size();
        for (int pos = 0; pos < count; ++pos) {
            for (int d = 0; d < numDimensions; ++d) {
                out.append(this.key(d, pos));
                out.append("\t");
            }
            out.append(this.value(pos));
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Ad-hoc demonstration that exercises the tensor API and logs the results at
     * debug level.
     *
     * <p>Messages are built with {@code String.format}, which only understands
     * {@code %}-style conversions; the values are therefore rendered with
     * {@code %s} (a {@code {}} placeholder would be printed literally and the
     * argument dropped).
     *
     * @param args ignored
     * @throws Exception propagated from the tensor operations
     */
    public static void main(String[] args) throws Exception {
        int[] a = new int[]{2, 2, 2};
        @SuppressWarnings("unchecked") // generic array creation is not expressible otherwise
        List<Integer>[] ndLists = new List[3];
        for (int d = 0; d < 3; ++d) {
            ndLists[d] = new ArrayList<Integer>();
        }
        List<Double> vals = new ArrayList<Double>();
        ndLists[0].add(1);
        ndLists[0].add(2);
        ndLists[0].add(3);
        ndLists[1].add(1);
        ndLists[1].add(2);
        ndLists[1].add(4);
        ndLists[2].add(1);
        ndLists[2].add(2);
        ndLists[2].add(5);
        vals.add(1.0);
        vals.add(2.0);
        vals.add(3.0);
        SparseTensor st1 = new SparseTensor(a, ndLists, vals);
        LOG.debug(st1);
        LOG.debug(String.format("Index of keys (1, 2) = %s", st1.getTargetKeyFromSubKey(new Integer[]{3, 4})));
        SparseTensor st = new SparseTensor(4, 4, 6);
        LOG.debug(st);
        st.set(1.0, 1, 0, 0);
        st.set(1.5, 1, 0, 0);
        st.set(2.0, 1, 1, 0);
        st.set(3.0, 2, 0, 0);
        st.set(4.0, 1, 3, 0);
        st.set(5.0, 1, 0, 5);
        st.set(6.0, 3, 1, 4);
        LOG.debug(st);
        LOG.debug(String.format("Keys (1, 0, 0) = %s", st.get(1, 0, 0)));
        LOG.debug(String.format("Keys (1, 1, 0) = %s", st.get(1, 1, 0)));
        LOG.debug(String.format("Keys (1, 2, 0) = %s", st.get(1, 2, 0)));
        LOG.debug(String.format("Keys (2, 0, 0) = %s", st.get(2, 0, 0)));
        LOG.debug(String.format("Keys (1, 0, 6) = %s", st.get(1, 0, 6)));
        LOG.debug(String.format("Keys (3, 1, 4) = %s", st.get(3, 1, 4)));
        LOG.debug(String.format("Index of dimension 0 key 1 = %s", st.getIndex(0, 1)));
        LOG.debug(String.format("Index of dimension 1 key 3 = %s", st.getIndex(1, 3)));
        LOG.debug(String.format("Index of dimension 2 key 1 = %s", st.getIndex(2, 1)));
        LOG.debug(String.format("Index of dimension 2 key 6 = %s", st.getIndex(2, 6)));
        st.set(4.5, 2, 1, 1);
        LOG.debug(st);
        LOG.debug(String.format("Index of dimension 2 key 1 = %s", st.getIndex(2, 1)));
        st.remove(2, 1, 1);
        LOG.debug(String.format("Index of dimension 2 key 1 = %s", st.getIndex(2, 1)));
        LOG.debug(String.format("Index of keys (1, 2, 0) = %s, value = %s", st.findIndex(1, 2, 0), st.get(1, 2, 0)));
        LOG.debug(String.format("Index of keys (3, 1, 4) = %s, value = %s", st.findIndex(3, 1, 4), st.get(3, 1, 4)));
        LOG.debug(String.format("Keys in dimension 2 associated with dimension 0 key 1 = %s", st.getRelevantKeys(0, 1, 2)));
        LOG.debug(String.format("norm = %s", st.norm()));
        SparseTensor st2 = st.clone();
        LOG.debug(String.format("make a clone = %s", st2));
        LOG.debug(String.format("inner with the clone = %s", st.innerProduct(st2)));
        st.set(2.5, 1, 0, 0);
        st2.remove(1, 0, 0);
        LOG.debug(String.format("st1 = %s", st));
        LOG.debug(String.format("st2 = %s", st2));
        LOG.debug(st);
        LOG.debug(String.format("fiber (1, 1, 0) = %s", st.fiber(1, 1, 0)));
        LOG.debug(String.format("fiber (2, 1, 0) = %s", st.fiber(2, 1, 0)));
        LOG.debug(String.format("slice (0, 1, 0) = %s", st.slice(0, 1, 0)));
        LOG.debug(String.format("slice (0, 2, 1) = %s", st.slice(0, 2, 1)));
        LOG.debug(String.format("slice (1, 2, 1) = %s", st.slice(1, 2, 1)));
        for (TensorEntry te : st) {
            te.set(te.get() + 0.588);
        }
        LOG.debug(String.format("Before shuffle: %s", st));
        st.shuffle();
        LOG.debug(String.format("After shuffle: %s", st));
        // Fill a 3 x 4 x 2 tensor with the values 1..24 so that entry (i, j, k)
        // holds 1 + i + 3*j + 12*k, inserted frontal slice by slice, row by row.
        st = new SparseTensor(3, 4, 2);
        for (int k = 0; k < 2; ++k) {
            for (int i = 0; i < 3; ++i) {
                for (int j = 0; j < 4; ++j) {
                    double val = 1.0 + i + 3 * j + 12 * k;
                    st.set(val, i, j, k);
                }
            }
        }
        LOG.debug(String.format("A new tensor = %s", st));
        LOG.debug(String.format("Mode X0 unfoldings = %s", st.matricization(0)));
        LOG.debug(String.format("Mode X1 unfoldings = %s", st.matricization(1)));
        LOG.debug(String.format("Mode X2 unfoldings = %s", st.matricization(2)));
    }

    /**
     * A movable view onto one stored entry of the enclosing tensor. The view holds
     * only a position; all reads and writes go straight to the tensor's key lists
     * and value list. One instance is reused by the iterator via {@link #update(int)}.
     */
    private class SparseTensorEntry
    implements TensorEntry {
        /** Position of the viewed entry in the tensor's lists; {@code -1} until set. */
        private int index = -1;

        private SparseTensorEntry() {
        }

        /**
         * Points this view at another entry.
         *
         * @param index position of the entry to view
         * @return this view, for chaining
         */
        public SparseTensorEntry update(int index) {
            this.index = index;
            return this;
        }

        /** Returns the viewed entry's key in dimension {@code d}. */
        @Override
        public int key(int d) {
            return SparseTensor.this.ndKeys[d].get(this.index);
        }

        /** Returns the viewed entry's value. */
        @Override
        public double get() {
            return SparseTensor.this.values.get(this.index);
        }

        /** Overwrites the viewed entry's value. */
        @Override
        public void set(double value) {
            SparseTensor.this.values.set(this.index, value);
        }

        /**
         * Deletes the viewed entry from the tensor.
         *
         * <p>Removing a list element shifts every later entry down by one position,
         * so the positions recorded in the key indices become stale. Each indexed
         * dimension is therefore rebuilt, the same way {@code SparseTensor.remove}
         * does. After the call this view points at whatever entry (if any) moved
         * into the freed position.
         */
        @Override
        public void remove() {
            for (int d = 0; d < SparseTensor.this.numDimensions; ++d) {
                SparseTensor.this.ndKeys[d].remove(this.index);
                if (SparseTensor.this.isIndexed(d)) {
                    SparseTensor.this.buildIndex(d);
                }
            }
            SparseTensor.this.values.remove(this.index);
        }

        /** Renders the tab-separated keys followed by the value. */
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int d = 0; d < SparseTensor.this.numDimensions; ++d) {
                sb.append(this.key(d)).append("\t");
            }
            sb.append(this.get());
            return sb.toString();
        }

        /** Returns a new array holding the viewed entry's key in every dimension. */
        @Override
        public int[] keys() {
            int[] res = new int[SparseTensor.this.numDimensions];
            for (int d = 0; d < SparseTensor.this.numDimensions; ++d) {
                res[d] = this.key(d);
            }
            return res;
        }
    }

    /**
     * Iterates over the tensor's entries in storage order. A single
     * {@code SparseTensorEntry} view is reused for every step, so an entry
     * returned by {@link #next()} is only valid until the next call.
     */
    private class TensorIterator
    implements Iterator<TensorEntry> {
        /** Position of the entry that the next call to {@code next()} will return. */
        private int index = 0;
        /** Reused view handed out by {@code next()}. */
        private SparseTensorEntry entry = new SparseTensorEntry();
        /** True between a call to {@code next()} and the matching {@code remove()}. */
        private boolean canRemove = false;

        private TensorIterator() {
        }

        @Override
        public boolean hasNext() {
            return this.index < SparseTensor.this.values.size();
        }

        /**
         * Advances to the next entry.
         *
         * @return the shared entry view, re-pointed at the next entry
         * @throws java.util.NoSuchElementException if the iteration has no more entries
         */
        @Override
        public TensorEntry next() {
            if (!this.hasNext()) {
                throw new java.util.NoSuchElementException();
            }
            this.canRemove = true;
            return this.entry.update(this.index++);
        }

        /**
         * Removes the entry most recently returned by {@code next()}.
         *
         * @throws IllegalStateException if {@code next()} has not been called yet, or
         *         {@code remove()} was already called since the last {@code next()}
         */
        @Override
        public void remove() {
            if (!this.canRemove) {
                throw new IllegalStateException("remove() requires a preceding call to next()");
            }
            this.entry.remove();
            // The following entries moved down by one position; step back so the
            // entry that took the removed slot is not skipped.
            --this.index;
            this.canRemove = false;
        }
    }
}

