#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tunnel.py
#
# Copyright 2021 Vincent Schouten
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to
#  deal in the Software without restriction, including without limitation the
#  rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
#  sell copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
#  all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
#  DEALINGS IN THE SOFTWARE.
#

"""
Main code for tunnel.

.. _Google Python Style Guide:
   http://google.github.io/styleguide/pyguide.html

NOTE: The Tunnel classes are responsible for purging the stream (i.e. leaving the stream's read position at COMMAND_PROMPT).

"""

# from abc import ABC, abstractmethod
import os
import threading
from time import sleep

import pexpect

from .logging import LoggerMixin

# Module metadata.
__author__ = 'Vincent Schouten <powermole@protonmail.com>'
__docformat__ = 'google'
__date__ = '10-05-2019'
__copyright__ = 'Copyright 2021, Vincent Schouten'
__credits__ = ['Vincent Schouten']
__license__ = 'MIT'
__maintainer__ = 'Vincent Schouten'
__email__ = '<powermole@protonmail.com>'
__status__ = 'Development'  # one of "Prototype", "Development", "Production".

# Pattern handed to pexpect's expect() (interpreted as a regular expression): a '#' or '$'
# followed by a space. This prompt is the default for Fedora and CentOS.
COMMAND_PROMPT = '[#$] '


class Tunnel(LoggerMixin):
    """Establishes a connection to the target destination host via one or more intermediaries.

    Be aware, the child's buffer needs to be purged periodically. This can be done by invoking
    periodically_purge_buffer(). As verbose mode is enabled for SSH (the child process), it
    will slowly fill up the buffer, so this has to be taken care of. But don't invoke this
    method before having start()'ed BootstrapAgent.
    """

    # Names of the base port pairs that are always forwarded, in command-line order.
    _BASE_FORWARD_NAMES = ('agent', 'heartbeat', 'command', 'transfer')

    def __init__(self, path_ssh_cfg, mode, all_host_addr, group_ports,  # pylint: disable=too-many-arguments
                 forward_connections=None):
        """Initializes the Tunnel object.

        Args:
            path_ssh_cfg (str): Path to the SSH config file that is generated by write_ssh_config_file().
            mode (str): Contains any of these values: TOR|FOR|PLAIN.
            all_host_addr (list): IP addresses of all hosts (eg. gateway/intermediary and destination hosts).
                The last entry is the destination host.
            group_ports (dict): Port names with port numbers.
            forward_connections (str): Formatted as "-Lport:host:hostport". Used when mode is FOR.

        """
        super().__init__()
        self.path_ssh_cfg = path_ssh_cfg
        self.mode = mode
        self.all_host_addr = all_host_addr
        self.last_host_addr = all_host_addr[-1]
        self.group_ports = group_ports
        self.forward_connections = forward_connections
        self.inbound_address_socks = 'localhost'
        self.child = None  # pexpect.spawn instance, set by start()
        self.thread = None  # purger thread, set by periodically_purge_buffer()
        self.terminate = False  # tells the purger thread to stop
        self.authenticated_hosts = []

    def __str__(self):
        return 'Tunnel'

    def _generate_ssh_runtime_param(self):
        """Composes the SSH command line with all port forwardings and the jump-host chain.

        Returns:
            str: The complete ``ssh`` command to spawn.

        Raises:
            ValueError: If ``self.mode`` is not one of TOR, FOR or PLAIN.

        """
        # block below composes _mode specific_ forwarding strings
        if self.mode == 'FOR':
            var_param = f'{self.forward_connections} '
        elif self.mode == 'TOR':
            var_param = f'-L{self.group_ports["local_port_proxy"]}:{self.inbound_address_socks}:' \
                        f'{self.group_ports["remote_port_proxy"]} '
        elif self.mode == 'PLAIN':
            var_param = ''  # no *additional* ports will be forwarded
        else:
            # Previously this fell through with None and failed later with an opaque TypeError.
            raise ValueError(f'unsupported mode {self.mode!r}; expected one of TOR, FOR, PLAIN')

        # -J takes the jump hosts comma-separated; the final host (the destination) follows after a space:
        # ['10.10.1.72', '10.10.2.92', '10.10.3.52'] --> '10.10.1.72,10.10.2.92 10.10.3.52'
        *jump_hosts, destination = self.all_host_addr
        if jump_hosts:
            order_of_hosts = f'{",".join(str(host) for host in jump_hosts)} {destination}'
        else:
            # NOTE(review): with a single host this yields '-J <host>' without a destination, exactly as
            # before; confirm whether single-host (direct) connections are meant to be supported.
            order_of_hosts = f'{destination}'

        # block below composes _base_ forwarding strings
        base_forwards = ''.join(
            f'-L{self.group_ports[f"local_port_{name}"]}:localhost:'
            f'{self.group_ports[f"remote_port_{name}"]} '
            for name in self._BASE_FORWARD_NAMES)
        runtime_param = f'ssh -v -F {self.path_ssh_cfg} {base_forwards}{var_param}-J {order_of_hosts}'

        self._logger.debug(runtime_param)
        return runtime_param

    def start(self, debug=None):
        """Establishes an SSH tunnel.

        It determines along the way if the authentication process is successful.

        In addition, this method also mines the output for 'Authenticated' keywords, so
        we can keep track of which hosts have been connected through.

        SSH is here a 'child application'.

        Args:
            debug (bool): If True, pexpect's timeout is disabled, so TIMEOUT will not be raised and reading
                          may block indefinitely. Use only for debugging purposes to capture the output of the
                          child, which is essentially hidden 'under the hood', and write it to a file.

        Returns:
            bool: True if every host was passed and the remote command prompt was reached. False otherwise,
                  in which case the child process (if it was spawned) is terminated.

        Raises:
            ValueError: If the configured mode is unsupported.

        """
        result = False
        try:
            arguments = {"command": self._generate_ssh_runtime_param(), "env": {"TERM": "dumb"}, "encoding": 'utf-8'}
            if debug:
                # timeout=None makes pexpect block indefinitely instead of raising TIMEOUT, as documented above.
                arguments.update({"timeout": None})
            self.child = pexpect.spawn(**arguments)
            # setecho() doesn't seem to have effect.
            #    doc says: Not supported on platforms where isatty() returns False.
            #    perhaps related to the recursive shells (SSH spawns a new shell in the current shell)
            self.child.setecho(False)
            self._logger.debug('going through the stream to match patterns: %s', self.all_host_addr)
            for hostname in self.all_host_addr:
                # according to the documentation, "If you wish to read up to the end of the child's output
                #    without generating an EOF exception then use the expect(pexpect.EOF) method."
                #    but apparently this doesn't work in a shell within a shell (SSH spawns a new shell)
                index = self.child.expect(
                    [f'Authenticated to {hostname}', 'Last failed login:', 'Last login:', 'socket error',
                     'not accessible', 'fingerprint', 'open failed: connect failed:', pexpect.TIMEOUT])
                result = False  # reset var as this var could be set True in a previous iteration, we want fresh start
                if index == 0:
                    self._logger.info('authenticated to %s', hostname)  # logger level is "info" to inform user
                    self.authenticated_hosts.append(hostname)
                    result = True
                elif index == 1:
                    self._logger.debug('there were failed login attempts')
                    result = True
                elif index == 2:
                    self._logger.debug('there were no failed login attempts')
                    result = True
                elif index == 3:
                    self._logger.error('socket error. probable cause: SSH service on proxy or target machine disabled')
                    break
                elif index == 4:
                    self._logger.error('the identity file is not accessible')
                    break
                elif index == 5:
                    self._logger.warning('warning: hostname automatically added to list of known hosts')
                    # security issue: unknown host keys are accepted without verification
                    self.child.sendline('yes')
                    # NOTE(review): the loop then moves on to the next hostname, so this host is not added to
                    # authenticated_hosts, and if it was the last host, result stays False; confirm intent.
                elif index == 6:
                    self._logger.error('SSH could not connect to %s', hostname)
                    break
                elif index == 7:
                    self._logger.error('TIMEOUT exception was thrown. SSH could probably not connect to %s', hostname)
                    break
                else:
                    self._logger.error('unknown state reached')
            if result:
                # Only wait for the prompt on success; after a failure the child is terminated below anyway.
                self.child.expect(COMMAND_PROMPT)
        except pexpect.exceptions.EOF:
            self._logger.error('EOF is read; SSH has exited abnormally')
            result = False  # the tunnel is unusable even if every host was passed
        except pexpect.exceptions.ExceptionPexpect as exc:
            # e.g. TIMEOUT while waiting for COMMAND_PROMPT, or spawn() failing to start ssh
            self._logger.error('pexpect error: %s', exc)
            result = False
        if not result:
            self._logger.error('debug information: %s', str(self.child))
            if self.child is not None:  # spawn() may have failed before a child existed
                self.child.terminate()
        return result

    def stop(self):
        """Closes the SSH connection essentially by terminating the program SSH.

        Also signals the purger thread (if any) to exit. Safe to call when start() was never invoked.

        Returns:
            bool: Always True.

        """
        self.terminate = True
        if self.child is not None and self.child.isalive():
            self._logger.debug('SSH is alive, terminating')
            self.child.terminate()
        self._logger.debug('SSH terminated')
        return True

    def debug(self):
        """Captures the output of the child (warning: BLOCKING).

        Everything the child prints is appended to ``~/mylog.txt`` until the child reaches EOF or
        pexpect raises (e.g. TIMEOUT).
        """
        log_path = os.path.expanduser('~/mylog.txt')  # open() does not expand '~' by itself
        # Keep the file open while the child writes to it; previously it was closed before use.
        with open(log_path, 'a', encoding='utf-8') as log_file:
            self.child.logfile = log_file
            try:
                self.child.readlines()
            except pexpect.ExceptionPexpect:
                pass  # TIMEOUT/EOF simply end the capture
            finally:
                self.child.logfile = None  # don't leave a closed file attached to the child

    def periodically_purge_buffer(self):
        """Purges the child's (SSH) output buffer due to buffer limitations."""
        self.thread = threading.Thread(target=self._run_purger)
        self.thread.start()

    def _run_purger(self):
        """Drains the child's output about every two seconds until stop() is called or the child exits."""
        while not self.terminate:
            try:
                # TIMEOUT is listed as a pattern, so expect() returns instead of raising when output stops.
                self.child.expect([pexpect.TIMEOUT], timeout=0.2)
            except pexpect.exceptions.EOF:
                # The child has exited, so nothing is left to purge; continuing would busy-loop on EOF.
                self._logger.debug('SSH reached EOF, stopping buffer purger')
                break
            sleep(2)
