__init__.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820
  1. from .ghash import GhashCry, GhashCon
  2. from Crypto.Util.number import long_to_bytes, bytes_to_long
  3. from Crypto.Util import Counter
  4. from Crypto.Cipher import AES
  5. import os
  6. import time
  7. import struct
  8. import random
  9. import iofree
  10. import typing
  11. from dataclasses import dataclass
  12. from enum import IntEnum
  13. from types import SimpleNamespace
  14. from nacl.public import PrivateKey
  15. from nacl.bindings import crypto_scalarmult
  16. import hashlib
  17. from Crypto.PublicKey import RSA
  18. # from Cryptography import x509
  19. # from OpenSSL.crypto import load_certificate, FILETYPE_ASN1
  20. from .models import *
  21. from . import ciphers
  22. from .key_schedule import PSKWrapper
  23. from .utils import pack_int, pack_list, pack_all
  24. from Utils import *
  # Cap, in seconds, applied to a session ticket's advertised lifetime
  # (seven days).
  MAX_LIFETIME = 7 * 24 * 60 * 60
  # Ticket-age arithmetic is reduced modulo 2**32.
  AGE_MOD = 1 << 32
  class Alert(Exception):
      """Raised when a TLS alert is received or must be sent.

      Attributes:
          level: the alert level (an AlertLevel member or raw value).
          description: the alert description (an AlertDescription member or
              raw value).
      """

      def __init__(self, level, description):
          # Forward to Exception so ``args``, ``str()`` and pickling carry the
          # alert instead of an empty exception.
          super().__init__(level, description)
          self.level = level
          self.description = description
  class MyIntEnum(IntEnum):
      """IntEnum base adding a value lookup with a descriptive error."""

      @classmethod
      def from_value(cls, value: int):
          """Return the member of ``cls`` whose value equals *value*.

          Raises:
              Exception: if no member of ``cls`` has that value.
          """
          try:
              # The enum machinery does the value lookup (hash-based, with a
              # linear fallback), replacing the hand-rolled scan.
              return cls(value)
          except ValueError:
              raise Exception(f"Unknown {cls.__name__} type: {value}") from None
  class UInt8Enum(MyIntEnum):
      """Enum whose members serialize to exactly one byte."""

      def pack(self) -> bytes:
          """Return the member's value encoded as a single big-endian byte."""
          return int(self).to_bytes(length=1, byteorder="big")
  class UInt16Enum(MyIntEnum):
      """Enum whose members serialize to two bytes."""

      def pack(self) -> bytes:
          """Return the member's value encoded as two big-endian bytes."""
          return int(self).to_bytes(length=2, byteorder="big")
  class HandshakeType(UInt8Enum):
      """Handshake message type codes, with helpers that frame a message body."""

      client_hello = 1
      server_hello = 2
      new_session_ticket = 4
      end_of_early_data = 5
      encrypted_extensions = 8
      certificate = 11
      certificate_request = 13
      certificate_verify = 15
      finished = 20
      key_update = 24
      message_hash = 254

      def pack_data(self, data: bytes) -> bytes:
          """Frame *data* as a handshake message.

          The result is the type byte followed by ``pack_int(3, data)``.
          ``pack_int(3, ...)`` presumably emits a 3-byte length prefix plus
          the data (the uint24 handshake length field).
          """
          msg_type = self.pack()
          return msg_type + pack_int(3, data)

      def tls_inner_plaintext(self, content: bytes) -> bytes:
          """Build a TLSInnerPlaintext carrying handshake message *content*.

          Layout: the framed message, the ``handshake`` content-type byte,
          then 0-10 zero bytes of padding. The padding length comes from
          ``random`` (not a CSPRNG).
          """
          framed = self.pack_data(content)
          inner_type = ContentType.handshake.pack()
          zero_padding = bytes(random.randint(0, 10))
          return framed + inner_type + zero_padding
  class ExtensionType(UInt16Enum):
      """TLS extension type codes, plus helpers to encode and decode extensions."""

      server_name = 0
      max_fragment_length = 1
      status_request = 5
      supported_groups = 10
      signature_algorithms = 13
      use_srtp = 14
      heartbeat = 15
      application_layer_protocol_negotiation = 16
      signed_certificate_timestamp = 18
      client_certificate_type = 19
      server_certificate_type = 20
      padding = 21
      pre_shared_key = 41
      early_data = 42
      supported_versions = 43
      cookie = 44
      psk_key_exchange_modes = 45
      certificate_authorities = 47
      oid_filters = 48
      post_handshake_auth = 49
      signature_algorithms_cert = 50
      key_share = 51

      def pack_data(self, data: bytes) -> bytes:
          """Encode one extension: the 2-byte type, then ``pack_int(2, data)``."""
          type_code = self.pack()
          return type_code + pack_int(2, data)

      @classmethod
      def server_name_list(cls, host_names: list) -> bytes:
          """Build a server_name extension with one host_name entry per name."""
          entries = (
              NameType.host_name.pack_data(host.encode()) for host in host_names
          )
          return cls.server_name.pack_data(pack_list(2, entries))

      @classmethod
      def supported_versions_list(cls) -> bytes:
          """Build a ClientHello supported_versions extension offering 0x0304 only."""
          tls13 = b"\x03\x04"
          return cls.supported_versions.pack_data(pack_int(1, tls13))

      @classmethod
      def supported_groups_list(cls, named_group, *named_groups) -> bytes:
          """Build a supported_groups extension from one or more group members."""
          groups = (named_group,) + named_groups
          encoded = (group.pack() for group in groups)
          return cls.supported_groups.pack_data(pack_list(2, encoded))

      @classmethod
      def signature_algorithms_list(cls, algo, *algos) -> bytes:
          """Build a signature_algorithms extension from one or more schemes."""
          schemes = (algo,) + algos
          encoded = (scheme.pack() for scheme in schemes)
          return cls.signature_algorithms.pack_data(pack_list(2, encoded))

      @classmethod
      def unpack_from(cls, mv: memoryview):
          """Decode concatenated extensions into ``{ExtensionType: value}``.

          Each extension is a 2-byte type, a 2-byte length and that many bytes
          of data; the value comes from :meth:`unpack`. A trailing type with
          no length field is decoded with empty data.

          Raises:
              AssertionError: if an extension body is shorter than declared.
              Exception: for an unknown type, or one without a decoder.
          """
          parsed = {}
          cursor, end = 0, len(mv)
          while cursor < end:
              code = int.from_bytes(mv[cursor:cursor + 2], "big")
              cursor += 2
              if cursor < end:
                  size = int.from_bytes(mv[cursor:cursor + 2], "big")
                  body = mv[cursor + 2:cursor + 2 + size]
                  assert body.nbytes == size, "extension length does not match"
                  cursor += 2 + size
              else:
                  body = b""
              member = cls.from_value(code)
              parsed[member] = member.unpack(body)
          return parsed

      def unpack(self, data: memoryview):
          """Decode the body of an extension of this type.

          Returns bytes for supported_versions, a KeyShareEntry for key_share,
          the raw view for server_name, an int for pre_shared_key, and an int
          (or None when empty) for early_data.

          Raises:
              AssertionError: on an unexpected body length.
              Exception: for extension types without a decoder.
          """
          if self is ExtensionType.supported_versions:
              return bytes(data)
          if self is ExtensionType.key_share:
              return NamedGroup.unpack_from(data)
          if self is ExtensionType.server_name:
              return data
          if self is ExtensionType.pre_shared_key:
              assert len(data) == 2, "invalid length"
              return int.from_bytes(data, "big")
          if self is ExtensionType.early_data:
              if not data:
                  return None
              assert len(data) == 4, "expect uint32 max_early_data_size"
              return int.from_bytes(data, "big")
          raise Exception("not support yet")
  class ContentType(UInt8Enum):
      """Record-layer content types, with helpers that build records."""

      invalid = 0
      change_cipher_spec = 20
      alert = 21
      handshake = 22
      application_data = 23

      def tls_plaintext(self, data: bytes) -> bytes:
          """Wrap *data* in one or more plaintext records of this type.

          *data* is cut into fragments of at most 16384 bytes. Each record is
          the type byte, a legacy version, and ``pack_int(2, fragment)``. The
          version is 0x0301 for the first handshake record and 0x0303
          otherwise.

          Raises:
              AssertionError: if *data* is empty.
          """
          assert len(data) > 0, "need data"
          view = memoryview(data)
          max_fragment = 16384
          records = []
          for index, start in enumerate(range(0, len(view), max_fragment)):
              fragment = view[start:start + max_fragment]
              if index == 0 and self is ContentType.handshake:
                  legacy_version = b"\x03\x01"
              else:
                  legacy_version = b"\x03\x03"
              records.append(self.pack() + legacy_version + pack_int(2, fragment))
          return b"".join(records)

      def tls_inner_plaintext(self, content: bytes) -> bytes:
          """Return *content* + this type byte + 0-10 zero padding bytes."""
          type_byte = self.pack()
          padding = bytes(random.randint(0, 10))
          return content + type_byte + padding
  class AlertLevel(UInt8Enum):
      """Alert severity: the first byte of an alert message."""

      warning = 1
      fatal = 2
  class AlertDescription(UInt8Enum):
      """Alert description codes: the second byte of an alert message.

      The values match the AlertDescription enum of RFC 8446, section 6.
      """

      close_notify = 0
      unexpected_message = 10
      bad_record_mac = 20
      record_overflow = 22
      handshake_failure = 40
      bad_certificate = 42
      unsupported_certificate = 43
      certificate_revoked = 44
      certificate_expired = 45
      certificate_unknown = 46
      illegal_parameter = 47
      unknown_ca = 48
      access_denied = 49
      decode_error = 50
      decrypt_error = 51
      protocol_version = 70
      insufficient_security = 71
      internal_error = 80
      inappropriate_fallback = 86
      user_canceled = 90
      missing_extension = 109
      unsupported_extension = 110
      unrecognized_name = 112
      bad_certificate_status_response = 113
      unknown_psk_identity = 115
      certificate_required = 116
      no_application_protocol = 120
  class SignatureScheme(UInt16Enum):
      """Signature schemes understood by this client.

      Only ed25519 is enabled; the other RFC 8446 code points are kept as
      comments. Const.all_signature_algorithms packs every member of this
      enum, so ed25519 is the only scheme the client advertises.
      """

      # # RSASSA-PKCS1-v1_5 algorithms
      # rsa_pkcs1_sha256 = 0x0401
      # rsa_pkcs1_sha384 = 0x0501
      # rsa_pkcs1_sha512 = 0x0601
      # # ECDSA algorithms
      # ecdsa_secp256r1_sha256 = 0x0403
      # ecdsa_secp384r1_sha384 = 0x0503
      # ecdsa_secp521r1_sha512 = 0x0603
      # # RSASSA-PSS algorithms with public key OID rsaEncryption
      # rsa_pss_rsae_sha256 = 0x0804
      # rsa_pss_rsae_sha384 = 0x0805
      # rsa_pss_rsae_sha512 = 0x0806
      # # EdDSA algorithms
      ed25519 = 0x0807
      # ed448 = 0x0808
      # # RSASSA-PSS algorithms with public key OID RSASSA-PSS
      # rsa_pss_pss_sha256 = 0x0809
      # rsa_pss_pss_sha384 = 0x080a
      # rsa_pss_pss_sha512 = 0x080b
      # # Legacy algorithms
      # rsa_pkcs1_sha1 = 0x0201
      # ecdsa_sha1 = 0x0203
      # # Reserved Code Points
      # # private_use(0xFE00..0xFFFF)
  # Finite-field Diffie-Hellman parameters keyed by group name. Every entry is
  # commented out, so this mapping is empty at runtime. Re-enabling the entries
  # would require the names ``dh`` and ``backend`` (presumably from the
  # ``cryptography`` package), which this module does not import.
  dh_parameters = {
      # "ffdhe2048": dh.generate_parameters(generator=2, key_size=2048, backend=backend),
      # "ffdhe3072": dh.generate_parameters(generator=2, key_size=3072, backend=backend),
      # "ffdhe4096": dh.generate_parameters(generator=2, key_size=4096, backend=backend),
      # "ffdhe8192": dh.generate_parameters(generator=2, key_size=8192, backend=backend),
  }
  class NamedGroup(UInt16Enum):
      """Key-exchange groups; only X25519 key pairs are generated here."""

      # Elliptic-curve groups (ECDHE)
      secp256r1 = 0x0017
      secp384r1 = 0x0018
      secp521r1 = 0x0019
      x25519 = 0x001D
      x448 = 0x001E
      # Finite-field groups (DHE)
      ffdhe2048 = 0x0100
      ffdhe3072 = 0x0101
      ffdhe4096 = 0x0102
      ffdhe6144 = 0x0103
      ffdhe8192 = 0x0104
      # Reserved: ffdhe_private_use (0x01FC..0x01FF),
      #           ecdhe_private_use (0xFE00..0xFEFF)

      @classmethod
      def new_x25519(cls):
          """Generate an ephemeral X25519 key pair.

          Returns:
              ``(private_key, entry)``: the nacl private key and the packed
              KeyShareEntry (group id followed by ``pack_int(2, public_key)``).
          """
          secret = PrivateKey.generate()
          public_bytes = bytes(secret.public_key)
          entry = cls.x25519.pack() + pack_int(2, public_bytes)
          return secret, entry

      @classmethod
      def unpack_from(cls, data: memoryview):
          """Parse a single KeyShareEntry (2-byte group, 2-byte length, key).

          Raises:
              Exception: for an unknown group id.
              AssertionError: if the declared length does not cover the rest.
          """
          group = cls.from_value(int.from_bytes(data[:2], "big"))
          declared = int.from_bytes(data[2:4], "big")
          body = data[4:]
          assert declared == len(body), "group length does not match"
          return KeyShareEntry(group, bytes(body))
  @dataclass
  class KeyShareEntry:
      """A key share: a named group plus its opaque key_exchange bytes."""

      __slots__ = ("group", "key_exchange")

      group: NamedGroup
      key_exchange: bytes

      def pack(self):
          """Serialize as the group id followed by length-prefixed key bytes."""
          group_id = self.group.pack()
          return group_id + pack_int(2, self.key_exchange)
  class CertificateType(UInt8Enum):
      """Certificate type codes (X509 = 0, RawPublicKey = 2)."""

      X509 = 0
      RawPublicKey = 2
  @dataclass
  class CertificateEntry:
      """One certificate from a TLS 1.3 Certificate message.

      Attributes:
          cert_type: CertificateType.X509 for entries built by unpack_from.
          cert_data: the certificate bytes, as a slice of the parsed buffer.
          extensions: per-certificate extensions from ExtensionType.unpack_from.
      """

      cert_type: CertificateType
      cert_data: bytes
      extensions: dict
      __slots__ = ("cert_type", "cert_data", "extensions")

      @classmethod
      def unpack_from(cls, data: memoryview):
          """Parse a Certificate handshake body into a list of entries.

          Layout:
              certificate_request_context: 1-byte length, then the context.
              certificate_list: 3-byte length, then the entries.
              Each entry: 3-byte cert length, cert bytes, 2-byte extensions
              length, then the extensions.

          Certificates are NOT validated here.

          Raises:
              AssertionError: on any length mismatch.
          """
          # Accept bytes too; ExtensionType.unpack_from needs memoryview slices.
          data = memoryview(data)
          context_len = data[0]
          # The certificate_request_context is not used; skip over it.
          data = data[1 + context_len:]
          list_len = int.from_bytes(data[:3], "big")
          cert_list = data[3:]
          assert list_len == len(cert_list), "Certificate length does not match"
          certs = []
          while cert_list:
              cert_len = int.from_bytes(cert_list[:3], "big")
              cert_data = cert_list[3:3 + cert_len]
              assert len(cert_data) == cert_len, "certificate entry length does not match"
              # TODO: validate the certificate (chain, signature, host name).
              # The key is not parsed here because RSA.import_key would reject
              # Ed25519 certificates, the only kind matching the advertised
              # SignatureScheme members.
              cert_list = cert_list[3 + cert_len:]
              ext_len = int.from_bytes(cert_list[:2], "big")
              assert ext_len <= len(cert_list[2:]), "extensions length does not match"
              extensions = ExtensionType.unpack_from(cert_list[2:2 + ext_len])
              cert_list = cert_list[2 + ext_len:]
              certs.append(cls(CertificateType.X509, cert_data, extensions))
          return certs
  class KeyUpdateRequest(UInt8Enum):
      """Values of the request_update field of a KeyUpdate message."""

      update_not_requested = 0
      update_requested = 1
  class PskKeyExchangeMode(UInt8Enum):
      """PSK key-exchange modes for the psk_key_exchange_modes extension."""

      psk_ke = 0
      psk_dhe_ke = 1

      def extension(self):
          """Build a psk_key_exchange_modes extension offering only this mode."""
          mode_list = pack_int(1, self.pack())
          return ExtensionType.psk_key_exchange_modes.pack_data(mode_list)

      @classmethod
      def both_extensions(cls):
          """Build a psk_key_exchange_modes extension offering psk_ke and psk_dhe_ke."""
          # 0x00 = psk_ke, 0x01 = psk_dhe_ke
          mode_list = pack_int(1, b"\x00\x01")
          return ExtensionType.psk_key_exchange_modes.pack_data(mode_list)
  class CipherSuite(UInt16Enum):
      """TLS 1.3 cipher suites (RFC 8446, appendix B.4)."""

      TLS_AES_128_GCM_SHA256 = 0x1301
      # The remaining suites are referenced by get_cipher() and pack_all().
      TLS_AES_256_GCM_SHA384 = 0x1302
      TLS_CHACHA20_POLY1305_SHA256 = 0x1303
      TLS_AES_128_CCM_SHA256 = 0x1304
      TLS_AES_128_CCM_8_SHA256 = 0x1305

      @classmethod
      def all(cls) -> set:
          """Return the cached set of every member's 2-byte wire encoding."""
          if not hasattr(cls, "_all"):
              cls._all = {suite.pack() for suite in cls}
          return cls._all

      @classmethod
      def select(cls, data):
          """Return the first 2-byte code in *data* that is a member, else None.

          *data* is a bytes-like concatenation of 2-byte cipher suite codes.
          """
          view = memoryview(data)
          known = cls.all()
          for offset in range(0, view.nbytes, 2):
              # Copy to bytes: slices of writable buffers are unhashable.
              candidate = view[offset:offset + 2].tobytes()
              if candidate in known:
                  return candidate
          return None

      @classmethod
      def get_cipher(cls, data):
          """Map a 2-byte big-endian suite code to its class in ``ciphers``.

          Raises:
              Exception: if the code is not a known suite.
          """
          value = int.from_bytes(data, "big")
          if value == cls.TLS_AES_128_GCM_SHA256:
              return ciphers.TLS_AES_128_GCM_SHA256
          elif value == cls.TLS_AES_256_GCM_SHA384:
              return ciphers.TLS_AES_256_GCM_SHA384
          elif value == cls.TLS_AES_128_CCM_SHA256:
              return ciphers.TLS_AES_128_CCM_SHA256
          elif value == cls.TLS_AES_128_CCM_8_SHA256:
              return ciphers.TLS_AES_128_CCM_8_SHA256
          elif value == cls.TLS_CHACHA20_POLY1305_SHA256:
              return ciphers.TLS_CHACHA20_POLY1305_SHA256
          else:
              raise Exception("bad cipher suite")

      @classmethod
      def pack_all(cls):
          """Encode every suite, in preference order, via the module-level pack_all()."""
          # Inside a method, ``pack_all`` resolves to the module-level helper,
          # not to this classmethod.
          return pack_all(
              2,
              [
                  cls.TLS_CHACHA20_POLY1305_SHA256,
                  cls.TLS_AES_128_GCM_SHA256,
                  cls.TLS_AES_256_GCM_SHA384,
                  cls.TLS_AES_128_CCM_SHA256,
                  cls.TLS_AES_128_CCM_8_SHA256,
              ],
          )
  class NameType(UInt8Enum):
      """Name types for server_name entries; only host_name is defined."""

      host_name = 0

      def pack_data(self, data: bytes) -> bytes:
          """Return the name-type byte followed by ``pack_int(2, data)``."""
          type_byte = self.pack()
          return type_byte + pack_int(2, data)
  class Const:
      """Pre-encoded extensions, computed once when the class body runs."""

      # signature_algorithms extension listing every SignatureScheme member.
      all_signature_algorithms = ExtensionType.signature_algorithms.pack_data(
          pack_all(2, SignatureScheme)
      )
      # supported_groups extension offering only x25519.
      all_supported_groups = ExtensionType.supported_groups.pack_data(
          pack_all(2, [NamedGroup.x25519])
      )
      # psk_key_exchange_modes extensions: single-mode and both-mode variants.
      psk_ke_extension = PskKeyExchangeMode.psk_ke.extension()
      psk_dhe_ke_extension = PskKeyExchangeMode.psk_dhe_ke.extension()
      psk_both_extensions = PskKeyExchangeMode.both_extensions()
  def server_hello_pack(legacy_session_id_echo, cipher_suite, extensions) -> bytes:
      """Build a ServerHello handshake message inside plaintext record(s).

      Args:
          legacy_session_id_echo: session id echoed from the ClientHello.
          cipher_suite: the selected suite; anything whose ``pack()`` returns
              its 2-byte code (e.g. a CipherSuite member).
          extensions: the encoded extensions, either one bytes-like blob or an
              iterable of encoded extensions; ``None`` means no extensions.

      Returns:
          The record bytes produced by ``ContentType.handshake.tls_plaintext``.
      """
      if extensions is None:
          extensions = b""
      if isinstance(extensions, (bytes, bytearray, memoryview)):
          packed_extensions = pack_int(2, bytes(extensions))
      else:
          packed_extensions = pack_list(2, extensions)
      legacy_version = b"\x03\x03"
      msg = b"".join(
          (
              legacy_version,
              os.urandom(32),  # ServerHello.random
              pack_int(1, legacy_session_id_echo),
              cipher_suite.pack(),
              b"\x00",  # legacy_compression_method: null
              # extensions<6..2^16-1> is a mandatory part of ServerHello.
              packed_extensions,
          )
      )
      return ContentType.handshake.tls_plaintext(
          HandshakeType.server_hello.pack_data(msg)
      )
  def client_hello_key_share_extension(*key_share_entries):
      """Build a ClientHello key_share extension from the given entries.

      The entries are presumably packed KeyShareEntry bytes, such as the
      second element returned by NamedGroup.new_x25519().
      """
      client_shares = pack_list(2, key_share_entries)
      return ExtensionType.key_share.pack_data(client_shares)
  @dataclass
  class PskIdentity:
      """One PSK identity offered in the ClientHello pre_shared_key extension."""

      identity: bytes  # opaque identity, e.g. a session ticket
      obfuscated_ticket_age: int  # encoded as a 4-byte big-endian integer
      binder_len: int  # size of the zero-filled placeholder binder
  def client_pre_shared_key_extension(
      psk_identities: typing.Iterable
  ) -> typing.Tuple[bytes, int]:
      """Build a ClientHello pre_shared_key extension with zero-filled binders.

      Args:
          psk_identities: PskIdentity-like objects (``identity``,
              ``obfuscated_ticket_age``, ``binder_len``); any iterable,
              including a generator.

      Returns:
          ``(extension, binders_len)``. *extension* is the encoded extension
          with the binders list placed last. *binders_len* is the length of
          that encoded binders list, presumably so the caller can locate and
          replace the placeholders.
      """
      # Materialize once: the identities are walked twice below, and a
      # one-shot iterator would otherwise leave the identity list empty.
      psk_identities = list(psk_identities)
      binders = pack_psk_binder_entries(
          i.binder_len * b"\x00" for i in psk_identities
      )
      identities = pack_list(
          2,
          (
              pack_int(2, i.identity) + i.obfuscated_ticket_age.to_bytes(4, "big")
              for i in psk_identities
          ),
      )
      return (
          ExtensionType.pre_shared_key.pack_data(identities + binders),
          len(binders),
      )
  def pack_psk_binder_entries(binder_list: typing.Iterable[bytes]) -> bytes:
      """Encode PSK binders: each with a 1-byte length, the list with 2 bytes."""
      entries = (pack_int(1, entry) for entry in binder_list)
      return pack_list(2, entries)
  def unpack_certificate_verify(mv: memoryview):
      """Parse a CertificateVerify body.

      Returns:
          SimpleNamespace with ``algorithm`` (a SignatureScheme member) and
          ``signature`` (a slice of *mv* of the declared 2-byte length).

      Raises:
          Exception: if the algorithm is not a SignatureScheme member.
      """
      scheme = SignatureScheme.from_value(int.from_bytes(mv[:2], "big"))
      sig_len = int.from_bytes(mv[2:4], "big")
      sig = mv[4:4 + sig_len]
      return SimpleNamespace(algorithm=scheme, signature=sig)
  def unpack_new_session_ticket(mv: memoryview):
      """Parse a NewSessionTicket handshake body.

      Layout:
          uint32 lifetime, uint32 age_add.
          Nonce: 1-byte length, then the nonce.
          Ticket: 2-byte length, then the ticket.
          Extensions: 2-byte length, then the extensions.

      Returns:
          A NewSessionTicket. ``max_early_data_size`` is None unless a
          non-empty early_data extension is present.

      Raises:
          AssertionError: on any length mismatch.
      """
      # Accept bytes too; ExtensionType.unpack_from needs memoryview slices.
      mv = memoryview(mv)
      lifetime, age_add, nonce_len = struct.unpack_from("!IIB", mv)
      mv = mv[9:]  # "!IIB" is 4 + 4 + 1 bytes
      # Copy out of the buffer so the stored ticket does not pin (or lock the
      # resizing of) the caller's record buffer.
      nonce = bytes(mv[:nonce_len])
      assert len(nonce) == nonce_len, "nonce length does not match"
      mv = mv[nonce_len:]
      ticket_len = int.from_bytes(mv[:2], "big")
      mv = mv[2:]
      ticket = bytes(mv[:ticket_len])
      assert len(ticket) == ticket_len, "ticket length does not match"
      mv = mv[ticket_len:]
      ext_len = int.from_bytes(mv[:2], "big")
      mv = mv[2:]
      assert ext_len == len(mv), "extension length does not match"
      extensions = ExtensionType.unpack_from(mv)
      return NewSessionTicket(
          lifetime=lifetime,
          age_add=age_add,
          nonce=nonce,
          ticket=ticket,
          max_early_data_size=extensions.get(ExtensionType.early_data),
      )
  @dataclass
  class NewSessionTicket:
      """A session ticket issued by the server.

      Attributes:
          lifetime: advertised lifetime in seconds (capped at MAX_LIFETIME).
          age_add: value the ticket age is obfuscated with.
          nonce: per-ticket nonce.
          ticket: opaque ticket used as the PSK identity.
          max_early_data_size: from the early_data extension, or None.
      """

      lifetime: int
      age_add: int
      nonce: bytes
      ticket: bytes
      max_early_data_size: int

      def __post_init__(self):
          # Monotonic receive time measures the ticket age (immune to clock
          # jumps); wall-clock time marks the expiry.
          self._received_monotonic = time.monotonic()
          self.outdated_time = time.time() + min(self.lifetime, MAX_LIFETIME)

      @property
      def obfuscated_ticket_age(self) -> int:
          """Return ticket age in ms plus ``age_add``, modulo 2**32.

          This follows RFC 8446 4.2.11. The age is measured from when this
          object was created, at the moment the property is read.
          """
          age_ms = int((time.monotonic() - self._received_monotonic) * 1000)
          return (age_ms + self.age_add) % AGE_MOD

      def is_outdated(self):
          """Return True once the (capped) lifetime has elapsed."""
          return time.time() >= self.outdated_time

      def to_psk_identity(self, binder_len: int):
          """Return a PskIdentity offering this ticket with the current age."""
          return PskIdentity(self.ticket, self.obfuscated_ticket_age, binder_len)
  class TLSClientSession:
      """Client side of a TLS 1.3 handshake using an X25519 key share.

      The running transcript, ``handshake_context``, is a bytearray holding
      the handshake messages as ASCII hex. It is decoded with
      ``bytes.fromhex`` whenever the key schedule or a Finished computation
      needs the raw bytes.
      """

      def __init__(
          self,
          private_key: bytes,
          key_share_entry: bytes,
          client_hello_data: bytearray,
          data_callback,
          data_callback_done
      ):
          """Create a session seeded with the already-built ClientHello.

          Args:
              private_key: the client X25519 private key. It is passed through
                  ``bytes()`` before use, so a nacl ``PrivateKey`` also works.
              key_share_entry: the packed key share offered in the ClientHello.
              client_hello_data: the ClientHello handshake message as ASCII
                  hex. It becomes the transcript and is extended in place.
              data_callback: stored for callers; not invoked by this class.
              data_callback_done: called as ``(ccs_hex, finished_record_hex)``
                  once the server Finished has been verified.
          """
          self.private_key = private_key
          self.key_share_entry = key_share_entry
          # Snapshot taken before the transcript (which aliases the caller's
          # buffer) grows.
          self.client_hello_data = bytes(client_hello_data)
          self.handshake_context = client_hello_data
          self.server_finished = False
          self.data_callback = data_callback
          self.data_callback_done = data_callback_done
          self.client = None
          self.server = None
          # Application traffic cipher; None until the handshake completes.
          self.cipher = None
          # Decrypts server records; None until the ServerHello is processed.
          self.peer_cipher = None
          self.session_tickets = []
          gBackPrint("tls", " started")

      def _transcript(self) -> bytes:
          """Return the handshake transcript decoded from its hex form."""
          return bytes.fromhex(self.handshake_context.decode())

      def _append_transcript(self, message) -> None:
          """Append a raw handshake message (bytes-like) to the hex transcript."""
          self.handshake_context.extend(bytes(message).hex().encode("utf-8"))

      def unpack_server_hello(self, mv: memoryview):
          """Parse a ServerHello body (the part after the 4-byte header).

          Returns:
              SimpleNamespace with handshake_type, random,
              legacy_session_id_echo, cipher_suite (a class from
              ``ciphers``) and extensions (dict from ExtensionType.unpack_from).

          Raises:
              AssertionError: on a bad version, compression method or length.
              Exception: for an unknown cipher suite or extension.
          """
          assert mv[:2] == b"\x03\x03", "version must be 0x0303"
          server_random = bytes(mv[2:34])
          session_id_len = mv[34]
          legacy_session_id_echo = bytes(mv[35:35 + session_id_len])
          mv = mv[35 + session_id_len:]
          cipher_suite = CipherSuite.get_cipher(mv[:2])
          assert mv[2] == 0, "legacy_compression_method should be 0"
          extension_length = int.from_bytes(mv[3:5], "big")
          extensions_mv = mv[5:]
          assert (
              extensions_mv.nbytes == extension_length
          ), "extensions length does not match"
          extensions = ExtensionType.unpack_from(extensions_mv)
          return SimpleNamespace(
              handshake_type=HandshakeType.server_hello,
              random=server_random,
              legacy_session_id_echo=legacy_session_id_echo,
              cipher_suite=cipher_suite,
              extensions=extensions,
          )

      def unpack_handshake(self, mv: memoryview):
          """Process the handshake messages contained in *mv*.

          Each message is a 1-byte type, a 3-byte length and its body. Every
          message except NewSessionTicket is appended to the transcript.

          Returns:
              The parsed ServerHello, which must be the last message in *mv*;
              None when *mv* holds no ServerHello.

          Raises:
              AssertionError: on length mismatches or a bad server Finished.
              Exception: on an unknown handshake type.
          """
          import hmac

          while len(mv) > 0:
              handshake_type = mv[0]
              length = int.from_bytes(mv[1:4], "big")
              handshake_data = mv[4:4 + length]
              message = mv[:4 + length]
              simplePrint(" handshake_type", handshake_type)
              simplePrint(" length", length)
              simplePrint(" handshake_data", handshake_data.tobytes().hex())
              if handshake_type == HandshakeType.server_hello:
                  assert len(mv[4:]) == length, "handshake length does not match"
                  self._append_transcript(message)
                  return self.unpack_server_hello(handshake_data)
              # Reject truncated messages instead of parsing a partial body.
              assert len(handshake_data) == length, "handshake length does not match"
              if handshake_type == HandshakeType.encrypted_extensions:
                  self._append_transcript(message)
                  ext_len = int.from_bytes(handshake_data[:2], "big")
                  simplePrint(" handshake_data 1", ext_len)
                  handshake_data = handshake_data[2:]
                  simplePrint(" handshake_data 2", ext_len)
                  assert (
                      len(handshake_data) == ext_len
                  ), "encrypted extensions length does not match"
                  self.encrypted_extensions = ExtensionType.unpack_from(
                      handshake_data)
                  simplePrint(" handshake_data 3", ext_len)
              elif handshake_type in (
                  HandshakeType.certificate_request,
                  HandshakeType.certificate,
                  HandshakeType.certificate_verify,
              ):
                  # TODO: certificate parsing and verification are not
                  # implemented; these messages only feed the transcript.
                  self._append_transcript(message)
              elif handshake_type == HandshakeType.finished:
                  simplePrint(" finished", handshake_type)
                  expected = self.peer_cipher.verify_data(self._transcript())
                  # Constant-time comparison of the server's verify_data.
                  assert hmac.compare_digest(
                      bytes(handshake_data), bytes(expected)
                  ), "server handshake finished does not match"
                  self._append_transcript(message)
                  self.server_finished = True
                  simplePrint("server_finished ", self.server_finished)
              elif handshake_type == HandshakeType.new_session_ticket:
                  # Post-handshake message: not part of the transcript. Tickets
                  # are only logged; parsing is left disabled.
                  simplePrint("session_tickets", handshake_data.tobytes().hex())
              else:
                  raise Exception(f"unknown handshake type {handshake_type}")
              mv = mv[4 + length:]

      def tls_response(self, mv: memoryview):
          """Handle one TLS record received from the server.

          Behaviour by record type:
              - Plaintext handshake: must carry the ServerHello; the handshake
                traffic secrets are derived from it.
              - application_data: decrypted with ``peer_cipher``.
              - Once the server Finished verifies, the client Finished record
                is built and ``data_callback_done`` is called.

          Raises:
              Alert: on a received alert or an oversized record. Also on an
                  encrypted record arriving before keys exist or carrying no
                  content type.
              AssertionError / Exception: for malformed or unexpected input.
          """
          head = memoryview(mv[:5])
          assert head[1:
                      3] == b"\x03\x03", f"bad legacy_record_version {head[1:3]}"
          record_type = head[0]
          length = int.from_bytes(head[3:], "big")
          # Encrypted records may exceed the plaintext limit by up to 256 bytes.
          if record_type == ContentType.application_data:
              max_length = 16384 + 256
          else:
              max_length = 16384
          if length > max_length:
              raise Alert(AlertLevel.fatal, AlertDescription.record_overflow)
          # Only this record's fragment; bytes past the declared length are
          # not part of it.
          content = memoryview(mv[5:5 + length])
          if record_type == ContentType.alert:
              level = AlertLevel.from_value(content[0])
              description = AlertDescription.from_value(content[1])
              raise Alert(level, description)
          elif record_type == ContentType.handshake:
              self._handle_server_hello(content)
          elif record_type == ContentType.application_data:
              self._handle_encrypted_record(content, head)
          elif record_type == ContentType.change_cipher_spec:
              assert content == b"\x01", "change_cipher should be 0x01"
          else:
              raise Exception(f"Unknown content type: {record_type}")

      def _handle_server_hello(self, content: memoryview) -> None:
          """Parse the plaintext ServerHello and derive the handshake traffic secrets."""
          self.peer_handshake = self.unpack_handshake(content)
          assert (
              self.peer_handshake is not None
              and self.peer_handshake.handshake_type == HandshakeType.server_hello
          ), "expect server hello"
          simplePrint("handshake_context", self.handshake_context)
          key_share = self.peer_handshake.extensions.get(ExtensionType.key_share)
          if key_share is None:
              # Without the server's key share no (EC)DHE secret can be derived.
              raise Alert(AlertLevel.fatal, AlertDescription.missing_extension)
          peer_pk = key_share.key_exchange
          simplePrint("peer_pk\t", bytes(peer_pk).hex())
          # NOTE(review): this logs the client's private key; debugging aid only.
          simplePrint("private_key", bytes(self.private_key).hex())
          shared_key = crypto_scalarmult(bytes(self.private_key), peer_pk)
          simplePrint("shared_key", bytes(shared_key).hex())
          TLSCipher = self.peer_handshake.cipher_suite
          self.TLSCipher = TLSCipher
          simplePrint("TLSCipher", TLSCipher)
          key_scheduler = TLSCipher.tls_hash.scheduler(shared_key)
          self.key_scheduler = key_scheduler
          simplePrint("key_scheduler", self.key_scheduler)
          secret = key_scheduler.server_handshake_traffic_secret(self._transcript())
          simplePrint("s_hand_traf_s",
                      secret.hex())
          self.peer_cipher = TLSCipher(secret)
          simplePrint("peer_cipher", self.peer_cipher)
          simplePrint("secret", bytes(secret).hex())
          client_handshake_traffic_secret = key_scheduler.client_handshake_traffic_secret(
              self._transcript()
          )
          self.client_handshake_traffic_secret = client_handshake_traffic_secret
          simplePrint("c_hand_traf_s",
                      client_handshake_traffic_secret.hex())

      def _handle_encrypted_record(self, content: memoryview, head: memoryview) -> None:
          """Decrypt an application_data record and dispatch on its inner type."""
          if self.peer_cipher is None:
              # Encrypted data before the ServerHello: no keys to decrypt with.
              raise Alert(AlertLevel.fatal, AlertDescription.unexpected_message)
          # Drop TLSInnerPlaintext zero padding; the last remaining byte is the
          # real content type.
          plaintext = self.peer_cipher.decrypt(content, head).rstrip(b"\x00")
          if not plaintext:
              # RFC 8446 5.4: no non-zero octet means unexpected_message.
              raise Alert(AlertLevel.fatal, AlertDescription.unexpected_message)
          simplePrint("plaintext", (plaintext.hex(), type(plaintext)))
          content_type = ContentType.from_value(plaintext[-1])
          simplePrint(" content_type", content_type)
          if content_type == ContentType.handshake:
              self.unpack_handshake(memoryview(plaintext[:-1]))
              if self.server_finished:
                  self._finish_handshake()
          elif content_type == ContentType.application_data:
              simplePrint("DATA ISSSSSS", plaintext[:-1])
          elif content_type == ContentType.alert:
              level = AlertLevel.from_value(plaintext[0])
              description = AlertDescription.from_value(plaintext[1])
              raise Alert(level, description)
          elif content_type == ContentType.invalid:
              raise Exception("invalid content type")
          else:
              raise Exception(f"unexpected content type {content_type}")

      def _finish_handshake(self) -> None:
          """Build the client Finished record and switch to application traffic keys."""
          cipher = self.TLSCipher(self.client_handshake_traffic_secret)
          client_finished = cipher.verify_data(self._transcript())
          simplePrint("client_finished", client_finished.hex())
          client_finished_data = HandshakeType.finished.pack_data(
              client_finished
          )
          simplePrint("client_finished_data",
                      client_finished_data.hex())
          # TLSInnerPlaintext: message plus the handshake content type (0x16).
          inner_plaintext = client_finished_data + b'\x16'
          simplePrint("inner_plaintext", inner_plaintext.hex())
          record = cipher.tls_ciphertext(inner_plaintext)
          simplePrint("record", record.hex())
          change_cipher_spec = ContentType.change_cipher_spec.tls_plaintext(
              b"\x01"
          )
          simplePrint("change_cipher_spec", change_cipher_spec.hex())
          self.change_cipher_spec = change_cipher_spec
          self.record = record
          simplePrint("record", record.hex())
          # Application secrets use the transcript through the server Finished;
          # the client Finished is appended only afterwards.
          transcript = self._transcript()
          server_secret = self.key_scheduler.server_application_traffic_secret_0(
              transcript
          )
          simplePrint("server_secret", server_secret.hex())
          self.peer_cipher = self.TLSCipher(server_secret)
          self.server_finished = False
          client_secret = self.key_scheduler.client_application_traffic_secret_0(
              transcript
          )
          simplePrint("client_secret", client_secret)
          self.cipher = self.TLSCipher(client_secret)
          self.client = self.cipher
          self.server = self.peer_cipher
          self._append_transcript(client_finished_data)
          spacePrint()
          # NOTE(review): these lines print live traffic keys; debugging aid only.
          gPrint("Client key", self.cipher.key.hex())
          gPrint("Client iv", self.cipher.ivhex.hex())
          gPrint("server key", self.peer_cipher.key.hex())
          gPrint("server iv", self.peer_cipher.ivhex.hex())
          spacePrint()
          # '140303000101' is the change_cipher_spec record in hex: type 0x14,
          # version 0x0303, length 1, payload 0x01.
          self.data_callback_done('140303000101', self.record.hex())

      def pack_client_hello(self):
          """Return the ClientHello wrapped in plaintext handshake record(s)."""
          # client_hello_data is ASCII hex (it seeds the hex transcript), so
          # decode it to the raw handshake message before framing.
          raw_hello = bytes.fromhex(self.client_hello_data.decode())
          return ContentType.handshake.tls_plaintext(raw_hello)

      def pack_application_data(self, payload: bytes) -> bytes:
          """Encrypt *payload* as an application_data record.

          Raises:
              Exception: if the handshake has not completed yet.
          """
          if self.cipher is None:
              raise Exception("handshake not finished")
          inner_plaintext = ContentType.application_data.tls_inner_plaintext(
              payload)
          return self.cipher.tls_ciphertext(inner_plaintext)

      def pack_alert(self, description: AlertDescription, level: AlertLevel) -> bytes:
          """Return an alert record, encrypted once application keys exist."""
          payload = level.pack() + description.pack()
          if self.cipher is not None:
              inner_plaintext = ContentType.alert.tls_inner_plaintext(payload)
              return self.cipher.tls_ciphertext(inner_plaintext)
          else:
              return ContentType.alert.tls_plaintext(payload)

      def pack_warning(self, description: AlertDescription) -> bytes:
          """Return a warning-level alert record."""
          return self.pack_alert(description, AlertLevel.warning)

      def pack_fatal(self, description: AlertDescription) -> bytes:
          """Return a fatal-level alert record."""
          return self.pack_alert(description, AlertLevel.fatal)

      def pack_close(self) -> bytes:
          """Return a close_notify alert record."""
          return self.pack_warning(AlertDescription.close_notify)

      def pack_canceled(self) -> bytes:
          """Return a user_canceled alert record."""
          return self.pack_warning(AlertDescription.user_canceled)