#!/usr/bin/env python3

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from pathlib import Path
import sys

# More standard library imports; the only third-party import (click) follows below
from textwrap import dedent
from typing import Iterable

import click

# Exact header text every checked file is expected to carry, stored without
# leading or trailing whitespace so it can be compared with ``str.startswith``.
# NOTE: the name shadows the interactive ``license`` builtin installed by
# ``site``; it is kept because the commands in this file refer to it.
license = "\n".join(
    (
        "# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.",
        "#",
        '# Licensed under the Apache License, Version 2.0 (the "License").',
        "# You may not use this file except in compliance with the License.",
        "# A copy of the License is located at",
        "#",
        "#     http://www.apache.org/licenses/LICENSE-2.0",
        "#",
        '# or in the "license" file accompanying this file. This file is distributed',
        '# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either',
        "# express or implied. See the License for the specific language governing",
        "# permissions and limitations under the License.",
    )
)


@click.group()
def cli() -> None:
    """Check or add the Apache 2.0 license header on Python source files."""
    # click shows the docstring above as the description in ``--help``; the
    # group itself does no work, the subcommands registered on it do.


def _iter_python_files(root: Path) -> Iterable[Path]:
    """Return the files designated by *root*.

    A path that is a regular file is returned as-is, whatever its suffix.
    Any other path is searched recursively for ``*.py`` files.
    """
    if root.is_file():
        return [root]
    # The pattern also matches directories that happen to be named
    # ``something.py``; those cannot be read, so keep regular files only.
    return (path for path in root.glob("**/*.py") if path.is_file())


def _license_offset(content: str) -> int:
    """Return the index in *content* where the license header belongs.

    This is 0 for ordinary files. For a script that starts with a ``#!``
    interpreter line it is the index just past that line and any blank lines
    directly after it, because the interpreter line has to remain the first
    line of the file for the script to stay executable.
    """
    if not content.startswith("#!"):
        return 0

    newline = content.find("\n")
    if newline == -1:
        # The whole file is a single interpreter line without a line break.
        return len(content)

    offset = newline + 1
    while content.startswith("\n", offset):
        offset += 1
    return offset


@cli.command()
@click.argument("roots", type=click.Path(), nargs=-1, required=True)
def fix(roots: Iterable[str]) -> None:
    """Add the license header to files that do not have it yet.

    Each of ROOTS is either a single file or a directory that is searched
    recursively for *.py files. A leading "#!" line is kept as the first line
    and the header is inserted after it.
    """
    for root in map(Path, roots):
        for file_path in _iter_python_files(root):
            # Python source is UTF-8 unless declared otherwise, so do not
            # depend on the locale's preferred encoding.
            file_content = file_path.read_text(encoding="utf-8")
            offset = _license_offset(file_content)

            if file_content.startswith(license, offset):
                print(f"Skipping {file_path} with correct header")
                continue

            print(f"Prepending missing header to {file_path}")
            prefix = file_content[:offset]
            if prefix and not prefix.endswith("\n"):
                # Interpreter line without a trailing newline: terminate it
                # so that the header starts on a line of its own.
                prefix += "\n"
            file_path.write_text(
                prefix + license + "\n\n" + file_content[offset:],
                encoding="utf-8",
            )


@cli.command()
@click.argument("roots", type=click.Path(), nargs=-1, required=True)
def check(roots: Iterable[str]) -> None:
    """Report files that do not have the license header.

    Each of ROOTS is either a single file or a directory that is searched
    recursively for *.py files. The header may be preceded by a "#!" line.
    Exits with status 1 if any file lacks the header and 0 otherwise.
    """
    exit_code = 0

    for root in map(Path, roots):
        for file_path in _iter_python_files(root):
            # Same decoding and header rule as ``fix`` so that a file written
            # by ``fix`` always passes ``check``.
            file_content = file_path.read_text(encoding="utf-8")

            if not file_content.startswith(license, _license_offset(file_content)):
                print(f"File {file_path} does not have correct header")
                exit_code = 1

    sys.exit(exit_code)


# Invoke the click command group only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()
