# key_schedule.py (5.3 KB)
  1. import hmac
  2. import hkdf
  3. import hashlib
  4. class TlsHash:
  5. def __init__(self, hashmod=hashlib.sha256):
  6. self.hashmod = hashmod
  7. self.hash_len = hashmod().digest_size
  8. def hkdf_extract(self, salt: bytes, input_key_material: bytes) -> bytes:
  9. if input_key_material is None:
  10. input_key_material = b"\x00" * self.hash_len
  11. return hkdf.hkdf_extract(salt, input_key_material, self.hashmod)
  12. def hkdf_label(self, label: bytes, context: bytes, length: int) -> bytes:
  13. label = b"tls13 " + label
  14. return (
  15. length.to_bytes(2, "big")
  16. + len(label).to_bytes(1, "big")
  17. + label
  18. + len(context).to_bytes(1, "big")
  19. + context
  20. )
  21. def hkdf_expand_label(
  22. self, secret: bytes, label: bytes, context: bytes, length: int
  23. ) -> bytes:
  24. hkdf_label = self.hkdf_label(label, context, length)
  25. return hkdf.hkdf_expand(secret, hkdf_label, length, self.hashmod)
  26. def derive_secret(self, secret: bytes, label: bytes, messages) -> bytes:
  27. if type(messages) == list:
  28. messages = b"".join(messages)
  29. return self.hkdf_expand_label(
  30. secret, label, self.hashmod(messages).digest(), self.hash_len
  31. )
  32. def transcript_hash(self, *msgs):
  33. return self.hashmod(b"".join(msgs)).digest()
  34. # def transcript_hash(self, client_hello_data, *others):
  35. # digest = self.hashmod(client_hello_data).digest()
  36. # return self.hashmod(
  37. # b"\xfe\x00\x00"
  38. # + self.hash_len.to_bytes(1, "big")
  39. # + digest
  40. # + b"".join(others)
  41. # ).digest()
  42. def derive_key(self, secret: bytes, key_length: int) -> bytes:
  43. return self.hkdf_expand_label(secret, b"key", b"", key_length)
  44. def derive_iv(self, secret: bytes, iv_length: int) -> bytes:
  45. return self.hkdf_expand_label(secret, b"iv", b"", iv_length)
  46. def finished_key(self, base_key: bytes) -> bytes:
  47. return self.hkdf_expand_label(base_key, b"finished", b"", self.hash_len)
  48. def verify_data(self, secret: bytes, msg: bytes) -> bytes:
  49. return hmac.new(
  50. self.finished_key(secret), self.transcript_hash(msg), self.hashmod
  51. ).digest()
  52. def scheduler(self, ecdhe: bytes, psk: bytes = None):
  53. return KeyScheduler(self, ecdhe, psk)
  54. tls_sha256 = TlsHash()
  55. tls_sha384 = TlsHash(hashlib.sha384)
  56. class PSKWrapper:
  57. def __init__(self, psk: bytes, tls_hash=tls_sha256, is_ext: bool = True):
  58. self.tls_hash = tls_hash
  59. self.early_secret = self.tls_hash.hkdf_extract(None, psk)
  60. self.is_ext = is_ext
  61. def ext_binder_key(self) -> bytes:
  62. return self.tls_hash.derive_secret(self.early_secret, b"ext binder", b"")
  63. def res_binder_key(self) -> bytes:
  64. return self.tls_hash.derive_secret(self.early_secret, b"res binder", b"")
  65. def binder_key(self) -> bytes:
  66. return self.ext_binder_key() if self.is_ext else self.res_binder_key()
  67. def client_early_traffic_secret(self, messages) -> bytes:
  68. return self.tls_hash.derive_secret(self.early_secret, b"c e traffic", messages)
  69. def early_exporter_master_secret(self, messages) -> bytes:
  70. return self.tls_hash.derive_secret(self.early_secret, b"e exp master", messages)
  71. class KeyScheduler:
  72. def __init__(self, tls_hash, ecdhe: bytes, psk: bytes = None):
  73. self.tls_hash = tls_hash
  74. self.ecdhe = ecdhe
  75. self.early_secret = self.tls_hash.hkdf_extract(None, psk)
  76. self.first_salt = self.tls_hash.derive_secret(
  77. self.early_secret, b"derived", b""
  78. )
  79. self.handshake_secret = self.tls_hash.hkdf_extract(self.first_salt, self.ecdhe)
  80. self.second_salt = self.tls_hash.derive_secret(
  81. self.handshake_secret, b"derived", b""
  82. )
  83. self.master_secret = self.tls_hash.hkdf_extract(self.second_salt, None)
  84. def client_handshake_traffic_secret(self, messages) -> bytes:
  85. return self.tls_hash.derive_secret(
  86. self.handshake_secret, b"c hs traffic", messages
  87. )
  88. def server_handshake_traffic_secret(self, messages) -> bytes:
  89. return self.tls_hash.derive_secret(
  90. self.handshake_secret, b"s hs traffic", messages
  91. )
  92. def client_application_traffic_secret_0(self, messages) -> bytes:
  93. return self.tls_hash.derive_secret(
  94. self.master_secret, b"c ap traffic", messages
  95. )
  96. def server_application_traffic_secret_0(self, messages) -> bytes:
  97. return self.tls_hash.derive_secret(
  98. self.master_secret, b"s ap traffic", messages
  99. )
  100. def application_traffic_secret_N(self, last_secret) -> bytes:
  101. return self.tls_hash.hkdf_expand_label(
  102. last_secret, b"traffic upd", b"", self.tls_hash.hash_len
  103. )
  104. def exporter_master_secret(self, messages) -> bytes:
  105. return self.tls_hash.derive_secret(self.master_secret, b"exp master", messages)
  106. def resumption_master_secret(self, messages) -> bytes:
  107. return self.tls_hash.derive_secret(self.master_secret, b"res master", messages)
  108. def resumption_psk(self, messages, ticket_nonce: bytes) -> bytes:
  109. secret = self.resumption_master_secret(messages)
  110. return self.tls_hash.hkdf_expand_label(
  111. secret, b"resumption", ticket_nonce, self.tls_hash.hash_len
  112. )