common.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. # -*- coding: utf-8 -*-
  2. #
  3. # SelfTest/Hash/common.py: Common code for Crypto.SelfTest.Hash
  4. #
  5. # Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
  6. #
  7. # ===================================================================
  8. # The contents of this file are dedicated to the public domain. To
  9. # the extent that dedication to the public domain is not available,
  10. # everyone is granted a worldwide, perpetual, royalty-free,
  11. # non-exclusive license to exercise all rights associated with the
  12. # contents of this file for any purpose whatsoever.
  13. # No rights are reserved.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  16. # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  17. # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  18. # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  19. # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  20. # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  21. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. # SOFTWARE.
  23. # ===================================================================
"""Self-testing for PyCrypto cipher modules"""
  25. import unittest
  26. from binascii import a2b_hex, b2a_hex, hexlify
  27. from tls.Crypto.Util.py3compat import b, _memoryview
  28. from tls.Crypto.Util.strxor import strxor_c
  29. class _NoDefault: pass # sentinel object
  30. def _extract(d, k, default=_NoDefault):
  31. """Get an item from a dictionary, and remove it from the dictionary."""
  32. try:
  33. retval = d[k]
  34. except KeyError:
  35. if default is _NoDefault:
  36. raise
  37. return default
  38. del d[k]
  39. return retval
  40. # Generic cipher test case
  41. class CipherSelfTest(unittest.TestCase):
  42. def __init__(self, module, params):
  43. unittest.TestCase.__init__(self)
  44. self.module = module
  45. # Extract the parameters
  46. params = params.copy()
  47. self.description = _extract(params, 'description')
  48. self.key = b(_extract(params, 'key'))
  49. self.plaintext = b(_extract(params, 'plaintext'))
  50. self.ciphertext = b(_extract(params, 'ciphertext'))
  51. self.module_name = _extract(params, 'module_name', None)
  52. self.assoc_data = _extract(params, 'assoc_data', None)
  53. self.mac = _extract(params, 'mac', None)
  54. if self.assoc_data:
  55. self.mac = b(self.mac)
  56. mode = _extract(params, 'mode', None)
  57. self.mode_name = str(mode)
  58. if mode is not None:
  59. # Block cipher
  60. self.mode = getattr(self.module, "MODE_" + mode)
  61. self.iv = _extract(params, 'iv', None)
  62. if self.iv is None:
  63. self.iv = _extract(params, 'nonce', None)
  64. if self.iv is not None:
  65. self.iv = b(self.iv)
  66. else:
  67. # Stream cipher
  68. self.mode = None
  69. self.iv = _extract(params, 'iv', None)
  70. if self.iv is not None:
  71. self.iv = b(self.iv)
  72. self.extra_params = params
  73. def shortDescription(self):
  74. return self.description
  75. def _new(self):
  76. params = self.extra_params.copy()
  77. key = a2b_hex(self.key)
  78. old_style = []
  79. if self.mode is not None:
  80. old_style = [ self.mode ]
  81. if self.iv is not None:
  82. old_style += [ a2b_hex(self.iv) ]
  83. return self.module.new(key, *old_style, **params)
  84. def isMode(self, name):
  85. if not hasattr(self.module, "MODE_"+name):
  86. return False
  87. return self.mode == getattr(self.module, "MODE_"+name)
  88. def runTest(self):
  89. plaintext = a2b_hex(self.plaintext)
  90. ciphertext = a2b_hex(self.ciphertext)
  91. assoc_data = []
  92. if self.assoc_data:
  93. assoc_data = [ a2b_hex(b(x)) for x in self.assoc_data]
  94. ct = None
  95. pt = None
  96. #
  97. # Repeat the same encryption or decryption twice and verify
  98. # that the result is always the same
  99. #
  100. for i in range(2):
  101. cipher = self._new()
  102. decipher = self._new()
  103. # Only AEAD modes
  104. for comp in assoc_data:
  105. cipher.update(comp)
  106. decipher.update(comp)
  107. ctX = b2a_hex(cipher.encrypt(plaintext))
  108. ptX = b2a_hex(decipher.decrypt(ciphertext))
  109. if ct:
  110. self.assertEqual(ct, ctX)
  111. self.assertEqual(pt, ptX)
  112. ct, pt = ctX, ptX
  113. self.assertEqual(self.ciphertext, ct) # encrypt
  114. self.assertEqual(self.plaintext, pt) # decrypt
  115. if self.mac:
  116. mac = b2a_hex(cipher.digest())
  117. self.assertEqual(self.mac, mac)
  118. decipher.verify(a2b_hex(self.mac))
  119. class CipherStreamingSelfTest(CipherSelfTest):
  120. def shortDescription(self):
  121. desc = self.module_name
  122. if self.mode is not None:
  123. desc += " in %s mode" % (self.mode_name,)
  124. return "%s should behave like a stream cipher" % (desc,)
  125. def runTest(self):
  126. plaintext = a2b_hex(self.plaintext)
  127. ciphertext = a2b_hex(self.ciphertext)
  128. # The cipher should work like a stream cipher
  129. # Test counter mode encryption, 3 bytes at a time
  130. ct3 = []
  131. cipher = self._new()
  132. for i in range(0, len(plaintext), 3):
  133. ct3.append(cipher.encrypt(plaintext[i:i+3]))
  134. ct3 = b2a_hex(b("").join(ct3))
  135. self.assertEqual(self.ciphertext, ct3) # encryption (3 bytes at a time)
  136. # Test counter mode decryption, 3 bytes at a time
  137. pt3 = []
  138. cipher = self._new()
  139. for i in range(0, len(ciphertext), 3):
  140. pt3.append(cipher.encrypt(ciphertext[i:i+3]))
  141. # PY3K: This is meant to be text, do not change to bytes (data)
  142. pt3 = b2a_hex(b("").join(pt3))
  143. self.assertEqual(self.plaintext, pt3) # decryption (3 bytes at a time)
  144. class RoundtripTest(unittest.TestCase):
  145. def __init__(self, module, params):
  146. from tls.Crypto import Random
  147. unittest.TestCase.__init__(self)
  148. self.module = module
  149. self.iv = Random.get_random_bytes(module.block_size)
  150. self.key = b(params['key'])
  151. self.plaintext = 100 * b(params['plaintext'])
  152. self.module_name = params.get('module_name', None)
  153. def shortDescription(self):
  154. return """%s .decrypt() output of .encrypt() should not be garbled""" % (self.module_name,)
  155. def runTest(self):
  156. ## ECB mode
  157. mode = self.module.MODE_ECB
  158. encryption_cipher = self.module.new(a2b_hex(self.key), mode)
  159. ciphertext = encryption_cipher.encrypt(self.plaintext)
  160. decryption_cipher = self.module.new(a2b_hex(self.key), mode)
  161. decrypted_plaintext = decryption_cipher.decrypt(ciphertext)
  162. self.assertEqual(self.plaintext, decrypted_plaintext)
  163. class IVLengthTest(unittest.TestCase):
  164. def __init__(self, module, params):
  165. unittest.TestCase.__init__(self)
  166. self.module = module
  167. self.key = b(params['key'])
  168. def shortDescription(self):
  169. return "Check that all modes except MODE_ECB and MODE_CTR require an IV of the proper length"
  170. def runTest(self):
  171. self.assertRaises(TypeError, self.module.new, a2b_hex(self.key),
  172. self.module.MODE_ECB, b(""))
  173. def _dummy_counter(self):
  174. return "\0" * self.module.block_size
  175. class NoDefaultECBTest(unittest.TestCase):
  176. def __init__(self, module, params):
  177. unittest.TestCase.__init__(self)
  178. self.module = module
  179. self.key = b(params['key'])
  180. def runTest(self):
  181. self.assertRaises(TypeError, self.module.new, a2b_hex(self.key))
  182. class ByteArrayTest(unittest.TestCase):
  183. """Verify we can use bytearray's for encrypting and decrypting"""
  184. def __init__(self, module, params):
  185. unittest.TestCase.__init__(self)
  186. self.module = module
  187. # Extract the parameters
  188. params = params.copy()
  189. self.description = _extract(params, 'description')
  190. self.key = b(_extract(params, 'key'))
  191. self.plaintext = b(_extract(params, 'plaintext'))
  192. self.ciphertext = b(_extract(params, 'ciphertext'))
  193. self.module_name = _extract(params, 'module_name', None)
  194. self.assoc_data = _extract(params, 'assoc_data', None)
  195. self.mac = _extract(params, 'mac', None)
  196. if self.assoc_data:
  197. self.mac = b(self.mac)
  198. mode = _extract(params, 'mode', None)
  199. self.mode_name = str(mode)
  200. if mode is not None:
  201. # Block cipher
  202. self.mode = getattr(self.module, "MODE_" + mode)
  203. self.iv = _extract(params, 'iv', None)
  204. if self.iv is None:
  205. self.iv = _extract(params, 'nonce', None)
  206. if self.iv is not None:
  207. self.iv = b(self.iv)
  208. else:
  209. # Stream cipher
  210. self.mode = None
  211. self.iv = _extract(params, 'iv', None)
  212. if self.iv is not None:
  213. self.iv = b(self.iv)
  214. self.extra_params = params
  215. def _new(self):
  216. params = self.extra_params.copy()
  217. key = a2b_hex(self.key)
  218. old_style = []
  219. if self.mode is not None:
  220. old_style = [ self.mode ]
  221. if self.iv is not None:
  222. old_style += [ a2b_hex(self.iv) ]
  223. return self.module.new(key, *old_style, **params)
  224. def runTest(self):
  225. plaintext = a2b_hex(self.plaintext)
  226. ciphertext = a2b_hex(self.ciphertext)
  227. assoc_data = []
  228. if self.assoc_data:
  229. assoc_data = [ bytearray(a2b_hex(b(x))) for x in self.assoc_data]
  230. cipher = self._new()
  231. decipher = self._new()
  232. # Only AEAD modes
  233. for comp in assoc_data:
  234. cipher.update(comp)
  235. decipher.update(comp)
  236. ct = b2a_hex(cipher.encrypt(bytearray(plaintext)))
  237. pt = b2a_hex(decipher.decrypt(bytearray(ciphertext)))
  238. self.assertEqual(self.ciphertext, ct) # encrypt
  239. self.assertEqual(self.plaintext, pt) # decrypt
  240. if self.mac:
  241. mac = b2a_hex(cipher.digest())
  242. self.assertEqual(self.mac, mac)
  243. decipher.verify(bytearray(a2b_hex(self.mac)))
  244. class MemoryviewTest(unittest.TestCase):
  245. """Verify we can use memoryviews for encrypting and decrypting"""
  246. def __init__(self, module, params):
  247. unittest.TestCase.__init__(self)
  248. self.module = module
  249. # Extract the parameters
  250. params = params.copy()
  251. self.description = _extract(params, 'description')
  252. self.key = b(_extract(params, 'key'))
  253. self.plaintext = b(_extract(params, 'plaintext'))
  254. self.ciphertext = b(_extract(params, 'ciphertext'))
  255. self.module_name = _extract(params, 'module_name', None)
  256. self.assoc_data = _extract(params, 'assoc_data', None)
  257. self.mac = _extract(params, 'mac', None)
  258. if self.assoc_data:
  259. self.mac = b(self.mac)
  260. mode = _extract(params, 'mode', None)
  261. self.mode_name = str(mode)
  262. if mode is not None:
  263. # Block cipher
  264. self.mode = getattr(self.module, "MODE_" + mode)
  265. self.iv = _extract(params, 'iv', None)
  266. if self.iv is None:
  267. self.iv = _extract(params, 'nonce', None)
  268. if self.iv is not None:
  269. self.iv = b(self.iv)
  270. else:
  271. # Stream cipher
  272. self.mode = None
  273. self.iv = _extract(params, 'iv', None)
  274. if self.iv is not None:
  275. self.iv = b(self.iv)
  276. self.extra_params = params
  277. def _new(self):
  278. params = self.extra_params.copy()
  279. key = a2b_hex(self.key)
  280. old_style = []
  281. if self.mode is not None:
  282. old_style = [ self.mode ]
  283. if self.iv is not None:
  284. old_style += [ a2b_hex(self.iv) ]
  285. return self.module.new(key, *old_style, **params)
  286. def runTest(self):
  287. plaintext = a2b_hex(self.plaintext)
  288. ciphertext = a2b_hex(self.ciphertext)
  289. assoc_data = []
  290. if self.assoc_data:
  291. assoc_data = [ memoryview(a2b_hex(b(x))) for x in self.assoc_data]
  292. cipher = self._new()
  293. decipher = self._new()
  294. # Only AEAD modes
  295. for comp in assoc_data:
  296. cipher.update(comp)
  297. decipher.update(comp)
  298. ct = b2a_hex(cipher.encrypt(memoryview(plaintext)))
  299. pt = b2a_hex(decipher.decrypt(memoryview(ciphertext)))
  300. self.assertEqual(self.ciphertext, ct) # encrypt
  301. self.assertEqual(self.plaintext, pt) # decrypt
  302. if self.mac:
  303. mac = b2a_hex(cipher.digest())
  304. self.assertEqual(self.mac, mac)
  305. decipher.verify(memoryview(a2b_hex(self.mac)))
  306. def make_block_tests(module, module_name, test_data, additional_params=dict()):
  307. tests = []
  308. extra_tests_added = False
  309. for i in range(len(test_data)):
  310. row = test_data[i]
  311. # Build the "params" dictionary with
  312. # - plaintext
  313. # - ciphertext
  314. # - key
  315. # - mode (default is ECB)
  316. # - (optionally) description
  317. # - (optionally) any other parameter that this cipher mode requires
  318. params = {}
  319. if len(row) == 3:
  320. (params['plaintext'], params['ciphertext'], params['key']) = row
  321. elif len(row) == 4:
  322. (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
  323. elif len(row) == 5:
  324. (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
  325. params.update(extra_params)
  326. else:
  327. raise AssertionError("Unsupported tuple size %d" % (len(row),))
  328. if not "mode" in params:
  329. params["mode"] = "ECB"
  330. # Build the display-name for the test
  331. p2 = params.copy()
  332. p_key = _extract(p2, 'key')
  333. p_plaintext = _extract(p2, 'plaintext')
  334. p_ciphertext = _extract(p2, 'ciphertext')
  335. p_mode = _extract(p2, 'mode')
  336. p_description = _extract(p2, 'description', None)
  337. if p_description is not None:
  338. description = p_description
  339. elif p_mode == 'ECB' and not p2:
  340. description = "p=%s, k=%s" % (p_plaintext, p_key)
  341. else:
  342. description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
  343. name = "%s #%d: %s" % (module_name, i+1, description)
  344. params['description'] = name
  345. params['module_name'] = module_name
  346. params.update(additional_params)
  347. # Add extra test(s) to the test suite before the current test
  348. if not extra_tests_added:
  349. tests += [
  350. RoundtripTest(module, params),
  351. IVLengthTest(module, params),
  352. NoDefaultECBTest(module, params),
  353. ByteArrayTest(module, params),
  354. ]
  355. extra_tests_added = True
  356. # Add the current test to the test suite
  357. tests.append(CipherSelfTest(module, params))
  358. return tests
  359. def make_stream_tests(module, module_name, test_data):
  360. tests = []
  361. extra_tests_added = False
  362. for i in range(len(test_data)):
  363. row = test_data[i]
  364. # Build the "params" dictionary
  365. params = {}
  366. if len(row) == 3:
  367. (params['plaintext'], params['ciphertext'], params['key']) = row
  368. elif len(row) == 4:
  369. (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
  370. elif len(row) == 5:
  371. (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
  372. params.update(extra_params)
  373. else:
  374. raise AssertionError("Unsupported tuple size %d" % (len(row),))
  375. # Build the display-name for the test
  376. p2 = params.copy()
  377. p_key = _extract(p2, 'key')
  378. p_plaintext = _extract(p2, 'plaintext')
  379. p_ciphertext = _extract(p2, 'ciphertext')
  380. p_description = _extract(p2, 'description', None)
  381. if p_description is not None:
  382. description = p_description
  383. elif not p2:
  384. description = "p=%s, k=%s" % (p_plaintext, p_key)
  385. else:
  386. description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
  387. name = "%s #%d: %s" % (module_name, i+1, description)
  388. params['description'] = name
  389. params['module_name'] = module_name
  390. # Add extra test(s) to the test suite before the current test
  391. if not extra_tests_added:
  392. tests += [
  393. ByteArrayTest(module, params),
  394. ]
  395. import sys
  396. if sys.version[:3] != '2.6':
  397. tests.append(MemoryviewTest(module, params))
  398. extra_tests_added = True
  399. # Add the test to the test suite
  400. tests.append(CipherSelfTest(module, params))
  401. tests.append(CipherStreamingSelfTest(module, params))
  402. return tests
  403. # vim:set ts=4 sw=4 sts=4 expandtab: