#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Project      : AppZoo.
# @File         : nh_bert
# @Time         : 2020/11/19 3:34 下午
# @Author       : yuanjie
# @Email        : yuanjie@xiaomi.com
# @Software     : PyCharm
# @Description  : tf2 bert4keras


from meutils.pipe import *

# Set before the bert4keras imports below. NOTE(review): bert4keras is expected
# to read TF_KERAS at import time to choose tf.keras as its backend, so this
# line must stay above those imports - confirm against the installed version.
os.environ['TF_KERAS'] = '1'

from meutils.zk_utils import get_zk_config
from meutils.bert_utils.bert4keras_utils import text2seq
from meutils.datamodels.ArticleInfo import ArticleInfo

from bert4keras.backend import keras, search_layer, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from tensorflow.keras.layers import Lambda, Dense
from tensorflow.keras.utils import to_categorical

# Configuration: artifact URLs read from the '/mipush/nh_model' config node
# (presumably stored in ZooKeeper, judging by the helper's name - confirm).
cfg = get_zk_config('/mipush/nh_model')
vocab_url = cfg['vocab_url']
nh_bert_model_url_strict = cfg['nh_bert_model_url']['strict']
nh_bert_model_url_nostrict = cfg['nh_bert_model_url']['nostrict']  # loose (non-strict) variant
nh_lgb_model_url_strict = cfg['nh_lgb_model_url']['strict']
nh_lgb_model_url_nostrict = cfg['nh_lgb_model_url']['nostrict']  # loose (non-strict) variant

# The vocabulary is fetched only when no local 'vocab.txt' exists yet.
if not Path('vocab.txt').exists():
    download(vocab_url, 'vocab.txt')
tokenizer = Tokenizer('vocab.txt', do_lower_case=True)

# Unlike the vocabulary, download() is called unconditionally for the four
# model artifacts every time this module is executed.
download(nh_bert_model_url_strict, 'nh_bert_strict')
download(nh_bert_model_url_nostrict, 'nh_bert_nostrict')
download(nh_lgb_model_url_strict, 'nh_lgb_strict')
download(nh_lgb_model_url_nostrict, 'nh_lgb_nostrict')

# Keras models are loaded with compile=False; this file only calls predict()
# on them.
nh_bert_strict = keras.models.load_model('nh_bert_strict', compile=False)
nh_bert_nostrict = keras.models.load_model('nh_bert_nostrict', compile=False)

# Models deserialized with joblib; merge_score() calls predict_proba() on them
# (LightGBM classifiers, going by the "lgb" naming - confirm).
nh_lgb_strict = joblib.load('nh_lgb_strict')
nh_lgb_nostrict = joblib.load('nh_lgb_nostrict')

# Run one prediction per BERT model on a sample text and log the output
# (the first log message reads "Initializing KerasModel").
logger.info("初始化KerasModel")
logger.info(nh_bert_strict.predict(text2seq("文本")))
logger.info(nh_bert_nostrict.predict(text2seq("文本")))


# Score fusion
def merge_score(X, text="", mode_type='strict'):
    """Blend the tabular-model score with the BERT score for one mode.

    Args:
        X: Feature rows passed straight to ``predict_proba`` of the
            joblib-loaded model (``get_feats`` produces a one-row list).
        text: Text converted with ``text2seq`` and fed to the BERT model.
        mode_type: ``'strict'`` selects the strict model pair; any other
            value selects the non-strict pair.

    Returns:
        A Python list holding ``0.8 * tabular + 0.2 * bert``, where both
        scores are column 1 of the respective model output (presumably the
        positive-class probability - confirm against the training setup).
    """
    # Pick the model pair once; the scoring below is identical for both modes.
    if mode_type == 'strict':
        lgb_model, bert_model = nh_lgb_strict, nh_bert_strict
    else:
        lgb_model, bert_model = nh_lgb_nostrict, nh_bert_nostrict

    lgb_score = lgb_model.predict_proba(X)[:, 1]
    bert_score = bert_model.predict(text2seq(text))[:, 1]

    blended = lgb_score * 0.8 + bert_score * 0.2
    return blended.tolist()


# API
def get_feats(ac):
    """Turn a raw article payload into a one-row feature matrix.

    The payload is logged, validated through ``ArticleInfo`` and flattened:
    the ``id``, ``title``, ``nCategory1`` and ``nSubCategory1`` fields are
    dropped, the remaining field values are kept in model order, and the
    ``createTime`` and ``publishTime`` values (which must be lists, since
    they are concatenated) are appended at the end.

    Args:
        ac: Mapping of article fields. It is not modified. A ``'category'``
            entry, if present, is ignored.

    Returns:
        ``[row]`` - a list containing the single feature row.

    Raises:
        KeyError: If the ``ArticleInfo`` dict lacks one of the fields that
            are popped below.
    """
    logger.info(ac)

    # Work on a shallow copy so the caller's dict keeps its 'category' entry,
    # and tolerate payloads without that key: the value is discarded anyway.
    payload = dict(ac)
    payload.pop('category', None)

    feats = ArticleInfo(**payload).dict()

    # Identifier / text / category fields are not part of the feature row.
    for key in ('id', 'title', 'nCategory1', 'nSubCategory1'):
        feats.pop(key)

    # Time features go last; pop them before collecting the remaining values.
    dt_feat = feats.pop('createTime') + feats.pop('publishTime')
    row = list(feats.values()) + dt_feat
    return [row]


def predict_strict(**ac):
    """Score an article with the strict model pair.

    Registered as the handler for ``/nh_bert/strict`` when this file runs as
    a script. The keyword arguments are the article fields; ``title`` is
    used as the BERT input and falls back to a placeholder sentence
    ("please enter a text") when absent.

    Returns:
        The blended score list produced by ``merge_score``.
    """
    # get_feats(ac) is evaluated before the title lookup, as argument order dictates.
    return merge_score(
        get_feats(ac),
        ac.get("title", "请输入一个文本"),
        'strict',
    )


def predict_nostrict(**ac):
    """Score an article with the non-strict (loose) model pair.

    Registered as the handler for ``/nh_bert/nostrict`` when this file runs
    as a script. The keyword arguments are the article fields; ``title`` is
    used as the BERT input and falls back to a placeholder sentence
    ("please enter a text") when absent.

    Returns:
        The blended score list produced by ``merge_score``.
    """
    # get_feats(ac) is evaluated before the title lookup, as argument order dictates.
    return merge_score(
        get_feats(ac),
        ac.get("title", "请输入一个文本"),
        'nostrict',
    )


if __name__ == '__main__':
    # Script entry point: expose both scorers as POST routes and serve on 8000.
    from appzoo import App

    app = App()

    # Registration order matches the original: strict first, then non-strict.
    for _route, _handler in (
        ('/nh_bert/strict', predict_strict),
        ('/nh_bert/nostrict', predict_nostrict),
    ):
        app.add_route(_route, _handler, method="POST")

    app.run(port=8000, access_log=False)
