# Copyright (c) 2018, Palo Alto Networks
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# Authors: Nathan Embery

import json
import logging
import xml.etree.ElementTree as elementTree
from abc import ABC
from abc import abstractmethod
from base64 import urlsafe_b64encode
from typing import Tuple
from xml.etree.ElementTree import ParseError

import xmltodict
from jinja2 import BaseLoader
from jinja2 import Environment
from jinja2.exceptions import TemplateAssertionError
from jinja2.exceptions import UndefinedError
from jsonpath_ng import parse
from lxml import etree
from passlib.hash import md5_crypt

from skilletlib.exceptions import SkilletLoaderException
from skilletlib.exceptions import SkilletValidationException

logger = logging.getLogger(__name__)


class Snippet(ABC):
    """
    Snippet implements a basic template object snippet. Each snippet type sub-classes this class and must
    implement 'execute'.
    """
    # set of required metadata, each snippet will define what attributes are required in the snippet definition
    # by default, we only require a 'name' attribute, but sub-classes will require more
    required_metadata = {'name'}

    # dict of optional metadata and their default values. These values will be set on the snippet class but
    # will not throw an exception if they are not present. Entries whose default is None are skipped entirely
    optional_metadata = dict()

    # set a default output type. this can be overridden for each SnippetType. This is used to determine the default
    # output handler to use for each snippet class. A single snippet may also override it via its own
    # 'output_type' metadata attribute (see capture_outputs)
    output_type = 'xml'

    def __init__(self, metadata):

        # first validate all the required fields are present in the metadata (snippet definition)
        self.metadata = self.sanitize_metadata(metadata)
        # always have a default name, subclasses will set additional fields on the class
        self.name = self.metadata['name']
        # set up jinja environment and add any custom filters. Snippet sub-classes can override add_filters
        # to append additional filters. See the PanosSnippet class for an example
        self.__init_env()

        # set all the required fields with their values from the snippet definition
        for k in self.required_metadata:
            setattr(self, k, self.metadata[k])

        # iterate the optional_metadata dict and set the default values
        # if they have not been set in the snippet metadata directly.
        # NOTE: defaults are also written back into the metadata dict that was passed in
        for k, v in self.optional_metadata.items():
            if v is None:
                continue
            if k in self.metadata:
                setattr(self, k, self.metadata[k])
            else:
                setattr(self, k, v)
                self.metadata[k] = v

        self.context = dict()

    def update_context(self, context: dict) -> None:
        """
        This will update the snippet context with the passed in dict.
        This gets called before render_metadata
        :param context: dict of values to merge into this snippet's context
        :return: None
        """
        self.context.update(context)

    @abstractmethod
    def execute(self, context: dict) -> Tuple[dict, str]:
        """
        Execute this Snippet and return a tuple consisting of the updated context and a string representing
        success, failure, or running.

        Each snippet sub class must override this method!

        :param context: context to use for variable interpolation
        :return: Tuple containing updated context dictionary and string indicating success or failure
        """
        return dict(), 'success'

    def should_execute(self, context: dict) -> bool:
        """
        Evaluate 'when' conditionals and return a bool if this snippet should be executed
        :param context: jinja context containing previous outputs and user supplied variables
        :return: boolean
        """

        logger.debug(f'Checking snippet: {self.name}')

        if 'when' not in self.metadata:
            # always execute when no when conditional is present
            logger.debug(f'No conditional present, proceeding with skillet: {self.name}')
            return True

        results = self.execute_conditional(self.metadata['when'], context)
        logger.debug(f'  Conditional Evaluation results: {results} ')
        return results

    def execute_conditional(self, test: str, context: dict) -> bool:
        """
        Evaluate 'test' conditionals and return a bool
        :param test: string of the conditional to execute, i.e. the body of a jinja '{% if %}' statement
        :param context: jinja context containing previous outputs and user supplied variables
        :return: boolean. Undefined variables, type errors and other runtime errors evaluate to False
        :raises SkilletValidationException: if the jinja expression is malformed
        """
        try:
            # wrap the expression in an if / else so jinja evaluates it with its own truthiness rules
            test_str = '{{%- if {0} -%}} True {{%- else -%}} False {{%- endif -%}}'.format(test)
            test_template = self._env.from_string(test_str)
            results = test_template.render(context)
            return str(results).strip() == 'True'
        except UndefinedError as ude:
            logger.error(ude)
            # always return false on error condition
            return False
        except TypeError as te:
            logger.error(te)
            return False
        except TemplateAssertionError as tae:
            logger.error(tae)
            raise SkilletValidationException('Malformed Jinja expression!') from tae
        except Exception as e:
            # catch-all - always return false on other error conditions
            logger.error(e)
            logger.error(type(e))
            return False

    def get_output(self) -> Tuple[str, str]:
        """
        get_output can be used when a snippet executes async and cannot or will not return output right away
        snippets that operate async must override this method
        :return: tuple of (output, status)
        """

        return '', 'success'

    def get_default_output(self, results: str, status: str) -> dict:
        """
        each snippet type can override this method to provide its own default output. This is used
        when there are no variables defined to be captured
        :param results: raw output from snippet execution
        :param status: status of the snippet.execute method
        :return: dict of default outputs keyed by the snippet name
        """

        return {
            self.name: {
                'results': status,
                'raw': results
            }
        }

    def capture_outputs(self, results: str, status: str) -> dict:
        """
        All snippet output or portions of snippet output can be captured and saved on the context as a new variable
        :param results: the raw output from the snippet execution
        :param status: status of the snippet.execute method
        :return: a dictionary containing all captured variables
        """

        captured_outputs = dict()
        output_type = self.metadata.get('output_type', self.output_type)

        # check if this snippet type wants to handle its own outputs
        if hasattr(self, f'handle_output_type_{output_type}'):
            func = getattr(self, f'handle_output_type_{output_type}')
            return func(results)

        # otherwise, check all the normal types here
        if 'outputs' not in self.metadata:
            return captured_outputs

        # each handler renders any jinja found in the capture_* keys it uses, on a copy of the definition,
        # so the snippet metadata is never modified and nothing is rendered twice
        handlers = {
            'xml': self.__handle_xml_outputs,
            'manual': self.__handle_manual_outputs,
            'text': self.__handle_text_outputs,
            'json': self.__handle_json_outputs,
        }
        handler = handlers.get(output_type)

        for output in self.metadata['outputs']:

            if 'name' not in output:
                continue

            if 'capture_variable' in output:
                outputs = {output['name']: self.render(output['capture_variable'], self.context)}
            elif handler is not None:
                outputs = handler(output, results)
            else:
                # unknown output type: nothing to capture for this definition
                outputs = dict()

            captured_outputs.update(outputs)
            # make captured values available to the output definitions that follow
            self.context.update(outputs)

        return captured_outputs

    def __render_output_metadata(self, output: dict, context: dict) -> dict:
        """
        Render jinja syntax found in the capture_* keys of an output definition.

        A shallow copy is returned so the snippet metadata keeps the original templates and can be rendered
        again against a different context on a later execution.

        :param output: a single output definition from the snippet 'outputs' list
        :param context: jinja context used for rendering
        :return: a copy of the output definition with the capture_* values rendered
        """
        rendered = dict(output)
        for key in ('capture_value', 'capture_pattern', 'capture_object', 'capture_list'):
            if key in rendered:
                rendered[key] = self.render(rendered[key], context)

        return rendered

    def __filter_outputs(self, output_definition: dict, output: (str, dict, list), local_context: dict) -> (list, None):
        """
        Filter OUT items that do not pass the test
        :param output_definition: the output definition from the skillet
        :param output: the captured object to test
        :param local_context: local context for the jinja expression based tests. Each tested value is exposed
        to the expression as 'item' (this dict is modified)
        :return: the output unchanged if no 'filter_items' test is defined, otherwise a list of the items that
        passed the test. A str or dict output is tested as a single item; any other type yields an empty list
        """
        if 'filter_items' not in output_definition:
            return output

        # grab the test string to evaluate
        test_str = output_definition['filter_items']

        if isinstance(output, list):
            candidates = output
        elif isinstance(output, (str, dict)):
            candidates = [output]
        else:
            candidates = []

        # keep a new list of all the items that have matched the test
        filtered_items = list()
        for item in candidates:
            local_context['item'] = item
            if self.execute_conditional(test_str, local_context):
                filtered_items.append(item)

        return filtered_items

    def render(self, template_str: str, context: (dict, None) = None) -> str:
        """
        Render a jinja template string
        :param template_str: the jinja template to render
        :param context: jinja context to render with; defaults to this snippet's context when None
        :return: the rendered string
        """
        if context is None:
            context = self.context
        return self._env.from_string(template_str).render(context)

    def sanitize_metadata(self, metadata: dict) -> dict:
        """
        Ensure the configured metadata is valid for this snippet type
        :param metadata: dict
        :return: validated metadata dict (the same object that was passed in)
        :raises SkilletLoaderException: if any attribute listed in required_metadata is missing
        """
        name = metadata.get('name', '')
        # sorted for a deterministic error message
        missing = [attr_name for attr_name in sorted(self.required_metadata) if attr_name not in metadata]
        if missing:
            raise SkilletLoaderException(f'Invalid snippet metadata configuration: attribute: {", ".join(missing)} '
                                         f'is required for snippet: {name}')

        return metadata

    def render_metadata(self, context: dict) -> dict:
        """
        Each snippet sub class can override this method to perform jinja variable interpolation on various items
        in its snippet definition. For example, the PanosSnippet will check the 'xpath' attribute and perform
        the required interpolation
        :param context: context from environment
        :return: metadata with jinja rendered variables
        """
        self.context.update(context)

        return self.metadata

    # define functions for custom jinja filters
    @staticmethod
    def __md5_hash(txt: str) -> str:
        """
        Returns the MD5 Hashed secret for use as a password hash in the PAN-OS configuration
        :param txt: text to be hashed
        :return: password hash of the string with salt and configuration information. Suitable to place in the phash field
        in the configurations
        """

        return md5_crypt.hash(txt)

    def __init_env(self) -> None:
        """
        init the jinja2 environment, stored on self._env, and add any required filters
        :return: None
        """
        self._env = Environment(loader=BaseLoader())
        self._env.filters["md5_hash"] = self.__md5_hash
        self.add_filters()

    def add_filters(self) -> None:
        """
        Each snippet sub-class can add additional filters. See the PanosSnippet for examples
        :return: None
        """
        pass

    def __handle_text_outputs(self, output_definition: dict, results: str) -> dict:
        """
        Parse the results string as a text blob into a single variable.

        - name: system_info
          path: /api/?type=op&cmd=<show><system><info></info></system></show>&key={{ api_key }}
          output_type: text
          outputs:
            - name: system_info_as_xml

        :param output_definition: a single output definition; its 'name' is used as the variable name, falling
        back to the snippet name
        :param results: results string from the action
        :return: dict of outputs, in this case a single entry
        """
        output_name = output_definition.get('name', self.name)
        return {output_name: results}

    @staticmethod
    def __xpath_entries(xml_doc, capture_pattern: str) -> list:
        """
        Evaluate an xpath expression and always return a list.

        lxml returns a list for node-set results, but a bool, float or str for expressions such as
        boolean(), count() or string(). Those scalar results are wrapped in a one-item list, with non-string
        scalars converted to str.

        :param xml_doc: parsed lxml document
        :param capture_pattern: xpath expression to evaluate
        :return: list of elements and / or strings
        """
        entries = xml_doc.xpath(capture_pattern)
        if isinstance(entries, list):
            return entries
        if isinstance(entries, str):
            return [entries]
        return [str(entries)]

    @staticmethod
    def __has_mixed_tags(elements: list) -> bool:
        """
        Determine whether a list of xpath results contains elements with differing tag names
        :param elements: list of xpath results
        :return: False if any result is a string or all elements share one tag name, True otherwise
        """
        # some xpath queries can return a list of str
        if any(isinstance(el, str) for el in elements):
            return False
        return len({el.tag for el in elements}) != 1

    @staticmethod
    def __xml_entry_to_object(entry) -> (dict, str):
        """
        Convert a single xpath result into a python object
        :param entry: an xml element, or a str for attribute / text xpath results
        :return: the str unchanged, or the element converted to a dict by xmltodict
        """
        if isinstance(entry, str):
            return entry
        return xmltodict.parse(elementTree.tostring(entry))

    def __capture_xml_value(self, xml_doc, capture_pattern: str) -> (str, list):
        """
        Capture text values using an xpath expression (capture_value / capture_pattern semantics)
        :param xml_doc: parsed lxml document
        :param capture_pattern: xpath expression to evaluate
        :return: '' when nothing matched, a single str when one result matched, otherwise a list of str
        """
        entries = self.__xpath_entries(xml_doc, capture_pattern)
        logger.debug(f'found entries: {entries}')

        if not entries:
            return ''

        if len(entries) == 1:
            entry = entries[0]
            if isinstance(entry, str):
                return entry
            if len(entry) == 0:
                # this tag has no children, so grab the text. An empty tag yields '' rather than 'None'
                return entry.text.strip() if entry.text is not None else ''
            # we have 1 Element returned, so the user has a fairly specific xpath
            # however, this element has children itself, so we can't return a text value
            # just return the tag name of this element only
            return entry.tag

        # we have a list of results returned from the users xpath query
        # if the tags differ from each other, return tag names rather than text
        return_tags = self.__has_mixed_tags(entries)
        capture_list = list()
        for entry in entries:
            if isinstance(entry, str):
                capture_list.append(entry)
            elif len(entry) != 0 or return_tags:
                capture_list.append(entry.tag)
            elif entry.text is not None:
                capture_list.append(entry.text.strip())
            # If there is no text, then try to grab a sensible attribute
            # if you need more control than this, then you should first
            # capture_object to convert to a python object then use a jinja filter
            # to get what you need
            elif 'value' in entry.attrib:
                capture_list.append(entry.attrib.get('value', ''))
            elif 'name' in entry.attrib:
                capture_list.append(entry.attrib.get('name', ''))
            else:
                capture_list.append(json.dumps(dict(entry.attrib)))

        return capture_list

    def __handle_xml_outputs(self, output_definition: dict, results: str) -> dict:
        """
        Parse the results string as an XML document
        Example .meta-cnc snippets section:
        snippets:

          - name: system_info
            path: /api/?type=op&cmd=<show><system><info></info></system></show>&key={{ api_key }}
            output_type: xml
            outputs:
              - name: hostname
                capture_value: result/system/hostname
              - name: uptime
                capture_value: result/system/uptime
              - name: sw_version
                capture_value: result/system/sw-version

        :param output_definition: a single output definition from the snippet 'outputs' list
        :param results: string as returned from some action, to be parsed as XML document
        :return: dict containing all outputs found from the capture pattern in each output
        :raises SkilletLoaderException: if the results cannot be parsed as XML
        """

        captured_output = dict()

        try:
            xml_doc = etree.XML(results)
        except (ParseError, etree.XMLSyntaxError) as pe:
            # the document is parsed with lxml, which raises XMLSyntaxError rather than the stdlib ParseError
            logger.error('Could not parse XML document in output_utils')
            raise SkilletLoaderException(f'Could not parse output as XML in {self.name}') from pe

        # allow jinja syntax in capture_pattern, capture_value, capture_object etc
        local_context = self.context.copy()
        output = self.__render_output_metadata(output_definition, local_context)

        var_name = output['name']
        if 'capture_pattern' in output or 'capture_value' in output:
            capture_pattern = output['capture_value'] if 'capture_value' in output else output['capture_pattern']
            captured_output[var_name] = self.__capture_xml_value(xml_doc, capture_pattern)

        elif 'capture_object' in output:
            entries = self.__xpath_entries(xml_doc, output['capture_object'])

            if not entries:
                captured_output[var_name] = None
            elif len(entries) == 1:
                captured_output[var_name] = self.__xml_entry_to_object(entries[0])
            else:
                captured_output[var_name] = [self.__xml_entry_to_object(entry) for entry in entries]

        elif 'capture_list' in output:
            entries = self.__xpath_entries(xml_doc, output['capture_list'])
            captured_output[var_name] = [self.__xml_entry_to_object(entry) for entry in entries]

        # filter selected items here; skip when the definition had no capture_* key and nothing was captured
        if var_name in captured_output:
            captured_output[var_name] = self.__filter_outputs(output, captured_output[var_name], local_context)

        return captured_output

    def _handle_base64_outputs(self, results: str) -> dict:
        """
        Parses results and returns a dict containing base64 encoded values
        :param results: str (encoded as utf-8) or bytes as returned from some action, to be encoded as base64
        :return: dict mapping each named output to the url-safe base64 encoding of the results
        :raises SkilletLoaderException: if the results cannot be encoded
        """

        outputs = dict()

        snippet_name = self.metadata.get('name', 'unknown')

        try:
            if 'outputs' not in self.metadata:
                logger.info(f'No output defined in this snippet {snippet_name}')
                return outputs

            var_names = [output['name'] for output in self.metadata['outputs'] if 'name' in output]
            if not var_names:
                return outputs

            # every output receives the same value, so encode only once
            if isinstance(results, (bytes, bytearray)):
                results_as_bytes = bytes(results)
            else:
                results_as_bytes = bytes(results, 'utf-8')
            encoded_results = urlsafe_b64encode(results_as_bytes).decode('utf-8')

            for var_name in var_names:
                outputs[var_name] = encoded_results

        except TypeError as te:
            raise SkilletLoaderException(f'Could not base64 encode results {snippet_name}') from te

        return outputs

    def __handle_json_outputs(self, output_definition: dict, results: str) -> dict:
        """
        Parses results as JSON and captures values from it using jsonpath expressions

        output_type: json
        outputs:
          - name: salt_auth_token
            capture_object: '$.return[0].token'

        See here for more jsonpath examples: https://github.com/h2non/jsonpath-ng

        :param output_definition: a single output definition from the snippet 'outputs' list
        :param results: str / bytes as returned from some action, to be parsed as JSON. Results that are not
        str-like (already decoded JSON) are used as-is
        :return: dict containing all outputs found from the capture pattern in each output. On error, the dict
        contains a 'fail_message' key instead
        """
        captured_output = dict()

        local_context = self.context.copy()
        output = self.__render_output_metadata(output_definition, local_context)

        capture_patterns = [output[key] for key in ('capture_pattern', 'capture_value', 'capture_object')
                            if key in output]
        if 'name' not in output or not capture_patterns:
            return captured_output

        var_name = output['name']

        try:
            # some Skillet types may return us json already, only decode results that are actually str-like
            if isinstance(results, (str, bytes, bytearray)):
                json_object = json.loads(results)
            else:
                json_object = results

            for capture_pattern in capture_patterns:
                # short cut for just getting all the results
                if capture_pattern in ('$', '.'):
                    captured_output[var_name] = json_object
                    continue

                jsonpath_expr = parse(capture_pattern)
                result = jsonpath_expr.find(json_object)
                if len(result) == 1:
                    captured_output[var_name] = str(result[0].value)
                else:
                    # FR #81 - add ability to capture from a list
                    captured_output[var_name] = [r.value for r in result]

        except ValueError as ve:
            logger.error('Caught error converting results to json')
            captured_output['fail_message'] = str(ve)
        except Exception as e:
            logger.error('Unknown exception here!')
            logger.error(e)
            captured_output['fail_message'] = str(e)

        return captured_output

    def __handle_manual_outputs(self, output_definition: dict, results: str) -> dict:
        """
        Manually set a value in the context, this could be useful with 'when' conditionals
        :param output_definition: a single output definition; requires 'name' and 'capture_value'
        :param results: results from snippet execution, ignored in this method
        :return: dict containing manually defined name / value pair, with capture_value rendered as jinja
        """
        if 'name' not in output_definition or 'capture_value' not in output_definition:
            return dict()

        var_name = output_definition['name']
        value = str(self.render(output_definition['capture_value'], self.context))

        return {var_name: value}

    def _handle_manual_outputs(self, results: str) -> dict:
        """
        Manually set a value in the context, this could be useful with 'when' conditionals
        :param results: results from snippet execution, ignored in this method
        :return: dict containing manually defined name / value pairs taken from each output's 'value' key
        """
        outputs = dict()

        try:
            if 'outputs' not in self.metadata:
                logger.info('No outputs defined in this snippet')
                return outputs

            for output in self.metadata['outputs']:

                if 'name' not in output:
                    continue

                outputs[output['name']] = output['value']

        except KeyError as ke:
            logger.error(f'Could not locate required attributes for manual output: {ke} in snippet: {self.name}')

        return outputs