#!/usr/bin/env python
# -*- coding:UTF-8 -*-
'''
@Description: This script is used to plot metagene results for whole transcripts.
			usage: python PlotMetageneAnalysisForTheWholeRegions.py -i test_scaled_dataframe.txt -o test -g si-Ctrl,si-eIF3e -r si_ctrl_1_80S,si_ctrl_2_80S,si_ctrl_3_80S__si_3e_1_80S,si_3e_2_80S,si_3e_3_80S -f pdf -b 15,90,60 --mode all
			input:
			1) metagene results generated by MetageneAnalysisForTheWholeRegions.py
			output (depending on --mode):
			1) plots for all samples
			2) plots for the group means
'''


import numpy as np
import pandas as pd
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from optparse import OptionParser
from functools import reduce
from itertools import chain
from collections import defaultdict
from .__init__ import __version__



def create_parser_for_the_whole_metagene_plot():
	'''Build the command-line parser for the whole-transcript metagene plot.

	The usage text is "usage: python %prog [options]" followed by this module's
	docstring. Returns an optparse.OptionParser; the caller runs parse_args().
	'''
	usage_text="\n".join(["usage: python %prog [options]",__doc__,""])
	cli=OptionParser(usage=usage_text,version=__version__)
	# (flags, keyword arguments) for each option, registered in this order.
	option_specs=[
		(("-i","--input"),dict(action="store",type="string",dest="density_file",
			help="Input file in txt format.And the files have N columns, meaning N samples and have total bins rows. [FiveUTR+CDS+ThreeUTR]")),
		(("-o","--output"),dict(action="store",type="string",dest="output_prefix",
			help="Prefix of output files.[required]")),
		(("-g","--group"),dict(action="store",type="string",dest="group_name",
			help="Group name of each group separated by comma. e.g. 'si-control,si-eIF3e'")),
		(("-r","--replicate"),dict(action="store",type="string",dest="replicate_name",
			help="Replicate name of each group separated by comma. e.g. 'si_3e_1_80S,si_3e_2_80S__si_cttl_1_80S,si_ctrl_2_80S'")),
		(("-f","--format"),dict(action="store",type="string",dest="output_format",default='pdf',
			help="Output file format,'pdf','png' or 'jpg'. default=%default")),
		(("-b","--bins"),dict(action='store',type='string',dest="bins",default='15,90,60',
			help="Bins to scale the transcript length.e.g.'15,90,60'. bins must be separated by comma, namely '5UTRBins,CDSBins,3UTRBins'. default=%default")),
		(("--ymax",),dict(action="store",type="float",dest="ymax",default=None,
			help="The max of ylim. default=%default")),
		(("--ymin",),dict(action="store",type="float",dest="ymin",default=None,
			help="The min of ylim. default=%default")),
		(("--mode",),dict(action="store",type="string",dest="mode",default='all',
			help="plot all samples or just mean samples [all or mean].If choose 'all',output all samples as well as mean samples, else just mean samples.default=%default")),
		(("--xlabel-loc",),dict(action="store",type="float",dest="xlabelLoc",default=None,
			help="Location of xlabel. Used to control the yaxis location of xlabel. default=%default")),
	]
	for flags,settings in option_specs:
		cli.add_option(*flags,**settings)
	return cli

def lengths_offsets_split(value):
	'''Convert a comma-separated string of integers into a list of ints.

	Example: '15,90,60' -> [15, 90, 60]. Raises ValueError if any field is
	not an integer literal.
	'''
	return [int(field) for field in value.split(',')]

def plot_read_coverage_distribution(data,samples,bins,inOutPrefix,inOutFomat,ymin,ymax,xlabelLoc,text_font=None,legend_font=None):
	'''Plot per-bin density curves for samples across the 5'UTR, CDS and 3'UTR bins.

	Parameters:
		data: DataFrame with one column per sample. Its row count must equal the
			total number of bins, because the x values are 0..sum(bins)-1.
		samples: column names of ``data`` to draw, one line per name.
		bins: comma-separated string '5UTRBins,CDSBins,3UTRBins', e.g. '15,90,60'.
		inOutPrefix: output path prefix; the figure is saved as
			``<inOutPrefix>_metaplot.<inOutFomat>``.
		inOutFomat: matplotlib output format, e.g. 'pdf', 'png' or 'jpg'.
		ymin, ymax: optional y-axis limits (None = autoscale). ``ymax`` alone
			gives the range (0, ymax); ``ymin`` requires ``ymax``.
		xlabelLoc: optional y coordinate (data units) at which the 5'UTR /
			Coding region / 3'UTR labels are written; None = no labels.
		text_font, legend_font: matplotlib font dicts for the axis/region labels
			and for the legend; default bold Arial, size 20.

	Raises:
		IOError: if ymin is given without ymax, or if the y range is empty or
			inverted.
	'''
	# None defaults avoid sharing one mutable dict across calls.
	if text_font is None:
		text_font={"size":20,"family":"Arial","weight":"bold"}
	if legend_font is None:
		legend_font={"size":20,"family":"Arial","weight":"bold"}
	# Validate y limits up front with explicit None checks so that 0 is a usable limit.
	if ymin is not None and ymax is None:
		raise IOError("Please offer the ymax parameter as well!")
	if ymax is not None:
		ylim_low=0 if ymin is None else ymin
		if ylim_low >= ymax:
			raise IOError("Please enter correct ymin and ymax parameters!")
	bins_vector=lengths_offsets_split(bins)
	win_len=int(np.sum(bins_vector))
	utr5_end=bins_vector[0]
	cds_end=bins_vector[0]+bins_vector[1]
	if len(samples) <=8:
		# Fixed palette; every entry is visible on the white background.
		colors=["b","orangered","green","c","m","y","k","tab:brown"]
	else:
		colors=sns.color_palette('husl',len(samples))
	# NOTE: this changes the global matplotlib rc state for the rest of the process.
	plt.rc('font',weight='bold')
	fig=plt.figure(figsize=(10,6))
	try:
		ax=fig.add_subplot(111)
		x=np.arange(0,win_len)
		for color,sample in zip(colors,samples):
			ax.plot(x,data.loc[:,sample],color=color,label=sample,linewidth=2)
		ax.set_xticks([])
		# Dashed separators at the 5'UTR/CDS and CDS/3'UTR boundaries.
		for boundary in (utr5_end,cds_end):
			ax.axvline(boundary,color='gray',dashes=(2,3),clip_on=False,linewidth=2)
		ax.set_ylabel("Relative footprint density (AU)",fontdict=text_font)
		ax.set_xlabel("Normalized transcript length",fontdict=text_font,labelpad=30)
		ax.spines["top"].set_visible(False)
		ax.spines["right"].set_visible(False)
		ax.spines["bottom"].set_linewidth(2)
		ax.spines["left"].set_linewidth(2)
		ax.tick_params(which="both",width=2)
		if xlabelLoc is not None:
			label_y=float(xlabelLoc)
			# Text is left-aligned at x: the 5'UTR start, a third of the way into
			# the CDS, and halfway into the 3'UTR.
			ax.text(0,label_y,"5'UTR",fontdict=text_font)
			ax.text(utr5_end+bins_vector[1]/3,label_y,"Coding region",fontdict=text_font)
			ax.text(cds_end+bins_vector[2]/2,label_y,"3'UTR",fontdict=text_font)
		if ymax is not None:
			ax.set_ylim(ylim_low,ymax)
		ax.legend(loc="best",prop=legend_font)
		fig.tight_layout()
		fig.savefig(inOutPrefix+"_metaplot"+"."+inOutFomat,format=inOutFomat)
	finally:
		# Release the figure even if drawing or saving fails.
		plt.close(fig)

def calculate_mean_data(data,samples,group_names,replicate_names,output_prefix):
	'''Average replicate columns into one column per group and save the result.

	Parameters:
		data: DataFrame with one column per replicate/sample.
		samples: sample (column) names of ``data``; only its length is used here.
		group_names: list of group names, e.g. ['si-Ctrl', 'si-eIF3e'].
		replicate_names: list parallel to ``group_names``; each entry is a
			comma-separated string naming the columns of that group.
		output_prefix: prefix of the tab-separated mean table written to
			``<output_prefix>_mean_scaled_dataframe.txt``.

	Returns:
		``data`` itself when it has exactly one sample (nothing is written).
		Otherwise, a new DataFrame with one column per group, in
		``group_names`` order, holding the row-wise mean of that group's
		replicate columns.

	Raises:
		IOError: if there are no samples, if the numbers of groups and
			replicate sets differ, or if a group lists no replicate names.
		KeyError: if a replicate name is not a column of ``data``.
	'''
	if len(samples) < 1:
		raise IOError("There is no samples in the file, please check your input!")
	if len(samples) == 1:
		return data
	if len(group_names) != len(replicate_names):
		raise IOError("The number of groups (-g) and replicate sets (-r) must match!")
	group_means={}
	for group,reps in zip(group_names,replicate_names):
		# Drop blank fields so an empty replicate set is detected below.
		labels=[r.strip() for r in reps.split(',') if r.strip()]
		if not labels:
			raise IOError("Please reset your -g -r parameters because nothing present here.")
		# Vectorised mean works for any replicate count; skipna=False keeps NaN
		# propagating exactly like sum / n.
		group_means[group]=data.loc[:,labels].mean(axis=1,skipna=False).to_numpy()
	data_mean=pd.DataFrame(group_means)
	## write the mean density file
	data_mean.to_csv(output_prefix+"_mean_scaled_dataframe.txt",sep="\t",index=False)
	return data_mean


def main():
	'''Command-line entry point: read the density table, average replicates and plot.

	Reads the tab-separated table given by -i and computes per-group means with
	calculate_mean_data, which also writes ``<prefix>_mean_scaled_dataframe.txt``.
	Figures are then drawn according to --mode. 'mean'/'Mean'/'m'/'M' draws
	only the group-mean plot. Any other value, including the default 'all',
	draws both the per-sample plot and the group-mean plot.
	'''
	parser=create_parser_for_the_whole_metagene_plot()
	options,_=parser.parse_args()
	# Report missing required options as a usage error instead of failing later
	# with an AttributeError/TypeError on None.
	required=(("density_file","-i/--input"),("output_prefix","-o/--output"),
		("group_name","-g/--group"),("replicate_name","-r/--replicate"))
	for dest,flag in required:
		if getattr(options,dest) is None:
			parser.error("option %s is required" % flag)
	density_file=options.density_file
	output_prefix=options.output_prefix
	group_names=options.group_name.strip().split(',')
	# Replicate sets of different groups are separated by '__'.
	replicate_names=options.replicate_name.strip().split('__')
	output_format=options.output_format
	ymin,ymax=options.ymin,options.ymax
	bins=options.bins.strip()
	mode=options.mode
	xlabelLoc=options.xlabelLoc
	if len(bins.split(',')) != 3:
		raise IOError("Please check your -b parameters!")
	print("your input file is: "+str(density_file),file=sys.stderr)
	data=pd.read_csv(density_file,sep="\t")
	# np.unique de-duplicates and sorts, so curves/legend follow alphabetical order.
	samples=np.unique(data.columns)
	## calculate the mean density
	data_mean=calculate_mean_data(data,samples,group_names,replicate_names,output_prefix)
	samples_new=np.unique(data_mean.columns)
	font={"size":20,"family":"Arial","weight":"bold"}
	if mode not in ('mean','Mean','m','M'):
		plot_read_coverage_distribution(data,samples,bins,output_prefix,output_format,ymin,ymax,xlabelLoc,text_font=dict(font),legend_font=dict(font))
	plot_read_coverage_distribution(data_mean,samples_new,bins,output_prefix+"_mean",output_format,ymin,ymax,xlabelLoc,text_font=dict(font),legend_font=dict(font))
	print("finished plot the ribosome footprint density",file=sys.stderr)


if __name__ == "__main__":
	# Run the CLI only when executed directly, not on import.
	# NOTE(review): the relative import of __version__ at the top of the file
	# means this works only when run as a package module (python -m ...).
	main()
