#  Copyright (c) maiot GmbH 2020. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.

import os
from typing import Text, List

import apache_beam as beam
import tensorflow as tf

from zenml.core.components.bulk_inferrer.constants import PREDICTIONS
from zenml.core.steps.base_step import BaseStep


class BaseInferrer(BaseStep):
    """
    Base inferrer class. This step is responsible for inference (batch).

    It stores the label names and the output location. It also builds the
    Beam sink that writes the inference results as TFRecord files of
    ``tf.train.Example`` protos.
    """

    def __init__(self,
                 labels: List[Text],
                 **kwargs):
        """
        Base Inferrer constructor.

        Args:
            labels: List of label names (presumably the label feature
                names of the model — confirm against callers). Stored on
                the instance and also forwarded to ``BaseStep`` as the
                ``labels`` keyword argument.
            **kwargs: Additional keyword arguments, forwarded unchanged to
                ``BaseStep.__init__``.
        """
        self.labels = labels
        # Not known at construction time; it must be provided through
        # set_output_uri() before write_inference_results() is called.
        self.output_uri = None
        super(BaseInferrer, self).__init__(
            labels=labels,
            **kwargs,
        )

    def get_labels(self) -> List[Text]:
        """Return the label names passed to the constructor."""
        return self.labels

    def set_output_uri(self, output_uri: Text):
        """
        Set the directory URI under which inference results are written.

        Args:
            output_uri: Base URI. The results are written below
                ``<output_uri>/<PREDICTIONS>``.
        """
        self.output_uri = output_uri

    def write_inference_results(self):
        """
        Build the Beam sink that writes inference results.

        The sink is a ``beam.io.WriteToTFRecord`` transform. Its file path
        prefix is ``os.path.join(self.output_uri, PREDICTIONS)`` and it
        encodes elements with a ``ProtoCoder`` for ``tf.train.Example``.

        Returns:
            The ``WriteToTFRecord`` PTransform, ready to be applied to a
            PCollection of ``tf.train.Example`` protos.

        Raises:
            ValueError: If ``set_output_uri`` has not been called, or was
                called with ``None``.
        """
        if self.output_uri is None:
            # Fail early with an actionable message. Without this check,
            # os.path.join(None, ...) raises an opaque TypeError.
            raise ValueError(
                'output_uri is not set; call set_output_uri() before '
                'write_inference_results().')
        return beam.io.WriteToTFRecord(
            os.path.join(self.output_uri, PREDICTIONS),
            # With Beam's default compression_type (AUTO), the '.gz'
            # suffix makes the sink gzip-compress its output files.
            file_name_suffix='.gz',
            coder=beam.coders.ProtoCoder(tf.train.Example))
