key_schedule.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import hmac
  2. import hkdf
  3. import hashlib
  4. from Utils import *
  5. class TlsHash:
  6. def __init__(self, hashmod=hashlib.sha256):
  7. self.hashmod = hashmod
  8. self.hash_len = hashmod().digest_size
  9. def hkdf_extract(self, salt: bytes, input_key_material: bytes) -> bytes:
  10. if input_key_material is None:
  11. input_key_material = b"\x00" * self.hash_len
  12. return hkdf.hkdf_extract(salt, input_key_material, self.hashmod)
  13. def hkdf_label(self, label: bytes, context: bytes, length: int) -> bytes:
  14. label = b"tls13 " + label
  15. return (
  16. length.to_bytes(2, "big")
  17. + len(label).to_bytes(1, "big")
  18. + label
  19. + len(context).to_bytes(1, "big")
  20. + context
  21. )
  22. def hkdf_expand_label(
  23. self, secret: bytes, label: bytes, context: bytes, length: int
  24. ) -> bytes:
  25. hkdf_label = self.hkdf_label(label, context, length)
  26. return hkdf.hkdf_expand(secret, hkdf_label, length, self.hashmod)
  27. def derive_secret(self, secret: bytes, label: bytes, messages) -> bytes:
  28. if type(messages) == list:
  29. messages = b"".join(messages)
  30. simplePrint("\n\tHash", label, "\t", bytes(
  31. self.hashmod(messages).digest()).hex())
  32. return self.hkdf_expand_label(
  33. secret, label, self.hashmod(messages).digest(), self.hash_len
  34. )
  35. def transcript_hash(self, msgs):
  36. return self.hashmod(msgs).digest()
  37. # def transcript_hash(self, client_hello_data, *others):
  38. # digest = self.hashmod(client_hello_data).digest()
  39. # return self.hashmod(
  40. # b"\xfe\x00\x00"
  41. # + self.hash_len.to_bytes(1, "big")
  42. # + digest
  43. # + b"".join(others)
  44. # ).digest()
  45. def derive_key(self, secret: bytes, key_length: int) -> bytes:
  46. return self.hkdf_expand_label(secret, b"key", b"", key_length)
  47. def derive_iv(self, secret: bytes, iv_length: int) -> bytes:
  48. return self.hkdf_expand_label(secret, b"iv", b"", iv_length)
  49. def finished_key(self, base_key: bytes) -> bytes:
  50. return self.hkdf_expand_label(base_key, b"finished", b"", self.hash_len)
  51. def verify_data(self, secret: bytes, msg: bytes) -> bytes:
  52. a = hmac.new(
  53. self.finished_key(secret), self.transcript_hash(msg), self.hashmod
  54. ).digest()
  55. return a
  56. def scheduler(self, ecdhe: bytes, psk: bytes = None):
  57. return KeyScheduler(self, ecdhe, psk)
  58. tls_sha256 = TlsHash()
  59. tls_sha384 = TlsHash(hashlib.sha384)
  60. class PSKWrapper:
  61. def __init__(self, psk: bytes, tls_hash=tls_sha256, is_ext: bool = True):
  62. self.tls_hash = tls_hash
  63. self.early_secret = self.tls_hash.hkdf_extract(None, psk)
  64. self.is_ext = is_ext
  65. def ext_binder_key(self) -> bytes:
  66. return self.tls_hash.derive_secret(self.early_secret, b"ext binder", b"")
  67. def res_binder_key(self) -> bytes:
  68. return self.tls_hash.derive_secret(self.early_secret, b"res binder", b"")
  69. def binder_key(self) -> bytes:
  70. return self.ext_binder_key() if self.is_ext else self.res_binder_key()
  71. def client_early_traffic_secret(self, messages) -> bytes:
  72. return self.tls_hash.derive_secret(self.early_secret, b"c e traffic", messages)
  73. def early_exporter_master_secret(self, messages) -> bytes:
  74. return self.tls_hash.derive_secret(self.early_secret, b"e exp master", messages)
  75. class KeyScheduler:
  76. def __init__(self, tls_hash, ecdhe: bytes, psk: bytes = None):
  77. self.tls_hash = tls_hash
  78. self.ecdhe = ecdhe
  79. self.early_secret = self.tls_hash.hkdf_extract(None, psk)
  80. simplePrint("\n\tearly_secret\t", self.early_secret.hex())
  81. self.first_salt = self.tls_hash.derive_secret(
  82. self.early_secret, b"derived", b""
  83. )
  84. simplePrint("\n\tderived_secret\t", self.first_salt.hex())
  85. self.handshake_secret = self.tls_hash.hkdf_extract(
  86. self.first_salt, self.ecdhe)
  87. simplePrint("\n\thandshake_sec\t", self.handshake_secret.hex())
  88. self.second_salt = self.tls_hash.derive_secret(
  89. self.handshake_secret, b"derived", b""
  90. )
  91. simplePrint("\n\tsecond_salt\t", self.second_salt.hex())
  92. self.master_secret = self.tls_hash.hkdf_extract(self.second_salt, None)
  93. simplePrint("\n\tmaster_secret\t", self.master_secret.hex())
  94. def client_handshake_traffic_secret(self, messages) -> bytes:
  95. return self.tls_hash.derive_secret(
  96. self.handshake_secret, b"c hs traffic", messages
  97. )
  98. def server_handshake_traffic_secret(self, messages) -> bytes:
  99. return self.tls_hash.derive_secret(
  100. self.handshake_secret, b"s hs traffic", messages
  101. )
  102. def client_application_traffic_secret_0(self, messages) -> bytes:
  103. return self.tls_hash.derive_secret(
  104. self.master_secret, b"c ap traffic", messages
  105. )
  106. def server_application_traffic_secret_0(self, messages) -> bytes:
  107. return self.tls_hash.derive_secret(
  108. self.master_secret, b"s ap traffic", messages
  109. )
  110. def application_traffic_secret_N(self, last_secret) -> bytes:
  111. return self.tls_hash.hkdf_expand_label(
  112. last_secret, b"traffic upd", b"", self.tls_hash.hash_len
  113. )
  114. def exporter_master_secret(self, messages) -> bytes:
  115. return self.tls_hash.derive_secret(self.master_secret, b"exp master", messages)
  116. def resumption_master_secret(self, messages) -> bytes:
  117. return self.tls_hash.derive_secret(self.master_secret, b"res master", messages)
  118. def resumption_psk(self, messages, ticket_nonce: bytes) -> bytes:
  119. secret = self.resumption_master_secret(messages)
  120. return self.tls_hash.hkdf_expand_label(
  121. secret, b"resumption", ticket_nonce, self.tls_hash.hash_len
  122. )