import abc
import nacl.bindings
from .key_schedule import tls_sha256, tls_sha384
from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Utils import *


# GF(2^128) defined by 1 + a + a^2 + a^7 + a^128
# Please note the MSB is x0 and LSB is x127
def gf_2_128_mul(x, y):
    """Multiply two elements of GF(2^128) in the bit-reflected convention.

    Operands are 128-bit integers whose most significant bit holds the
    coefficient of x^0 and whose least significant bit holds x^127. The
    reduction constant 0xE1 << 120 encodes 1 + a + a^2 + a^7 in that bit
    order (the a^128 term is implied).

    Args:
        x: first operand, 0 <= x < 2**128.
        y: second operand, 0 <= y < 2**128.

    Returns:
        The field product as an integer below 2**128.
    """
    # NOTE(review): these are assertions, so they disappear under `python -O`.
    assert x < (1 << 128)
    assert y < (1 << 128)
    res = 0
    # Walk y from its MSB (coefficient of x^0) down to its LSB (x^127).
    for i in range(127, -1, -1):
        # Accumulate x when the current bit of y is set; multiplying by the
        # 0/1 bit avoids a data-dependent branch.
        res ^= x * ((y >> i) & 1)  # branchless
        # x <- x * a: a right shift in this bit order. If the x^127
        # coefficient (LSB) was shifted out, fold it back in via the
        # reduction polynomial.
        x = (x >> 1) ^ ((x & 1) * 0xE1000000000000000000000000000000)
    assert res < 1 << 128
    return res


class InvalidInputException(Exception):
    """Raised for malformed caller input; ``str()`` yields the message."""

    def __init__(self, msg):
        # Forward msg to Exception so `args` is populated. Without it, pickling
        # or copying the exception calls cls() with no arguments and fails,
        # and repr() omits the message.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return str(self.msg)


class InvalidTagException(Exception):
    """Signals that an AEAD authentication tag did not verify."""

    def __str__(self):
        return 'The authentication tag is invalid.'


class TLS_AEAD_Cipher(abc.ABC):
    """Base class for one direction of TLS record protection with an AEAD.

    Subclasses provide KEY_LEN, MAC_LEN, ``tls_hash`` and ``cipher()``.
    ``tls_hash`` must offer ``derive_key``, ``derive_iv``,
    ``hkdf_expand_label``, ``verify_data`` and ``hash_len``.
    ``cipher()`` must return a fresh AEAD object exposing ``update``,
    ``encrypt_and_digest`` and ``decrypt_and_verify``.
    """

    # Per-record nonce length in bytes; the derived IV has this length too.
    NONCE_LEN = 12

    @property
    @abc.abstractmethod
    def KEY_LEN(self):
        """Length in bytes of the symmetric key derived from the secret."""

    @property
    @abc.abstractmethod
    def MAC_LEN(self):
        """Length in bytes of the authentication tag appended to ciphertext."""

    @property
    @abc.abstractmethod
    def tls_hash(self):
        """Key-schedule helper used to derive keys, IVs and new secrets."""

    @abc.abstractmethod
    def cipher(self):
        """Return a new AEAD cipher object for the next record.

        Implementations typically obtain the nonce from ``get_nonce()``,
        which advances the sequence number as a side effect.
        """

    def __init__(self, secret):
        self.reset(secret)

    def reset(self, secret):
        """Re-derive key and IV from *secret* and restart the sequence number.

        Args:
            secret: the traffic secret (bytes) this direction is keyed from.
        """
        self.secret = secret
        self.key = self.tls_hash.derive_key(self.secret, self.KEY_LEN)
        # NOTE(review): writes raw key material to stdout; only acceptable in
        # a debugging/teaching build.
        print("\n\tkey\t\t", self.key.hex())
        iv = self.tls_hash.derive_iv(self.secret, self.NONCE_LEN)
        # Despite its name, `ivhex` holds the raw IV bytes, not a hex string.
        self.ivhex = iv
        # Kept as an int so get_nonce() can XOR the sequence number into it.
        self.iv = int.from_bytes(iv, "big")
        print("\n\tiv\t\t", iv.hex())
        self.sequence_number = 0

    def next_application_traffic_secret(self):
        """Derive the next traffic secret using the "traffic upd" label.

        Returns:
            HKDF-Expand-Label(secret, "traffic upd", "", hash_len), i.e. the
            TLS 1.3 KeyUpdate derivation.
        """
        return self.tls_hash.hkdf_expand_label(
            self.secret, b"traffic upd", b"", self.tls_hash.hash_len
        )

    def update_traffic_secret(self):
        """Rekey this direction in place with the next traffic secret."""
        self.reset(self.next_application_traffic_secret())

    def verify_data(self, msg):
        """Delegate Finished-style verify_data computation to ``tls_hash``."""
        return self.tls_hash.verify_data(self.secret, msg)

    def get_nonce(self):
        """Return the nonce for the next record and advance the sequence number.

        The nonce is the IV XORed with the sequence number, which is
        left-padded to NONCE_LEN bytes (big-endian).
        """
        nonce = self.sequence_number ^ self.iv
        nonce = nonce.to_bytes(self.NONCE_LEN, "big")
        simplePrint("IV", self.iv)
        simplePrint("Key", self.key.hex())
        simplePrint("nonce", nonce.hex())
        simplePrint("sequence_number", self.sequence_number)
        self.sequence_number += 1
        return nonce

    def decrypt(self, ciphertext, associated_data):
        """Authenticate and decrypt one record body.

        Args:
            ciphertext: encrypted payload with the MAC_LEN-byte tag appended.
                May be any bytes-like object (bytes, bytearray, memoryview).
            associated_data: additional authenticated data (the record header).

        Returns:
            The decrypted plaintext.

        Raises:
            Whatever the AEAD object's ``decrypt_and_verify`` raises when the
            tag does not match (ValueError for pycryptodome).
        """
        cipher = self.cipher()
        # bytes(...) accepts every buffer type; the former .tobytes() only
        # worked on memoryview and crashed for bytes/bytearray input.
        simplePrint("ciphertext", bytes(ciphertext[: -self.MAC_LEN]).hex())
        simplePrint("Tag", bytes(ciphertext[-self.MAC_LEN:]).hex())
        simplePrint("associated_data", bytes(associated_data).hex())
        cipher.update(associated_data)
        return cipher.decrypt_and_verify(
            ciphertext[: -self.MAC_LEN], ciphertext[-self.MAC_LEN:]
        )

    def encrypt(self, plaintext, associated_data):
        """Encrypt *plaintext* and return ciphertext followed by the tag."""
        cipher = self.cipher()
        self.update2(associated_data, cipher)
        ciphertext, tag = cipher.encrypt_and_digest(plaintext)
        return ciphertext + tag

    def update2(self, associated_data, cipher):
        """Feed *associated_data* into *cipher* as authenticated-only data."""
        cipher.update(associated_data)

    def tls_ciphertext(self, plaintext):
        """Build a complete protected record: 5-byte header + encrypted body.

        The header is content type 0x17 (application_data), version 0x0303,
        and the 2-byte length of the encrypted body including the tag. The
        same header is authenticated as associated data.

        NOTE(review): *plaintext* is presumably the already-framed inner
        plaintext (content plus inner content-type byte); confirm at callers.
        """
        head = b"\x17\x03\x03" + \
            (len(plaintext) + self.MAC_LEN).to_bytes(2, "big")
        return head + self.encrypt(plaintext, head)


class TLS_AES_128_GCM_SHA256(TLS_AEAD_Cipher):
    """TLS_AES_128_GCM_SHA256: AES-128 in GCM mode with a 16-byte tag."""

    KEY_LEN = 16
    MAC_LEN = 16
    tls_hash = tls_sha256

    def cipher(self):
        """Return a fresh AES-GCM object keyed with the next record nonce."""
        return AES.new(
            self.key, AES.MODE_GCM, nonce=self.get_nonce(), mac_len=self.MAC_LEN
        )