#!/usr/bin/env python
# -*- coding:UTF-8 -*-
'''

@Description: Plot the polarity distribution directly from a previously computed polarity table, without recomputing it from the BAM files.

		1) The input file must be a tab-separated table (loaded as a pandas DataFrame) with N columns (N is the number of samples);
		the index is the transcript id and each column holds the polarity of one sample. The input file is generated by Polarity_calculation.py.
		2) The output figure can be written in pdf/png/jpg format.
'''

import numpy as np
import pandas as pd
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from optparse import OptionParser
from collections import defaultdict
from .__init__ import __version__

def create_parser():
	'''Build the optparse parser describing this script's command line.

	The usage text embeds the module docstring, and the ``--version`` string
	is taken from the package's ``__version__``. Nothing is parsed here; the
	configured ``OptionParser`` is returned to the caller.
	'''
	usage_text="usage: python %prog [options]\n"+__doc__+"\n"
	cli=OptionParser(usage=usage_text,version=__version__)
	## every option as (flags, keyword arguments); registered in this order,
	## which is also the order shown by --help
	option_table=[
		(("-i","--input"),
			dict(action="store",type="string",dest="polarity",
				help="Input file in txt format.Generated by Polarity_calculation.py")),
		(("-o","--output"),
			dict(action="store",type="string",dest="output_prefix",
				help="Prefix of output files.[required]")),
		(("-f","--format"),
			dict(action="store",type="string",dest="output_format",default='pdf',
				help="Output file format,'pdf','png' or 'jpg'. default=%default")),
		(('-g','--group'),
			dict(action="store",type="string",dest="group_name",
				help="Group name of each group separated by comma. e.g. 'si-control,si-eIF3e'")),
		(('-r','--replicate'),
			dict(action="store",type="string",dest="replicate_name",
				help="Replicate name of each group separated by comma. e.g. 'si_3e_1_80S,si_3e_2_80S__si_cttl_1_80S,si_ctrl_2_80S'")),
		(("-y","--ymax"),
			dict(action="store",type="float",dest="ymax",default=5,
				help="The max of ylim. default=%default")),
		(("--mode",),
			dict(action="store",type="string",dest="mode",default='all',
				help="plot all samples or just mean samples [all or mean].If choose 'all',output all samples as well as mean samples, else just mean samples.default=%default")),
	]
	for flags,settings in option_table:
		cli.add_option(*flags,**settings)
	return cli

def Draw_polarity_for_all_samples(data,samples,inOutPrefix,inOutFomat,ymax,text_font={"size":20,"family":"Arial","weight":"bold"},legend_font={"size":20,"family":"Arial","weight":"bold"}):
	'''Plot one polarity-score density curve per sample and save the figure.

	Column ``i`` of ``data`` (accessed by position) is drawn as a kernel
	density curve labelled ``samples[i]``; NaN scores are dropped first.
	The figure is saved to ``inOutPrefix + "_polarity." + inOutFomat``.

	Parameters:
		data: pandas DataFrame whose columns hold the polarity scores.
		samples: sequence of curve labels, one per column of ``data``.
		inOutPrefix: prefix of the output file path.
		inOutFomat: output format, also passed to ``savefig`` as ``format``.
		ymax: upper limit of the y axis (the lower limit is fixed at -0.1).
		text_font: font dictionary used for both axis labels.
		legend_font: font properties used for the legend.
			Both default dictionaries are shared between calls; they are
			only read here and must not be mutated.
	'''
	plt.rc('font',weight='bold')
	## "w" (white) is deliberately not in the named colours: a white curve is
	## invisible on the white background, so larger sample sets use 'husl'
	base_colors=["b","orangered","green","c","m","y","k"]
	if len(samples) <= len(base_colors):
		colors=base_colors
	else:
		colors=sns.color_palette('husl',len(samples))
	fig=plt.figure(figsize=(5,4))
	ax=fig.add_subplot(111)
	for i,(label,color) in enumerate(zip(samples,colors)):
		scores=data.iloc[:,i].values
		scores=scores[~np.isnan(scores)]
		## a column without any valid score has no density to estimate
		if scores.size == 0:
			continue
		## kdeplot replaces the deprecated sns.distplot(hist=False,rug=False);
		## the axes are passed explicitly instead of relying on the current axes
		sns.kdeplot(scores,label=label,color=color,ax=ax)
	ax.set_ylim(-0.1,ymax)
	ax.set_xlabel("Polarity score",fontdict=text_font)
	ax.set_ylabel("Relative gene numbers",fontdict=text_font)
	ax.spines["top"].set_visible(False)
	ax.spines["right"].set_visible(False)
	ax.spines["bottom"].set_linewidth(2)
	ax.spines["left"].set_linewidth(2)
	ax.tick_params(which="both",width=2,length=2)
	ax.legend(loc="best",prop=legend_font)
	fig.tight_layout()
	fig.savefig(inOutPrefix+"_polarity."+inOutFomat,format=inOutFomat)
	## close the figure so repeated calls do not accumulate open figures
	plt.close(fig)

def calculate_mean_polarity(data,groups,replicates,output_prefix):
	'''Average the polarity of each group's replicates, transcript by transcript.

	For every group the replicate columns are selected from ``data`` and the
	row-wise mean is taken over the replicates that have a value (NaN is
	ignored). Transcripts with no value in any replicate of a group are left
	out of that group; after the groups are joined they show up as NaN.

	Parameters:
		data: pandas DataFrame indexed by transcript id, one column per sample.
		groups: list of group names.
		replicates: list of strings, one per group and in the same order, each
			holding the comma-separated column names of that group's replicates.
		output_prefix: the result is written (tab-separated) to
			``output_prefix + "_mean_polarity_dataframe.txt"``.

	Returns:
		pandas DataFrame with one column per group and the sorted union of the
		transcripts of all groups as index.

	Raises:
		KeyError: if a group has no matching entry in ``replicates``.
	'''
	## map every group to the column names of its replicates
	labels_dict={}
	for g,r in zip(groups,replicates):
		labels_dict[g]=r.strip().split(',')
	unmapped=[str(g) for g in groups if g not in labels_dict]
	if unmapped:
		raise KeyError("No replicates were given for group(s): "+", ".join(unmapped))

	## calculate the mean polarity between different replicates
	data_mean_dict={}
	for g in groups:
		absent=[name for name in labels_dict[g] if name not in data.columns]
		if absent:
			## reindex silently turns unknown names into all-NaN columns, so a
			## misspelled replicate would otherwise go unnoticed
			print("Warning: replicate(s) "+", ".join(absent)+" of group "+str(g)+" not found in the input columns.",file=sys.stderr)
		group_data=data.reindex(columns=labels_dict[g])
		## one vectorised row mean replaces a per-transcript lookup; rows where
		## every replicate is NaN yield NaN and are dropped
		means=group_data.mean(axis=1,skipna=True).dropna()
		## rebuilt from raw values so the index stays unnamed and the header of
		## the written file keeps an empty first cell
		data_mean_dict[g]=pd.DataFrame({g:means.values},index=means.index.values)

	## concatenate different data frame
	data_mean=pd.concat(list(data_mean_dict.values()),axis=1,sort=True)
	## write the mean density file
	data_mean.to_csv(output_prefix+"_mean_polarity_dataframe.txt",sep="\t")
	return data_mean


def main():
	'''Command-line entry point: read the polarity table, average the groups and plot.

	The options are validated first, so a missing option, a bad ``--mode`` or
	a mismatch between groups and replicate sets is reported before any file
	is read or written. The table is then loaded, the per-group mean polarity
	is computed (and written by ``calculate_mean_polarity``), and the density
	plots are drawn: all samples plus the group means for mode 'all', only
	the group means for mode 'mean'.

	Raises:
		IOError: if ``--mode`` is neither an 'all' nor a 'mean' spelling.
	'''
	parser=create_parser()
	(options,_args)=parser.parse_args()
	## optparse leaves unset options as None; report them as a usage error
	## instead of failing later with an AttributeError on None.strip()
	required=[("-i/--input",options.polarity),("-o/--output",options.output_prefix),("-g/--group",options.group_name),("-r/--replicate",options.replicate_name)]
	missing=[flag for flag,value in required if not value]
	if missing:
		parser.error("Missing required option(s): "+", ".join(missing))
	## validate the mode before doing any work or writing any file
	mode=options.mode
	plot_all=mode in ['all','All','a','A']
	plot_mean_only=mode in ['mean','Mean','m','M']
	if not (plot_all or plot_mean_only):
		raise IOError("Please enter a correct --mode parameter [all or mean]")
	groups=options.group_name.strip().split(',')
	replicates=options.replicate_name.strip().split('__')
	## groups and replicate sets are paired by position, so the counts must agree
	if len(groups) != len(replicates):
		parser.error("The number of groups (%d) does not match the number of replicate sets (%d)." % (len(groups),len(replicates)))
	output_prefix=options.output_prefix
	output_format=options.output_format
	ymax=options.ymax
	data=pd.read_csv(options.polarity,sep="\t",index_col=0)
	samples=data.columns
	## calculate the mean polarity
	data_mean=calculate_mean_polarity(data,groups,replicates,output_prefix)
	samples_mean=data_mean.columns
	text_font={"size":20,"family":"Arial","weight":"bold"}
	legend_font={"size":10,"family":"Arial","weight":"bold"}
	print("Start plot the polarity...",file=sys.stderr)
	if plot_all:
		Draw_polarity_for_all_samples(data,samples,output_prefix,output_format,ymax,text_font=text_font,legend_font=legend_font)
	## the mean plot is produced in both modes
	Draw_polarity_for_all_samples(data_mean,samples_mean,output_prefix+"_mean",output_format,ymax,text_font=text_font,legend_font=legend_font)
	print("finished plot the polarity!",file=sys.stderr)

# Run the command-line entry point only when this file is executed, not when it is imported.
# NOTE(review): the module uses a package-relative import (`from .__init__ import __version__`),
# so it presumably has to be launched as part of its package (e.g. with `python -m`) - confirm.
if __name__ == "__main__":
    main()