__init__.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. import os
  2. import time
  3. import struct
  4. import random
  5. import iofree
  6. import typing
  7. from dataclasses import dataclass
  8. from enum import IntEnum
  9. from types import SimpleNamespace
  10. from nacl.public import PrivateKey
  11. from nacl.bindings import crypto_scalarmult
  12. from Crypto.PublicKey import RSA
  13. # from Cryptography import x509
  14. # from OpenSSL.crypto import load_certificate, FILETYPE_ASN1
  15. from . import ciphers
  16. from .key_schedule import PSKWrapper
  17. from .utils import pack_int, pack_list, pack_all
  18. MAX_LIFETIME = 24 * 3600 * 7
  19. AGE_MOD = 2 ** 32
  20. class Alert(Exception):
  21. def __init__(self, level, description):
  22. self.level = level
  23. self.description = description
  24. class MyIntEnum(IntEnum):
  25. @classmethod
  26. def from_value(cls, value: int):
  27. for e in cls:
  28. if e == value:
  29. return e
  30. raise Exception(f"Known {cls.__name__} type: {value}")
  31. class UInt8Enum(MyIntEnum):
  32. def pack(self) -> bytes:
  33. return self.to_bytes(1, "big")
  34. class UInt16Enum(MyIntEnum):
  35. def pack(self) -> bytes:
  36. return self.to_bytes(2, "big")
  37. class HandshakeType(UInt8Enum):
  38. client_hello = 1
  39. server_hello = 2
  40. new_session_ticket = 4
  41. end_of_early_data = 5
  42. encrypted_extensions = 8
  43. certificate = 11
  44. certificate_request = 13
  45. certificate_verify = 15
  46. finished = 20
  47. key_update = 24
  48. message_hash = 254
  49. def pack_data(self, data: bytes) -> bytes:
  50. return self.pack() + pack_int(3, data)
  51. def tls_inner_plaintext(self, content: bytes) -> bytes:
  52. return (
  53. self.pack_data(content)
  54. + ContentType.handshake.pack()
  55. + (b"\x00" * random.randint(0, 10))
  56. )
  57. class ExtensionType(UInt16Enum):
  58. server_name = 0
  59. max_fragment_length = 1
  60. status_request = 5
  61. supported_groups = 10
  62. signature_algorithms = 13
  63. use_srtp = 14
  64. heartbeat = 15
  65. application_layer_protocol_negotiation = 16
  66. signed_certificate_timestamp = 18
  67. client_certificate_type = 19
  68. server_certificate_type = 20
  69. padding = 21
  70. pre_shared_key = 41
  71. early_data = 42
  72. supported_versions = 43
  73. cookie = 44
  74. psk_key_exchange_modes = 45
  75. certificate_authorities = 47
  76. oid_filters = 48
  77. post_handshake_auth = 49
  78. signature_algorithms_cert = 50
  79. key_share = 51
  80. def pack_data(self, data: bytes) -> bytes:
  81. return self.pack() + pack_int(2, data)
  82. @classmethod
  83. def server_name_list(cls, host_names: list) -> bytes:
  84. return cls.server_name.pack_data(
  85. pack_list(
  86. 2, (NameType.host_name.pack_data(name.encode()) for name in host_names)
  87. )
  88. )
  89. @classmethod
  90. def supported_versions_list(cls) -> bytes:
  91. return cls.supported_versions.pack_data(pack_int(1, b"\x03\x04"))
  92. @classmethod
  93. def supported_groups_list(cls, named_group, *named_groups) -> bytes:
  94. return cls.supported_groups.pack_data(
  95. pack_list(2, (group.pack() for group in (named_group, *named_groups)))
  96. )
  97. @classmethod
  98. def signature_algorithms_list(cls, algo, *algos) -> bytes:
  99. return cls.signature_algorithms.pack_data(
  100. pack_list(2, (alg.pack() for alg in (algo, *algos)))
  101. )
  102. @classmethod
  103. def unpack_from(cls, mv: memoryview):
  104. extensions = {}
  105. while mv:
  106. type_value = int.from_bytes(mv[:2], "big")
  107. mv = mv[2:]
  108. if mv:
  109. extension_data_lenth = int.from_bytes(mv[:2], "big")
  110. pos = extension_data_lenth + 2
  111. extension_data = mv[2:pos]
  112. assert (
  113. extension_data.nbytes == extension_data_lenth
  114. ), "extension length does not match"
  115. mv = mv[pos:]
  116. else:
  117. extension_data = b""
  118. et = cls.from_value(type_value)
  119. extensions[et] = et.unpack(extension_data)
  120. return extensions
  121. def unpack(self, data: memoryview):
  122. if self == ExtensionType.supported_versions:
  123. return bytes(data)
  124. if self == ExtensionType.key_share:
  125. return NamedGroup.unpack_from(data)
  126. if self == ExtensionType.server_name:
  127. return data.decode()
  128. if self == ExtensionType.pre_shared_key:
  129. assert len(data) == 2, "invalid length"
  130. return int.from_bytes(data, "big")
  131. if self == ExtensionType.early_data:
  132. if data:
  133. assert len(data) == 4, "expect uint32 max_early_data_size"
  134. return int.from_bytes(data, "big")
  135. return
  136. raise Exception("not support yet")
  137. class ContentType(UInt8Enum):
  138. invalid = 0
  139. change_cipher_spec = 20
  140. alert = 21
  141. handshake = 22
  142. application_data = 23
  143. def tls_plaintext(self, data: bytes) -> bytes:
  144. assert len(data) > 0, "need data"
  145. data = memoryview(data)
  146. fragments = []
  147. while True:
  148. if len(data) > 16384:
  149. fragments.append(data[:16384])
  150. data = data[16384:]
  151. else:
  152. fragments.append(data)
  153. break
  154. return b"".join(
  155. (
  156. self.pack()
  157. + (
  158. b"\x03\x01"
  159. if i == 0 and self is ContentType.handshake
  160. else b"\x03\x03"
  161. )
  162. + pack_int(2, fragment)
  163. for i, fragment in enumerate(fragments)
  164. )
  165. )
  166. def tls_inner_plaintext(self, content: bytes) -> bytes:
  167. return content + self.pack() + (b"\x00" * random.randint(0, 10))
  168. class AlertLevel(UInt8Enum):
  169. warning = 1
  170. fatal = 2
  171. class AlertDescription(UInt8Enum):
  172. close_notify = 0
  173. unexpected_message = 10
  174. bad_record_mac = 20
  175. record_overflow = 22
  176. handshake_failure = 40
  177. bad_certificate = 42
  178. unsupported_certificate = 43
  179. certificate_revoked = 44
  180. certificate_expired = 45
  181. certificate_unknown = 46
  182. illegal_parameter = 47
  183. unknown_ca = 48
  184. access_denied = 49
  185. decode_error = 50
  186. decrypt_error = 51
  187. protocol_version = 70
  188. insufficient_security = 71
  189. internal_error = 80
  190. inappropriate_fallback = 86
  191. user_canceled = 90
  192. missing_extension = 109
  193. unsupported_extension = 110
  194. unrecognized_name = 112
  195. bad_certificate_status_response = 113
  196. unknown_psk_identity = 115
  197. certificate_required = 116
  198. no_application_protocol = 120
  199. class SignatureScheme(UInt16Enum):
  200. # RSASSA-PKCS1-v1_5 algorithms
  201. rsa_pkcs1_sha256 = 0x0401
  202. rsa_pkcs1_sha384 = 0x0501
  203. rsa_pkcs1_sha512 = 0x0601
  204. # ECDSA algorithms
  205. ecdsa_secp256r1_sha256 = 0x0403
  206. ecdsa_secp384r1_sha384 = 0x0503
  207. ecdsa_secp521r1_sha512 = 0x0603
  208. # RSASSA-PSS algorithms with public key OID rsaEncryption
  209. rsa_pss_rsae_sha256 = 0x0804
  210. rsa_pss_rsae_sha384 = 0x0805
  211. rsa_pss_rsae_sha512 = 0x0806
  212. # EdDSA algorithms
  213. ed25519 = 0x0807
  214. ed448 = 0x0808
  215. # RSASSA-PSS algorithms with public key OID RSASSA-PSS
  216. rsa_pss_pss_sha256 = 0x0809
  217. rsa_pss_pss_sha384 = 0x080a
  218. rsa_pss_pss_sha512 = 0x080b
  219. # Legacy algorithms
  220. rsa_pkcs1_sha1 = 0x0201
  221. ecdsa_sha1 = 0x0203
  222. # Reserved Code Points
  223. # private_use(0xFE00..0xFFFF)
  224. dh_parameters = {
  225. # "ffdhe2048": dh.generate_parameters(generator=2, key_size=2048, backend=backend),
  226. # "ffdhe3072": dh.generate_parameters(generator=2, key_size=3072, backend=backend),
  227. # "ffdhe4096": dh.generate_parameters(generator=2, key_size=4096, backend=backend),
  228. # "ffdhe8192": dh.generate_parameters(generator=2, key_size=8192, backend=backend),
  229. }
  230. class NamedGroup(UInt16Enum):
  231. # Elliptic Curve Groups (ECDHE)
  232. secp256r1 = 0x0017
  233. secp384r1 = 0x0018
  234. secp521r1 = 0x0019
  235. x25519 = 0x001D
  236. x448 = 0x001E
  237. # Finite Field Groups (DHE)
  238. ffdhe2048 = 0x0100
  239. ffdhe3072 = 0x0101
  240. ffdhe4096 = 0x0102
  241. ffdhe6144 = 0x0103
  242. ffdhe8192 = 0x0104
  243. # Reserved Code Points
  244. # ffdhe_private_use(0x01FC..0x01FF)
  245. # ecdhe_private_use(0xFE00..0xFEFF)
  246. # def dh_key_share_entry(self):
  247. # private_key = dh_parameters[self.name].generate_private_key()
  248. # peer_public_key = private_key.public_key()
  249. # opaque = peer_public_key.public_bytes(
  250. # Encoding.DER, PublicFormat.SubjectPublicKeyInfo
  251. # )
  252. # return private_key, self.pack() + pack_int(2, opaque)
  253. @classmethod
  254. def new_x25519(cls):
  255. private_key = PrivateKey.generate()
  256. key_exchange = bytes(private_key.public_key)
  257. return private_key, cls.x25519.pack() + pack_int(2, key_exchange)
  258. @classmethod
  259. def unpack_from(cls, data: memoryview):
  260. value = int.from_bytes(data[:2], "big")
  261. group_type = cls.from_value(value)
  262. length = int.from_bytes(data[2:4], "big")
  263. assert length == len(data[4:]), "group length does not match"
  264. key_exchange = bytes(data[4:])
  265. return KeyShareEntry(group_type, key_exchange)
  266. @dataclass
  267. class KeyShareEntry:
  268. group: NamedGroup
  269. key_exchange: bytes
  270. __slots__ = ("group", "key_exchange")
  271. def pack(self):
  272. return self.group.pack() + pack_int(2, self.key_exchange)
  273. class CertificateType(UInt8Enum):
  274. X509 = 0
  275. RawPublicKey = 2
  276. @dataclass
  277. class CertificateEntry:
  278. cert_type: CertificateType
  279. cert_data: bytes
  280. extensions: dict
  281. __slots__ = ("cert_type", "cert_data", "extensions")
  282. @classmethod
  283. def unpack_from(cls, data: memoryview):
  284. certificate_request_context_len = data[0]
  285. certificate_request_context = data[1 : 1 + certificate_request_context_len]
  286. certificate_request_context
  287. data = data[1 + certificate_request_context_len :]
  288. certificate_list_len = int.from_bytes(data[:3], "big")
  289. certificate_list = data[3:]
  290. assert certificate_list_len == len(
  291. certificate_list
  292. ), "Certificate length does not match"
  293. certs = []
  294. while certificate_list:
  295. cert_data_len = int.from_bytes(certificate_list[:3], "big")
  296. cert_data = certificate_list[3 : 3 + cert_data_len]
  297. cert_type = 0
  298. if cert_type == CertificateType.X509:
  299. # x = x509.load_der_x509_certificate(data=cert_data, backend=backend)
  300. # x = load_certificate(FILETYPE_ASN1, cert_data)
  301. key = RSA.import_key(cert_data)
  302. key
  303. certificate_list = certificate_list[3 + cert_data_len :]
  304. extensions_len = int.from_bytes(certificate_list[:2], "big")
  305. assert extensions_len <= len(
  306. certificate_list[2:]
  307. ), "extensions length does not match"
  308. extensions = ExtensionType.unpack_from(
  309. certificate_list[2 : 2 + extensions_len]
  310. )
  311. certificate_list = certificate_list[2 + extensions_len :]
  312. certs.append(cls(cert_type, cert_data, extensions))
  313. return certs
  314. class KeyUpdateRequest(UInt8Enum):
  315. update_not_requested = 0
  316. update_requested = 1
  317. class PskKeyExchangeMode(UInt8Enum):
  318. psk_ke = 0
  319. psk_dhe_ke = 1
  320. def extension(self):
  321. return ExtensionType.psk_key_exchange_modes.pack_data(pack_int(1, self.pack()))
  322. @classmethod
  323. def both_extensions(cls):
  324. return ExtensionType.psk_key_exchange_modes.pack_data(pack_int(1, b"\x00\x01"))
  325. class CipherSuite(UInt16Enum):
  326. TLS_AES_128_GCM_SHA256 = 0x1301
  327. TLS_AES_256_GCM_SHA384 = 0x1302
  328. TLS_CHACHA20_POLY1305_SHA256 = 0x1303
  329. TLS_AES_128_CCM_SHA256 = 0x1304
  330. TLS_AES_128_CCM_8_SHA256 = 0x1305
  331. @classmethod
  332. def all(cls) -> set:
  333. if not hasattr(cls, "_all"):
  334. cls._all = {suite.pack() for suite in cls}
  335. return cls._all
  336. @classmethod
  337. def select(cls, data):
  338. data = memoryview(data)
  339. for i in (0, data.nbytes, 2):
  340. if data[i : i + 2] in cls.all():
  341. return data[i : i + 2].tobytes()
  342. @classmethod
  343. def get_cipher(cls, data):
  344. value = int.from_bytes(data, "big")
  345. if value == cls.TLS_AES_128_GCM_SHA256:
  346. return ciphers.TLS_AES_128_GCM_SHA256
  347. elif value == cls.TLS_AES_256_GCM_SHA384:
  348. return ciphers.TLS_AES_256_GCM_SHA384
  349. elif value == cls.TLS_AES_128_CCM_SHA256:
  350. return ciphers.TLS_AES_128_CCM_SHA256
  351. elif value == cls.TLS_AES_128_CCM_8_SHA256:
  352. return ciphers.TLS_AES_128_CCM_8_SHA256
  353. elif value == cls.TLS_CHACHA20_POLY1305_SHA256:
  354. return ciphers.TLS_CHACHA20_POLY1305_SHA256
  355. else:
  356. raise Exception("bad cipher suite")
  357. @classmethod
  358. def pack_all(cls):
  359. return pack_all(
  360. 2,
  361. [
  362. cls.TLS_CHACHA20_POLY1305_SHA256,
  363. cls.TLS_AES_128_GCM_SHA256,
  364. cls.TLS_AES_256_GCM_SHA384,
  365. cls.TLS_AES_128_CCM_SHA256,
  366. cls.TLS_AES_128_CCM_8_SHA256,
  367. ],
  368. )
  369. class NameType(UInt8Enum):
  370. host_name = 0
  371. def pack_data(self, data: bytes) -> bytes:
  372. return self.pack() + pack_int(2, data)
class Const:
    """Pre-packed extension blobs, computed once when this class body runs."""

    # signature_algorithms extension built from the SignatureScheme members
    # (presumably in definition order — depends on pack_all; confirm).
    all_signature_algorithms = ExtensionType.signature_algorithms.pack_data(
        pack_all(2, SignatureScheme)
    )
    # supported_groups extension listing only x25519, the one group for which
    # NamedGroup provides key generation (new_x25519).
    all_supported_groups = ExtensionType.supported_groups.pack_data(
        pack_all(2, [NamedGroup.x25519])
    )
    # psk_key_exchange_modes extensions offering psk_ke only, psk_dhe_ke only,
    # or both modes.
    psk_ke_extension = PskKeyExchangeMode.psk_ke.extension()
    psk_dhe_ke_extension = PskKeyExchangeMode.psk_dhe_ke.extension()
    psk_both_extensions = PskKeyExchangeMode.both_extensions()
  383. def server_hello_pack(legacy_session_id_echo, cipher_suite, extensions) -> bytes:
  384. legacy_version = b"\x03\x03"
  385. msg = b"".join(
  386. (
  387. legacy_version,
  388. os.urandom(32),
  389. pack_int(1, legacy_session_id_echo),
  390. cipher_suite.pack(),
  391. b"\x00",
  392. )
  393. )
  394. return ContentType.handshake.tls_plaintext(
  395. HandshakeType.server_hello.pack_data(msg)
  396. )
  397. def client_hello_key_share_extension(*key_share_entries):
  398. return ExtensionType.key_share.pack_data(pack_list(2, key_share_entries))
@dataclass
class PskIdentity:
    """One PSK identity offered in the client's pre_shared_key extension."""

    # Opaque identity (e.g. a session ticket); packed with a 2-byte length.
    identity: bytes
    # Packed as a 4-byte big-endian integer right after the identity.
    obfuscated_ticket_age: int
    # Size in bytes of this identity's binder; client_pre_shared_key_extension
    # reserves that many zero bytes as a placeholder.
    binder_len: int
  404. def client_pre_shared_key_extension(
  405. psk_identities: typing.Iterable
  406. ) -> typing.Tuple[bytes, int]:
  407. binders = pack_psk_binder_entries((i.binder_len * b"\x00" for i in psk_identities))
  408. return (
  409. ExtensionType.pre_shared_key.pack_data(
  410. pack_list(
  411. 2,
  412. (
  413. pack_int(2, i.identity) + i.obfuscated_ticket_age.to_bytes(4, "big")
  414. for i in psk_identities
  415. ),
  416. )
  417. + binders
  418. ),
  419. len(binders),
  420. )
  421. def pack_psk_binder_entries(binder_list: typing.Iterable[bytes]) -> bytes:
  422. return pack_list(2, (pack_int(1, binder) for binder in binder_list))
  423. def unpack_certificate_verify(mv: memoryview):
  424. algorithm = int.from_bytes(mv[:2], "big")
  425. scheme = SignatureScheme.from_value(algorithm)
  426. signature_len = int.from_bytes(mv[2:4], "big")
  427. signature = mv[4 : 4 + signature_len]
  428. return SimpleNamespace(algorithm=scheme, signature=signature)
  429. def unpack_new_session_ticket(mv: memoryview):
  430. lifetime, age_add, nonce_len = struct.unpack_from("!IIB", mv)
  431. mv = mv[9:]
  432. nonce = mv[:nonce_len]
  433. mv = mv[nonce_len:]
  434. ticket_len = int.from_bytes(mv[:2], "big")
  435. mv = mv[2:]
  436. ticket = bytes(mv[:ticket_len])
  437. mv = mv[ticket_len:]
  438. ext_len = int.from_bytes(mv[:2], "big")
  439. mv = mv[2:]
  440. assert ext_len == len(mv), "extension length does not match"
  441. extensions = ExtensionType.unpack_from(mv)
  442. return NewSessionTicket(
  443. lifetime=lifetime,
  444. age_add=age_add,
  445. nonce=nonce,
  446. ticket=ticket,
  447. max_early_data_size=extensions.get(ExtensionType.early_data),
  448. )
  449. @dataclass
  450. class NewSessionTicket:
  451. lifetime: int
  452. age_add: int
  453. nonce: bytes
  454. ticket: bytes
  455. max_early_data_size: int
  456. def __post_init__(self):
  457. self.outdated_time = time.time() + min(self.lifetime, MAX_LIFETIME)
  458. self.obfuscated_ticket_age = ((self.lifetime * 1000) + self.age_add) % AGE_MOD
  459. def is_outdated(self):
  460. return time.time() >= self.outdated_time
  461. def to_psk_identity(self, binder_len: int):
  462. return PskIdentity(self.ticket, self.obfuscated_ticket_age, binder_len)
  463. class TLSClientSession:
  464. def __init__(
  465. self,
  466. private_key: bytes ,
  467. key_share_entry: bytes ,
  468. client_hello_data: memoryview
  469. ):
  470. self.private_key, key_share_entry = private_key, key_share_entry
  471. self.handshake_context = client_hello_data
  472. self.server_finished = False
  473. def unpack_server_hello(self, mv: memoryview):
  474. assert mv[:2] == b"\x03\x03", "version must be 0x0303"
  475. random = bytes(mv[2:34])
  476. legacy_session_id_echo_length = mv[34]
  477. legacy_session_id_echo = bytes(mv[35 : 35 + legacy_session_id_echo_length])
  478. mv = mv[35 + legacy_session_id_echo_length :]
  479. cipher_suite = CipherSuite.get_cipher(mv[:2])
  480. assert mv[2] == 0, "legacy_compression_method should be 0"
  481. extension_length = int.from_bytes(mv[3:5], "big")
  482. extensions_mv = mv[5:]
  483. assert (
  484. extensions_mv.nbytes == extension_length
  485. ), "extensions length does not match"
  486. extensions = ExtensionType.unpack_from(extensions_mv)
  487. return SimpleNamespace(
  488. handshake_type=HandshakeType.server_hello,
  489. random=random,
  490. legacy_session_id_echo=legacy_session_id_echo,
  491. cipher_suite=cipher_suite,
  492. extensions=extensions,
  493. )
  494. def unpack_handshake(self, mv: memoryview):
  495. handshake_type = mv[0]
  496. length = int.from_bytes(mv[1:4], "big")
  497. assert len(mv[4:]) == length, f"handshake length does not match"
  498. handshake_data = mv[4:]
  499. if handshake_type == HandshakeType.server_hello:
  500. self.handshake_context.extend(mv)
  501. return self.unpack_server_hello(handshake_data)
  502. elif handshake_type == HandshakeType.encrypted_extensions:
  503. self.handshake_context.extend(mv)
  504. ext_len = int.from_bytes(handshake_data[:2], "big")
  505. handshake_data = handshake_data[2:]
  506. assert (
  507. len(handshake_data) == ext_len
  508. ), "encrypted extensions length does not match"
  509. self.encrypted_extensions = ExtensionType.unpack_from(handshake_data)
  510. elif handshake_type == HandshakeType.certificate_request:
  511. self.handshake_context.extend(mv)
  512. elif handshake_type == HandshakeType.certificate:
  513. self.handshake_context.extend(mv)
  514. self.certificate_entry = CertificateEntry.unpack_from(handshake_data)
  515. elif handshake_type == HandshakeType.certificate_verify:
  516. self.handshake_context.extend(mv)
  517. self.certificate_verify = unpack_certificate_verify(handshake_data)
  518. print(self.certificate_verify)
  519. elif handshake_type == HandshakeType.finished:
  520. assert handshake_data == self.peer_cipher.verify_data(
  521. self.handshake_context
  522. ), "server handshake finished does not match"
  523. self.handshake_context.extend(mv)
  524. self.server_finished = True
  525. elif handshake_type == HandshakeType.new_session_ticket:
  526. self.session_tickets.append(unpack_new_session_ticket(handshake_data))
  527. else:
  528. raise Exception(f"unknown handshake type {handshake_type}")
  529. def tls_response(self, mv: memoryview):
  530. head = mv[:5]
  531. assert head[1:3] == b"\x03\x03", f"bad legacy_record_version {head[1:3]}"
  532. length = int.from_bytes(head[3:], "big")
  533. if (head[0] == ContentType.application_data and length > (16384 + 256)) or (
  534. head[0] != ContentType.application_data and length > 16384
  535. ):
  536. raise Alert(AlertLevel.fatal, AlertDescription.record_overflow)
  537. content = mv[5:]
  538. if head[0] == ContentType.alert:
  539. level = AlertLevel.from_value(content[0])
  540. description = AlertDescription.from_value(content[1])
  541. raise Alert(level, description)
  542. elif head[0] == ContentType.handshake:
  543. self.peer_handshake = self.unpack_handshake(content)
  544. assert (
  545. self.peer_handshake.handshake_type == HandshakeType.server_hello
  546. ), "expect server hello"
  547. peer_pk = self.peer_handshake.extensions[
  548. ExtensionType.key_share
  549. ].key_exchange
  550. shared_key = crypto_scalarmult(bytes(self.private_key), peer_pk)
  551. TLSCipher = self.peer_handshake.cipher_suite
  552. self.TLSCipher = TLSCipher
  553. key_scheduler = TLSCipher.tls_hash.scheduler(shared_key, None)
  554. self.key_scheduler = key_scheduler
  555. secret = key_scheduler.server_handshake_traffic_secret(
  556. self.handshake_context
  557. )
  558. # server handshake cipher
  559. self.peer_cipher = TLSCipher(secret)
  560. client_handshake_traffic_secret = key_scheduler.client_handshake_traffic_secret(
  561. self.handshake_context
  562. )
  563. print("\n\tpeer_pk\t",bytes(peer_pk).hex())
  564. print("\n\tprivate_key\t",bytes(self.private_key).hex() )
  565. print("\n\tshared_key\t",bytes(shared_key).hex() )
  566. print("\n\tself.key_scheduler.\t",self.key_scheduler)
  567. print("\n\tself.peer_cipher\t",self.peer_cipher)
  568. print("\n\tself.TLSCipher\t",self.TLSCipher)
  569. print("\n\tsecret\t",bytes(secret).hex())
  570. print("\n\tkey\t",self.peer_cipher.key )
  571. print("\n\tIV\t",self.peer_cipher.iv )
  572. print("\n",)
  573. elif head[0] == ContentType.application_data:
  574. plaintext = self.peer_cipher.decrypt(content, head).rstrip(b"\x00")
  575. content_type = ContentType.from_value(plaintext[-1])
  576. if content_type == ContentType.handshake:
  577. self.unpack_handshake(plaintext[:-1])
  578. if self.server_finished:
  579. # client handshake cipher
  580. cipher = TLSCipher(client_handshake_traffic_secret)
  581. client_finished = cipher.verify_data(self.handshake_context)
  582. client_finished_data = HandshakeType.finished.pack_data(
  583. client_finished
  584. )
  585. inner_plaintext = ContentType.handshake.tls_inner_plaintext(
  586. client_finished_data
  587. )
  588. record = cipher.tls_ciphertext(inner_plaintext)
  589. change_cipher_spec = ContentType.change_cipher_spec.tls_plaintext(
  590. b"\x01"
  591. )
  592. # parser.write(change_cipher_spec + record)
  593. # server application cipher
  594. server_secret = key_scheduler.server_application_traffic_secret_0(
  595. self.handshake_context
  596. )
  597. self.peer_cipher = TLSCipher(server_secret)
  598. self.server_finished = False
  599. # client application cipher
  600. client_secret = key_scheduler.client_application_traffic_secret_0(
  601. self.handshake_context
  602. )
  603. self.cipher = TLSCipher(client_secret)
  604. self.handshake_context.extend(client_finished_data)
  605. elif content_type == ContentType.application_data:
  606. self.data_callback(plaintext[:-1])
  607. elif content_type == ContentType.alert:
  608. level = AlertLevel.from_value(plaintext[0])
  609. description = AlertDescription.from_value(plaintext[1])
  610. raise Alert(level, description)
  611. elif content_type == ContentType.invalid:
  612. raise Exception("invalid content type")
  613. else:
  614. raise Exception(f"unexpected content type {content_type}")
  615. elif head[0] == ContentType.change_cipher_spec:
  616. assert content == b"\x01", "change_cipher should be 0x01"
  617. else:
  618. raise Exception(f"Unknown content type: {head[0]}")
  619. def pack_client_hello(self):
  620. data = ContentType.handshake.tls_plaintext(self.client_hello_data)
  621. return data
  622. def pack_application_data(self, payload: bytes) -> bytes:
  623. inner_plaintext = ContentType.application_data.tls_inner_plaintext(payload)
  624. return self.cipher.tls_ciphertext(inner_plaintext)
  625. def pack_alert(self, description: AlertDescription, level: AlertLevel) -> bytes:
  626. payload = level.pack() + description.pack()
  627. if self.cipher:
  628. inner_plaintext = ContentType.alert.tls_inner_plaintext(payload)
  629. return self.cipher.tls_ciphertext(inner_plaintext)
  630. else:
  631. return ContentType.alert.tls_plaintext(payload)
  632. def pack_warning(self, description: AlertDescription) -> bytes:
  633. return self.pack_alert(description, AlertLevel.warning)
  634. def pack_fatal(self, description: AlertDescription) -> bytes:
  635. return self.pack_alert(description, AlertLevel.fatal)
  636. def pack_close(self) -> bytes:
  637. return self.pack_warning(AlertDescription.close_notify)
  638. def pack_canceled(self) -> bytes:
  639. return self.pack_warning(AlertDescription.user_canceled)