models.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. import enum
  2. import time
  3. import random
  4. import iofree
  5. from . import ciphers
  6. from iofree import schema
  7. MAX_LIFETIME = 24 * 3600 * 7
  8. AGE_MOD = 2 ** 32
  9. class AlertLevel(enum.IntEnum):
  10. warning = 1
  11. fatal = 2
  12. class AlertDescription(enum.IntEnum):
  13. close_notify = 0
  14. unexpected_message = 10
  15. bad_record_mac = 20
  16. record_overflow = 22
  17. handshake_failure = 40
  18. bad_certificate = 42
  19. unsupported_certificate = 43
  20. certificate_revoked = 44
  21. certificate_expired = 45
  22. certificate_unknown = 46
  23. illegal_parameter = 47
  24. unknown_ca = 48
  25. access_denied = 49
  26. decode_error = 50
  27. decrypt_error = 51
  28. protocol_version = 70
  29. insufficient_security = 71
  30. internal_error = 80
  31. inappropriate_fallback = 86
  32. user_canceled = 90
  33. missing_extension = 109
  34. unsupported_extension = 110
  35. unrecognized_name = 112
  36. bad_certificate_status_response = 113
  37. unknown_psk_identity = 115
  38. certificate_required = 116
  39. no_application_protocol = 120
  40. class KeyUpdateRequest(enum.IntEnum):
  41. update_not_requested = 0
  42. update_requested = 1
  43. class ExtensionType(enum.IntEnum):
  44. server_name = 0
  45. max_fragment_length = 1
  46. status_request = 5
  47. supported_groups = 10
  48. signature_algorithms = 13
  49. use_srtp = 14
  50. heartbeat = 15
  51. application_layer_protocol_negotiation = 16
  52. signed_certificate_timestamp = 18
  53. client_certificate_type = 19
  54. server_certificate_type = 20
  55. padding = 21
  56. pre_shared_key = 41
  57. early_data = 42
  58. supported_versions = 43
  59. cookie = 44
  60. psk_key_exchange_modes = 45
  61. certificate_authorities = 47
  62. oid_filters = 48
  63. post_handshake_auth = 49
  64. signature_algorithms_cert = 50
  65. key_share = 51
  66. class HandshakeType(enum.IntEnum):
  67. client_hello = 1
  68. server_hello = 2
  69. new_session_ticket = 4
  70. end_of_early_data = 5
  71. encrypted_extensions = 8
  72. certificate = 11
  73. certificate_request = 13
  74. certificate_verify = 15
  75. finished = 20
  76. key_update = 24
  77. message_hash = 254
  78. class NameType(enum.IntEnum):
  79. host_name = 0
  80. class SignatureScheme(enum.IntEnum):
  81. # RSASSA-PKCS1-v1_5 algorithms
  82. # rsa_pkcs1_sha256 = 0x0401
  83. # rsa_pkcs1_sha384 = 0x0501
  84. # rsa_pkcs1_sha512 = 0x0601
  85. # # ECDSA algorithms
  86. # ecdsa_secp256r1_sha256 = 0x0403
  87. # ecdsa_secp384r1_sha384 = 0x0503
  88. # ecdsa_secp521r1_sha512 = 0x0603
  89. # # RSASSA-PSS algorithms with public key OID rsaEncryption
  90. # rsa_pss_rsae_sha256 = 0x0804
  91. # rsa_pss_rsae_sha384 = 0x0805
  92. # rsa_pss_rsae_sha512 = 0x0806
  93. # EdDSA algorithms
  94. ed25519 = 0x0807
  95. # ed448 = 0x0808
  96. # # RSASSA-PSS algorithms with public key OID RSASSA-PSS
  97. # rsa_pss_pss_sha256 = 0x0809
  98. # rsa_pss_pss_sha384 = 0x080a
  99. # rsa_pss_pss_sha512 = 0x080b
  100. # # Legacy algorithms
  101. # rsa_pkcs1_sha1 = 0x0201
  102. # ecdsa_sha1 = 0x0203
  103. # # Reserved Code Points
  104. # # private_use(0xFE00..0xFFFF)
  105. class NamedGroup(enum.IntEnum):
  106. # Elliptic Curve Groups (ECDHE)
  107. # secp256r1 = 0x0017
  108. # secp384r1 = 0x0018
  109. # secp521r1 = 0x0019
  110. x25519 = 0x001D
  111. # x448 = 0x001E
  112. # # Finite Field Groups (DHE)
  113. # ffdhe2048 = 0x0100
  114. # ffdhe3072 = 0x0101
  115. # ffdhe4096 = 0x0102
  116. # ffdhe6144 = 0x0103
  117. # ffdhe8192 = 0x0104
  118. # Reserved Code Points
  119. # ffdhe_private_use(0x01FC..0x01FF)
  120. # ecdhe_private_use(0xFE00..0xFEFF)
  121. class PskKeyExchangeMode(enum.IntEnum):
  122. psk_ke = 0
  123. psk_dhe_ke = 1
  124. class CipherSuite(enum.IntEnum):
  125. TLS_AES_128_GCM_SHA256 = 0x1301
  126. # TLS_AES_256_GCM_SHA384 = 0x1302
  127. # TLS_CHACHA20_POLY1305_SHA256 = 0x1303
  128. # TLS_AES_128_CCM_SHA256 = 0x1304
  129. # TLS_AES_128_CCM_8_SHA256 = 0x1305
  130. class ContentType(enum.IntEnum):
  131. invalid = 0
  132. change_cipher_spec = 20
  133. alert = 21
  134. handshake = 22
  135. application_data = 23
  136. heartbeat = 24
  137. def tls_plaintext(self, payload):
  138. return TLSPlaintext.pack(self, payload)
class TLSCiphertext(schema.BinarySchema):
    """Encrypted record as it appears on the wire.

    The outer content type must be application_data and the legacy record
    version must be 0x0303; both are enforced with ``MustEqual``.
    """

    opaque_type = schema.MustEqual(
        schema.SizedIntEnum(schema.uint8, ContentType), ContentType.application_data
    )
    legacy_record_version = schema.MustEqual(schema.Bytes(2), b"\x03\x03")
    # Encrypted payload with a 16-bit length prefix (presumably the sealed
    # TLSInnerPlaintext produced by the cipher — confirm in ciphers module).
    encrypted_record = schema.LengthPrefixedBytes(schema.uint16be)
class ServerName(schema.BinarySchema):
    """One entry of the server_name extension's list.

    name_type must be host_name; the name itself is a uint16 length-prefixed
    string selected through the Switch on name_type.
    """

    name_type = schema.MustEqual(
        schema.SizedIntEnum(schema.uint8, NameType), NameType.host_name
    )
    name = schema.Switch(
        "name_type", {NameType.host_name: schema.LengthPrefixedString(schema.uint16be)}
    )
class PskIdentity(schema.BinarySchema):
    """A PSK identity offered in the pre_shared_key extension."""

    # Opaque identity; NewSessionTicket.to_psk_identity puts the ticket here.
    identity = schema.LengthPrefixedBytes(schema.uint16be)
    # Ticket age in ms plus ticket_age_add, modulo 2**32.
    obfuscated_ticket_age = schema.uint32be
class OfferedPsks(schema.BinarySchema):
    """Client-side body of the pre_shared_key extension."""

    # uint16 length-prefixed list of offered identities.
    identities = schema.LengthPrefixedObjectList(schema.uint16be, PskIdentity)
    # uint16 length-prefixed list of binders, each a uint8 length-prefixed HMAC,
    # one per identity.
    binders = schema.LengthPrefixedObjectList(
        schema.uint16be, schema.LengthPrefixedBytes(schema.uint8)
    )
class KeyShareEntry(schema.BinarySchema):
    """One (group, public key) pair from the key_share extension."""

    group = schema.SizedIntEnum(schema.uint16be, NamedGroup)
    # Public key bytes for ``group``, uint16 length-prefixed.
    key_exchange = schema.LengthPrefixedBytes(schema.uint16be)
  163. extensions = {
  164. ExtensionType.server_name: schema.LengthPrefixedObject(
  165. schema.uint16be, schema.LengthPrefixedObjectList(schema.uint16be, ServerName)
  166. ),
  167. ExtensionType.signature_algorithms: schema.LengthPrefixedObject(
  168. schema.uint16be,
  169. schema.LengthPrefixedObjectList(
  170. schema.uint16be, schema.SizedIntEnum(schema.uint16be, SignatureScheme)
  171. ),
  172. ),
  173. ExtensionType.supported_groups: schema.LengthPrefixedObject(
  174. schema.uint16be,
  175. schema.LengthPrefixedObjectList(
  176. schema.uint16be, schema.SizedIntEnum(schema.uint16be, NamedGroup)
  177. ),
  178. ),
  179. ExtensionType.psk_key_exchange_modes: schema.LengthPrefixedObject(
  180. schema.uint16be,
  181. schema.LengthPrefixedObjectList(
  182. schema.uint8, schema.SizedIntEnum(schema.uint8, PskKeyExchangeMode)
  183. ),
  184. ),
  185. ExtensionType.early_data: schema.LengthPrefixedBytes(schema.uint16be),
  186. ExtensionType.pre_shared_key: schema.LengthPrefixedObject(
  187. schema.uint16be, OfferedPsks
  188. ),
  189. }
  190. client_extensions = extensions.copy()
  191. client_extensions[ExtensionType.supported_versions] = schema.LengthPrefixedObject(
  192. schema.uint16be, schema.LengthPrefixedObjectList(schema.uint8, schema.Bytes(2))
  193. )
  194. client_extensions[ExtensionType.key_share] = schema.LengthPrefixedObject(
  195. schema.uint16be, schema.LengthPrefixedObjectList(schema.uint16be, KeyShareEntry)
  196. )
  197. server_extensions = extensions.copy()
  198. server_extensions[ExtensionType.supported_versions] = schema.LengthPrefixedBytes(
  199. schema.uint16be
  200. )
  201. server_extensions[ExtensionType.key_share] = schema.LengthPrefixedObject(
  202. schema.uint16be, KeyShareEntry
  203. )
  204. class Extension(schema.BinarySchema):
  205. @classmethod
  206. def server_names(cls, names):
  207. return cls(ExtensionType.server_name, [ServerName(..., name) for name in names])
  208. @classmethod
  209. def supported_versions(cls, versions):
  210. return cls(ExtensionType.supported_versions, versions)
  211. @classmethod
  212. def selected_version(cls, version):
  213. return cls(ExtensionType.supported_versions, version)
  214. @classmethod
  215. def signature_algorithms(cls, schemes):
  216. return cls(ExtensionType.signature_algorithms, schemes)
  217. @classmethod
  218. def supported_groups(cls, groups):
  219. return cls(ExtensionType.supported_groups, groups)
  220. @classmethod
  221. def key_share(cls, key_share_entries):
  222. return cls(ExtensionType.key_share, key_share_entries)
  223. @classmethod
  224. def psk_key_exchange_modes(cls, modes):
  225. return cls(ExtensionType.psk_key_exchange_modes, modes)
  226. @classmethod
  227. def early_data(cls, data):
  228. return cls(ExtensionType.early_data, data)
  229. @classmethod
  230. def pre_shared_key(cls, offered_psks: OfferedPsks):
  231. return cls(ExtensionType.pre_shared_key, offered_psks)
class ServerExtension(Extension):
    """Extension as sent by a server; ext_data is decoded via server_extensions."""

    ext_type = schema.SizedIntEnum(schema.uint16be, ExtensionType)
    # Schema for ext_data is looked up in server_extensions by ext_type.
    ext_data = schema.Switch("ext_type", server_extensions)
class ClientExtension(Extension):
    """Extension as sent by a client; ext_data is decoded via client_extensions."""

    ext_type = schema.SizedIntEnum(schema.uint16be, ExtensionType)
    # Schema for ext_data is looked up in client_extensions by ext_type.
    ext_data = schema.Switch("ext_type", client_extensions)
class ClientHello(schema.BinarySchema):
    """ClientHello handshake message body."""

    legacy_version = schema.MustEqual(schema.Bytes(2), b"\x03\x03")
    # 32 bytes of client random.
    rand = schema.Bytes(32)
    legacy_session_id = schema.LengthPrefixedBytes(schema.uint8)
    cipher_suites = schema.LengthPrefixedObjectList(
        schema.uint16be, schema.SizedIntEnum(schema.uint16be, CipherSuite)
    )
    # Fixed encoding: length byte 0x01 followed by the single method 0x00.
    legacy_compression_methods = schema.MustEqual(schema.Bytes(2), b"\x01\x00")
    extensions = schema.LengthPrefixedObjectList(schema.uint16be, ClientExtension)
  247. class ServerHello(schema.BinarySchema):
  248. legacy_version = schema.MustEqual(schema.Bytes(2), b"\x03\x03")
  249. rand = schema.Bytes(32)
  250. legacy_session_id_echo = schema.LengthPrefixedBytes(schema.uint8)
  251. cipher_suite = schema.SizedIntEnum(schema.uint16be, CipherSuite)
  252. legacy_compression_method = schema.MustEqual(schema.uint8, 0)
  253. extensions = schema.LengthPrefixedObjectList(schema.uint16be, ServerExtension)
  254. def get_cipher(self):
  255. if self.cipher_suite == CipherSuite.TLS_AES_128_GCM_SHA256:
  256. return ciphers.TLS_AES_128_GCM_SHA256
  257. elif self.cipher_suite == CipherSuite.TLS_AES_256_GCM_SHA384:
  258. return ciphers.TLS_AES_256_GCM_SHA384
  259. elif self.cipher_suite == CipherSuite.TLS_AES_128_CCM_SHA256:
  260. return ciphers.TLS_AES_128_CCM_SHA256
  261. elif self.cipher_suite == CipherSuite.TLS_AES_128_CCM_8_SHA256:
  262. return ciphers.TLS_AES_128_CCM_8_SHA256
  263. elif self.cipher_suite == CipherSuite.TLS_CHACHA20_POLY1305_SHA256:
  264. return ciphers.TLS_CHACHA20_POLY1305_SHA256
  265. else:
  266. raise Exception("bad cipher suite")
  267. @property
  268. def extensions_dict(self):
  269. return {ext.ext_type: ext.ext_data for ext in self.extensions}
class CertificateEntry(schema.BinarySchema):
    """One certificate in a Certificate message, with its per-entry extensions."""

    # Raw certificate bytes (uint24 length-prefixed). No certificate parsing is
    # done in this class; cert_data stays opaque bytes here.
    cert_data = schema.LengthPrefixedBytes(schema.uint24be)
    extensions = schema.LengthPrefixedObjectList(schema.uint16be, ServerExtension)
class Certificate(schema.BinarySchema):
    """Certificate handshake message body."""

    certificate_request_context = schema.LengthPrefixedBytes(schema.uint8)
    # uint24 length-prefixed list of CertificateEntry.
    certificate_list = schema.LengthPrefixedObjectList(
        schema.uint24be, CertificateEntry
    )
class CertificateVerify(schema.BinarySchema):
    """CertificateVerify handshake message body."""

    algorithm = schema.SizedIntEnum(schema.uint16be, SignatureScheme)
    # Signature bytes, uint16 length-prefixed.
    signature = schema.LengthPrefixedBytes(schema.uint16be)
  285. class NewSessionTicket(schema.BinarySchema):
  286. ticket_lifetime = schema.uint32be
  287. ticket_age_add = schema.uint32be
  288. ticket_nonce = schema.LengthPrefixedBytes(schema.uint8)
  289. ticket = schema.LengthPrefixedBytes(schema.uint16be)
  290. extensions = schema.LengthPrefixedObjectList(schema.uint16be, ServerExtension)
  291. def __post_init__(self):
  292. self.outdated_time = time.time() + min(self.ticket_lifetime, MAX_LIFETIME)
  293. self.obfuscated_ticket_age = (
  294. (self.ticket_lifetime * 1000) + self.ticket_age_add
  295. ) % AGE_MOD
  296. def is_outdated(self):
  297. return time.time() >= self.outdated_time
  298. def to_psk_identity(self):
  299. return PskIdentity(self.ticket, self.obfuscated_ticket_age)
class Handshake(schema.BinarySchema):
    """Handshake message: one-byte type plus a uint24 length-prefixed body.

    The body schema is picked by ``msg_type`` from the table below. Types
    without an entry (e.g. certificate_request) presumably fail to parse —
    confirm how iofree's Switch handles a missing key.
    """

    msg_type = schema.SizedIntEnum(schema.uint8, HandshakeType)
    msg = schema.LengthPrefixedObject(
        schema.uint24be,
        schema.Switch(
            "msg_type",
            {
                HandshakeType.client_hello: ClientHello,
                HandshakeType.server_hello: ServerHello,
                # EncryptedExtensions is just a uint16 length-prefixed extension list.
                HandshakeType.encrypted_extensions: schema.LengthPrefixedObjectList(
                    schema.uint16be, ServerExtension
                ),
                HandshakeType.certificate: Certificate,
                HandshakeType.certificate_verify: CertificateVerify,
                # verify_data fixed at 32 bytes.
                # NOTE(review): verify_data is as long as the suite's hash, so
                # this only fits SHA-256 suites; revisit if others are enabled.
                HandshakeType.finished: schema.Bytes(32),
                HandshakeType.new_session_ticket: NewSessionTicket,
                # EndOfEarlyData must have an empty body.
                HandshakeType.end_of_early_data: schema.MustEqual(
                    schema.Bytes(-1), b""
                ),
                # KeyUpdate body is the single request_update byte.
                HandshakeType.key_update: schema.SizedIntEnum(
                    schema.uint8, KeyUpdateRequest
                ),
            },
        ),
    )
class Alert(schema.BinarySchema):
    """Alert message: one severity byte followed by one description byte."""

    level = schema.SizedIntEnum(schema.uint8, AlertLevel)
    description = schema.SizedIntEnum(schema.uint8, AlertDescription)
  328. conten_type_cases = {
  329. ContentType.handshake: Handshake,
  330. ContentType.application_data: schema.Bytes(-1),
  331. ContentType.alert: Alert,
  332. ContentType.change_cipher_spec: schema.MustEqual(schema.Bytes(1), b"\x01"),
  333. }
  334. class TLSPlaintext(schema.BinarySchema):
  335. content_type = schema.SizedIntEnum(schema.uint8, ContentType)
  336. legacy_record_version = schema.Bytes(2)
  337. fragment = schema.LengthPrefixedBytes(schema.uint16be)
  338. # @classmethod
  339. # def get_handshake(cls, content_type: ContentType):
  340. # plaintext = yield from cls.get_value()
  341. # return Handshake.parse(plaintext.fragment)
  342. @classmethod
  343. def pack(cls, content_type: ContentType, data: bytes) -> bytes:
  344. assert len(data) > 0, "need data"
  345. data = memoryview(data)
  346. fragments = []
  347. while True:
  348. if len(data) > 16384:
  349. fragments.append(data[:16384])
  350. data = data[16384:]
  351. else:
  352. fragments.append(data)
  353. break
  354. is_handshake = content_type is ContentType.handshake
  355. return b"".join(
  356. cls(
  357. content_type,
  358. b"\x03\x01" if i == 0 and is_handshake else b"\x03\x03",
  359. bytes(frg),
  360. ).binary
  361. for i, frg in enumerate(fragments)
  362. )
  363. def is_overflow(self):
  364. return (
  365. self.content_type is ContentType.application_data
  366. and len(self.fragment) > (16384 + 256)
  367. ) or (
  368. self.content_type is not ContentType.application_data
  369. and len(self.fragment) > 16384
  370. )
  371. class TLSInnerPlaintext(schema.BinarySchema):
  372. content = schema.Bytes(-1)
  373. content_type = schema.SizedIntEnum(schema.uint8, ContentType)
  374. padding = schema.Bytes(-1)
  375. def tls_ciphertext(self, cipher):
  376. return cipher.tls_ciphertext(self.binary)
  377. @classmethod
  378. def pack(cls, content, content_type):
  379. padding = b"\x00" * random.randint(0, 10)
  380. return cls(content, content_type, padding)
  381. @classmethod
  382. def from_alert(cls, alert: Alert):
  383. padding = b"\x00" * random.randint(0, 10)
  384. return cls(alert.binary, ContentType.alert, padding)
  385. @classmethod
  386. def from_handshake(cls, handshake: Handshake):
  387. padding = b"\x00" * random.randint(0, 10)
  388. return cls(handshake.binary, ContentType.handshake, padding)
  389. @classmethod
  390. def from_application_data(cls, payload: bytes):
  391. padding = b"\x00" * random.randint(0, 10)
  392. return cls(payload, ContentType.application_data, padding)
  393. @classmethod
  394. def get_value(cls):
  395. yield from iofree.wait()
  396. bytes_ = yield from iofree.read()
  397. bytes_without_padding = bytes_.rstrip(b"\x00")
  398. padding_len = len(bytes_) - len(bytes_without_padding)
  399. content = bytes_without_padding[:-1]
  400. content_type = bytes_without_padding[-1]
  401. return cls(content, ContentType(content_type), b"\x00" * padding_len)