decrypt.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from Crypto.Cipher import AES
  2. from Crypto.Util import Counter
  3. from Crypto.Util.number import long_to_bytes, bytes_to_long
  4. from Utils import *
  5. # GF(2^128) defined by 1 + a + a^2 + a^7 + a^128
  6. # Please note the MSB is x0 and LSB is x127
  7. def gf_2_128_mul(x, y):
  8. assert x < (1 << 128)
  9. assert y < (1 << 128)
  10. res = 0
  11. for i in range(127, -1, -1):
  12. res ^= x * ((y >> i) & 1) # branchless
  13. x = (x >> 1) ^ ((x & 1) * 0xE1000000000000000000000000000000)
  14. assert res < 1 << 128
  15. return res
  16. class InvalidInputException(Exception):
  17. def __init__(self, msg):
  18. self.msg = msg
  19. def __str__(self):
  20. return str(self.msg)
  21. class InvalidTagException(Exception):
  22. def __str__(self):
  23. return 'The authenticaiton tag is invalid.'
  24. # Galois/Counter Mode with AES-128 and 96-bit IV
  25. class AES_GCM:
  26. def __init__(self, master_key, init_value,seq=0):
  27. self.change_key(master_key)
  28. self.seq = seq
  29. self.init_value = init_value
  30. self.b = b''
  31. def change_key(self, master_key):
  32. if master_key >= (1 << 128):
  33. raise InvalidInputException('Master key should be 128-bit')
  34. self.__master_key = long_to_bytes(master_key, 16)
  35. self.__aes_ecb = AES.new(self.__master_key, AES.MODE_ECB)
  36. self.__auth_key = bytes_to_long(self.__aes_ecb.encrypt(b'\x00' * 16))
  37. print("\nEK0", self.__auth_key)
  38. table = [] # for 8-bit
  39. for i in range(16):
  40. row = []
  41. for j in range(256):
  42. row.append(gf_2_128_mul(self.__auth_key, j << (8 * i)))
  43. table.append(tuple(row))
  44. self.__pre_table = tuple(table)
  45. self.prev_init_value = None # reset
  46. def __times_auth_key(self, val):
  47. res = 0
  48. for i in range(16):
  49. res ^= self.__pre_table[i][val & 0xFF]
  50. val >>= 8
  51. return res
  52. def __ghash(self, aad, txt):
  53. len_aad = len(aad)
  54. len_txt = len(txt)
  55. if 0 == len_aad % 16:
  56. data = aad
  57. else:
  58. data = aad + b'\x00' * (16 - len_aad % 16)
  59. if 0 == len_txt % 16:
  60. data += txt
  61. else:
  62. data += txt + b'\x00' * (16 - len_txt % 16)
  63. tag = 0
  64. assert len(data) % 16 == 0
  65. for i in range(len(data) // 16):
  66. tag ^= bytes_to_long(data[i * 16: (i + 1) * 16])
  67. tag = self.__times_auth_key(tag)
  68. tag ^= ((8 * len_aad) << 64) | (8 * len_txt)
  69. tag = self.__times_auth_key(tag)
  70. return tag
  71. def getNounce(self):
  72. print(self.init_value)
  73. n = self.init_value ^ self.seq
  74. print(n)
  75. self.seq += 1
  76. return n
  77. def decrypt(self, c):
  78. t = []
  79. while len(c) > 0:
  80. assert(c[:3] == b'\x17\x03\x03')
  81. a = bytes_to_long(c[3:5])
  82. auth_data = c[:5]
  83. cipher = c[5:5+a-16]
  84. auth_tag = c[5+a-16:5+a]
  85. bPrint("auth_data", auth_data.hex())
  86. bPrint("cipher", cipher.hex())
  87. bPrint("auth_tag", auth_tag.hex())
  88. t.append(self._decrypt(cipher, auth_tag, auth_data))
  89. c = c[5+a:]
  90. return t
  91. def _decrypt(self, ciphertext, auth_tag, auth_data=b''):
  92. init_value = self.getNounce()
  93. # if auth_tag != self.__ghash(auth_data, ciphertext) ^\
  94. # bytes_to_long(self.__aes_ecb.encrypt(
  95. # long_to_bytes((init_value << 32) | 1, 16))):
  96. # raise InvalidTagException
  97. # TODO CHECK TAG
  98. len_ciphertext = len(ciphertext)
  99. if len_ciphertext > 0:
  100. counter = Counter.new(
  101. nbits=32,
  102. prefix=long_to_bytes(init_value, 12),
  103. initial_value=2,
  104. allow_wraparound=True)
  105. aes_ctr = AES.new(self.__master_key, AES.MODE_CTR, counter=counter)
  106. if 0 != len_ciphertext % 16:
  107. padded_ciphertext = ciphertext + \
  108. b'\x00' * (16 - len_ciphertext % 16)
  109. else:
  110. padded_ciphertext = ciphertext
  111. plaintext = aes_ctr.decrypt(padded_ciphertext)[:len_ciphertext]
  112. else:
  113. plaintext = b''
  114. self.b += plaintext
  115. return plaintext