123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import abc
- import nacl.bindings
- from .key_schedule import tls_sha256, tls_sha384
- from Crypto.Cipher import AES
- from Crypto.Util import Counter
- from Crypto.Util.number import long_to_bytes, bytes_to_long
- from Utils import *
def gf_2_128_mul(x, y):
    """Multiply two elements of GF(2^128) in the bit order used by GCM/GHASH.

    The field is GF(2)[a] / (1 + a + a^2 + a^7 + a^128).  An element is a
    128-bit integer whose most significant bit is the coefficient of a^0
    (x0) and whose least significant bit is the coefficient of a^127
    (x127); consequently the multiplicative identity is ``1 << 127``.

    Args:
        x: first factor, an int in ``[0, 2**128)``.
        y: second factor, an int in ``[0, 2**128)``.

    Returns:
        The product, an int in ``[0, 2**128)``.

    Raises:
        ValueError: if either factor lies outside ``[0, 2**128)``.
    """
    bound = 1 << 128
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
    # and a bare ``< 1 << 128`` test also let negative ints through, which
    # would silently produce a wrong product.
    if not (0 <= x < bound and 0 <= y < bound):
        raise ValueError("gf_2_128_mul operands must lie in [0, 2**128)")
    # a^128 = 1 + a + a^2 + a^7; in the reflected order those coefficients are
    # the top bits x0, x1, x2 and x7, i.e. 0xE1 followed by 120 zero bits.
    reduction = 0xE1 << 120
    res = 0
    # Consume y's coefficients from x0 (MSB) down to x127 (LSB); before step k
    # the variable x holds (original x) * a^k reduced modulo the polynomial.
    for i in range(127, -1, -1):
        # Multiply by the bit instead of branching. Python big-int arithmetic
        # is not constant-time, so do not rely on this against timing attacks.
        res ^= x * ((y >> i) & 1)
        # x *= a: shift toward x127; if a^127's coefficient falls off, fold the
        # resulting a^128 term back in with the reduction constant.
        x = (x >> 1) ^ ((x & 1) * reduction)
    return res
class InvalidInputException(Exception):
    """Exception carrying a caller-supplied message about invalid input.

    The message is kept on ``self.msg`` for existing callers and is also
    forwarded to ``Exception.__init__`` so ``args``, ``repr()`` and pickling
    behave like any other exception.
    """

    def __init__(self, msg):
        # Without forwarding msg, self.args is empty: repr() hides the message
        # and unpickling fails because __init__ is re-invoked with no args.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return str(self.msg)
class InvalidTagException(Exception):
    """Exception signalling that an authentication tag is invalid.

    Its string form is a fixed message regardless of constructor arguments.
    """

    def __str__(self):
        return 'The authentication tag is invalid.'
class TLS_AEAD_Cipher(abc.ABC):
    """Base class for TLS 1.3 record protection with an AEAD cipher.

    Subclasses supply ``KEY_LEN``, ``MAC_LEN``, ``tls_hash`` and ``cipher()``.
    On construction, and on every traffic-secret update, the write key and
    static IV are derived from ``secret`` through ``tls_hash`` and the record
    sequence number restarts at 0. ``get_nonce()`` builds per-record nonces
    by XOR-ing the sequence number into the IV.
    """

    NONCE_LEN = 12
    # RFC 8446 section 5.3: sequence numbers are 64-bit and must never wrap;
    # once exhausted the connection has to rekey or terminate.
    MAX_SEQUENCE_NUMBER = (1 << 64) - 1

    @property
    @abc.abstractmethod
    def KEY_LEN(self):
        """Length in bytes of the AEAD key (subclasses set a class attribute)."""

    @property
    @abc.abstractmethod
    def MAC_LEN(self):
        """Length in bytes of the authentication tag appended to ciphertext."""

    @property
    @abc.abstractmethod
    def tls_hash(self):
        """Key-schedule helper.

        Used here via derive_key, derive_iv, hkdf_expand_label, hash_len
        and verify_data.
        """

    @abc.abstractmethod
    def cipher(self):
        """Return a new cipher object for one encrypt/decrypt call.

        The object must expose update(), encrypt_and_digest() and
        decrypt_and_verify().
        """

    def __init__(self, secret):
        self.reset(secret)

    def reset(self, secret):
        """Install a new traffic secret, re-derive key and IV, and zero the sequence.

        NOTE(review): this prints the derived key and IV to stdout, which
        exposes key material; keep it to debugging sessions only.
        """
        self.secret = secret
        self.key = self.tls_hash.derive_key(self.secret, self.KEY_LEN)
        print("\n\tkey\t\t", self.key.hex())
        iv = self.tls_hash.derive_iv(self.secret, self.NONCE_LEN)
        # Despite its name, ``ivhex`` holds the raw IV bytes (kept for callers).
        self.ivhex = iv
        # Stored as an int so get_nonce() can XOR the sequence number into it.
        self.iv = int.from_bytes(iv, "big")
        print("\n\tiv\t\t", iv.hex())
        self.sequence_number = 0

    def next_application_traffic_secret(self):
        """Return the next traffic secret (HKDF-Expand-Label "traffic upd")."""
        return self.tls_hash.hkdf_expand_label(
            self.secret, b"traffic upd", b"", self.tls_hash.hash_len
        )

    def update_traffic_secret(self):
        """Advance to the next traffic secret and re-derive key/IV from it."""
        self.reset(self.next_application_traffic_secret())

    def verify_data(self, msg):
        """Delegate verify-data computation for ``msg`` to the key schedule."""
        return self.tls_hash.verify_data(self.secret, msg)

    def get_nonce(self):
        """Return the nonce for the next record and advance the sequence number.

        Returns:
            ``NONCE_LEN`` big-endian bytes of ``sequence_number XOR iv``.

        Raises:
            OverflowError: if the 64-bit sequence number space is exhausted;
                update the traffic secret (which resets the counter) first.
        """
        if self.sequence_number > self.MAX_SEQUENCE_NUMBER:
            raise OverflowError(
                "TLS record sequence number exhausted; update the traffic secret"
            )
        nonce = self.sequence_number ^ self.iv
        nonce = nonce.to_bytes(self.NONCE_LEN, "big")
        simplePrint("IV", self.iv)
        simplePrint("Key", self.key.hex())
        simplePrint("nonce", nonce.hex())
        simplePrint("sequence_number", self.sequence_number)
        self.sequence_number += 1
        return nonce

    def decrypt(self, ciphertext, associated_data):
        """Decrypt and authenticate one record.

        Args:
            ciphertext: encrypted payload with the ``MAC_LEN``-byte tag
                appended; any bytes-like object (bytes, bytearray, memoryview).
            associated_data: bytes-like data authenticated alongside it.

        Returns:
            Whatever the cipher's decrypt_and_verify() returns (the plaintext).

        Raises:
            Whatever decrypt_and_verify() raises on a tag mismatch
            (ValueError for a PyCryptodome GCM cipher).
        """
        cipher = self.cipher()
        body = ciphertext[: -self.MAC_LEN]
        tag = ciphertext[-self.MAC_LEN:]
        # bytes(...) works for bytes, bytearray and memoryview alike; calling
        # .tobytes() here crashed on plain bytes, which encrypt() accepts.
        simplePrint("ciphertext", bytes(body).hex())
        simplePrint("Tag", bytes(tag).hex())
        simplePrint("associated_data", bytes(associated_data).hex())
        cipher.update(associated_data)
        return cipher.decrypt_and_verify(body, tag)

    def encrypt(self, plaintext, associated_data):
        """Encrypt ``plaintext`` and return ``ciphertext + tag``."""
        cipher = self.cipher()
        self.update2(associated_data, cipher)
        ciphertext, tag = cipher.encrypt_and_digest(plaintext)
        return ciphertext + tag

    def update2(self, associated_data, cipher):
        """Feed ``associated_data`` into ``cipher`` as authenticated-only data."""
        cipher.update(associated_data)

    def tls_ciphertext(self, plaintext):
        """Build a full protected record: 5-byte header followed by AEAD output.

        The header is opaque type 0x17 (application_data), legacy version
        0x0303, and the 2-byte length of ciphertext plus tag. The header
        itself is used as the associated data, per RFC 8446 section 5.2.
        """
        head = b"\x17\x03\x03" + \
            (len(plaintext) + self.MAC_LEN).to_bytes(2, "big")
        return head + self.encrypt(plaintext, head)
class TLS_AES_128_GCM_SHA256(TLS_AEAD_Cipher):
    """The TLS_AES_128_GCM_SHA256 suite.

    AES-128 in GCM mode with a 16-byte tag, keyed through the SHA-256 key
    schedule helper ``tls_sha256``.
    """

    tls_hash = tls_sha256
    KEY_LEN = 16
    MAC_LEN = 16

    def cipher(self):
        """Create an AES-GCM cipher object for one record.

        Each call consumes one sequence number via ``get_nonce()``.
        """
        per_record_nonce = self.get_nonce()
        return AES.new(
            self.key,
            AES.MODE_GCM,
            nonce=per_record_nonce,
            mac_len=self.MAC_LEN,
        )
|