#!/usr/bin/env python
# Copyright (c) 2020 - The Procedural Generation for Gazebo authors
# For information on the respective copyright owner see the NOTICE file
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import logging
import numpy as np
import trimesh
import matplotlib.pyplot as plt
from pcg_gazebo.utils import is_array
from pcg_gazebo.parsers import parse_xacro, parse_sdf
from pcg_gazebo.generators import WorldGenerator
from pcg_gazebo.generators.engines import get_engine_tags
from pcg_gazebo.simulation import is_gazebo_model, \
    add_custom_gazebo_resource_path
from pcg_gazebo.visualization import plot_workspace, plot_shapely_geometry
from pcg_gazebo.generators.occupancy import generate_occupancy_grid


# Module-level world generator instance shared by the helper functions
# and by the script entry point in this file.
WORLD_GEN = WorldGenerator()


def export_world_to_file(input_filename):
    """Export the world held by ``WORLD_GEN`` to a ``.world`` file.

    The ``.world`` extension is appended to the file name when missing.
    If a file with the resulting name already exists in the target
    folder, a numeric suffix (``_1``, ``_2``, ...) is inserted before the
    extension until an unused file name is found.

    Args:
        input_filename: Desired output path. When it has no directory
            component, the current directory (``'.'``) is used.

    Returns:
        The path (folder joined with the final file name) passed to the
        world export call.
    """
    extension = '.world'
    folder = os.path.dirname(input_filename) or '.'
    filename = os.path.basename(input_filename)
    if not filename.endswith(extension):
        filename += extension

    # Derive the stem from the resolved file name rather than splitting the
    # base name at its first dot, so that names containing dots (for
    # example "my.map.world") keep their full name when a suffix is added.
    stem = filename[:-len(extension)]
    i = 1
    while os.path.isfile(os.path.join(folder, filename)):
        filename = '{}_{}{}'.format(stem, i, extension)
        i += 1

    WORLD_GEN.world.export_to_file(
        output_dir=folder,
        filename=filename,
        with_default_ground_plane=False,
        with_default_sun=False)
    return os.path.join(folder, filename)


def add_factory_asset(model, is_static):
    """Register a factory asset description in ``WORLD_GEN``.

    Supported inputs are the keywords ``random_box``, ``random_cylinder``
    and ``random_sphere``, or the path to an existing mesh file ending in
    ``.stl``, ``.dae`` or ``.obj``.

    Args:
        model: Factory keyword or mesh file path.
        is_static: When truthy the asset is described with ``mass=0``;
            otherwise the mass is given as an inline expression string
            (presumably evaluated by the generator when the model is
            created -- confirm against the pcg_gazebo documentation).

    Returns:
        The tag under which the asset was registered: the keyword itself,
        or for a mesh the base file name up to its first dot.

    Raises:
        NotImplementedError: If ``model`` is neither a supported keyword
            nor an existing mesh file with a supported extension.
        RuntimeError: If ``WORLD_GEN.add_asset`` reports a failure.
    """
    mass_inline = "__import__('pcg_gazebo').random.rand()"
    random_length = "2 * __import__('pcg_gazebo').random.rand()"

    if model == 'random_box':
        tag = model
        asset_type = 'box'
        asset_args = dict(
            size="2 * __import__('pcg_gazebo').random.rand(3)",
            name='box')
    elif model == 'random_cylinder':
        tag = model
        asset_type = 'cylinder'
        asset_args = dict(
            radius=random_length,
            length=random_length,
            name='cylinder')
    elif model == 'random_sphere':
        tag = model
        asset_type = 'sphere'
        asset_args = dict(
            radius=random_length,
            name='sphere')
    elif os.path.isfile(model) and \
            model.endswith(('.stl', '.dae', '.obj')):
        tag = os.path.basename(model).split('.')[0]
        asset_type = 'mesh'
        logging.info('Add mesh asset from <{}> with tag {}'.format(
            model, tag))
        asset_args = dict(
            visual_mesh=model,
            name=tag)
    else:
        raise NotImplementedError('Invalid factory description input')

    # Use the function argument here: the previous implementation read an
    # undefined local name and only worked through a leaked global variable.
    asset_args.update(
        color='xkcd',
        mass=0 if is_static else mass_inline)

    # The registration must not live inside an assert statement, otherwise
    # it is removed entirely when Python runs with optimizations (-O).
    added = WORLD_GEN.add_asset(
        tag=tag,
        description=dict(
            type=asset_type,
            args=asset_args))
    if not added:
        raise RuntimeError(
            'Could not add asset <{}> to the world generator'.format(tag))
    return tag


if __name__ == '__main__':
    # Script entry point: parse the command-line options, define the
    # workspace constraint, register a random pose engine for the requested
    # models, run the engines and export the resulting world to a file.
    description = \
        'This script will either take a Gazebo world (or ' \
        'use an empty one if none is provided) and populate it ' \
        'with models placed within a workspace. ' \
        'The workspace can be either provided as a set of 2D or ' \
        '3D points, the convex hull of the models in the world, or ' \
        'the free space found in the world grid map.'
    usage_text = '''
    pcg-populate-world --models MODEL_1 MODEL_2 --num 3 3 --min-distance 0.2 --static 1  # Models 1 and 2 must be Gazebo models
    pcg-populate-world --models random_box --num 2  # Add two random boxes to an empty world
    pcg-populate-world --models random_box random_cylinder random_sphere --num -1 -1 -1 --tangent-to-ground  # Add as many random boxes, cylinders and spheres as possible in the default workspace and place them tangent to the ground plane
    pcg-populate-world --models MESH_FILENAME --num 1  # Add mesh as model in the world
    pcg-populate-world --world WORLD_FILE --models MODEL_1 --num 5  # Add MODEL_1 to world loaded from file
    '''
    parser = argparse.ArgumentParser(
        description=description,
        epilog=usage_text,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument(
        '--world', '-w',
        type=str,
        help='Gazebo world as file or XML input to be populated')
    parser.add_argument(
        '--config', '-c',
        type=str,
        help='YAML file with the configuration for engines '
             'and models to be added to the world')
    parser.add_argument(
        '--models', '-m',
        nargs='+', type=str,
        help='Name of Gazebo model to include in the world')
    parser.add_argument(
        '--num', '-n',
        nargs='+', type=int,
        help='Number of elements to add to world (use '
             '-1 if no maximum limit is necessary)')
    parser.add_argument(
        '--min-distance', '-d',
        nargs='+',
        type=float,
        help='Minimum distance between objects (either '
             'for all models or per model)')
    parser.add_argument(
        '--static', '-s',
        nargs='+',
        type=int,
        help='Whether the model should be static or not')
    parser.add_argument(
        '--export-filename', '-f',
        type=str,
        help='Name of the file to export the generated '
             'world to. If none is provided, the input world '
             'filename is reused, or ~/.gazebo/worlds/<world '
             'name>.world if no world file was given. A numeric '
             'suffix is added if the file already exists.')
    parser.add_argument(
        '--custom-models-path', '-p',
        nargs='+',
        type=str,
        help='Custom folder containing Gazebo models')
    parser.add_argument(
        '--random-roll', action='store_true',
        help='For random engines, the roll angle will also be randomized')
    parser.add_argument(
        '--random-pitch', action='store_true',
        help='For random engines, the pitch angle will also be randomized')
    parser.add_argument(
        '--random-yaw', action='store_true',
        help='For random engines, the yaw angle will also be randomized')
    parser.add_argument(
        '--workspace', '-k',
        type=str,
        help='Name of model or set of coordinates provided as '
             '(1) [(x, y, z), (x, y, z), ...] for a 3D workspace, '
             '(2) [(x, y), (x, y), ...] for a 2D workspace, '
             '(3) name of models in the provided world that '
             'delimit the workspace (e.g. walls). '
             '(4) "gridmap" to compute a grid map of the input'
             ' world and place objects in the free space found. '
             'If no workspace is defined and there is a world'
             ' with models provided, the workspace will be the'
             ' convex hull around all existing models. If no '
             'world is provided, an area of 20 x 20 m on the '
             'ground plane will be set as the workspace')
    parser.add_argument(
        '--tangent-to-ground', '-t',
        action='store_true',
        help='Place all models tangent to the ground plane')
    parser.add_argument(
        '--preview',
        action='store_true',
        help='Preview the workspace and generated world')

    args = parser.parse_args()

    def _per_model(values, default, error_msg):
        """Return one option value per entry of ``args.models``.

        ``values`` is the parsed list of a ``nargs='+'`` option, or
        ``None`` when the option was not given. A missing option yields
        ``default`` for every model, a single value is repeated for every
        model, and any other length must match the number of models.
        """
        n_models = len(args.models)
        if values is None:
            return [default for _ in range(n_models)]
        if len(values) == 1:
            return [values[0] for _ in range(n_models)]
        if len(values) != n_models:
            # parser.error() prints the usage message and exits; unlike an
            # assert, this check is not removed when running with -O
            parser.error(error_msg)
        return values

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s | %(levelname)s | %(module)s | %(message)s',
        datefmt='%m-%d %H:%M',)

    # NOTE(review): basicConfig() is called without a filename above, so
    # the root logger presumably already has a stream handler and ERROR
    # records may be printed twice through this extra handler -- confirm
    # whether that is intended.
    console = logging.StreamHandler()
    console.setLevel(logging.ERROR)
    # add the handler to the root logging
    logging.getLogger('').addHandler(console)

    if args.custom_models_path is not None:
        for models_path in args.custom_models_path:
            if not add_custom_gazebo_resource_path(models_path):
                logging.error(
                    'Could not add <{}> as a custom Gazebo'
                    ' model path'.format(models_path))
            else:
                logging.info(
                    'Added <{}> as a custom Gazebo model '
                    'path'.format(models_path))

    if args.world is None:
        logging.info('No world was provided, using empty world instead')
    else:
        logging.info('Initializing world generator')
        WORLD_GEN.init_from_sdf(args.world)
        if args.preview:
            WORLD_GEN.world.create_scene().show()

    if args.config:
        logging.info(
            'Setting configuration for world generation from '
            'input configuration file <{}>'.format(args.config))
        WORLD_GEN.from_yaml(args.config)

    # Defining a workspace. Every branch below also sets position_dofs,
    # which is later used by the pose policy of the engine.
    if args.workspace is not None:
        # SECURITY NOTE: eval() executes the --workspace string as Python
        # code. It is kept so that existing invocations keep working, but
        # this option must never be fed with untrusted input. A plain model
        # name (or "gridmap") is expected to fail here and is then handled
        # by the branches below.
        try:
            points = np.array(eval(args.workspace))
        except Exception:
            points = None
        if points is not None:
            if len(points.shape) != 2:
                raise ValueError(
                    'Workspace points must have either'
                    ' 2 or 3 components')
            position_dofs = ['x', 'y', 'z']

            if points.shape[1] == 2:
                WORLD_GEN.add_constraint(
                    name='workspace',
                    type='workspace',
                    frame='world',
                    geometry_type='area',
                    points=points)
            elif points.shape[1] == 3:
                WORLD_GEN.add_constraint(
                    name='workspace',
                    type='workspace',
                    frame='world',
                    geometry_type='mesh',
                    points=points)
            else:
                raise ValueError(
                    'Workspace points have the wrong '
                    'dimension = {}'.format(points.shape))
        elif args.workspace == 'gridmap':
            position_dofs = ['x', 'y']
            model_tags =  \
                [tag for tag in WORLD_GEN.world.models
                 if tag != 'ground_plane']
            free_space_polygon = WORLD_GEN.world.get_free_space_polygon(
                ground_plane_models=model_tags,
                ignore_models=['ground_plane']
            )
            WORLD_GEN.add_constraint(
                name='workspace',
                type='workspace',
                frame='world',
                geometry_type='polygon',
                polygon=free_space_polygon)
        else:
            logging.info(
                'Computing the workspace limited by the models'
                ' provided in the workspace list = {}'.format(
                    args.workspace))
            position_dofs = ['x', 'y', 'z']

            if is_array(args.workspace):
                ws = args.workspace
            else:
                ws = [args.workspace]
            group = WORLD_GEN.world.to_model_group(
                include_models=ws,
                ignore_models=['ground_plane'])
            WORLD_GEN.add_constraint(
                name='workspace',
                type='workspace',
                frame='world',
                geometry_type='mesh',
                entity=group)
    else:
        if len(WORLD_GEN.world.models):
            # Trying to guess the free space
            position_dofs = ['x', 'y', 'z']
            WORLD_GEN.add_constraint(
                name='workspace',
                type='workspace',
                frame='world',
                geometry_type='mesh',
                entity=WORLD_GEN.world)
        else:
            # Empty world: use the corners of a 20 x 20 square centered at
            # the origin as the workspace area
            position_dofs = ['x', 'y']
            xx, yy = np.meshgrid([-10, 10], [-10, 10])
            WORLD_GEN.add_constraint(
                name='workspace',
                type='workspace',
                frame='world',
                geometry_type='area',
                points=[[x, y] for x, y in
                        zip(xx.flatten(), yy.flatten())])

    if args.tangent_to_ground:
        WORLD_GEN.add_constraint(
            name='tangent_to_ground_plane',
            type='tangent',
            frame='world',
            reference=dict(
                type='plane',
                args=dict(
                    origin=[0, 0, 0],
                    normal=[0, 0, 1]
                )
            )
        )

    if args.preview:
        fig, ax = plot_workspace(WORLD_GEN.constraints.get('workspace'))
        ax.set_xlabel('X [m]')
        ax.set_ylabel('Y [m]')
        # Only axes that provide set_zlabel get a Z label
        if hasattr(ax, 'set_zlabel'):
            ax.set_zlabel('Z [m]')
        ax.set_title('Workspace')
        plt.show()

    if args.models is not None:
        min_distance = _per_model(
            args.min_distance, 0,
            'Number of min. distance inputs must be equal '
            'to number of models')
        static = _per_model(
            args.static, True,
            'Number of static flags must be equal '
            'to number of models')
        num = _per_model(
            args.num, None,
            'List of models and number of elements'
            ' must have the same length')

        max_num = dict()
        constraints = list()

        use_random_picker = False
        model_names = list()
        for model, n, md, s in zip(
                args.models, num, min_distance, static):
            if not is_gazebo_model(model):
                # Add the model as a factory model description
                model_names.append(add_factory_asset(model, is_static=s))
                use_random_picker = True
            else:
                model_names.append(model)

            logging.info(
                'Added model={}, # models={}, min. '
                'distance={}, static={}'.format(
                    model_names[-1], n, md, s))
            # A missing or non-positive count means no maximum limit
            if n is None or n < 1:
                max_num[model_names[-1]] = None
                use_random_picker = True
            else:
                max_num[model_names[-1]] = n

            if args.tangent_to_ground:
                constraints.append(dict(
                    model=model_names[-1],
                    constraint='tangent_to_ground_plane'))

        policy = dict(
            models=model_names,
            config=list())

        # Position degrees of freedom are sampled inside the workspace
        policy['config'].append(dict())
        policy['config'][0]['dofs'] = position_dofs
        policy['config'][0]['tag'] = 'workspace'
        policy['config'][0]['workspace'] = 'workspace'

        a_dofs = list()
        for a_dof, a_tag in zip(
                [args.random_roll, args.random_pitch, args.random_yaw],
                ['roll', 'pitch', 'yaw']):
            if a_dof:
                a_dofs.append(a_tag)

        # Requested angles are sampled uniformly in [-pi, pi]
        if len(a_dofs) > 0:
            policy['config'].append(dict())
            policy['config'][-1]['dofs'] = a_dofs
            policy['config'][-1]['tag'] = 'uniform'
            policy['config'][-1]['min'] = -np.pi
            policy['config'][-1]['max'] = np.pi

        # NOTE(review): min_distance receives md, i.e. only the value of
        # the LAST model of the loop above, although --min-distance may be
        # given per model. Presumably the engine accepts a single scalar;
        # confirm against the engine API before passing the full list.
        WORLD_GEN.add_engine(
            tag='engine',
            engine_name='random_pose',
            models=model_names,
            max_num=max_num,
            model_picker='random' if use_random_picker else 'size',
            no_collision=True,
            min_distance=md,
            policies=[policy],
            constraints=constraints
        )

    WORLD_GEN.run_engines(attach_models=True)

    if args.preview:
        WORLD_GEN.world.create_scene().show()

    # Choose the export target: explicit file name first, then the input
    # world file, otherwise a file named after the world in ~/.gazebo/worlds
    if args.export_filename:
        export_target = args.export_filename
    elif args.world is not None and os.path.isfile(args.world):
        export_target = args.world
    else:
        export_target = os.path.join(
            os.path.expanduser('~'),
            '.gazebo',
            'worlds',
            WORLD_GEN.world.name + '.world')
    filename = export_world_to_file(export_target)

    logging.info('World exported at: {}'.format(filename))
