#!/usr/bin/env python3

import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

# Command-line interface: one positional CSV path plus plot settings.
parser = argparse.ArgumentParser(description='Plot data points, and create graphs of linear regressions')
parser.add_argument('input', type=str, help='File path to some input.csv')
# NOTE: argparse applies `type` to every token individually, so `type=list`
# would explode a column name such as "time" into ['t', 'i', 'm', 'e'].
# `type=str` with `nargs=2` yields the intended two-element list of names.
# The option is required because the script indexes both entries unconditionally.
parser.add_argument('--cols', type=str, nargs=2, required=True, metavar=('XCOL', 'YCOL'),
                    help='Specify the names of the two columns to plot (x first, then y)')
parser.add_argument('--labels', type=str, nargs=2, default=[], metavar=('XLABEL', 'YLABEL'),
                    help='Specify labels for the x-axis, and the y-axis')
parser.add_argument('--title', type=str, help='Specify the title of the figure')
parser.add_argument('--saveas', type=str, help='Exports the figure with the given filename')
args = parser.parse_args()

# Process user input
input_csv = args.input
cols = args.cols
saveas = args.saveas

# Fall back to generic text when no (or empty) labels/title were supplied.
labels = args.labels or ['X-Axis', 'Y-Axis']
title = args.title or 'Linear Regression'

def signif(x, p=3):
    """Round ``x`` to ``p`` significant figures, element-wise.

    Args:
        x: A number or array-like of numbers; converted with ``np.asarray``.
        p: Number of significant figures to keep (default 3).

    Returns:
        A NumPy value shaped like ``x`` holding the rounded numbers.
    """
    values = np.asarray(x)
    # Zeros and non-finite entries have no usable order of magnitude, so they
    # get a placeholder of 10**(p-1); its exponent below works out to 0 and
    # those entries are therefore scaled by 1.
    has_magnitude = np.isfinite(values) & (values != 0)
    magnitudes = np.where(has_magnitude, np.abs(values), 10**(p-1))
    # Power of ten that shifts the p-th significant digit into the units place.
    exponents = p - 1 - np.floor(np.log10(magnitudes))
    scale = 10 ** exponents
    return np.round(values * scale) / scale


# Load the two requested columns as flat arrays.
df = pd.read_csv(input_csv)
x = df[cols[0]].values.flatten()
y = df[cols[1]].values.flatten()

# Drop rows where either coordinate is missing: blank CSV cells load as NaN,
# and NaN values fed into the least-squares fit make it fail or yield NaN.
valid = ~(pd.isna(x) | pd.isna(y))
x, y = x[valid], y[valid]
if x.size < 2:
    raise SystemExit('Need at least two complete (x, y) rows to fit a line')

# Create Linear Regression (degree-1 least-squares fit: y = m*x + b)
coef = np.polyfit(x, y, 1)
m, b = coef
poly1d_fn = np.poly1d(coef)
# Choose the operator from the intercept's sign so the legend reads
# "y=2.0x - 3.0" rather than "y=2.0x + -3.0".
sign = '-' if b < 0 else '+'
plot_label = f'y={signif(m)}x {sign} {signif(abs(b))}'

# Set up figure
fig, ax = plt.subplots()                    # Create a figure and an axes.
ax.plot(x, poly1d_fn(x), label=plot_label)  # Plot best fit line
ax.plot(x, y, 'o', label='Original Data')   # Plot the points

# Format Graph
ax.set_xlabel(labels[0])    # Add an x-label to the axes.
ax.set_ylabel(labels[1])    # Add a y-label to the axes.
ax.set_title(title)         # Add a title to the axes.
ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f')) # Show x tick labels with two decimal places
ax.legend()                 # Add a legend.

# Show the figure interactively unless an output filename was given.
if saveas is None:
    plt.show()
else:
    plt.savefig(saveas)
