#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------
import sys
import os
import pytest
import logging
import uuid
import warnings
import datetime
from logging.handlers import RotatingFileHandler

from azure.identity import EnvironmentCredential
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.eventhub import EventHubManagementClient
from azure.eventhub import EventHubProducerClient
from azure.eventhub._pyamqp import ReceiveClient
from azure.eventhub._pyamqp.authentication import SASTokenAuth
from azure.eventhub.extensions.checkpointstoreblob import BlobCheckpointStore
from azure.eventhub.extensions.checkpointstoreblobaio import BlobCheckpointStore as BlobCheckpointStoreAsync
try:
    import uamqp
    uamqp_transport_params = [True, False]
    uamqp_transport_ids = ["uamqp", "pyamqp"]
except (ModuleNotFoundError, ImportError):
    uamqp_transport_params = [False]
    uamqp_transport_ids = ["pyamqp"]

from devtools_testutils import get_region_override
from tracing_common import FakeSpan

# Paths pytest should not collect from this directory (none at present).
collect_ignore = []

# Number of partitions every live test Event Hub is created with.
PARTITION_COUNT = 2

# Placeholders, in order: host name, SAS key name, SAS key, Event Hub (entity) name.
CONN_STR = "Endpoint=sb://{}/;SharedAccessKeyName={};SharedAccessKey={};EntityPath={}"

# Name prefixes for generated Azure resources; the fixtures append a UUID to each.
RES_GROUP_PREFIX = "eh-res-group"
NAMESPACE_PREFIX = "eh-ns"
EVENTHUB_PREFIX = "eh"

# SAS authorization rule whose keys are read from each created namespace.
EVENTHUB_DEFAULT_AUTH_RULE_NAME = "RootManageSharedAccessKey"

# "westus" is handed to get_region_override as the fallback region
# (presumably overridable from the test environment; TODO confirm).
LOCATION = get_region_override("westus")


def pytest_addoption(parser):
    """Register the ``--sleep`` command-line option.

    The value is kept as a raw string (default ``"True"``); the ``sleep``
    fixture turns it into a boolean.
    """
    sleep_option = dict(
        action="store",
        default="True",
        help="sleep on reconnect test: True or False",
    )
    parser.addoption("--sleep", **sleep_option)


@pytest.fixture
def sleep(request):
    """Return True when ``--sleep`` is one of true/yes/1/y (case-insensitive)."""
    raw_flag = request.config.getoption("--sleep")
    truthy_spellings = {'true', 'yes', '1', 'y'}
    return raw_flag.lower() in truthy_spellings

@pytest.fixture(scope="session", params=uamqp_transport_params, ids=uamqp_transport_ids)
def uamqp_transport(request):
    """Parametrized flag: True runs against uamqp, False against pyamqp.

    The uamqp variant is only among the params when ``uamqp`` could be imported.
    """
    use_uamqp = request.param
    return use_uamqp

@pytest.fixture(scope="session")
def storage_connection_str():
    """Return the storage connection string from ``AZURE_STORAGE_CONN_STR``.

    Skips the requesting test when the variable is not set.
    """
    conn_str = os.environ.get('AZURE_STORAGE_CONN_STR')
    if conn_str is None:
        pytest.skip('AZURE_STORAGE_CONN_STR undefined')
    return conn_str

@pytest.fixture()
def checkpoint_store(storage_connection_str):
    """Sync blob checkpoint store bound to a uniquely named container."""
    container_name = "blobcontainer{}".format(uuid.uuid4())
    return BlobCheckpointStore.from_connection_string(storage_connection_str, container_name)

@pytest.fixture()
def checkpoint_store_aio(storage_connection_str):
    """Async blob checkpoint store bound to a uniquely named container."""
    container_name = "blobcontainer{}".format(uuid.uuid4())
    return BlobCheckpointStoreAsync.from_connection_string(storage_connection_str, container_name)

def get_logger(filename, level=logging.INFO):
    """Configure and return the ``azure.eventhub`` logger used by the tests.

    Applies ``level`` to the ``azure.eventhub`` logger and INFO to the
    ``uamqp`` logger. Attaches a shared stdout handler to each of them if it
    has no handlers yet. When ``filename`` is truthy, also attaches a rotating
    file handler (5 MiB per file, 2 backups) to ``azure.eventhub``. The file
    handler is attached at most once per file, even across repeated calls.

    :param filename: Log file path, or a falsy value for console-only logging.
    :param level: Level applied to the ``azure.eventhub`` logger.
    :returns: The ``azure.eventhub`` logger.
    """
    azure_logger = logging.getLogger("azure.eventhub")
    azure_logger.setLevel(level)
    uamqp_logger = logging.getLogger("uamqp")
    # uamqp stays at INFO regardless of ``level``.
    uamqp_logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
    console_handler = logging.StreamHandler(stream=sys.stdout)
    console_handler.setFormatter(formatter)
    if not azure_logger.handlers:
        azure_logger.addHandler(console_handler)
    if not uamqp_logger.handlers:
        uamqp_logger.addHandler(console_handler)

    if filename:
        # FileHandler stores the absolute path in ``baseFilename``; compare on
        # that so repeated calls do not stack handlers and duplicate every record.
        log_path = os.path.abspath(filename)
        already_attached = any(
            isinstance(handler, RotatingFileHandler) and handler.baseFilename == log_path
            for handler in azure_logger.handlers
        )
        if not already_attached:
            file_handler = RotatingFileHandler(filename, maxBytes=5*1024*1024, backupCount=2)
            file_handler.setFormatter(formatter)
            azure_logger.addHandler(file_handler)

    return azure_logger


log = get_logger(None, logging.DEBUG)


@pytest.fixture(scope="session")
def timeout_factor(uamqp_transport):
    """Multiplier applied to timeout values: 1000 for uamqp, 1 for pyamqp.

    NOTE(review): presumably uamqp takes timeouts in milliseconds while pyamqp
    takes seconds; confirm against the transport implementations.
    """
    return 1000 if uamqp_transport else 1

@pytest.fixture(scope="session")
def fake_span():
    """Expose the ``FakeSpan`` class itself (not an instance) from tracing_common."""
    span_class = FakeSpan
    return span_class


@pytest.fixture(scope="session")
def get_credential():
    """Select the credential used by the management-plane fixtures.

    ``AZURE_TEST_USE_PWSH_AUTH=true`` selects ``AzurePowerShellCredential``.
    Otherwise ``AZURE_TEST_USE_CLI_AUTH=true`` selects ``AzureCliCredential``;
    the PowerShell flag wins when both are set. Flags are compared
    case-insensitively. With neither flag, ``EnvironmentCredential`` is returned.
    """
    pwsh_requested = os.environ.get("AZURE_TEST_USE_PWSH_AUTH", "false").lower() == "true"
    cli_requested = os.environ.get("AZURE_TEST_USE_CLI_AUTH", "false").lower() == "true"

    if pwsh_requested:
        # User-based authentication through Azure PowerShell.
        log.info(
            "Environment variable AZURE_TEST_USE_PWSH_AUTH set to 'true'. Using AzurePowerShellCredential."
        )
        from azure.identity import AzurePowerShellCredential
        return AzurePowerShellCredential()

    if cli_requested:
        # User-based authentication through Azure CLI.
        log.info("Environment variable AZURE_TEST_USE_CLI_AUTH set to 'true'. Using AzureCliCredential.")
        from azure.identity import AzureCliCredential
        return AzureCliCredential()

    return EnvironmentCredential()


@pytest.fixture(scope="session")
def resource_group(get_credential):
    """Create a throwaway resource group for the session and delete it on teardown.

    Skips when ``AZURE_SUBSCRIPTION_ID`` is not set. The group is created in
    ``LOCATION`` and tagged ``DeleteAfter`` with a UTC timestamp one day ahead.
    Deletion is started with ``begin_delete`` but not awaited. Teardown
    failures are reported as ``UserWarning`` rather than raised.

    :param get_credential: Credential passed to ``ResourceManagementClient``.
    :yields: The object returned by ``resource_groups.create_or_update``.
    """
    try:
        subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"]
    except KeyError:
        pytest.skip('AZURE_SUBSCRIPTION_ID undefined')
    base_url = os.environ.get("EVENTHUB_RESOURCE_MANAGER_URL", "https://management.azure.com/")
    # Always yield "<base_url>/.default"; plain concatenation produced a
    # malformed scope whenever the configured URL lacked a trailing slash.
    credential_scopes = ["{}/.default".format(base_url.rstrip("/"))]
    resource_client = ResourceManagementClient(
        get_credential, subscription_id, base_url=base_url, credential_scopes=credential_scopes
    )
    resource_group_name = RES_GROUP_PREFIX + str(uuid.uuid4())
    expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)
    parameters = {
        "location": LOCATION,
        # Presumably consumed by an external cleanup job if teardown never runs.
        "tags": {'DeleteAfter': expiry.replace(microsecond=0).isoformat()},
    }
    try:
        rg = resource_client.resource_groups.create_or_update(resource_group_name, parameters)
        yield rg
    finally:
        try:
            resource_client.resource_groups.begin_delete(resource_group_name)
        except Exception as exc:  # pylint: disable=broad-except
            # Best-effort cleanup. Unlike the former bare ``except:``, this lets
            # KeyboardInterrupt/SystemExit propagate.
            warnings.warn(UserWarning("resource group teardown failed: {!r}".format(exc)))


@pytest.fixture(scope="session")
def eventhub_namespace(resource_group, get_credential):
    """Create an Event Hubs namespace for the session and delete it on teardown.

    Skips when ``AZURE_SUBSCRIPTION_ID`` is not set. Teardown waits for the
    namespace deletion to complete. Teardown failures are reported as
    ``UserWarning`` rather than raised.

    :param resource_group: Resource group the namespace is created in.
    :param get_credential: Credential passed to ``EventHubManagementClient``.
    :yields: ``(namespace_name, connection_string, key_name, primary_key)`` for
        the ``EVENTHUB_DEFAULT_AUTH_RULE_NAME`` rule.
    """
    try:
        subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"]
    except KeyError:
        # Message previously said "defined", the opposite of the actual condition.
        pytest.skip('AZURE_SUBSCRIPTION_ID undefined')
    base_url = os.environ.get("EVENTHUB_RESOURCE_MANAGER_URL", "https://management.azure.com/")
    # Always yield "<base_url>/.default", with or without a trailing slash on the URL.
    credential_scopes = ["{}/.default".format(base_url.rstrip("/"))]
    resource_client = EventHubManagementClient(
        get_credential, subscription_id, base_url=base_url, credential_scopes=credential_scopes
    )
    namespace_name = NAMESPACE_PREFIX + str(uuid.uuid4())
    try:
        namespace = resource_client.namespaces.begin_create_or_update(
            resource_group.name, namespace_name, {"location": LOCATION}
        ).result()
        key = resource_client.namespaces.list_keys(resource_group.name, namespace_name, EVENTHUB_DEFAULT_AUTH_RULE_NAME)
        yield namespace.name, key.primary_connection_string, key.key_name, key.primary_key
    finally:
        try:
            resource_client.namespaces.begin_delete(resource_group.name, namespace_name).wait()
        except Exception as exc:  # pylint: disable=broad-except
            # Best-effort cleanup that no longer swallows KeyboardInterrupt/SystemExit.
            warnings.warn(UserWarning("eventhub namespace teardown failed: {!r}".format(exc)))


@pytest.fixture()
def live_eventhub(resource_group, eventhub_namespace, get_credential):  # pylint: disable=redefined-outer-name
    """Create a fresh Event Hub (``PARTITION_COUNT`` partitions) for one test.

    Skips when ``AZURE_SUBSCRIPTION_ID`` is not set. The Event Hub is deleted
    on teardown. Teardown failures are reported as ``UserWarning`` rather than
    raised.

    :param resource_group: Resource group holding the namespace.
    :param eventhub_namespace: ``(name, connection_string, key_name, primary_key)`` tuple.
    :param get_credential: Credential passed to ``EventHubManagementClient``.
    :yields: Dict with resource group, hostname, SAS key name/value, namespace,
        Event Hub name, consumer group ``$Default``, partition ``'0'``, and an
        entity-scoped connection string.
    """
    try:
        subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"]
    except KeyError:
        # Message previously said "defined", the opposite of the actual condition.
        pytest.skip('AZURE_SUBSCRIPTION_ID undefined')
    base_url = os.environ.get("EVENTHUB_RESOURCE_MANAGER_URL", "https://management.azure.com/")
    # Always yield "<base_url>/.default", with or without a trailing slash on the URL.
    credential_scopes = ["{}/.default".format(base_url.rstrip("/"))]
    resource_client = EventHubManagementClient(
        get_credential, subscription_id, base_url=base_url, credential_scopes=credential_scopes
    )
    eventhub_name = EVENTHUB_PREFIX + str(uuid.uuid4())
    eventhub_ns_name, connection_string, key_name, primary_key = eventhub_namespace
    eventhub_endpoint_suffix = os.environ.get("EVENT_HUB_ENDPOINT_SUFFIX", ".servicebus.windows.net")
    try:
        eventhub = resource_client.event_hubs.create_or_update(
            resource_group.name, eventhub_ns_name, eventhub_name, {"partition_count": PARTITION_COUNT}
        )
        live_eventhub_config = {
            'resource_group': resource_group.name,
            'hostname': "{}{}".format(eventhub_ns_name, eventhub_endpoint_suffix),
            'key_name': key_name,
            'access_key': primary_key,
            'namespace': eventhub_ns_name,
            'event_hub': eventhub.name,
            'consumer_group': '$Default',
            'partition': '0',
            'connection_str': connection_string + ";EntityPath=" + eventhub.name
        }
        yield live_eventhub_config
    finally:
        try:
            resource_client.event_hubs.delete(resource_group.name, eventhub_ns_name, eventhub_name)
        except Exception as exc:  # pylint: disable=broad-except
            # Best-effort cleanup that no longer swallows KeyboardInterrupt/SystemExit.
            warnings.warn(UserWarning("eventhub teardown failed: {!r}".format(exc)))

@pytest.fixture()
def resource_mgmt_client(get_credential):
    """Provide an ``EventHubManagementClient`` for the configured subscription.

    Skips when ``AZURE_SUBSCRIPTION_ID`` is not set.

    :param get_credential: Credential passed to the management client.
    :yields: The management client.
    """
    try:
        subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"]
    except KeyError:
        # Message previously said "defined", the opposite of the actual condition.
        pytest.skip('AZURE_SUBSCRIPTION_ID undefined')
    base_url = os.environ.get("EVENTHUB_RESOURCE_MANAGER_URL", "https://management.azure.com/")
    # Always yield "<base_url>/.default", with or without a trailing slash on the URL.
    credential_scopes = ["{}/.default".format(base_url.rstrip("/"))]
    resource_client = EventHubManagementClient(
        get_credential, subscription_id, base_url=base_url, credential_scopes=credential_scopes
    )
    yield resource_client

@pytest.fixture()
def connection_str(live_eventhub):
    """Valid SAS connection string for the live Event Hub, rendered from ``CONN_STR``."""
    ordered_fields = ('hostname', 'key_name', 'access_key', 'event_hub')
    return CONN_STR.format(*(live_eventhub[field] for field in ordered_fields))


@pytest.fixture()
def invalid_hostname(live_eventhub):
    """Connection string with real credentials but a bogus namespace host name."""
    bogus_host = "invalid123.servicebus.windows.net"
    remaining = [live_eventhub[field] for field in ('key_name', 'access_key', 'event_hub')]
    return CONN_STR.format(bogus_host, *remaining)


@pytest.fixture()
def invalid_key(live_eventhub):
    """Connection string for the live Event Hub whose SAS key is the literal ``"invalid"``."""
    host = live_eventhub['hostname']
    policy = live_eventhub['key_name']
    entity = live_eventhub['event_hub']
    return CONN_STR.format(host, policy, "invalid", entity)


@pytest.fixture()
def invalid_policy(live_eventhub):
    """Connection string for the live Event Hub whose SAS key name is the literal ``"invalid"``."""
    host = live_eventhub['hostname']
    secret = live_eventhub['access_key']
    entity = live_eventhub['event_hub']
    return CONN_STR.format(host, "invalid", secret, entity)


@pytest.fixture()
def connstr_receivers(live_eventhub, uamqp_transport):
    """Open one low-level AMQP receive client per partition of the live Event Hub.

    Uses ``uamqp.ReceiveClient`` when ``uamqp_transport`` is true, otherwise the
    bundled pyamqp ``ReceiveClient``. Every receiver that was successfully
    opened is closed on teardown, including when opening a later one fails.

    :param live_eventhub: Config dict from the ``live_eventhub`` fixture.
    :param uamqp_transport: True for uamqp, False for pyamqp.
    :yields: ``(connection_str, receivers)``; receivers are ordered by partition id.
    """
    connection_str = live_eventhub["connection_str"]
    hostname = live_eventhub['hostname']
    # The SAS audience is the Event Hub entity, identical for every partition.
    uri = "sb://{}/{}".format(hostname, live_eventhub['event_hub'])
    receivers = []
    try:
        for p in (str(i) for i in range(PARTITION_COUNT)):
            source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format(
                hostname,
                live_eventhub['event_hub'],
                live_eventhub['consumer_group'],
                p)
            if uamqp_transport:
                sas_auth = uamqp.authentication.SASTokenAuth.from_shared_access_key(
                    uri, live_eventhub['key_name'], live_eventhub['access_key'])
                receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=0, prefetch=500)
            else:
                sas_auth = SASTokenAuth(
                    uri, uri, live_eventhub['key_name'], live_eventhub['access_key']
                )
                receiver = ReceiveClient(hostname, source, auth=sas_auth, network_trace=False, timeout=0, link_credit=500)
            receiver.open()
            receivers.append(receiver)
        yield connection_str, receivers
    finally:
        # Runs on normal teardown and when an open() above raised, so receivers
        # that were already opened are not leaked.
        for r in receivers:
            r.close()


@pytest.fixture()
def connstr_senders(live_eventhub, uamqp_transport):
    """Create one producer per partition from a shared ``EventHubProducerClient``.

    The senders and the client are closed on teardown. The client is also
    closed if partition discovery or producer creation fails, or if closing
    a sender raises.

    :param live_eventhub: Config dict from the ``live_eventhub`` fixture.
    :param uamqp_transport: Passed through to ``EventHubProducerClient``.
    :yields: ``(connection_str, senders)`` with one sender per partition id.
    """
    connection_str = live_eventhub["connection_str"]
    client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport)
    senders = []
    try:
        for p in client.get_partition_ids():
            senders.append(client._create_producer(partition_id=p))  # pylint: disable=protected-access
        yield connection_str, senders
    finally:
        try:
            for s in senders:
                s.close()
        finally:
            client.close()

# Note: This is duplicated between here and the basic conftest so that pytest does not emit unknown-marker
# warnings when the tests are run locally from within this SDK's directory. (Everything works either way;
# pytest just makes a bit of noise.)
def pytest_configure(config):
    """Register the ``liveTest`` marker with pytest."""
    marker_spec = "liveTest: mark test to be a live test only"
    config.addinivalue_line("markers", marker_spec)
