aes_m.mpc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. from copy import copy
  2. rcon_raw = [
  3. 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
  4. 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39,
  5. 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a,
  6. 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8,
  7. 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef,
  8. 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc,
  9. 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b,
  10. 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3,
  11. 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94,
  12. 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20,
  13. 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35,
  14. 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f,
  15. 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04,
  16. 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63,
  17. 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd,
  18. 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb
  19. ]
  20. nparallel = 1
  21. noutput = 1
  22. nthreads = 1
  23. rcon = VectorArray(len(rcon_raw), cgf2n, nparallel)
  24. for idx in range(len(rcon_raw)):
  25. rcon[idx] = cgf2n(rcon_raw[idx],size=nparallel)
  26. powers2 = VectorArray(8, cgf2n, nparallel)
  27. for idx in range(8):
  28. powers2[idx] = cgf2n(2,size=nparallel) ** (5 * idx)
  29. @vectorize
  30. def ApplyEmbedding(x):
  31. in_bytes = x.bit_decompose(8)
  32. out_bytes = [cgf2n(0) for _ in range(8)]
  33. out_bytes[0] = sum(in_bytes[0:8])
  34. out_bytes[1] = sum(in_bytes[idx] for idx in range(1, 8, 2))
  35. out_bytes[2] = in_bytes[2] + in_bytes[3] + in_bytes[6] + in_bytes[7]
  36. out_bytes[3] = in_bytes[3] + in_bytes[7]
  37. out_bytes[4] = in_bytes[4] + in_bytes[5] + in_bytes[6] + in_bytes[7]
  38. out_bytes[5] = in_bytes[5] + in_bytes[7]
  39. out_bytes[6] = in_bytes[6] + in_bytes[7]
  40. out_bytes[7] = in_bytes[7]
  41. return sum(powers2[idx] * out_bytes[idx] for idx in range(8))
  42. def embed_helper(in_bytes):
  43. out_bytes = [None] * 8
  44. out_bytes[0] = sum(in_bytes[0:8])
  45. out_bytes[1] = sum(in_bytes[idx] for idx in range(1, 8, 2))
  46. out_bytes[2] = in_bytes[2] + in_bytes[3] + in_bytes[6] + in_bytes[7]
  47. out_bytes[3] = in_bytes[3] + in_bytes[7]
  48. out_bytes[4] = in_bytes[4] + in_bytes[5] + in_bytes[6] + in_bytes[7]
  49. out_bytes[5] = in_bytes[5] + in_bytes[7]
  50. out_bytes[6] = in_bytes[6] + in_bytes[7]
  51. out_bytes[7] = in_bytes[7]
  52. return out_bytes
  53. @vectorize
  54. def ApplyBDEmbedding(x):
  55. entire_sequence_bits = copy(x)
  56. while len(entire_sequence_bits) < 8:
  57. entire_sequence_bits.append(0)
  58. in_bytes = entire_sequence_bits
  59. out_bytes = embed_helper(in_bytes)
  60. return sum(powers2[idx] * out_bytes[idx] for idx in range(8))
  61. def PreprocInverseEmbedding(x):
  62. in_bytes = x.bit_decompose_embedding()
  63. out_bytes = [cgf2n(0) for _ in range(8)]
  64. out_bytes[7] = in_bytes[7]
  65. out_bytes[6] = in_bytes[6] + out_bytes[7]
  66. out_bytes[5] = in_bytes[5] + out_bytes[7]
  67. out_bytes[4] = in_bytes[4] + out_bytes[5] + out_bytes[6] + out_bytes[7]
  68. out_bytes[3] = in_bytes[3] + out_bytes[7]
  69. out_bytes[2] = in_bytes[2] + out_bytes[3] + out_bytes[6] + out_bytes[7]
  70. out_bytes[1] = in_bytes[1] + out_bytes[3] + out_bytes[5] + out_bytes[7]
  71. out_bytes[0] = in_bytes[0] + sum(out_bytes[1:8])
  72. return out_bytes
  73. @vectorize
  74. def InverseEmbedding(x):
  75. out_bytes = PreprocInverseEmbedding(x)
  76. ret = cgf2n(0)
  77. for idx in range(7, -1, -1):
  78. ret = ret + (cgf2n(2) ** idx) * out_bytes[idx]
  79. return ret
  80. def InverseBDEmbedding(x):
  81. return PreprocInverseEmbedding(x)
  82. def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4):
  83. #cipherkey should be in hex
  84. cipherKeySize = len(cipherKey)
  85. round_key = [sgf2n(0,size=nparallel)] * 176
  86. temp = [cgf2n(0,size=nparallel)] * 4
  87. for i in range(Nk):
  88. for j in range(4):
  89. round_key[4 * i + j] = cipherKey[4 * i + j]
  90. for i in range(Nk, Nb * (Nr + 1)):
  91. for j in range(4):
  92. temp[j] = round_key[(i-1) * 4 + j]
  93. if i % Nk == 0:
  94. #rotate the 4 bytes word to the left
  95. k = temp[0]
  96. temp[0] = temp[1]
  97. temp[1] = temp[2]
  98. temp[2] = temp[3]
  99. temp[3] = k
  100. #now substitute word
  101. temp[0] = box.apply_sbox(temp[0])
  102. temp[1] = box.apply_sbox(temp[1])
  103. temp[2] = box.apply_sbox(temp[2])
  104. temp[3] = box.apply_sbox(temp[3])
  105. temp[0] = temp[0] + ApplyEmbedding(rcon[int(i//Nk)])
  106. for j in range(4):
  107. round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j]
  108. return round_key
  109. #Nr = 10 -> The number of rounds in AES Cipher.
  110. #Nb = 4 -> The number of columns of the AES state
  111. #Nk = 4 -> The number of words of a AES key
  112. def SecretArrayEmbedd(byte_array):
  113. return [ApplyEmbedding(_) for _ in byte_array]
  114. @vectorize
  115. def subBytes(state):
  116. for i in range(len(state)):
  117. state[i] = box.apply_sbox(state[i])
  118. def addRoundKey(roundKey):
  119. @vectorize
  120. def inner(state):
  121. for i in range(len(state)):
  122. state[i] = state[i] + roundKey[i]
  123. return inner
  124. # mixColumn takes a column and does stuff
  125. Kv = VectorArray(4, cgf2n, nparallel)
  126. Kv[1] = ApplyEmbedding(cgf2n(1,size=nparallel))
  127. Kv[2] = ApplyEmbedding(cgf2n(2,size=nparallel))
  128. Kv[3] = ApplyEmbedding(cgf2n(3,size=nparallel))
  129. Kv[4] = ApplyEmbedding(cgf2n(4,size=nparallel))
  130. @vectorize
  131. def mixColumn(column):
  132. temp = copy(column)
  133. v1 = Kv[1]
  134. v2 = Kv[2]
  135. v3 = Kv[3]
  136. v4 = Kv[4]
  137. # no multiplication
  138. doubles = [Kv[2] * t for t in temp]
  139. column[0] = doubles[0] + (temp[1] + doubles[1]) + temp[2] + temp[3]
  140. column[1] = temp[0] + doubles[1] + (temp[2] + doubles[2]) + temp[3]
  141. column[2] = temp[0] + temp[1] + doubles[2] + (temp[3] + doubles[3])
  142. column[3] = (temp[0] + doubles[0]) + temp[1] + temp[2] + doubles[3]
  143. @vectorize
  144. def mixColumns(state):
  145. for i in range(4):
  146. column = []
  147. for j in range(4):
  148. column.append(state[i*4+j])
  149. mixColumn(column)
  150. for j in range(4):
  151. state[i*4+j] = column[j]
  152. def rotate(word, n):
  153. return word[n:]+word[0:n]
  154. def shiftRows(state):
  155. for i in range(4):
  156. state[i::4] = rotate(state[i::4],i)
  157. @vectorize
  158. def state_collapse(state):
  159. return [InverseEmbedding(_) for _ in state]
  160. # such constants. very wow.
  161. _embedded_powers = [
  162. [0x1,0x2,0x4,0x8,0x10,0x20,0x40,0x80,0x100,0x200,0x400,0x800,0x1000,0x2000,0x4000,0x8000,0x10000,0x20000,0x40000,0x80000,0x100000,0x200000,0x400000,0x800000,0x1000000,0x2000000,0x4000000,0x8000000,0x10000000,0x20000000,0x40000000,0x80000000,0x100000000,0x200000000,0x400000000,0x800000000,0x1000000000,0x2000000000,0x4000000000,0x8000000000],
  163. [0x1,0x4,0x10,0x40,0x100,0x400,0x1000,0x4000,0x10000,0x40000,0x100000,0x400000,0x1000000,0x4000000,0x10000000,0x40000000,0x100000000,0x400000000,0x1000000000,0x4000000000,0x108401,0x421004,0x1084010,0x4210040,0x10840100,0x42100400,0x108401000,0x421004000,0x1084010000,0x4210040000,0x840008401,0x2100021004,0x8400084010,0x1000000842,0x4000002108,0x100021,0x400084,0x1000210,0x4000840,0x10002100],
  164. [0x1,0x10,0x100,0x1000,0x10000,0x100000,0x1000000,0x10000000,0x100000000,0x1000000000,0x108401,0x1084010,0x10840100,0x108401000,0x1084010000,0x840008401,0x8400084010,0x4000002108,0x400084,0x4000840,0x40008400,0x400084000,0x4000840000,0x8021004,0x80210040,0x802100400,0x8021004000,0x210802008,0x2108020080,0x1080010002,0x800008421,0x8000084210,0x108,0x1080,0x10800,0x108000,0x1080000,0x10800000,0x108000000,0x1080000000],
  165. [0x1,0x100,0x10000,0x1000000,0x100000000,0x108401,0x10840100,0x1084010000,0x8400084010,0x400084,0x40008400,0x4000840000,0x80210040,0x8021004000,0x2108020080,0x800008421,0x108,0x10800,0x1080000,0x108000000,0x800108401,0x10002108,0x1000210800,0x20004010,0x2000401000,0x42008020,0x4200802000,0x84200842,0x8420084200,0x2000421084,0x40000420,0x4000042000,0x10040,0x1004000,0x100400000,0x40108401,0x4010840100,0x1080200040,0x8021080010,0x2100421080],
  166. [0x1,0x10000,0x100000000,0x10840100,0x8400084010,0x40008400,0x80210040,0x2108020080,0x108,0x1080000,0x800108401,0x1000210800,0x2000401000,0x4200802000,0x8420084200,0x40000420,0x10040,0x100400000,0x4010840100,0x8021080010,0x40108421,0x1080000040,0x100421080,0x4200040100,0x1084200,0x842108401,0x1004210042,0x2008400004,0x4210000008,0x401080210,0x840108001,0x1000000840,0x100001000,0x840100,0x8401000000,0x800000001,0x84210800,0x2100001084,0x210802100,0x8001004210],
  167. [0x1,0x100000000,0x8400084010,0x80210040,0x108,0x800108401,0x2000401000,0x8420084200,0x10040,0x4010840100,0x40108421,0x100421080,0x1084200,0x1004210042,0x4210000008,0x840108001,0x100001000,0x8401000000,0x84210800,0x210802100,0x800000401,0x2100420080,0x8000004000,0x4010002,0x4000800100,0x842000420,0x8421084,0x421080210,0x80010042,0x10802108,0x800000020,0x1084,0x8401084010,0x1004200040,0x4000840108,0x100020,0x2108401000,0x8400080210,0x84210802,0x10802100],
  168. [0x1,0x8400084010,0x108,0x2000401000,0x10040,0x40108421,0x1084200,0x4210000008,0x100001000,0x84210800,0x800000401,0x8000004000,0x4000800100,0x8421084,0x80010042,0x800000020,0x8401084010,0x4000840108,0x2108401000,0x84210802,0x20,0x8000004210,0x2100,0x8401004,0x200800,0x802108420,0x21084000,0x4200842108,0x2000020000,0x1084210000,0x100421,0x1004010,0x10840008,0x108421080,0x1000200840,0x108001,0x8020004210,0x10040108,0x2108401004,0x1084210040],
  169. [0x1,0x108,0x10040,0x1084200,0x100001000,0x800000401,0x4000800100,0x80010042,0x8401084010,0x2108401000,0x20,0x2100,0x200800,0x21084000,0x2000020000,0x100421,0x10840008,0x1000200840,0x8020004210,0x2108401004,0x400,0x42000,0x4010000,0x421080000,0x21004,0x2008420,0x210800100,0x4200002,0x401000210,0x2108401084,0x8000,0x840000,0x80200000,0x8421000000,0x420080,0x40108400,0x4210002000,0x84000040,0x8020004200,0x2108400084]
  170. ]
  171. enum_squarings = VectorArray(8 * 40, cgf2n, nparallel)
  172. for i,_list in enumerate(_embedded_powers):
  173. for j,x in enumerate(_list):
  174. enum_squarings[40 * i + j] = cgf2n(x, size=nparallel)
  175. @vectorize
  176. def fancy_squaring(bd_val, exponent):
  177. #This is even more fancy; it performs directly on bit dec values
  178. #returns x ** (2 ** exp) from a bit decomposed value
  179. return sum(enum_squarings[exponent * 40 + idx] * bd_val[idx]
  180. for idx in range(len(bd_val)))
  181. def inverseMod(val):
  182. #embedded now!
  183. #returns x ** 254 using offline squaring
  184. #returns an embedded result
  185. raw_bit_dec = val.bit_decompose_embedding()
  186. bd_val = [cgf2n(0,size=nparallel)] * 40
  187. for idx in range(40):
  188. if idx % 5 == 0:
  189. bd_val[idx] = raw_bit_dec[idx // 5]
  190. bd_squared = bd_val
  191. squared_index = 2
  192. mapper = [0] * 129
  193. for idx in range(1, 8):
  194. bd_squared = fancy_squaring(bd_val, idx)
  195. mapper[squared_index] = bd_squared
  196. squared_index *= 2
  197. enum_powers = [
  198. 2, 4, 8, 16, 32, 64, 128
  199. ]
  200. inverted_product = \
  201. ((mapper[2] * mapper[4]) * (mapper[8] * mapper[16])) * ((mapper[32] * mapper[64]) * mapper[128])
  202. return inverted_product
  203. K01 = VectorArray(8, cgf2n, nparallel)
  204. for idx in range(8):
  205. K01[idx] = ApplyBDEmbedding([0,1]) ** idx
  206. class SpdzBox(object):
  207. def init_matrices(self):
  208. self.matrix_inv = [
  209. [0,0,1,0,0,1,0,1],
  210. [1,0,0,1,0,0,1,0],
  211. [0,1,0,0,1,0,0,1],
  212. [1,0,1,0,0,1,0,0],
  213. [0,1,0,1,0,0,1,0],
  214. [0,0,1,0,1,0,0,1],
  215. [1,0,0,1,0,1,0,0],
  216. [0,1,0,0,1,0,1,0]
  217. ]
  218. to_add = [1,0,1,0,0,0,0,0]
  219. self.addition_inv = [cgf2n(_,size=nparallel) for _ in to_add]
  220. self.forward_matrix = [
  221. [1,0,0,0,1,1,1,1],
  222. [1,1,0,0,0,1,1,1],
  223. [1,1,1,0,0,0,1,1],
  224. [1,1,1,1,0,0,0,1],
  225. [1,1,1,1,1,0,0,0],
  226. [0,1,1,1,1,1,0,0],
  227. [0,0,1,1,1,1,1,0],
  228. [0,0,0,1,1,1,1,1]
  229. ]
  230. forward_add = [1,1,0,0,0,1,1,0]
  231. self.forward_add = VectorArray(len(forward_add), cgf2n, nparallel)
  232. for i,x in enumerate(forward_add):
  233. self.forward_add[i] = cgf2n(x, size=nparallel)
  234. def __init__(self):
  235. constants = [
  236. 0x63, 0x8F, 0xB5, 0x01, 0xF4, 0x25, 0xF9, 0x09, 0x05
  237. ]
  238. self.powers = [
  239. 0, 127, 191, 223, 239, 247, 251, 253, 254
  240. ]
  241. self.constants = [ApplyEmbedding(cgf2n(_,size=nparallel)) for _ in constants]
  242. self.init_matrices()
  243. def forward_bit_sbox(self, emb_byte):
  244. emb_byte_inverse = inverseMod(emb_byte)
  245. unembedded_x = InverseBDEmbedding(emb_byte_inverse)
  246. linear_transform = list()
  247. for row in self.forward_matrix:
  248. result = cgf2n(0, size=nparallel)
  249. for idx in range(len(row)):
  250. result = result + unembedded_x[idx] * row[idx]
  251. linear_transform.append(result)
  252. #do the sum(linear_transfor + additive_layer)
  253. summation_bd = [0 for _ in range(8)]
  254. for idx in range(8):
  255. summation_bd[idx] = linear_transform[idx] + self.forward_add[idx]
  256. #Now raise this to power of 254
  257. result = cgf2n(0,size=nparallel)
  258. for idx in range(8):
  259. result += ApplyBDEmbedding([summation_bd[idx]]) * K01[idx];
  260. return result
  261. def apply_sbox(self, what):
  262. #applying with the multiplicative chain
  263. return self.forward_bit_sbox(what)
  264. box = SpdzBox()
  265. def aesRound(roundKey):
  266. @vectorize
  267. def inner(state):
  268. subBytes(state)
  269. shiftRows(state)
  270. mixColumns(state)
  271. addRoundKey(roundKey)(state)
  272. return inner
  273. # returns a 16-byte round key based on an expanded key and round number
  274. def createRoundKey(expandedKey, n):
  275. return expandedKey[(n*16):(n*16+16)]
  276. # wrapper function for 10 rounds of AES since we're using a 128-bit key
  277. def aesMain(expandedKey, numRounds=10):
  278. @vectorize
  279. def inner(state):
  280. roundKey = createRoundKey(expandedKey, 0)
  281. addRoundKey(roundKey)(state)
  282. for i in range(1, numRounds):
  283. roundKey = createRoundKey(expandedKey, i)
  284. aesRound(roundKey)(state)
  285. roundKey = createRoundKey(expandedKey, numRounds)
  286. subBytes(state)
  287. shiftRows(state)
  288. addRoundKey(roundKey)(state)
  289. return inner
  290. def encrypt_without_key_schedule(expandedKey):
  291. @vectorize
  292. def encrypt(plaintext):
  293. plaintext = SecretArrayEmbedd(plaintext)
  294. aesMain(expandedKey)(plaintext)
  295. return state_collapse(plaintext)
  296. return encrypt;
  297. """
  298. Test Vectors:
  299. plaintext:
  300. 6bc1bee22e409f96e93d7e117393172a
  301. key:
  302. 2b7e151628aed2a6abf7158809cf4f3c
  303. resulting cipher
  304. 3ad77bb40d7a3660a89ecaf32466ef97
  305. """
  306. def single_encryption():
  307. key = [sgf2n.get_raw_input_from(1) for _ in range(16)]
  308. message = [sgf2n.get_raw_input_from(2) for _ in range(16)]
  309. key = [ApplyEmbedding(_) for _ in key]
  310. expanded_key = expandAESKey(key)
  311. AES = encrypt_without_key_schedule(expanded_key)
  312. ciphertext = AES(message)
  313. for block in ciphertext:
  314. print_ln('%s', block.reveal())
  315. single_encryption()