import hmac

from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Utils import *


# GF(2^128) defined by 1 + a + a^2 + a^7 + a^128
# Please note the MSB is x0 and LSB is x127
def gf_2_128_mul(x, y):
    """Multiply two GF(2^128) elements using GCM's bit ordering.

    Both operands are 128-bit integers whose most significant bit holds the
    coefficient of x^0 and whose least significant bit holds x^127.
    """
    assert x < (1 << 128)
    assert y < (1 << 128)
    res = 0
    for i in range(127, -1, -1):
        res ^= x * ((y >> i) & 1)  # branchless
        # Multiply x by the generator: shift toward x^127, and when the x^127
        # coefficient falls off, reduce with R = 11100001 || 0^120.
        x = (x >> 1) ^ ((x & 1) * 0xE1000000000000000000000000000000)
    assert res < 1 << 128
    return res


class InvalidInputException(Exception):
    """Raised when a key or an input buffer is malformed."""

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return str(self.msg)


class InvalidTagException(Exception):
    """Raised when a GCM authentication tag does not match."""

    def __str__(self):
        return 'The authentication tag is invalid.'


# Galois/Counter Mode with AES-128 and 96-bit IV
class AES_GCM:
    """AES-128-GCM decryptor for a stream of TLS application-data records.

    Each record uses the 96-bit nonce ``init_value XOR seq``. ``seq`` is
    advanced once per decrypted record. All recovered plaintext is also
    appended to ``self.b``.
    """

    _TLS_APPDATA_HEADER = b'\x17\x03\x03'  # content type 0x17, version 03 03
    _TAG_LEN = 16

    def __init__(self, master_key, init_value, seq=0):
        self.change_key(master_key)
        self.seq = seq
        self.init_value = init_value
        self.b = b''

    def change_key(self, master_key):
        """Install a new 128-bit AES key (given as an int) and rebuild GHASH tables.

        Raises:
            InvalidInputException: ``master_key`` is negative or wider than 128 bits.
        """
        if master_key < 0 or master_key >= (1 << 128):
            raise InvalidInputException('Master key should be 128-bit')

        self.__master_key = long_to_bytes(master_key, 16)
        self.__aes_ecb = AES.new(self.__master_key, AES.MODE_ECB)
        # Hash subkey H = E_K(0^128).
        self.__auth_key = bytes_to_long(self.__aes_ecb.encrypt(b'\x00' * 16))
        # NOTE(review): debug output -- this prints the GHASH subkey H, which is
        # key material. Remove it outside of a debugging context.
        print("\nEK0", self.__auth_key)

        # 8-bit table: entry [i][j] = H * (j << 8i). Because GF multiplication
        # is linear, multiplying by H becomes 16 lookups XORed together.
        self.__pre_table = tuple(
            tuple(gf_2_128_mul(self.__auth_key, j << (8 * i)) for j in range(256))
            for i in range(16)
        )

        self.prev_init_value = None  # reset

    def __times_auth_key(self, val):
        """Return ``val * H`` in GF(2^128) using the precomputed byte table."""
        res = 0
        for i in range(16):
            res ^= self.__pre_table[i][val & 0xFF]
            val >>= 8
        return res

    def __ghash(self, aad, txt):
        """Return GHASH_H(aad, txt) as a 128-bit integer."""
        len_aad = len(aad)
        len_txt = len(txt)
        # AAD and text are each zero-padded to whole 16-byte blocks.
        data = aad + b'\x00' * (-len_aad % 16) + txt + b'\x00' * (-len_txt % 16)

        tag = 0
        for i in range(0, len(data), 16):
            tag ^= bytes_to_long(data[i:i + 16])
            tag = self.__times_auth_key(tag)
        # Final block: 64-bit bit-length of AAD followed by that of the text.
        tag ^= ((8 * len_aad) << 64) | (8 * len_txt)
        tag = self.__times_auth_key(tag)
        return tag

    def __expected_tag(self, init_value, ciphertext, auth_data):
        """Return the 16-byte GCM tag for *ciphertext*/*auth_data* under nonce *init_value*."""
        # For a 96-bit IV, J0 = IV || 0^31 || 1, and the tag is GHASH XOR E_K(J0).
        j0 = long_to_bytes((init_value << 32) | 1, 16)
        tag = self.__ghash(auth_data, ciphertext) ^ bytes_to_long(self.__aes_ecb.encrypt(j0))
        return long_to_bytes(tag, 16)

    def getNounce(self):
        """Return the nonce for the next record (``init_value XOR seq``) and advance ``seq``.

        The misspelled name is kept for backward compatibility.
        """
        print(self.init_value)
        n = self.init_value ^ self.seq
        print(n)
        self.seq += 1
        return n

    def decrypt(self, c, verify_tag=False):
        """Decrypt a buffer of concatenated TLS application-data records.

        Each record is a 5-byte header (``17 03 03`` plus a 2-byte big-endian
        length), then ``length`` bytes holding the ciphertext followed by a
        16-byte tag. The header bytes are used as the GCM additional data. A
        truncated final record is decrypted best-effort (CTR yields a correct
        plaintext prefix) unless ``verify_tag`` is set.

        Args:
            c: bytes containing zero or more records.
            verify_tag: when True, check each record's GCM tag. Defaults to
                False to keep the historical unauthenticated behavior.

        Returns:
            List with one plaintext ``bytes`` object per record.

        Raises:
            InvalidInputException: a record header is malformed, or its length
                is too short to contain the 16-byte tag.
            InvalidTagException: ``verify_tag`` is True and a tag does not match.
        """
        t = []
        while len(c) > 0:
            if len(c) < 5 or c[:3] != self._TLS_APPDATA_HEADER:
                raise InvalidInputException(
                    'Malformed record header: %s' % c[:5].hex())
            a = bytes_to_long(c[3:5])
            if a < self._TAG_LEN:
                # A shorter length would produce negative slice indices below.
                raise InvalidInputException(
                    'Record length %d is shorter than the 16-byte tag' % a)
            auth_data = c[:5]
            cipher = c[5:5 + a - self._TAG_LEN]
            auth_tag = c[5 + a - self._TAG_LEN:5 + a]
            bPrint("auth_data", auth_data.hex())
            bPrint("cipher", cipher.hex())
            bPrint("auth_tag", auth_tag.hex())
            t.append(self._decrypt(cipher, auth_tag, auth_data, verify_tag=verify_tag))
            c = c[5 + a:]
        return t

    def _decrypt(self, ciphertext, auth_tag, auth_data=b'', verify_tag=False):
        """Decrypt one record with the next nonce. Optionally verify its tag first.

        ``auth_tag`` may be given as bytes or as an int. Each call consumes one
        sequence number, even when tag verification fails.

        Raises:
            InvalidTagException: ``verify_tag`` is True and the tag does not match.
        """
        init_value = self.getNounce()

        if verify_tag:
            if isinstance(auth_tag, int):
                auth_tag = long_to_bytes(auth_tag, 16)
            expected = self.__expected_tag(init_value, ciphertext, auth_data)
            # Constant-time comparison avoids leaking how many tag bytes matched.
            if not hmac.compare_digest(expected, auth_tag):
                raise InvalidTagException

        len_ciphertext = len(ciphertext)
        if len_ciphertext > 0:
            # GCM encryption starts at inc32(J0), i.e. counter value 2.
            counter = Counter.new(
                nbits=32,
                prefix=long_to_bytes(init_value, 12),
                initial_value=2,
                allow_wraparound=True)
            aes_ctr = AES.new(self.__master_key, AES.MODE_CTR, counter=counter)
            # Pad to whole blocks (for CTR implementations that require it), then trim.
            padded_ciphertext = ciphertext + b'\x00' * (-len_ciphertext % 16)
            plaintext = aes_ctr.decrypt(padded_ciphertext)[:len_ciphertext]
        else:
            plaintext = b''
        self.b += plaintext
        return plaintext