#!/usr/bin/env python
# Copyright (c) 2019 - The Procedural Generation for Gazebo authors
# For information on the respective copyright owner see the NOTICE file
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import logging
import numpy as np
from pcg_gazebo.parsers import parse_sdf, parse_xacro
from pcg_gazebo.visualization import plot_occupancy_grid
from pcg_gazebo.simulation import World
from pcg_gazebo.utils import is_string


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'Generate occupancy grid map from a SDF world'
        ' file or the current scenario in Gazebo')
    parser.add_argument(
        '--world-file',
        '-w',
        type=str,
        help='SDF world filename')
    parser.add_argument(
        '--input-topic',
        type=str,
        help='Receive world XML file per ROS topic')
    parser.add_argument(
        '--xml',
        type=str,
        help='Receive world XML as string')
    parser.add_argument(
        '--from-simulation',
        '-s',
        action='store_true',
        help='Retrieve world description from the current Gazebo simulation')
    parser.add_argument(
        '--z-levels',
        '-l',
        nargs='+',
        type=float,
        help='Z levels to compute the grid map from')
    parser.add_argument(
        '--min-z',
        default=0.0,
        type=float,
        help='Minimum height for the Z rays in the ray tracing grid')
    parser.add_argument(
        '--max-z',
        default=1.0,
        type=float,
        help='Maximum height for the Z rays in the ray tracing grid')
    parser.add_argument(
        '--without-ground-plane',
        action='store_true',
        help='Ignore ground plane meshes from the map')
    parser.add_argument(
        '--occupied-color',
        default=0,
        type=int,
        help='Gray-scale color of the occupied cells')
    parser.add_argument(
        '--free-color',
        default=1,
        type=float,
        help='Gray-scale color of the free cells')
    parser.add_argument(
        '--unavailable-color',
        default=0.5,
        type=float,
        help='Gray-scale color of the unavailable cells')
    parser.add_argument(
        '--output-dir',
        default='/tmp',
        type=str,
        help='Output directory to store the map')
    parser.add_argument(
        '--output-filename',
        type=str,
        help='Name of the output map file')
    parser.add_argument(
        '--static-models-only',
        action='store_true',
        help='Uses only static models for the map construction')
    parser.add_argument(
        '--dpi',
        default=200,
        type=int,
        help='Figure DPI')
    parser.add_argument(
        '--figure-width',
        default=15,
        type=float,
        help='Width of the figure')
    parser.add_argument(
        '--figure-height',
        default=15,
        type=float,
        help='Height of the figure')
    parser.add_argument(
        '--figure-size-unit',
        default='cm',
        type=str,
        help='Figure size unit [cm, m or inch]')
    parser.add_argument(
        '--exclude-contains',
        type=str,
        nargs='+',
        help='List of keywords for model names to be excluded from the map')
    parser.add_argument(
        '--ground-plane-models',
        type=str,
        nargs='+',
        help='List of models that will be considered ground plane')
    parser.add_argument(
        '--map-x-limits',
        type=float,
        nargs='+',
        help='X limits of the output map in meters')
    parser.add_argument(
        '--map-y-limits',
        type=float,
        nargs='+',
        help='Y limits of the output map in meters')
    parser.add_argument(
        '--use-visual',
        action='store_true',
        help='Use visual meshes instead of collision')

    args = parser.parse_args()

    # Configure logger
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s | %(levelname)s | %(module)s | %(message)s',
        datefmt='%m-%d %H:%M',)

    console = logging.StreamHandler()
    console.setLevel(logging.ERROR)
    # add the handler to the root logger
    logging.getLogger('').addHandler(console)

    assert 0 <= args.occupied_color <= 1, \
        'Occupied color must be between 0 and 1'
    assert 0 <= args.unavailable_color <= 1, \
        'Unavailable color must be between 0 and 1'
    assert 0 <= args.free_color <= 1, \
        'Free color must be between 0 and 1'

    if args.z_levels is None:
        assert args.min_z is not None and args.max_z is not None, \
            'If the Z levels to compute the grid map are not provided,' \
            'the min_z and max_z must be available'
        assert args.min_z < args.max_z, \
            'Invalid min_z and max_y interval limits'
        z_levels = np.linspace(args.min_z, args.max_z, 5)
    else:
        z_levels = args.z_levels

    logging.info(
        'Grid map to be computed on the following Z levels={}'.format(
            z_levels))
    if args.map_x_limits is None:
        x_limits = None
    else:
        x_limits = args.map_x_limits
        assert len(x_limits) == 2, 'X limits must be provided as two elements'
        assert x_limits[0] < x_limits[1], 'X limits are invalid'
        logging.info('Map X limits: {}'.format(x_limits))

    if args.map_y_limits is None:
        y_limits = None
    else:
        y_limits = args.map_y_limits
        assert len(y_limits) == 2, 'Y limits must be provided as two elements'
        assert y_limits[0] < y_limits[1], 'Y limits are invalid'
        logging.info('Map Y limits: {}'.format(y_limits))

    if args.world_file is not None:
        assert os.path.isfile(args.world_file), \
            'World file is invalid, filename={}'.format(
                args.world_file)
        logging.info(
            'Reading the world description from SDF file, '
            'filename={}'.format(args.world_file))

        if args.world_file.endswith('.xacro'):
            logging.info('Parsing xacro file')
            sdf = parse_xacro(args.world_file, output_type='sdf')
        else:
            logging.info('Parsing SDF file')
            sdf = parse_sdf(args.world_file)

        world = World.from_sdf(sdf)

        logging.info('SDF world file was parsed, world name={}'.format(
            world.name))

        if args.output_filename is None:
            output_filename = os.path.basename(
                args.world_file).split('.')[0] + '.pgm'
        else:
            output_filename = args.output_filename
    elif args.input_topic is not None:
        try:
            import rospy
            from std_msgs.msg import String
        except ImportError:
            logging.error('rospy is not available, cannot start node')
            exit(1)
        if rospy.is_shutdown():
            raise rospy.ROSInitException('ROS master is not running!')

        logging.info(
            'Waiting to receive the world XML file through topic {}'.format(
                args.input_topic))
        rospy.init_node('generate_occupancy_map', anonymous=True)
        topic_xml = rospy.wait_for_message(args.input_topic, String)

        logging.info('World XML received through topic ' + args.input_topic)
        sdf = parse_sdf(topic_xml.data)

        world = World.from_sdf(sdf)

        logging.info('SDF world file was parsed, world name={}'.format(
            world.name))

        if args.output_filename is None:
            output_filename = '{}.pgm'.format(world.name)
        else:
            output_filename = args.output_filename
    elif args.xml is not None:
        logging.info('Parsing XML to SDF world')
        assert is_string(args.xml), 'Input XML is not a string'
        sdf = parse_sdf(args.xml)
        world = World.from_sdf(sdf)
    else:
        logging.error('No world input provided')
        exit(1)

    # Checking for ground-plane models
    if args.ground_plane_models is None:
        gp_models = ['ground_plane']
    else:
        gp_models = args.ground_plane_models

    logging.info('Flag ground plane models with names={}'.format(gp_models))
    for model_name in gp_models:
        if model_name in world.models:
            logging.info('Flagging model {} as ground plane'.format(
                model_name))
            world.set_as_ground_plane(model_name)

    logging.info('Generating occupancy grid')

    if args.exclude_contains:
        logging.info(
            'Excluding models that include the following keywords: {}'.format(
                args.exclude_contains))
        exclude_contains = args.exclude_contains
    else:
        exclude_contains = ['ground_plane']

    logging.info('Excluded models from the grid map computation')

    logging.info('Plotting occupancy grid map')
    plot_occupancy_grid(
        world.models,
        z_levels=z_levels,
        with_ground_plane=not args.without_ground_plane,
        static_models_only=args.static_models_only,
        dpi=args.dpi,
        fig_size=(args.figure_width, args.figure_height),
        fig_size_unit=args.figure_size_unit,
        occupied_color=[args.occupied_color for _ in range(3)],
        free_color=[args.free_color for _ in range(3)],
        unavailable_color=[args.unavailable_color for _ in range(3)],
        output_folder=args.output_dir,
        output_filename=output_filename,
        exclude_contains=exclude_contains,
        axis_x_limits=x_limits,
        axis_y_limits=y_limits,
        mesh_type='visual' if args.use_visual else 'collision',
        ground_plane_models=gp_models)

    logging.info('Output map filename: {}'.format(
        os.path.join(args.output_dir, output_filename)))
    logging.info('Map stored at: {}'.format(
        os.path.join(args.output_dir, output_filename)))
