/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.data.convertor.appender.AuxiliaryDataAppender;
import net.librec.data.model.ArffInstance;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

public class ReMFRecommender
extends MatrixFactorizationRecommender {
    private String hierarchy_side;
    private double rate = 5.0E-5;
    private double alpha;
    private int continentNum = 0;
    private int countryNum = 0;
    private int cityNum = 0;
    private int total_node = 0;
    private int non_leaf = 0;
    private Map<Integer, ArrayList<ArrayList>> node = new HashMap<Integer, ArrayList<ArrayList>>();
    private ArrayList<String> continentList = new ArrayList();
    private ArrayList<String> countryList = new ArrayList();
    private ArrayList<String> cityList = new ArrayList();
    private Map<String, Integer> ContinentMap = new HashMap<String, Integer>();
    private Map<String, Integer> CountryMap = new HashMap<String, Integer>();
    private Map<String, Integer> CityMap = new HashMap<String, Integer>();
    private Map<Integer, String> ConverseMap = new HashMap<Integer, String>();
    protected static Map<String, ArrayList<String>> hierarchy = new HashMap<String, ArrayList<String>>();
    private Map<Integer, String> userIdxToUserId;
    private Map<Integer, String> itemIdxToItemId;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.alpha = this.conf.getDouble("rec.alpha");
        this.hierarchy_side = this.conf.get("rec.side");
        this.userFactors.init(0.6);
        this.itemFactors.init(0.6);
        this.userIdxToUserId = this.context.getDataModel().getUserMappingData().inverse();
        this.itemIdxToItemId = this.context.getDataModel().getItemMappingData().inverse();
        try {
            hierarchy = this.readHierarchy();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    protected void trainModel() throws LibrecException {
        this.getLayers();
        this.getIDs();
        this.createHierarchy();
        double[][] coef = new double[this.non_leaf][2];
        for (int i = 0; i < coef.length; ++i) {
            coef[i][0] = 0.5;
            coef[i][1] = 0.5;
        }
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            int i;
            int y;
            int x;
            double L2g;
            int i2;
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix QS = new DenseMatrix(this.numItems, this.numFactors);
            double[][] transfer = new double[this.non_leaf][2];
            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                if (ruj <= 0.0) continue;
                double pred = this.predict(u, j);
                double euj = pred - ruj;
                this.loss += euj * euj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double puf = this.userFactors.get(u, f);
                    double qjf = this.itemFactors.get(j, f);
                    PS.plus(u, f, euj * qjf + (double)this.regUser * puf);
                    QS.plus(j, f, euj * puf + (double)this.regItem * qjf);
                    this.loss += (double)this.regUser * puf * puf + (double)this.regItem * qjf * qjf;
                }
            }
            HashBasedTable<Integer, Integer, Double> L2node = HashBasedTable.create();
            for (int nodeid = this.non_leaf; nodeid <= this.total_node; ++nodeid) {
                Object[] id;
                if (!this.node.containsKey(nodeid) || (id = this.node.get(nodeid).get(1).toArray()).length < 2) continue;
                double value = 0.0;
                for (int j = 0; j < id.length; ++j) {
                    for (int g = j + 1; g < id.length; ++g) {
                        int idj = (Integer)id[j];
                        int idg = (Integer)id[g];
                        for (int f = 0; f < this.numFactors; ++f) {
                            double ejg = 0.0;
                            if (this.hierarchy_side.equals("user")) {
                                ejg = this.userFactors.get(idj, f) - this.userFactors.get(idg, f);
                                PS.plus(idj, f, this.alpha * ejg);
                            }
                            if (this.hierarchy_side.equals("item")) {
                                ejg = this.itemFactors.get(idj, f) - this.itemFactors.get(idg, f);
                                QS.plus(idj, f, this.alpha * ejg);
                            }
                            this.loss += this.alpha * ejg * ejg;
                            value += ejg * ejg;
                        }
                    }
                }
                L2node.put(nodeid, nodeid, value);
            }
            for (i2 = this.non_leaf; i2 <= this.total_node; ++i2) {
                if (!this.node.containsKey(i2)) continue;
                for (int j = i2 + 1; j <= this.total_node; ++j) {
                    if (!this.node.containsKey(j)) continue;
                    ArrayList<Integer> list = new ArrayList<Integer>();
                    for (Object pi : this.node.get(i2).get(0)) {
                        for (Object pj : this.node.get(j).get(0)) {
                            if (pi != pj) continue;
                            list.add((Integer)pi);
                        }
                    }
                    double value = 0.0;
                    for (Object idi : this.node.get(i2).get(1)) {
                        for (Object idj : this.node.get(j).get(1)) {
                            int lastid = list.size() - 1;
                            double reg = coef[(Integer)list.get(lastid)][0];
                            for (int num = lastid - 1; num >= 0; --num) {
                                reg += coef[(Integer)list.get(num)][0] + reg * coef[(Integer)list.get(num)][1];
                            }
                            int idi_1 = (Integer)idi;
                            int idj_1 = (Integer)idj;
                            for (int f = 0; f < this.numFactors; ++f) {
                                double eij = 0.0;
                                if (this.hierarchy_side.equals("user")) {
                                    eij = this.userFactors.get(idi_1, f) - this.userFactors.get(idj_1, f);
                                    PS.plus(idi_1, f, this.alpha * reg * eij);
                                }
                                if (this.hierarchy_side.equals("item")) {
                                    eij = this.itemFactors.get(idi_1, f) - this.itemFactors.get(idj_1, f);
                                    QS.plus(idi_1, f, this.alpha * reg * eij);
                                }
                                this.loss += this.alpha * reg * eij * eij;
                                value += eij * eij;
                            }
                        }
                    }
                    L2node.put(i2, j, value);
                    L2node.put(j, i2, value);
                }
            }
            for (i2 = this.continentNum + 1; i2 < this.non_leaf; ++i2) {
                if (!this.node.containsKey(i2)) continue;
                double value = this.getValueG(i2, coef);
                L2g = 0.0;
                Object[] city = this.node.get(i2).get(2).toArray();
                for (x = 0; x < city.length; ++x) {
                    int cityx = (Integer)city[x];
                    if (L2node.contains(cityx, cityx)) {
                        L2g += ((Double)L2node.get(cityx, cityx)).doubleValue();
                    }
                    for (y = x + 1; y < city.length; ++y) {
                        int cityy = (Integer)city[y];
                        if (!L2node.contains(cityy, cityy)) continue;
                        L2g += ((Double)L2node.get(cityx, cityy)).doubleValue();
                    }
                }
                L2node.put(i2, i2, L2g);
                transfer[i2][0] = L2g * value;
            }
            for (i2 = 1; i2 <= this.continentNum; ++i2) {
                if (!this.node.containsKey(i2)) continue;
                double value = this.getValueG(i2, coef);
                L2g = 0.0;
                Object[] country = this.node.get(i2).get(2).toArray();
                for (x = 0; x < country.length; ++x) {
                    int countryx = (Integer)country[x];
                    if (L2node.contains(countryx, countryx)) {
                        L2g += ((Double)L2node.get(countryx, countryx)).doubleValue();
                    }
                    for (y = x + 1; y < country.length; ++y) {
                        int countryy = (Integer)country[y];
                        for (Object cityx : this.node.get(countryx).get(2)) {
                            for (Object cityy : this.node.get(countryy).get(2)) {
                                if (!L2node.contains(cityx, cityy)) continue;
                                L2g += ((Double)L2node.get(cityx, cityy)).doubleValue();
                            }
                        }
                    }
                }
                L2node.put(i2, i2, L2g);
                transfer[i2][0] = L2g * value;
            }
            double L2g2 = 0.0;
            for (i = 1; i <= this.continentNum; ++i) {
                if (!this.node.containsKey(i) || !L2node.contains(i, i)) continue;
                L2g2 += ((Double)L2node.get(i, i)).doubleValue();
            }
            transfer[0][0] = L2g2;
            for (i = 0; i < coef.length; ++i) {
                coef[i][0] = coef[i][0] - this.rate * Math.sqrt(Math.sqrt(transfer[i][0]));
                if (coef[i][0] < 0.0) {
                    coef[i][0] = 0.0;
                } else if (coef[i][0] > 1.0) {
                    coef[i][0] = 1.0;
                }
                coef[i][1] = 1.0 - coef[i][0];
                int number = 0;
                if (!this.node.containsKey(i)) continue;
                number = this.node.get(i).get(1).size();
            }
            this.userFactors = this.userFactors.plus(PS.times(-this.learnRate));
            this.itemFactors = this.itemFactors.plus(QS.times(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
            this.LOG.info("learnRate:" + this.learnRate);
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double pred = this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
        return pred;
    }

    protected void getLayers() {
        for (String user : hierarchy.keySet()) {
            String continent = hierarchy.get(user).get(0);
            String country = hierarchy.get(user).get(1);
            String city = hierarchy.get(user).get(2);
            if (!this.continentList.contains(continent)) {
                this.continentList.add(continent);
            }
            if (!this.countryList.contains(country)) {
                this.countryList.add(country);
            }
            if (this.cityList.contains(city)) continue;
            this.cityList.add(city);
        }
        this.continentNum = this.continentList.size();
        this.countryNum = this.countryList.size();
        this.cityNum = this.cityList.size();
        this.total_node = this.continentNum + this.countryNum + this.cityNum;
        this.non_leaf = this.continentNum + this.countryNum + 1;
        this.LOG.info("continentNum: " + this.continentNum + "; countryNum: " + this.countryNum + "; cityNum: " + this.cityNum);
        this.LOG.info("total number of nodes: " + this.total_node + "; the number of non_leaf nodes: " + this.non_leaf);
    }

    protected void getIDs() {
        int count = 1;
        for (String continent : this.continentList) {
            this.ContinentMap.put(continent, count);
            this.ConverseMap.put(count, continent);
            ++count;
        }
        for (String country : this.countryList) {
            this.CountryMap.put(country, count);
            this.ConverseMap.put(count, country);
            ++count;
        }
        for (String city : this.cityList) {
            this.CityMap.put(city, count);
            this.ConverseMap.put(count, city);
            ++count;
        }
    }

    protected void divideUI(int nodeid, int parentid, int element) {
        if (!this.node.containsKey(nodeid)) {
            ArrayList lists = new ArrayList();
            ArrayList<Integer> parent = new ArrayList<Integer>();
            ArrayList child = new ArrayList();
            ArrayList<Integer> temp = new ArrayList<Integer>();
            parent.addAll(this.node.get(parentid).get(0));
            parent.add(parentid);
            temp.add(element);
            lists.add(parent);
            lists.add(temp);
            lists.add(child);
            this.node.put(nodeid, lists);
        } else {
            this.node.get(nodeid).get(1).add(element);
        }
        if (!this.node.get(parentid).get(2).contains(nodeid)) {
            this.node.get(parentid).get(2).add(nodeid);
        }
    }

    protected void buildLayer(int begin, int end, int layer) {
        for (int num = begin; num < end; ++num) {
            if (!this.node.containsKey(num) || this.node.get(num).get(1).size() < 2) continue;
            for (Object innerid : this.node.get(num).get(1).toArray()) {
                String rawid = null;
                if (this.hierarchy_side.equals("user")) {
                    rawid = this.userIdxToUserId.get(innerid);
                }
                if (this.hierarchy_side.equals("item")) {
                    rawid = this.itemIdxToItemId.get(innerid);
                }
                String feature_rawid = hierarchy.get(rawid).get(layer - 1);
                int feature_innerid = 0;
                if (layer == 2) {
                    feature_innerid = this.CountryMap.get(feature_rawid);
                }
                if (layer == 3) {
                    feature_innerid = this.CityMap.get(feature_rawid);
                }
                this.divideUI(feature_innerid, num, (Integer)innerid);
            }
        }
    }

    protected void createHierarchy() {
        int continentid;
        String continent;
        ArrayList Lists2 = new ArrayList();
        ArrayList parent = new ArrayList();
        ArrayList element = new ArrayList();
        ArrayList child = new ArrayList();
        Lists2.add(parent);
        Lists2.add(element);
        Lists2.add(child);
        this.node.put(0, Lists2);
        if (this.hierarchy_side.equals("user")) {
            for (int uid = 0; uid < this.numUsers; ++uid) {
                String user = this.userIdxToUserId.get(uid);
                if (!hierarchy.containsKey(user)) continue;
                continent = hierarchy.get(user).get(0);
                continentid = this.ContinentMap.get(continent);
                this.divideUI(continentid, 0, uid);
            }
        }
        if (this.hierarchy_side.equals("item")) {
            for (int iid = 0; iid < this.numItems; ++iid) {
                String item = this.itemIdxToItemId.get(iid);
                if (!hierarchy.containsKey(item)) continue;
                continent = hierarchy.get(item).get(0);
                continentid = this.ContinentMap.get(continent);
                this.divideUI(continentid, 0, iid);
            }
        }
        this.buildLayer(0, this.continentNum, 2);
        this.buildLayer(this.continentNum + 1, this.continentNum + this.countryNum, 3);
    }

    protected double getValueG(int id, double[][] coef) {
        double valueg = 1.0;
        for (Object parentid : this.node.get(id).get(0)) {
            int pid = (Integer)parentid;
            valueg *= coef[pid][1];
        }
        return valueg;
    }

    protected Map<String, ArrayList<String>> readHierarchy() {
        HashMap<String, ArrayList<String>> hierarchy = new HashMap<String, ArrayList<String>>();
        ArrayList<ArffInstance> auxiliaryData = ((AuxiliaryDataAppender)this.getDataModel().getDataAppender()).getAuxiliaryData();
        for (ArffInstance instance : auxiliaryData) {
            String userId = (String)instance.getValueByIndex(0);
            String continent = (String)instance.getValueByIndex(1);
            String country = (String)instance.getValueByIndex(2);
            String city = (String)instance.getValueByIndex(3);
            ArrayList<String> temp = new ArrayList<String>();
            temp.add(continent);
            temp.add(country);
            temp.add(city);
            hierarchy.put(userId, temp);
        }
        return hierarchy;
    }
}

