diff --git a/src/dtls.rs b/src/dtls.rs index 47208ff..8a7e456 100644 --- a/src/dtls.rs +++ b/src/dtls.rs @@ -1,5 +1,7 @@ //! Datagram Transport Layer Security Version 1.2 (RFC 6347) +use core::cmp::min; + use alloc::vec::Vec; use nom::bytes::streaming::take; use nom::combinator::{complete, cond, map, map_parser, opt, verify}; @@ -54,6 +56,8 @@ pub struct DTLSClientHello<'a> { /// A list of compression methods supported by client pub comp: Vec, pub ext: Option<&'a [u8]>, + /// Whether this ClientHello message is complete or a fragment + pub is_parsing_complete: bool, } impl<'a> ClientHello<'a> for DTLSClientHello<'a> { @@ -99,6 +103,24 @@ pub struct DTLSMessageHandshake<'a> { pub body: DTLSMessageHandshakeBody<'a>, } +impl DTLSMessageHandshake<'_> { + pub fn is_parsing_complete(&self) -> bool { + match &self.body { + DTLSMessageHandshakeBody::ClientHello(ch) => ch.is_parsing_complete, + DTLSMessageHandshakeBody::Certificate(certs) => certs.is_parsing_complete, + _ => true, + } + } + + pub fn update_parsing_status(&mut self, complete: bool) { + match &mut self.body { + DTLSMessageHandshakeBody::ClientHello(ch) => ch.is_parsing_complete = complete, + DTLSMessageHandshakeBody::Certificate(certs) => certs.is_parsing_complete = complete, + _ => {} + } + } +} + /// DTLS Generic handshake message #[derive(Debug, PartialEq)] pub enum DTLSMessageHandshakeBody<'a> { @@ -171,7 +193,10 @@ fn parse_dtls_fragment(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody<'_>> /// DTLS Client Hello // Section 4.2 of RFC6347 -fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody<'_>> { +fn parse_dtls_client_hello_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], DTLSMessageHandshakeBody<'_>> { let (i, version) = TlsVersion::parse(i)?; let (i, random) = take(32usize)(i)?; let (i, sidlen) = verify(be_u8, |&n| n <= 32)(i)?; @@ -181,7 +206,13 @@ fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody< let (i, ciphers) = parse_cipher_suites(i, ciphers_len as usize)?; let (i, comp_len) = be_u8(i)?; let (i, comp) = parse_compressions_algs(i, comp_len as usize)?; - let (i, ext) = opt(complete(length_data(be_u16)))(i)?; + let (i, ext_len) = be_u16(i)?; + let is_parsing_complete = ext_len as usize <= i.len(); + let ext_len = match allow_partial { + false => ext_len as usize, + true => min(ext_len as usize, i.len()), + }; + let (i, ext) = opt(complete(take(ext_len)))(i)?; let content = DTLSClientHello { version, random, @@ -190,6 +221,7 @@ fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody< ciphers, comp, ext, + is_parsing_complete, }; Ok((i, DTLSMessageHandshakeBody::ClientHello(content))) } @@ -232,28 +264,43 @@ fn parse_dtls_handshake_msg_clientkeyexchange( )(i) } -fn parse_dtls_handshake_msg_certificate(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody<'_>> { - map(parse_tls_certificate, DTLSMessageHandshakeBody::Certificate)(i) -} - /// Parse a DTLS handshake message pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage<'_>> { + parse_dtls_message_handshake_inner(i, false) +} + +/// Parse a partial DTLS handshake message +pub fn parse_partial_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage<'_>> { + parse_dtls_message_handshake_inner(i, true) +} + +fn parse_dtls_message_handshake_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], DTLSMessage<'_>> { let (i, msg_type) = map(be_u8, TlsHandshakeType)(i)?; let (i, length) = be_u24(i)?; let (i, message_seq) = be_u16(i)?; let (i, fragment_offset) = be_u24(i)?; let (i, fragment_length) = be_u24(i)?; - // This packet contains fragment_length (which is less than length for fragmentation) - let (i, raw_msg) = take(fragment_length)(i)?; // Handshake messages can be fragmented over multiple packets. When fragmented, the user // needs the fragment_offset, fragment_length and length to determine whether they received // all the fragments. The DTLS spec allows for overlapping and duplicated fragments. let is_fragment = fragment_offset > 0 || fragment_length < length; + let is_data_complete = fragment_length <= i.len() as u32; + + let fragment_length = match allow_partial { + false => fragment_length, + true => min(fragment_length, i.len() as u32), + }; + // This packet contains fragment_length (which is less than length for fragmentation) + let (i, raw_msg) = take(fragment_length)(i)?; + let (_, body) = match msg_type { - _ if is_fragment => parse_dtls_fragment(raw_msg), - TlsHandshakeType::ClientHello => parse_dtls_client_hello(raw_msg), + _ if (is_fragment && !allow_partial) => parse_dtls_fragment(raw_msg), + TlsHandshakeType::ClientHello => parse_dtls_client_hello_inner(raw_msg, allow_partial), TlsHandshakeType::HelloVerifyRequest => parse_dtls_hello_verify_request(raw_msg), TlsHandshakeType::ServerHello => parse_dtls_handshake_msg_server_hello_tlsv12(raw_msg), TlsHandshakeType::ServerDone => { @@ -262,13 +309,14 @@ pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage<'_>> TlsHandshakeType::ClientKeyExchange => { parse_dtls_handshake_msg_clientkeyexchange(raw_msg, length as usize) } - TlsHandshakeType::Certificate => parse_dtls_handshake_msg_certificate(raw_msg), + TlsHandshakeType::Certificate => parse_tls_certificate_inner(raw_msg, allow_partial) + .map(|(rem, certs)| (rem, DTLSMessageHandshakeBody::Certificate(certs))), _ => { // eprintln!("Unsupported message type {:?}", msg_type); Err(Err::Error(make_error(i, ErrorKind::Switch))) } }?; - let msg = DTLSMessageHandshake { + let mut msg = DTLSMessageHandshake { msg_type, length, message_seq, @@ -276,6 +324,9 @@ pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage<'_>> fragment_length, body, }; + if msg.is_parsing_complete() && !is_data_complete { + msg.update_parsing_status(is_data_complete); + } Ok((i, DTLSMessage::Handshake(msg))) } @@ -296,11 +347,29 @@ pub fn parse_dtls_message_alert(i: &[u8]) -> IResult<&[u8], DTLSMessage<'_>> { pub fn parse_dtls_record_with_header<'i>( i: &'i [u8], hdr: &DTLSRecordHeader, +) -> IResult<&'i [u8], Vec>> { + parse_dtls_record_with_header_inner(i, hdr, false) +} + +pub fn parse_partial_dtls_record_with_header<'i>( + i: &'i [u8], + hdr: &DTLSRecordHeader, +) -> IResult<&'i [u8], Vec>> { + parse_dtls_record_with_header_inner(i, hdr, true) +} + +fn parse_dtls_record_with_header_inner<'i>( + i: &'i [u8], + hdr: &DTLSRecordHeader, + allow_partial: bool, ) -> IResult<&'i [u8], Vec>> { match hdr.content_type { TlsRecordType::ChangeCipherSpec => many1(complete(parse_dtls_message_changecipherspec))(i), TlsRecordType::Alert => many1(complete(parse_dtls_message_alert))(i), - TlsRecordType::Handshake => many1(complete(parse_dtls_message_handshake))(i), + TlsRecordType::Handshake => match allow_partial { + false => many1(complete(parse_dtls_message_handshake))(i), + true => many1(complete(parse_partial_dtls_message_handshake))(i), + }, // TlsRecordType::ApplicationData => many1(complete(parse_tls_message_applicationdata))(i), // TlsRecordType::Heartbeat => parse_tls_message_heartbeat(i, hdr.length), _ => { @@ -313,13 +382,31 @@ pub fn parse_dtls_record_with_header<'i>( /// Parse one DTLS plaintext record // Section 4.1 of RFC6347 pub fn parse_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext<'_>> { + parse_dtls_plaintext_record_inner(i, false) +} + +/// Parse one partial DTLS plaintext record +// Section 4.1 of RFC6347 +pub fn parse_partial_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext<'_>> { + parse_dtls_plaintext_record_inner(i, true) +} + +fn parse_dtls_plaintext_record_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], DTLSPlaintext<'_>> { let (i, header) = parse_dtls_record_header(i)?; // As in TLS 1.2, the length should not exceed 2^14. if header.length > MAX_RECORD_LEN { return Err(Err::Error(make_error(i, ErrorKind::TooLarge))); } - let (i, messages) = map_parser(take(header.length as usize), |i| { - parse_dtls_record_with_header(i, &header) + + let data_len = match allow_partial { + false => header.length as usize, + true => min(header.length as usize, i.len()), + }; + let (i, messages) = map_parser(take(data_len), |i| { + parse_dtls_record_with_header_inner(i, &header, allow_partial) })(i)?; Ok((i, DTLSPlaintext { header, messages })) } @@ -329,3 +416,9 @@ pub fn parse_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext<'_> pub fn parse_dtls_plaintext_records(i: &[u8]) -> IResult<&[u8], Vec>> { many1(complete(parse_dtls_plaintext_record))(i) } + +/// Parse multiple DTLS plaintext record +// Section 4.1 of RFC6347 +pub fn parse_partial_dtls_plaintext_records(i: &[u8]) -> IResult<&[u8], Vec>> { + many1(complete(parse_partial_dtls_plaintext_record))(i) +} diff --git a/src/tls_debug.rs b/src/tls_debug.rs index d102484..859c15c 100644 --- a/src/tls_debug.rs +++ b/src/tls_debug.rs @@ -22,6 +22,7 @@ impl<'a> fmt::Debug for TlsClientHelloContents<'a> { .field("ciphers", &self.ciphers) .field("comp", &self.comp) .field("ext", &self.ext.map(HexSlice)) + .field("is_parsing_complete", &self.is_parsing_complete) .finish() } } diff --git a/src/tls_handshake.rs b/src/tls_handshake.rs index c3515f4..7430d84 100644 --- a/src/tls_handshake.rs +++ b/src/tls_handshake.rs @@ -1,4 +1,5 @@ use alloc::vec::Vec; +use core::cmp::min; use core::convert::TryInto; use core::fmt; use core::ops::Deref; @@ -243,6 +244,9 @@ pub struct TlsClientHelloContents<'a> { pub comp: Vec, pub ext: Option<&'a [u8]>, + + /// Whether this ClientHello message is complete or a fragment + pub is_parsing_complete: bool, } impl<'a> TlsClientHelloContents<'a> { @@ -253,6 +257,7 @@ impl<'a> TlsClientHelloContents<'a> { c: Vec, co: Vec, e: Option<&'a [u8]>, + complete: bool, ) -> Self { TlsClientHelloContents { version: TlsVersion(v), @@ -261,6 +266,7 @@ impl<'a> TlsClientHelloContents<'a> { ciphers: c, comp: co, ext: e, + is_parsing_complete: complete, } } @@ -378,6 +384,8 @@ pub struct RawCertificate<'a> { #[derive(Clone, Debug, PartialEq)] pub struct TlsCertificateContents<'a> { pub cert_chain: Vec>, + /// Indicates whether the certificate message is complete or a fragment + pub is_parsing_complete: bool, } /// Certificate request, as defined in [RFC5246](https://tools.ietf.org/html/rfc5246) section 7.4.4 @@ -459,6 +467,24 @@ pub enum TlsMessageHandshake<'a> { KeyUpdate(u8), } +impl TlsMessageHandshake<'_> { + pub fn is_parsing_complete(&self) -> bool { + match self { + TlsMessageHandshake::ClientHello(ch) => ch.is_parsing_complete, + TlsMessageHandshake::Certificate(certs) => certs.is_parsing_complete, + _ => true, + } + } + + pub fn update_parsing_status(&mut self, complete: bool) { + match self { + TlsMessageHandshake::ClientHello(ch) => ch.is_parsing_complete = complete, + TlsMessageHandshake::Certificate(certs) => certs.is_parsing_complete = complete, + _ => {} + } + } +} + /// Parse a HelloRequest handshake message pub fn parse_tls_handshake_msg_hello_request(i: &[u8]) -> IResult<&[u8], TlsMessageHandshake<'_>> { Ok((i, TlsMessageHandshake::HelloRequest)) @@ -477,6 +503,13 @@ pub fn parse_tls_handshake_msg_hello_request(i: &[u8]) -> IResult<&[u8], TlsMess /// # } /// ``` pub fn parse_tls_handshake_client_hello(i: &[u8]) -> IResult<&[u8], TlsClientHelloContents<'_>> { + parse_tls_handshake_client_hello_inner(i, false) +} + +fn parse_tls_handshake_client_hello_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], TlsClientHelloContents<'_>> { let (i, version) = be_u16(i)?; let (i, random) = take(32usize)(i)?; let (i, sidlen) = verify(be_u8, |&n| n <= 32)(i)?; @@ -485,11 +518,43 @@ pub fn parse_tls_handshake_client_hello(i: &[u8]) -> IResult<&[u8], TlsClientHel let (i, ciphers) = parse_cipher_suites(i, ciphers_len as usize)?; let (i, comp_len) = be_u8(i)?; let (i, comp) = parse_compressions_algs(i, comp_len as usize)?; - let (i, ext) = opt(complete(length_data(be_u16)))(i)?; - let content = TlsClientHelloContents::new(version, random, sid, ciphers, comp, ext); + let (i, ext_len) = be_u16(i)?; + let is_parsing_complete = ext_len as usize <= i.len(); + let ext_len = match allow_partial { + false => ext_len as usize, + true => min(ext_len as usize, i.len()), + }; + let (i, ext) = opt(complete(take(ext_len)))(i)?; + let content = TlsClientHelloContents::new( + version, + random, + sid, + ciphers, + comp, + ext, + is_parsing_complete, + ); Ok((i, content)) } +/// Parse frag handshake message contents for ClientHello +/// +/// ```rust +/// use tls_parser::*; +/// +/// # pub fn do_stuff(bytes: &[u8]) { +/// if let Ok((_, ch)) = parse_partial_tls_handshake_client_hello(bytes) { +/// println!("ClientHello TLS version: {}", ch.version); +/// println!(" number of proposed ciphersuites: {}", ch.ciphers.len()); +/// } +/// # } +/// ``` +pub fn parse_partial_tls_handshake_client_hello( + i: &[u8], +) -> IResult<&[u8], TlsClientHelloContents<'_>> { + parse_tls_handshake_client_hello_inner(i, true) +} + /// Parse a ClientHello handshake message /// /// This function returns a [TlsMessageHandshake]. To get only the `ClientHello` contents, use the @@ -513,6 +578,16 @@ pub fn parse_tls_handshake_msg_client_hello(i: &[u8]) -> IResult<&[u8], TlsMessa )(i) } +/// Parse a partial ClientHello handshake message +pub fn parse_partial_tls_handshake_msg_client_hello( + i: &[u8], +) -> IResult<&[u8], TlsMessageHandshake<'_>> { + map( + parse_partial_tls_handshake_client_hello, + TlsMessageHandshake::ClientHello, + )(i) +} + pub(crate) fn parse_cipher_suites(i: &[u8], len: usize) -> IResult<&[u8], Vec> { if len == 0 { return Ok((i, Vec::new())); @@ -699,9 +774,31 @@ pub fn parse_tls_handshake_msg_hello_retry_request( } pub(crate) fn parse_tls_certificate(i: &[u8]) -> IResult<&[u8], TlsCertificateContents<'_>> { + parse_tls_certificate_inner(i, false) +} + +pub(crate) fn parse_partial_tls_certificate( + i: &[u8], +) -> IResult<&[u8], TlsCertificateContents<'_>> { + parse_tls_certificate_inner(i, true) +} + +pub(crate) fn parse_tls_certificate_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], TlsCertificateContents<'_>> { let (i, cert_len) = be_u24(i)?; + let cert_len = match allow_partial { + false => cert_len as usize, + true => min(cert_len as usize, i.len()), + }; + + let is_parsing_complete = cert_len as usize <= i.len(); let (i, cert_chain) = map_parser(take(cert_len as usize), parse_certs)(i)?; - let content = TlsCertificateContents { cert_chain }; + let content = TlsCertificateContents { + cert_chain, + is_parsing_complete, + }; Ok((i, content)) } @@ -710,6 +807,16 @@ pub fn parse_tls_handshake_msg_certificate(i: &[u8]) -> IResult<&[u8], TlsMessag map(parse_tls_certificate, TlsMessageHandshake::Certificate)(i) } +/// Parse a partial Certificate handshake message +pub fn parse_partial_tls_handshake_msg_certificate( + i: &[u8], +) -> IResult<&[u8], TlsMessageHandshake<'_>> { + map( + parse_partial_tls_certificate, + TlsMessageHandshake::Certificate, + )(i) +} + /// Parse a ServerKeyExchange handshake message pub fn parse_tls_handshake_msg_serverkeyexchange( i: &[u8], @@ -876,36 +983,60 @@ pub fn parse_tls_handshake_msg_key_update(i: &[u8]) -> IResult<&[u8], TlsMessage /// Parse a TLS handshake message pub fn parse_tls_message_handshake(i: &[u8]) -> IResult<&[u8], TlsMessage<'_>> { + parse_tls_message_handshake_inner(i, false) +} + +/// Parse a partial TLS handshake message +pub fn parse_partial_tls_message_handshake(i: &[u8]) -> IResult<&[u8], TlsMessage<'_>> { + parse_tls_message_handshake_inner(i, true) +} + +/// Parse a TLS handshake message +fn parse_tls_message_handshake_inner( + i: &[u8], + allow_partial: bool, +) -> IResult<&[u8], TlsMessage<'_>> { let (i, ht) = be_u8(i)?; let (i, hl) = be_u24(i)?; + let hl = match allow_partial { + false => hl as usize, + true => i.len(), + }; + let is_data_complete = hl <= i.len(); let (i, raw_msg) = take(hl)(i)?; - let (_, msg) = match TlsHandshakeType(ht) { + let (_, mut msg) = match TlsHandshakeType(ht) { TlsHandshakeType::HelloRequest => parse_tls_handshake_msg_hello_request(raw_msg), - TlsHandshakeType::ClientHello => parse_tls_handshake_msg_client_hello(raw_msg), - TlsHandshakeType::ServerHello => parse_tls_handshake_msg_server_hello(raw_msg), - TlsHandshakeType::NewSessionTicket => { - parse_tls_handshake_msg_newsessionticket(raw_msg, hl as usize) + TlsHandshakeType::ClientHello => { + parse_tls_handshake_client_hello_inner(raw_msg, allow_partial) + .map(|(rem, hello)| (rem, TlsMessageHandshake::ClientHello(hello))) } + TlsHandshakeType::ServerHello => parse_tls_handshake_msg_server_hello(raw_msg), + TlsHandshakeType::NewSessionTicket => parse_tls_handshake_msg_newsessionticket(raw_msg, hl), TlsHandshakeType::EndOfEarlyData => Ok((raw_msg, TlsMessageHandshake::EndOfEarlyData)), TlsHandshakeType::HelloRetryRequest => parse_tls_handshake_msg_hello_retry_request(raw_msg), - TlsHandshakeType::Certificate => parse_tls_handshake_msg_certificate(raw_msg), + TlsHandshakeType::Certificate => parse_tls_certificate_inner(raw_msg, allow_partial) + .map(|(rem, cert)| (rem, TlsMessageHandshake::Certificate(cert))), TlsHandshakeType::ServerKeyExchange => { - parse_tls_handshake_msg_serverkeyexchange(raw_msg, hl as usize) + parse_tls_handshake_msg_serverkeyexchange(raw_msg, hl) } TlsHandshakeType::CertificateRequest => parse_tls_handshake_msg_certificaterequest(raw_msg), - TlsHandshakeType::ServerDone => parse_tls_handshake_msg_serverdone(raw_msg, hl as usize), + TlsHandshakeType::ServerDone => parse_tls_handshake_msg_serverdone(raw_msg, hl), TlsHandshakeType::CertificateVerify => { - parse_tls_handshake_msg_certificateverify(raw_msg, hl as usize) + parse_tls_handshake_msg_certificateverify(raw_msg, hl) } TlsHandshakeType::ClientKeyExchange => { - parse_tls_handshake_msg_clientkeyexchange(raw_msg, hl as usize) + parse_tls_handshake_msg_clientkeyexchange(raw_msg, hl) } - TlsHandshakeType::Finished => parse_tls_handshake_msg_finished(raw_msg, hl as usize), + TlsHandshakeType::Finished => parse_tls_handshake_msg_finished(raw_msg, hl), // TlsHandshakeType::CertificateURL => parse_tls_handshake_msg_certificateurl(raw_msg), TlsHandshakeType::CertificateStatus => parse_tls_handshake_msg_certificatestatus(raw_msg), TlsHandshakeType::KeyUpdate => parse_tls_handshake_msg_key_update(raw_msg), TlsHandshakeType::NextProtocol => parse_tls_handshake_msg_next_protocol(raw_msg), _ => Err(Err::Error(make_error(i, ErrorKind::Switch))), }?; + + if msg.is_parsing_complete() && !is_data_complete { + msg.update_parsing_status(is_data_complete); + } Ok((i, TlsMessage::Handshake(msg))) } diff --git a/src/tls_message.rs b/src/tls_message.rs index ed78633..3ca1134 100644 --- a/src/tls_message.rs +++ b/src/tls_message.rs @@ -21,6 +21,15 @@ pub enum TlsMessage<'a> { Heartbeat(TlsMessageHeartbeat<'a>), } +impl TlsMessage<'_> { + pub fn is_parsing_complete(&self) -> bool { + match self { + TlsMessage::Handshake(hs) => hs.is_parsing_complete(), + _ => true, + } + } +} + /// TLS application data /// /// Since this message can only be sent after the handshake, data is diff --git a/src/tls_record.rs b/src/tls_record.rs index 55fceee..1596f82 100644 --- a/src/tls_record.rs +++ b/src/tls_record.rs @@ -1,3 +1,5 @@ +use core::cmp::min; + use alloc::vec::Vec; use nom::bytes::streaming::take; use nom::combinator::{complete, map_parser}; @@ -94,29 +96,66 @@ pub fn parse_tls_record_header(i: &[u8]) -> IResult<&[u8], TlsRecordHeader> { /// strict protocol conformance). /// /// This function will fail on fragmented records. To support fragmented records, use -/// [crate::TlsRecordsParser]]. +/// [crate::TlsRecordsParser]], or use [crate::parse_partial_tls_plaintext] to parse +/// partial plaintext records. #[rustfmt::skip] #[allow(clippy::trivially_copy_pass_by_ref)] // TlsRecordHeader is only 6 bytes, but we prefer not breaking current API pub fn parse_tls_record_with_header<'i>(i:&'i [u8], hdr:&TlsRecordHeader ) -> IResult<&'i [u8], Vec>> { + parse_tls_record_with_header_inner(i, hdr, false) +} + +/// Given partial data and a TLS record header, parse content. +/// +/// A record can contain multiple messages (with the same type). +/// +#[rustfmt::skip] +#[allow(clippy::trivially_copy_pass_by_ref)] // TlsRecordHeader is only 6 bytes, but we prefer not breaking current API +pub fn parse_partial_tls_record_with_header<'i>(i:&'i [u8], hdr:&TlsRecordHeader ) -> IResult<&'i [u8], Vec>> { + parse_tls_record_with_header_inner(i, hdr, true) +} + +fn parse_tls_record_with_header_inner<'i>( + i: &'i [u8], + hdr: &TlsRecordHeader, + allow_partial: bool, +) -> IResult<&'i [u8], Vec>> { match hdr.record_type { TlsRecordType::ChangeCipherSpec => many1(complete(parse_tls_message_changecipherspec))(i), - TlsRecordType::Alert => many1(complete(parse_tls_message_alert))(i), - TlsRecordType::Handshake => many1(complete(parse_tls_message_handshake))(i), - TlsRecordType::ApplicationData => many1(complete(parse_tls_message_applicationdata))(i), - TlsRecordType::Heartbeat => parse_tls_message_heartbeat(i, hdr.len), - _ => Err(Err::Error(make_error(i, ErrorKind::Switch))) + TlsRecordType::Alert => many1(complete(parse_tls_message_alert))(i), + TlsRecordType::Handshake => match allow_partial { + false => many1(complete(parse_tls_message_handshake))(i), + true => many1(complete(parse_partial_tls_message_handshake))(i), + }, + TlsRecordType::ApplicationData => many1(complete(parse_tls_message_applicationdata))(i), + TlsRecordType::Heartbeat => parse_tls_message_heartbeat(i, hdr.len), + _ => Err(Err::Error(make_error(i, ErrorKind::Switch))), } } /// Parse one packet only, as plaintext /// A single record can contain multiple messages, they must share the same record type pub fn parse_tls_plaintext(i: &[u8]) -> IResult<&[u8], TlsPlaintext<'_>> { + parse_tls_plaintext_inner(i, false) +} + +/// Parse one partial packet only, as plaintext +/// A single record can contain multiple messages, they must share the same record type +pub fn parse_partial_tls_plaintext(i: &[u8]) -> IResult<&[u8], TlsPlaintext<'_>> { + parse_tls_plaintext_inner(i, true) +} + +fn parse_tls_plaintext_inner(i: &[u8], allow_partial: bool) -> IResult<&[u8], TlsPlaintext<'_>> { let (i, hdr) = parse_tls_record_header(i)?; if hdr.len > MAX_RECORD_LEN { return Err(Err::Error(make_error(i, ErrorKind::TooLarge))); } - let (i, msg) = map_parser(take(hdr.len as usize), |i| { - parse_tls_record_with_header(i, &hdr) + + let data_len = match allow_partial { + false => hdr.len as usize, + true => min(hdr.len as usize, i.len()), + }; + let (i, msg) = map_parser(take(data_len), |i| { + parse_tls_record_with_header_inner(i, &hdr, allow_partial) })(i)?; Ok((i, TlsPlaintext { hdr, msg })) } diff --git a/tests/tls_extensions.rs b/tests/tls_extensions.rs index d64171c..46ae26b 100644 --- a/tests/tls_extensions.rs +++ b/tests/tls_extensions.rs @@ -50,6 +50,22 @@ static CLIENT_EXTENSIONS1: &[u8] = &[ assert_eq!(res, expected); } + #[test] + fn test_fragmented_tls_extensions() { + let bytes = &CLIENT_EXTENSIONS1[..29]; + let expected = Ok(( + &CLIENT_EXTENSIONS1[23..29], + vec![TlsExtension::SNI(vec![( + SNIType::HostName, + b"www.google.com", + )])], + )); + + let res = parse_tls_extensions(bytes); + + assert_eq!(res, expected); + } + #[test] fn test_tls_extension_max_fragment_length() { let empty = &b""[..]; diff --git a/tests/tls_handshake.rs b/tests/tls_handshake.rs index bf5c6cc..aa93775 100644 --- a/tests/tls_handshake.rs +++ b/tests/tls_handshake.rs @@ -5,6 +5,8 @@ extern crate nom; extern crate tls_parser; mod tls_handshake { + use std::num::NonZero; + use nom::{Err, Needed}; use tls_parser::*; @@ -71,6 +73,7 @@ static CH : &[u8] = &[ ciphers: ciphers.iter().map(|&x| TlsCipherSuiteID(x)).collect(), comp, ext: Some(&CH[220..]), + is_parsing_complete: true, }, ))], }; @@ -78,6 +81,52 @@ static CH : &[u8] = &[ assert_eq!(res, Ok((empty, expected))); } + #[test] + fn test_partial_tls_record_clienthello() { + let empty = &b""[..]; + let rand_data = [ + 0xb2, 0x9d, 0xd7, 0x87, 0xff, 0x21, 0xeb, 0x04, 0xc8, 0xa5, 0x38, 0x39, 0x9a, 0xcf, + 0xb7, 0xa3, 0x82, 0x1f, 0x82, 0x6c, 0x49, 0xbc, 0x8b, 0xb8, 0xa9, 0x03, 0x0a, 0x2d, + 0xce, 0x38, 0x0b, 0xf4, + ]; + let ciphers = &[ + 0xc030, 0xc02c, 0xc028, 0xc024, 0xc014, 0xc00a, 0x00a5, 0x00a3, 0x00a1, 0x009f, 0x006b, + 0x006a, 0x0069, 0x0068, 0x0039, 0x0038, 0x0037, 0x0036, 0x0088, 0x0087, 0x0086, 0x0085, + 0xc032, 0xc02e, 0xc02a, 0xc026, 0xc00f, 0xc005, 0x009d, 0x003d, 0x0035, 0x0084, 0xc02f, + 0xc02b, 0xc027, 0xc023, 0xc013, 0xc009, 0x00a4, 0x00a2, 0x00a0, 0x009e, 0x0067, 0x0040, + 0x003f, 0x003e, 0x0033, 0x0032, 0x0031, 0x0030, 0x009a, 0x0099, 0x0098, 0x0097, 0x0045, + 0x0044, 0x0043, 0x0042, 0xc031, 0xc02d, 0xc029, 0xc025, 0xc00e, 0xc004, 0x009c, 0x003c, + 0x002f, 0x0096, 0x0041, 0xc011, 0xc007, 0xc00c, 0xc002, 0x0005, 0x0004, 0xc012, 0xc008, + 0x0016, 0x0013, 0x0010, 0x000d, 0xc00d, 0xc003, 0x000a, 0x00ff, + ]; + let comp = vec![TlsCompressionID(0x00)]; + let expected = TlsPlaintext { + hdr: TlsRecordHeader { + record_type: TlsRecordType::Handshake, + version: TlsVersion::Tls10, + len: 300, + }, + msg: vec![TlsMessage::Handshake(TlsMessageHandshake::ClientHello( + TlsClientHelloContents { + version: TlsVersion::Tls12, + random: &rand_data, + session_id: None, + ciphers: ciphers.iter().map(|&x| TlsCipherSuiteID(x)).collect(), + comp, + ext: Some(&CH[220..CH.len() - 11]), + is_parsing_complete: false, + }, + ))], + }; + let res = parse_tls_plaintext(&CH[..CH.len() - 11]); + assert_eq!( + res, + Err(Err::Incomplete(Needed::Size(NonZero::new(11).unwrap()))) + ); + let res = parse_partial_tls_plaintext(&CH[..CH.len() - 11]); + assert_eq!(res, Ok((empty, expected))); + } + // tls response, composed of 4 records: // - Server Hello // - Server Certificate @@ -465,12 +514,51 @@ static SERVER_REPLY1: &[u8] = &[ len: 3081, }, msg: vec![TlsMessage::Handshake(TlsMessageHandshake::Certificate( - TlsCertificateContents { cert_chain: chain }, + TlsCertificateContents { + cert_chain: chain, + is_parsing_complete: true, + }, ))], }; assert_eq!(parse_tls_plaintext(bytes), Ok((empty, expected))); } + #[test] + fn test_partial_tls_record_certificate() { + let empty = &b""[..]; + let bytes = &SERVER_REPLY1[64..3000]; + let chain = vec![ + RawCertificate { + data: &bytes[15..1171], + }, + RawCertificate { + data: &bytes[1174..2186], + }, + // &bytes[2189..2936], incomplete cert data + ]; + for cert in &chain { + println!("cert len: {}", cert.data.len()); + } + let expected = TlsPlaintext { + hdr: TlsRecordHeader { + record_type: TlsRecordType::Handshake, + version: TlsVersion::Tls12, + len: 3081, + }, + msg: vec![TlsMessage::Handshake(TlsMessageHandshake::Certificate( + TlsCertificateContents { + cert_chain: chain, + is_parsing_complete: true, + }, + ))], + }; + + assert!(parse_tls_plaintext(bytes).is_err()); + + let result = parse_partial_tls_plaintext(bytes); + assert_eq!(result, Ok((empty, expected))); + } + #[test] fn test_tls_record_serverkeyexchange() { let empty = &b""[..]; diff --git a/tests/tls_tls13.rs b/tests/tls_tls13.rs index d75a3fb..8712c53 100644 --- a/tests/tls_tls13.rs +++ b/tests/tls_tls13.rs @@ -75,6 +75,7 @@ static TV_SERVER_HELLO_1: &[u8] = &[ ciphers: ciphers.iter().map(|&x| TlsCipherSuiteID(x)).collect(), comp: vec![TlsCompressionID(0)], ext: Some(&bytes[112..]), + is_parsing_complete: true, }, ))], };