#!/usr/bin/env python3
import sys
import textwrap
import argparse
import logging
from privex.helpers import ErrHelpParser, empty
from privex.pyrewall import conf, VERSION
from privex.pyrewall.conf import FILE_SUFFIX, CONF_DIRS, SEARCH_DIRS
from privex.pyrewall.core import find_file
from privex.pyrewall.PyreParser import PyreParser
from privex.pyrewall.repl import repl_main
from typing import Union, Tuple, Dict, List
from io import TextIOWrapper
from datetime import datetime

# Module-level logger (named after the 'privex.pyrewall.repl' logging namespace).
log = logging.getLogger('privex.pyrewall.repl')


# One-line summary of each sub-command; reused for the sub-parser descriptions and the help epilog.
CMD_DESC = dict(
    parse=f'Parse a {FILE_SUFFIX} file and output rules compatible with iptables-restore',
    reload=f'Reload {FILE_SUFFIX}, .v4 and .v6 files from the first valid folder in CONF_DIRS',
)

# Bulleted, newline-separated renderings of the configured directories, embedded into the help epilog.
CONF_DIR_LIST = "\n".join(["   - " + folder for folder in CONF_DIRS])
SEARCH_DIR_LIST = "\n".join(["   - " + folder for folder in SEARCH_DIRS])

# Epilog shown at the bottom of the top-level ``--help`` output: version banner, the list of
# sub-commands, and the configured CONF_DIRS / SEARCH_DIRS. The trailing backslash after the opening
# quotes suppresses the newline that would otherwise start the string.
# NOTE: the homepage URL previously read "wwww.privex.io" (four w's), which is not a valid link.
HELP_TEXT = textwrap.dedent(f'''\

PyreWall Version v{VERSION}
(C) 2019 Privex Inc. ( https://www.privex.io )
Official Repo: https://github.com/Privex/pyrewall


Sub-commands:

    parse  (-i 4|6) [filename]      - {CMD_DESC['parse']}
    reload                          - {CMD_DESC['reload']}

CONF_DIRS: 
{CONF_DIR_LIST}

SEARCH_DIRS: 
{SEARCH_DIR_LIST}

''')

# Root command line parser. RawDescriptionHelpFormatter keeps the pre-formatted HELP_TEXT epilog
# exactly as written instead of re-wrapping it.
parser = ErrHelpParser(
    epilog=HELP_TEXT,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description='PyreWall - Python firewall management using iptables',
)


def parse_stdin(ipver='both'):
    """
    Read Pyrewall source from standard input, parse it and print the resulting rules to standard output.

    :param str ipver: Which rule set(s) to print; passed through unchanged to :func:`print_rules`.
    """
    # Each input line is stripped of surrounding whitespace before being handed to the parser.
    stripped_lines = [raw.strip() for raw in sys.stdin]
    rules_v4, rules_v6 = PyreParser().parse_lines(lines=stripped_lines)
    print_rules(ip4=rules_v4, ip6=rules_v6, ipver=ipver)

def err(*msgs: str, file=sys.stderr, **kwargs):
    """
    Print ``msgs`` exactly like the builtin :func:`print`, but to ``file`` instead of standard output.

    :param str msgs: Positional values to print.
    :param file: Destination stream. The default is the ``sys.stderr`` object as it was when this
                 function was defined.
    :param kwargs: Any further keyword arguments accepted by :func:`print` (e.g. ``sep``, ``end``).
    """
    print(*msgs, file=file, **kwargs)

class RuleOutput:
    """
    Parse a Pyrewall file (or standard input) and write the resulting IPv4 / IPv6 rule lines to output streams.

    An instance is configured from an :class:`argparse.Namespace` carrying the attributes ``ipver``,
    ``file``, ``output``, ``output4`` and ``output6``. The class can be used as a context manager:
    leaving the ``with`` block closes every stream held by the instance, except the interpreter's own
    standard streams, which are never closed.
    """
    VER_TUPLE = Tuple[List[str], List[str]]

    # Accepted (lower-case) spellings of the IP version selector.
    _V4_NAMES = ('4', 'v4', 'ipv4', 'both')
    _V6_NAMES = ('6', 'v6', 'ipv6', 'both')
    # Names of the instance attributes that may hold an open stream.
    _STREAM_ATTRS = ('input_stream', 'output_stream', 'output_stream4', 'output_stream6')

    def __init__(self, opt: argparse.Namespace):
        """
        :param argparse.Namespace opt: Parsed command line options; must provide ``ipver``, ``file``,
                                       ``output``, ``output4`` and ``output6``.
        """
        super().__init__()
        self.ip_ver = opt.ipver
        self.input_file = opt.file
        self.output_file = opt.output
        self.output_file4 = opt.output4
        self.output_file6 = opt.output6

        self.input_stream = None
        self.output_stream = None
        self.output_stream4 = None
        self.output_stream6 = None

        self.rules_v4 = []
        self.rules_v6 = []

    @property
    def using_v4(self):
        """``True`` when the selected IP version includes IPv4 (compared case-insensitively)."""
        return str(self.ip_ver).lower() in self._V4_NAMES

    @property
    def using_v6(self):
        """``True`` when the selected IP version includes IPv6 (compared case-insensitively)."""
        return str(self.ip_ver).lower() in self._V6_NAMES

    @staticmethod
    def _get_stream(direction: str, dest: str, overwrite=False) -> TextIOWrapper:
        """
        Resolve ``dest`` to a text stream.

        :param str direction: ``"in"`` to read or ``"out"`` to write.
        :param str dest: A file path, or ``"-"`` for standard input / output (depending on ``direction``).
        :param bool overwrite: For output files only - open with mode ``'w'`` when true, otherwise with
                               mode ``'x'`` (exclusive creation, which fails if the file already exists).
        :raises AttributeError: If ``direction`` is neither ``"in"`` nor ``"out"``.
        :return: ``sys.stdin`` / ``sys.stdout`` for ``"-"``, otherwise the newly opened file object.
        """
        modes = 'r'

        if direction == 'out':
            if dest == '-':
                return sys.stdout

            modes = 'w' if overwrite else 'x'
        elif direction == 'in':
            if dest == '-':
                return sys.stdin
        else:
            raise AttributeError('direction must be "in" or "out".')

        return open(dest, modes)

    def parse_stream(self, stream: TextIOWrapper = None) -> VER_TUPLE:
        """
        Parse Pyrewall source read from a stream.

        :param stream: Stream to read; falls back to ``self.input_stream`` and then to ``sys.stdin``.
        :return: Whatever ``PyreParser.parse_lines`` returns (unpacked by :meth:`parse` as ``(ip4, ip6)``).
        """
        if stream is None:
            stream = sys.stdin if self.input_stream is None else self.input_stream

        lines = [raw.strip() for raw in stream.readlines()]
        return PyreParser().parse_lines(lines=lines)

    def parse_file(self, file=None) -> VER_TUPLE:
        """
        Locate ``file`` in ``SEARCH_DIRS`` and parse it.

        Prints an error to standard error and exits the process with status 1 when the file cannot be found.

        :param file: File name to look up; defaults to ``self.input_file``.
        :return: Whatever ``PyreParser.parse_file`` returns (unpacked by :meth:`parse` as ``(ip4, ip6)``).
        """
        f = self.input_file if file is None else file
        try:
            path = find_file(f, SEARCH_DIRS, extensions=conf.SEARCH_EXTENSIONS)
        except FileNotFoundError:
            err(f'ERROR: The file "{f}" could not be found in any of your search directories.')
            return sys.exit(1)
        err(f'Parsing file: {path}')
        return PyreParser().parse_file(path=path)

    def _open_outputs(self, output, custom_out: bool, overwrite: bool):
        """
        Open the shared output stream plus the per-version streams.

        Each distinct destination is opened only once, so identical paths share a single stream.
        When ``custom_out`` is true both versions are written to the shared ``output`` stream.
        """
        self.output_stream = self._get_stream(direction='out', dest=output, overwrite=overwrite)
        opened = {output: self.output_stream}

        def stream_for(dest):
            if custom_out or dest is None:
                return self.output_stream
            if dest not in opened:
                opened[dest] = self._get_stream(direction='out', dest=dest, overwrite=overwrite)
            return opened[dest]

        self.output_stream4 = stream_for(self.output_file4)
        self.output_stream6 = stream_for(self.output_file6)

    def _write_section(self, header: str, label: str, rules, dest):
        """Write ``header``, a begin marker, every rule and an end marker for one IP version to ``dest``."""
        self.output_rule(header, dest=dest)
        self.output_rule(f'# --- {label} Rules --- #', dest=dest)
        for rule in rules:
            self.output_rule(rule, dest=dest)
        self.output_rule(f'# --- End {label} Rules --- #', dest=dest)

    def parse(self, file=None, output=None, overwrite=False):
        """
        Parse the input and write the generated rules to the configured output stream(s).

        Input is read from standard input when ``file`` is ``"-"``, or when no file was given and
        standard input is not a terminal; otherwise ``file`` is looked up via :meth:`parse_file`.

        :param file: Input file name; defaults to ``self.input_file``.
        :param output: Output destination overriding ``self.output_file``. When given, the per-version
                       ``output4`` / ``output6`` destinations are ignored.
        :param bool overwrite: Allow existing output files to be overwritten (see :meth:`_get_stream`).
        """
        # Local import: the module only imports ``datetime`` from the datetime package.
        from datetime import timezone

        f = self.input_file if file is None else file
        custom_out = output is not None
        output = self.output_file if output is None else output

        # Read and parse the input BEFORE touching any output file, so that a missing or invalid input
        # does not leave behind an empty output file (which would make a later run fail in 'x' mode).
        if f == '-' or (empty(f) and not sys.stdin.isatty()):
            self.input_stream = sys.stdin
            ip4, ip6 = self.parse_stream(stream=self.input_stream)
        elif empty(f):
            return parser.error('Error! The following arguments are required: file')
        else:
            ip4, ip6 = self.parse_file(file=f)

        self.rules_v4, self.rules_v6 = ip4, ip6

        self._open_outputs(output, custom_out=custom_out, overwrite=overwrite)

        # Naive UTC timestamp with second precision (equivalent to the deprecated datetime.utcnow()).
        now = datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None)
        start_line = f'### Generated by PyreWall from file: "{f}" at date/time: {now.isoformat(" ")} UTC-0'

        if self.using_v4:
            self._write_section(start_line, 'IPv4', ip4, self.output_stream4)

        # Only emit the visual separator when both rule sets land in the very same stream.
        if self.output_stream4 is self.output_stream6:
            self.output_rule("\n#############################\n", dest=self.output_stream4)

        if self.using_v6:
            self._write_section(start_line, 'IPv6', ip6, self.output_stream6)

    @staticmethod
    def output_rule(rule: str, dest: TextIOWrapper = sys.stdout):
        """
        Write a single rule line followed by a newline.

        :param str rule: The line to write.
        :param dest: A writable stream, or the string ``"-"`` to ``print`` to standard output.
        """
        if dest == '-':
            return print(rule)

        dest.write(rule + "\n")

    @staticmethod
    def print_rules(ip4: list = None, ip6: list = None, ipver='both'):
        """Placeholder that currently does nothing; it takes no ``self``, hence a static method."""
        pass

    def _cleanup(self):
        """
        Close every stream held by this instance and reset the attributes to ``None``.

        The interpreter's standard streams are never closed. Errors raised while closing are logged and
        do not prevent the remaining streams from being handled. Safe to call repeatedly, and on
        instances whose ``__init__`` did not complete.
        """
        cls_name = self.__class__.__name__
        std_streams = (sys.stdin, sys.stdout, sys.stderr)

        for attr in self._STREAM_ATTRS:
            stream = getattr(self, attr, None)
            if stream is None:
                continue
            try:
                # Identity check: several attributes may reference the same stream, and closing an
                # already closed file object is harmless.
                if not any(stream is std for std in std_streams):
                    stream.close()
            except Exception:
                log.exception(f"Error while closing {cls_name}.{attr} ...")
            setattr(self, attr, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()

    def __del__(self):
        # Last-resort cleanup. The stream attributes are left set to None rather than deleted, so a
        # repeated call (or a finalizer running after a failed __init__) cannot raise AttributeError.
        self._cleanup()


def print_rules(ip4: list = None, ip6: list = None, ipver='both'):
    """
    Print the given rule lines to standard output, each IP version wrapped in begin / end marker comments.

    A blank line is always printed between the (possibly omitted) IPv4 and IPv6 sections.

    :param list ip4: IPv4 rule lines; a falsy value is treated as an empty list.
    :param list ip6: IPv6 rule lines; a falsy value is treated as an empty list.
    :param str ipver: Case-insensitive selector - ``4``/``v4``/``ipv4``, ``6``/``v6``/``ipv6`` or ``both``.
    """
    v4_rules = ip4 if ip4 else []
    v6_rules = ip6 if ip6 else []
    selected = ipver.lower()

    show_v4 = selected in ('4', 'v4', 'ipv4', 'both') and len(v4_rules) > 0
    if show_v4:
        print('# --- IPv4 Rules --- #', *v4_rules, '# --- End IPv4 Rules --- #', sep='\n')

    print()

    show_v6 = selected in ('6', 'v6', 'ipv6', 'both') and len(v6_rules) > 0
    if show_v6:
        print('# --- IPv6 Rules --- #', *v6_rules, '# --- End IPv6 Rules --- #', sep='\n')


def ap_parse(opt):
    """
    Handler for the ``parse`` sub-command: parse the requested input and write the generated rules.

    :param argparse.Namespace opt: Parsed command line options, passed straight to :class:`RuleOutput`.
    """
    # Use RuleOutput as a context manager so its streams are flushed and closed deterministically,
    # even when parsing aborts, instead of relying on the object's finalizer.
    with RuleOutput(opt) as rule_output:
        rule_output.parse()


def ap_reload(opt):
    """
    Handler for the ``reload`` sub-command. Currently it only reports which files were requested.

    :param opt: Parsed command line options; ``opt.files`` holds the files to reload.
    """
    print(f'reloading files: {opt.files}')


def ap_repl(opt):
    """
    Handler for the ``repl`` sub-command: hand the requested files over to ``repl_main``.

    :param opt: Parsed command line options; ``opt.files`` is the (possibly empty) list of files.
    """
    initial_files = opt.files
    repl_main(files=initial_files)


# ---- Sub-command definitions -------------------------------------------------------------------
# Every sub-parser stores its handler in ``func`` via set_defaults(); the dispatch code at the bottom
# of the file calls it with the parsed arguments.
sp = parser.add_subparsers()

# parse: convert a Pyrewall file (or stdin) into iptables-restore compatible rules.
parse_sp = sp.add_parser('parse', description=CMD_DESC['parse'])
parse_sp.add_argument('file', default=None, help='Pyrewall file to parse', nargs='?')
parse_sp.add_argument(
    '-i', type=str, default='both', dest='ipver',
    help='4 = Output only IPv4 config, 6 = Output only IPv6 config, both = Output both configurations (default)'
)

parse_sp.add_argument(
    '--output', '-o', type=str, default='-', dest='output',
    help='Output the IPTables rules lines to this file (default "-" (stdout))'
)

parse_sp.add_argument(
    '--output6', '-o6', type=str, default=None, dest='output6',
    help='Output only the IPv6 IPTables rules lines to this file (defaults to value of shared "--output")'
)

parse_sp.add_argument(
    '--output4', '-o4', type=str, default=None, dest='output4',
    help='Output only the IPv4 IPTables rules lines to this file (defaults to value of shared "--output")'
)


parse_sp.set_defaults(func=ap_parse)

# reload: must dispatch to ap_reload. It previously pointed at ap_parse, which expects the ``parse``
# options (ipver, file, output, ...) that this sub-parser does not define.
reload_sp = sp.add_parser('reload', description=CMD_DESC['reload'])
reload_sp.add_argument('files', help='Pyrewall file(s) to reload', nargs='+')

reload_sp.set_defaults(func=ap_reload)

# repl: interactive session. It gets its own description instead of reusing the one for ``parse``.
parse_repl = sp.add_parser(
    'repl', description='Start the interactive PyreWall REPL, optionally pre-loading Pyrewall file(s)'
)
parse_repl.add_argument('files', help='Optionally read these Pyrewall file(s) into the REPL in order', nargs='*')
parse_repl.set_defaults(func=ap_repl)

args = parser.parse_args()

# When no sub-command is given, the Namespace has no ``func`` attribute
# (see https://stackoverflow.com/a/54161510/2648583). Look the handler up explicitly rather than
# wrapping the call in ``except AttributeError``: that would also swallow any AttributeError raised
# inside a handler and misreport it as "Too few arguments".
func = getattr(args, 'func', None)

if func is None:
    # No sub-command: treat piped input as an implicit "parse", otherwise report a usage error.
    if not sys.stdin.isatty():
        parse_stdin()
        sys.exit(0)
    parser.error('Too few arguments')
    sys.exit(1)

func(args)

