123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- from Crypto.Cipher import AES
- from Crypto.Util import Counter
- from Crypto.Util.number import long_to_bytes, bytes_to_long
- from Utils import *
- # GF(2^128) defined by 1 + a + a^2 + a^7 + a^128
- # Please note the MSB is x0 and LSB is x127
def gf_2_128_mul(x, y):
    """Multiply two elements of GF(2^128) in GCM bit order.

    Elements are non-negative ints below 2**128 whose most significant bit
    is the coefficient of a^0 and whose least significant bit is the
    coefficient of a^127.  Products are reduced modulo
    1 + a + a^2 + a^7 + a^128.

    :param x: first operand (must be < 2**128).
    :param y: second operand (must be < 2**128).
    :return: the field product, an int < 2**128.
    """
    assert x < (1 << 128)
    assert y < (1 << 128)
    reduction = 0xE1000000000000000000000000000000
    product = 0
    v = x
    for bit_pos in reversed(range(128)):
        # Add v when this bit of y is set; multiplying by the 0/1 bit keeps
        # the loop free of data-dependent branches.
        product ^= v * ((y >> bit_pos) & 1)
        # v *= a: shift toward the LSB and fold the a^128 overflow back in.
        v = (v >> 1) ^ (reduction * (v & 1))
    assert product < 1 << 128
    return product
class InvalidInputException(Exception):
    """Raised when a caller supplies a malformed key, nonce or record.

    :param msg: human-readable description; also what ``str()`` returns.
    """

    def __init__(self, msg):
        # Forward msg to Exception so .args, repr() and pickling work
        # (pickle rebuilds exceptions as cls(*self.args)).
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return str(self.msg)
class InvalidTagException(Exception):
    """Raised when an authentication tag does not match the received data."""

    def __str__(self):
        return 'The authentication tag is invalid.'
- # Galois/Counter Mode with AES-128 and 96-bit IV
class AES_GCM:
    """AES-128 Galois/Counter Mode decryptor with a 96-bit IV.

    The nonce for each decrypted record is ``init_value ^ seq``; ``seq`` is
    advanced after every nonce is drawn, so successive records consume
    consecutive sequence numbers.  All plaintext recovered so far is
    accumulated in ``self.b``.

    :param master_key: the AES key as an int < 2**128.
    :param init_value: the base IV as an int (the derived nonce must fit in
        96 bits).
    :param seq: the first sequence number to use (default 0).
    """

    def __init__(self, master_key, init_value, seq=0):
        self.change_key(master_key)
        self.seq = seq
        self.init_value = init_value
        self.b = b''

    def change_key(self, master_key):
        """Install a new AES key and precompute the GHASH lookup table.

        :param master_key: the AES key as an int in [0, 2**128).
        :raises InvalidInputException: if the key is negative or does not
            fit in 128 bits.
        """
        if not 0 <= master_key < (1 << 128):
            raise InvalidInputException('Master key should be 128-bit')
        self.__master_key = long_to_bytes(master_key, 16)
        self.__aes_ecb = AES.new(self.__master_key, AES.MODE_ECB)
        # GHASH subkey H = E_K(0^128).
        self.__auth_key = bytes_to_long(self.__aes_ecb.encrypt(b'\x00' * 16))
        # NOTE(review): this debug print discloses the GHASH subkey H; drop it
        # wherever the output may be seen by someone who must not hold the key.
        print("\nEK0", self.__auth_key)
        # table[i][j] = H * (j << 8i), so H * v is 16 byte-indexed lookups.
        self.__pre_table = tuple(
            tuple(gf_2_128_mul(self.__auth_key, j << (8 * i))
                  for j in range(256))
            for i in range(16))
        self.prev_init_value = None  # reset

    def __times_auth_key(self, val):
        """Return H * val in GF(2^128) using the precomputed byte tables."""
        res = 0
        for row in self.__pre_table:
            res ^= row[val & 0xFF]
            val >>= 8
        return res

    def __ghash(self, aad, txt):
        """Compute GHASH_H(aad, txt).

        Both inputs are zero-padded to a multiple of 16 bytes, followed by a
        block holding their bit lengths (64 bits each).  A fresh buffer is
        built so a caller's bytearray is never mutated.
        """
        len_aad = len(aad)
        len_txt = len(txt)
        data = (bytes(aad) + b'\x00' * (-len_aad % 16)
                + bytes(txt) + b'\x00' * (-len_txt % 16))
        tag = 0
        for i in range(0, len(data), 16):
            tag ^= bytes_to_long(data[i:i + 16])
            tag = self.__times_auth_key(tag)
        tag ^= ((8 * len_aad) << 64) | (8 * len_txt)
        return self.__times_auth_key(tag)

    def __expected_tag(self, nonce, aad, ciphertext):
        """Return the 16-byte tag GHASH_H(aad, C) xor E_K(J0) as bytes.

        J0 = nonce || 0^31 || 1, the pre-counter block for a 96-bit nonce.
        """
        j0 = long_to_bytes((nonce << 32) | 1, 16)
        tag = self.__ghash(aad, ciphertext) ^ bytes_to_long(
            self.__aes_ecb.encrypt(j0))
        return long_to_bytes(tag, 16)

    def getNounce(self):
        """Return the next nonce (``init_value ^ seq``) and advance ``seq``."""
        print(self.init_value)
        n = self.init_value ^ self.seq
        print(n)
        self.seq += 1
        return n

    def decrypt(self, c, *, verify_tag=True):
        """Split a buffer of encrypted records and decrypt each one.

        Each record is a 5-byte header (``17 03 03`` followed by a 2-byte
        big-endian length) and then ``length`` bytes whose last 16 bytes are
        the GCM tag.  The header itself is the additional authenticated data.

        :param c: concatenated records.
        :param verify_tag: forwarded to :meth:`_decrypt`; when true (the
            default) every record's tag is checked.
        :return: list with one plaintext per record, in order.
        :raises InvalidInputException: on an unexpected header, a record too
            short to hold a tag, or a truncated record.
        :raises InvalidTagException: if a record fails authentication.
        """
        plaintexts = []
        pos = 0
        total = len(c)
        while pos < total:
            header = c[pos:pos + 5]
            # Explicit checks instead of assert, which vanishes under -O.
            if len(header) < 5 or header[:3] != b'\x17\x03\x03':
                raise InvalidInputException(
                    'Expected a 17 03 03 record header')
            length = bytes_to_long(header[3:5])
            if length < 16:
                raise InvalidInputException(
                    'Record is too short to contain a 16-byte tag')
            end = pos + 5 + length
            if end > total:
                raise InvalidInputException('Record is truncated')
            auth_data = header
            cipher = c[pos + 5:end - 16]
            auth_tag = c[end - 16:end]
            bPrint("auth_data", auth_data.hex())
            bPrint("cipher", cipher.hex())
            bPrint("auth_tag", auth_tag.hex())
            plaintexts.append(self._decrypt(cipher, auth_tag, auth_data,
                                            verify_tag=verify_tag))
            pos = end
        return plaintexts

    def _decrypt(self, ciphertext, auth_tag, auth_data=b'', *,
                 verify_tag=True):
        """Decrypt one GCM ciphertext with the next nonce in sequence.

        :param ciphertext: the encrypted payload, without the tag.
        :param auth_tag: the 16-byte authentication tag, as bytes or as an
            int.
        :param auth_data: additional authenticated data (default empty).
        :param verify_tag: when true (default), reject a mismatching tag
            before any plaintext is produced.
        :return: the plaintext bytes, which are also appended to ``self.b``.
        :raises InvalidInputException: if the derived nonce is not 96-bit.
        :raises InvalidTagException: if ``verify_tag`` and the tag does not
            match.
        """
        import hmac

        init_value = self.getNounce()
        if not 0 <= init_value < (1 << 96):
            # A wider value would produce a counter block longer than 16 bytes.
            raise InvalidInputException('Nonce should be 96-bit')
        if verify_tag:
            if isinstance(auth_tag, int):
                auth_tag = long_to_bytes(auth_tag, 16)
            # Compare bytes with bytes, in constant time, so the check
            # neither silently fails on a type mismatch nor leaks timing.
            expected = self.__expected_tag(init_value, auth_data, ciphertext)
            if not hmac.compare_digest(expected, bytes(auth_tag)):
                raise InvalidTagException
        len_ciphertext = len(ciphertext)
        if len_ciphertext > 0:
            # Counter blocks are nonce || 32-bit counter, starting at 2
            # (counter 1 is J0, reserved for the tag).
            counter = Counter.new(
                nbits=32,
                prefix=long_to_bytes(init_value, 12),
                initial_value=2,
                allow_wraparound=True)
            aes_ctr = AES.new(self.__master_key, AES.MODE_CTR, counter=counter)
            padded_ciphertext = bytes(ciphertext) + \
                b'\x00' * (-len_ciphertext % 16)
            plaintext = aes_ctr.decrypt(padded_ciphertext)[:len_ciphertext]
        else:
            plaintext = b''
        self.b += plaintext
        return plaintext
|