import unittest
from secml.array import CArray
from secml_malware.attack.whitebox.c_format_exploit_evasion import CFormatExploitEvasion

from secml_malware.attack.whitebox.tests.malware_test_base import End2EndBaseTests
from secml_malware.attack.whitebox import CHeaderEvasion, CPaddingEvasion, CKreukEvasion, CSuciuEvasion


class EvasionEnd2EndTestSuite(End2EndBaseTests):
	"""End-to-end tests for the white-box evasion attacks.

	Each test builds one attack configuration against ``self.classifier``, runs it
	on ``self.X`` and checks via ``assert_evasion_result`` that at least one sample
	detected before the attack is classified as benign after it.
	"""

	def setUp(self):
		"""Load the pretrained models and compute the reference labels.

		``self.Y`` holds the target classifier's own predictions on the clean
		samples, so every attack is judged against what the model detected before
		the attack rather than against ground-truth labels.
		"""
		super().setUp()
		self.classifier.load_pretrained_model(self.ember_path)
		self.surrogate_classifier.load_pretrained_model(self.surrogate_path)
		self.Y = self.classifier.predict(CArray(self.X), return_decision_function=False)

	def test_whitebox_ember_model_attack_random_init(self):
		"""CHeaderEvasion with random initialisation evades the classifier."""
		attack = CHeaderEvasion(
			self.classifier,
			is_debug=True,
			random_init=True,
		)
		self.assert_evasion_result(attack)

	def test_whitebox_ember_model_attack_no_random_init(self):
		"""CHeaderEvasion without random initialisation evades the classifier."""
		attack = CHeaderEvasion(
			self.classifier,
			is_debug=True,
			random_init=False,
		)
		self.assert_evasion_result(attack)

	def test_padding_whitebox_ember_model_attack(self):
		"""CPaddingEvasion with 1000 bytes and random initialisation evades."""
		padding_attack = CPaddingEvasion(
			self.classifier,
			1000,
			random_init=True,
			is_debug=True
		)
		self.assert_evasion_result(padding_attack)

	def test_kreuk_whitebox_attack(self):
		"""CKreukEvasion with default slack and norm settings evades."""
		kreuk_attack = CKreukEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			iterations=50,
			is_debug=True
		)
		self.assert_evasion_result(kreuk_attack)

	def test_kreuk_whitebox_no_slack_attack_p2(self):
		"""CKreukEvasion with compute_slack=False and p_norm=2 evades."""
		kreuk_attack = CKreukEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			iterations=50,
			is_debug=True,
			compute_slack=False,
			p_norm=2
		)
		self.assert_evasion_result(kreuk_attack)

	def test_kreuk_whitebox_attack_p2(self):
		"""CKreukEvasion with compute_slack=True and p_norm=2 evades."""
		kreuk_attack = CKreukEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			iterations=50,
			is_debug=True,
			compute_slack=True,
			p_norm=2
		)
		self.assert_evasion_result(kreuk_attack)

	def test_kreuk_whitebox_no_slack_attack(self):
		"""CKreukEvasion with compute_slack=False and the default norm evades."""
		kreuk_attack = CKreukEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			iterations=50,
			is_debug=True,
			compute_slack=False
		)
		self.assert_evasion_result(kreuk_attack)

	def test_suciu_appending_whitebox_attack(self):
		"""CSuciuEvasion with default slack settings evades."""
		suciu_attack = CSuciuEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			is_debug=True
		)
		self.assert_evasion_result(suciu_attack)

	def test_suciu_appending_whitebox_no_slack_attack(self):
		"""CSuciuEvasion with compute_slack=False evades."""
		suciu_attack = CSuciuEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			is_debug=True,
			compute_slack=False
		)
		self.assert_evasion_result(suciu_attack)

	def test_section_shift_attack(self):
		"""CFormatExploitEvasion extending by 0x200 with no PE-header extension evades."""
		shift_attack = CFormatExploitEvasion(
			self.classifier,
			preferable_extension_amount=0x200,
			pe_header_extension=0,
			iterations=20,
			is_debug=True
		)
		self.assert_evasion_result(shift_attack)

	def test_pe_shift_attack(self):
		"""CFormatExploitEvasion with only a 0x200 PE-header extension evades."""
		shift_attack = CFormatExploitEvasion(
			self.classifier,
			preferable_extension_amount=0,
			pe_header_extension=0x200,
			iterations=20,
			is_debug=True
		)
		self.assert_evasion_result(shift_attack)

	def assert_evasion_result(self, attack):
		"""Run ``attack`` on ``self.X`` and assert that it evaded the classifier.

		``self.Y`` are the classifier's predictions on the clean samples.
		Samples with ``Y == 0`` were therefore already missed before the attack;
		``self.X`` is assumed to contain only malware, so these are false negatives.
		The check looks only at samples detected before the attack (``Y == 1``).
		Strictly fewer of them may be detected after the attack.

		Args:
			attack: an attack object exposing ``run(X, Y)`` whose first return
				value is the predicted labels for the attacked samples.
		"""
		y_pred, _, _, _ = attack.run(self.X, self.Y)
		n_detected_before = sum(self.Y == 1)
		# Count only samples that were detected before the attack and are still
		# detected after it. Originally-missed samples must not shift the count;
		# otherwise the test can pass without any real evasion.
		n_detected_after = sum(
			1 for y_before, y_after in zip(self.Y, y_pred)
			if y_before == 1 and y_after == 1
		)
		self.assertLess(
			n_detected_after,
			n_detected_before,
			msg="No evasion achieved: {}/{} detected samples are still detected".format(
				n_detected_after, n_detected_before
			),
		)


if __name__ == "__main__":
	# Running this file directly executes every test case it defines.
	unittest.main(module="__main__")
