import lief
import numpy as np

from secml_malware.attack.whitebox.c_fast_gradient_sign_evasion import CFastGradientSignMethodEvasion
from secml_malware.models import CClassifierEnd2EndMalware
from secml_malware.utils.extend_pe import create_int_list_from_x_adv


class CKreukEvasion(CFastGradientSignMethodEvasion):
	"""
	Padding (and optionally slack-space) evasion attack by Kreuk et al.
	https://arxiv.org/abs/1802.04528

	The positions to perturb are recomputed for every sample inside ``_run``:
	the padding positions of the sample and, when ``compute_slack`` is True,
	the slack space of the sections of the PE file.
	"""

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			how_many_padding_bytes: int,
			epsilon: float,
			iterations: int = 100,
			is_debug: bool = False,
			threshold: float = 0.5,
			p_norm: float = np.inf,
			compute_slack: bool = True,
			store_checkpoints: int = None
	):
		"""
		Create the padding attack by Kreuk et al. https://arxiv.org/abs/1802.04528

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		how_many_padding_bytes : int
			the maximum number of padding positions to perturb, counted from
			the first padding position found in the sample
		epsilon : float
			the distortion amount
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		threshold : float, optional, default 0.5
			the detection threshold to bypass
		p_norm : float, optional, default np.inf
			the norm to use for computing the attack
		compute_slack : bool, optional, default True
			if True, also perturbs the slack space of the sections
			(bytes between the virtual size and the raw size of a section)
		store_checkpoints : int, optional, default None
			if set, it reconstructs the samples after the number of iterations specified
		"""
		super(CKreukEvasion, self).__init__(
			end2end_model=end2end_model,
			# The real indexes depend on the sample, they are computed in _run.
			indexes_to_perturb=[],
			epsilon=epsilon,
			iterations=iterations,
			is_debug=is_debug,
			threshold=threshold,
			penalty_regularizer=0,
			p_norm=p_norm,
			store_checkpoints=store_checkpoints
		)
		self.how_many_padding_bytes = how_many_padding_bytes
		self.compute_slack = compute_slack

	def _run(self, x0, y0, x_init=None):
		"""
		Compute the indexes to perturb for the sample x0, then run the attack.

		Parameters
		----------
		x0 :
			the sample to perturb
		y0 :
			the label of the sample, forwarded untouched to the parent ``_run``
		x_init : optional, default None
			the initial point, forwarded untouched to the parent ``_run``

		Returns
		-------
		whatever the parent ``_run`` returns
		"""
		padding = self._create_pading_indexes(x0)
		if self.compute_slack:
			# Slack indexes come first, followed by the padding indexes.
			self.indexes_to_perturb = self._create_slack_indexes(x0) + padding
		else:
			self.indexes_to_perturb = padding
		# NOTE(review): super() is anchored at CFastGradientSignMethodEvasion, so the
		# lookup of _run starts *after* that class in the MRO. Kept as in the original;
		# confirm that skipping the direct parent is intended.
		return super(CFastGradientSignMethodEvasion, self)._run(x0, y0, x_init=x_init)

	def _create_slack_indexes(self, x0):
		"""
		Return the indexes of the slack space of every section of the sample.

		The sample is converted back to a list of integers and parsed with lief.
		For each section whose raw size exceeds its virtual size, the positions
		from ``offset + virtual_size`` (included) to ``offset + size`` (excluded)
		are collected, both clipped to the maximum input length of the classifier.

		Parameters
		----------
		x0 :
			the sample to inspect

		Returns
		-------
		list of int
			the slack-space indexes, or an empty list if the sample cannot be parsed
		"""
		x_bytes = create_int_list_from_x_adv(
			x0,
			self.classifier.get_embedding_value(),
			self.classifier.get_is_shifting_values()
		)
		try:
			liefpe = lief.PE.parse(x_bytes)
		except Exception:
			# Best effort: a sample that cannot be parsed has no usable slack space.
			# Only Exception is caught, so KeyboardInterrupt/SystemExit still propagate.
			return []
		if liefpe is None:
			# presumably recent lief releases return None instead of raising
			# when the input is not a valid PE -- treat it like a parsing failure
			return []
		window_input_length = self.classifier.get_input_max_length()
		all_slack_space = []
		for s in liefpe.sections:
			if s.size > s.virtual_size:
				# Both bounds are clipped: positions beyond the input window are discarded.
				slack_start = min(window_input_length, s.offset + s.virtual_size)
				slack_end = min(window_input_length, s.offset + s.size)
				all_slack_space.extend(range(slack_start, slack_end))
		return all_slack_space

	def _create_pading_indexes(self, x0):
		"""
		Return the indexes of the padding positions that can be perturbed.

		The padding value is 256 when ``self.invalid_pos`` is -1, otherwise it is
		``self.invalid_pos``. The result starts at the first position of x0 equal
		to the padding value and spans at most ``how_many_padding_bytes`` positions,
		never exceeding the size of x0.

		Parameters
		----------
		x0 :
			the sample to inspect (it must expose ``find`` and ``size``,
			presumably a secml CArray -- TODO confirm)

		Returns
		-------
		list of int
			the padding indexes, or an empty list if x0 contains no padding value
		"""
		invalid_value = 256 if self.invalid_pos == -1 else self.invalid_pos
		padding_positions = x0.find(x0 == invalid_value)
		if not padding_positions:
			return []
		first_padding = padding_positions[0]
		last_padding = min(x0.size, first_padding + self.how_many_padding_bytes)
		return list(range(first_padding, last_padding))
