from abc import abstractmethod

from secml.array import CArray
from secml.ml.classifiers import CClassifier

from secml_malware.models import CClassifierEnd2EndMalware, CClassifierEmber


class CWrapperPhi:
	"""
	Abstract class that encapsulates a model for being used in a black-box way.

	A concrete wrapper maps samples from the input space into the feature
	space of the wrapped classifier (``extract_features``); ``predict`` then
	chains that mapping with the classifier's own ``predict``.

	Notes
	-----
	This class does not derive from ``abc.ABC``, so ``@abstractmethod`` is
	not enforced when instantiating: an instance that does not override
	``extract_features`` can be created, and the call itself raises
	NotImplementedError.
	"""

	def __init__(self, model: CClassifier):
		"""
		Creates the wrapper.

		Parameters
		----------
		model : CClassifier
			The model to wrap. It is stored as ``self.classifier``.
		"""
		self.classifier = model

	@abstractmethod
	def extract_features(self, x: CArray):
		"""
		It maps the input sample inside the feature space of the wrapped model.

		Parameters
		----------
		x : CArray
			The sample in the input space.

		Returns
		-------
		CArray
			The feature space representation of the input sample.

		Raises
		------
		NotImplementedError
			Always, unless overridden by a subclass.
		"""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	def predict(self, x: CArray, return_decision_function: bool = True):
		"""
		Returns the prediction of the sample (in input space).

		Parameters
		----------
		x : CArray
			The input sample(s) in input space.
		return_decision_function : bool, default True
			If True, it also returns the decision function value, rather than only the label.

		Returns
		-------
		CArray, (CArray)
			The label(s) of the sample(s).
			If return_decision_function is True, the output of the decision function is also returned,
			exactly as produced by the wrapped classifier's ``predict``.
		"""
		# Promote a single 1D sample to 2D; concrete extract_features
		# implementations iterate over the rows of this array.
		x = x.atleast_2d()
		# Any per-sample preprocessing (e.g. stripping padding) is the
		# responsibility of extract_features, not of predict.
		feature_vectors = self.extract_features(x)
		return self.classifier.predict(feature_vectors, return_decision_function=return_decision_function)


class CEmberWrapperPhi(CWrapperPhi):
	"""
	Class that wraps a GBDT classifier with EMBER feature set.
	"""

	# Number of columns allocated per sample for the EMBER feature vector
	# returned by CClassifierEmber.extract_features.
	EMBER_FEATURE_SIZE = 2381
	# Input-space value that marks padding: everything from its first
	# occurrence onward is discarded before feature extraction
	# (presumably chosen because real byte values lie in 0-255 — confirm).
	PADDING_VALUE = 256

	def __init__(self, model: CClassifierEmber):
		"""
		Creates the wrapper of a CClassifierEmber.

		Parameters
		----------
		model : CClassifierEmber
			The GBDT model to wrap.

		Raises
		------
		ValueError
			If ``model`` is not an instance of CClassifierEmber.
		"""
		if not isinstance(model, CClassifierEmber):
			raise ValueError(f"Input model is {type(model)} and not CClassifierEmber")
		super().__init__(model)

	def extract_features(self, x):
		"""
		It extracts the EMBER hand-crafted features.

		Each row of ``x`` is truncated at its first padding marker
		(``PADDING_VALUE``) and the remaining prefix is passed to the wrapped
		classifier's ``extract_features``.

		Parameters
		----------
		x : CArray
			The sample(s) in the input space; a 1D array is treated as one sample.

		Returns
		-------
		CArray
			A ``(n_samples, EMBER_FEATURE_SIZE)`` array with the feature space
			representation of each input sample.
		"""
		x = x.atleast_2d()
		clf: CClassifierEmber = self.classifier
		n_samples = x.shape[0]
		feature_vectors = CArray.zeros((n_samples, self.EMBER_FEATURE_SIZE))
		for i in range(n_samples):
			x_i = x[i, :]
			padding_positions = x_i.find(x_i == self.PADDING_VALUE)
			if padding_positions:
				# Keep only the content that precedes the first padding marker.
				x_i = x_i[0, :padding_positions[0]]
			feature_vectors[i, :] = clf.extract_features(x_i)
		return feature_vectors


class CEnd2EndWrapperPhi(CWrapperPhi):
	"""
	Class that wraps an end-to-end model
	"""

	def __init__(self, model: CClassifierEnd2EndMalware):
		"""
		Creates the wrapper of a CClassifierEnd2EndMalware.

		Parameters
		----------
		model : CClassifierEnd2EndMalware
			The end to end model to wrap.

		Raises
		------
		ValueError
			If ``model`` is not an instance of CClassifierEnd2EndMalware.
		"""
		if not isinstance(model, CClassifierEnd2EndMalware):
			raise ValueError(f"Input model is {type(model)} and not CClassifierEnd2EndMalware")
		super().__init__(model)

	def extract_features(self, x):
		"""
		Crops and pads the input sample for being passed to the network.

		Every row of the output has exactly ``clf.get_input_max_length()``
		columns. It is pre-filled with ``clf.get_embedding_value()``, and the
		first ``min(len(row), max_length)`` entries are overwritten with the
		input values plus ``clf.get_is_shifting_values()``.

		Parameters
		----------
		x : CArray
			The sample(s) in the input space; a 1D array is treated as one sample.

		Returns
		-------
		CArray
			The feature space representation of the input sample(s), one row per sample.
		"""
		clf: CClassifierEnd2EndMalware = self.classifier
		x = x.atleast_2d()
		# These getters do not depend on the row being processed: query them once.
		max_length = clf.get_input_max_length()
		# NOTE(review): presumably a 0/1 offset that keeps byte values distinct
		# from the padding embedding value — confirm in CClassifierEnd2EndMalware.
		shift = clf.get_is_shifting_values()
		padded_x = CArray.zeros((x.shape[0], max_length)) + clf.get_embedding_value()
		for i in range(x.shape[0]):
			x_i = x[i, :]
			# Crop rows longer than the network input size.
			length = min(x_i.shape[-1], max_length)
			if length == 0:
				# Nothing to copy: the row stays entirely padding.
				continue
			padded_x[i, :length] = x_i[0, :length] + shift
		return padded_x
