import os
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes


def generate_random_key():
    """
    Creates a brand-new random key with Fernet.generate_key() and
    returns it unchanged, ready to be passed as the `key` argument
    of encrypt() / decrypt()
    """
    fresh_key = Fernet.generate_key()
    return fresh_key

def generate_password_key(password, salt):
    """
    Returns a password-based cryptographic key generated by scrypt

    password: the password, either a str (encoded to UTF-8 bytes
              before derivation) or bytes (used as-is)
    salt:     bytes salt passed to scrypt; the same password and salt
              always derive the same key, so keep the salt alongside
              the ciphertext (it does not need to be secret)

    The 32 derived bytes are returned URL-safe base64 encoded, which
    is the key format that Fernet(key) is given in encrypt()/decrypt()
    """
    # Accept both text and raw bytes, like encrypt() and decrypt() do
    if isinstance(password, str):
        password_bytes = password.encode()
    else:
        password_bytes = password

    # Scrypt is a key derivation function (KDF) that generates a
    # cryptographic key. Its main benefit is that it is bottlenecked
    # by a machine's memory access speed. This means that it is
    # expensive for an adversary to brute-force the password

    kdf = Scrypt(
        salt=salt,  # Random salt so identical passwords derive different keys
        length=32,  # Output length in bytes
        n=2**14,    # CPU/memory cost, recommended value in https://www.tarsnap.com/scrypt/scrypt.pdf
        r=8,        # Block size
        p=1,        # Parallelization parameter
    )

    derived = kdf.derive(password_bytes)
    return base64.urlsafe_b64encode(derived)


def encrypt(key, message):
    """
    Returns a Fernet token that encrypts the message under the key

    key:     a Fernet key, e.g. from generate_random_key() or
             generate_password_key()
    message: the plaintext, either a str (encoded to UTF-8 bytes
             first) or bytes (encrypted as-is)

    Returns the ciphertext produced by Fernet.encrypt()
    """

    # Fernet uses AES in CBC mode with a 128-bit key for encryption
    # and decryption. It uses the standard PKCS7 padding

    # Text is converted to bytes here; anything else is handed to
    # Fernet unchanged so that it performs its own type checking
    if isinstance(message, str):
        message = message.encode()

    return Fernet(key).encrypt(message)


def decrypt(key, ciphertext):
    """
    Returns the decryption of a Fernet token under the key

    key:        the Fernet key the token was encrypted with
    ciphertext: the token returned by encrypt(), either as bytes or
                as a str (encoded to bytes first)

    Returns the plaintext as bytes; callers that encrypted a str must
    decode the result themselves. Per the Fernet documentation, a
    wrong key or a modified token raises cryptography.fernet.InvalidToken
    """

    # Tokens are often stored as text; convert them back to bytes
    if isinstance(ciphertext, str):
        ciphertext = ciphertext.encode()

    # Fernet uses AES in CBC mode with a 128-bit key for encryption
    # and decryption. It uses the standard PKCS7 padding

    return Fernet(key).decrypt(ciphertext)


if __name__ == "__main__":
    # Demo of the two-level scheme: a file is encrypted under a random
    # key, and that key is stored in a row entry that is itself
    # encrypted under a password-derived key.

    # Label stored in front of the file key inside the row entry. It is
    # named once so that step (4) strips exactly what step (2) added.
    entry_prefix = "entry: "

    # (1) Encrypt a file (random cryptographic key)
    file_key = generate_random_key()
    enc_file = encrypt(file_key, "contents of the file")
    print("Encrypted File:", enc_file)

    # (2) Encrypt a row entry (password-based key)
    salt = os.urandom(16)
    pb_key = generate_password_key("password123", salt)
    enc_entry = encrypt(pb_key, entry_prefix + file_key.decode())
    print("Encrypted Entry:", enc_entry)
    print()

    # (3) Decrypt a row entry (password-based key)
    # Note that the salt can be stored in plaintext; re-deriving with the
    # same password and salt yields the same key as in step (2)
    pb_key = generate_password_key("password123", salt)
    dec_entry = decrypt(pb_key, enc_entry)
    print("Decrypted Entry:", dec_entry)

    # (4) Decrypt a file (random cryptographic key)
    # dec_entry is bytes, so measure the prefix in its encoded form
    file_key = dec_entry[len(entry_prefix.encode()):]  # get rid of the label
    dec_file = decrypt(file_key, enc_file)
    print("Decrypted File:", dec_file)
