Skip to main content

iris_core/protocols/stream/tls/
parser.rs

//! TLS handshake parser.
//!
//! The TLS handshake parser uses a [fork](https://github.com/thegwan/tls-parser) of the
//! [tls-parser](https://docs.rs/tls-parser/latest/tls_parser/) crate to parse the handshake phase
//! of a TLS connection. It maintains TLS state, stores selected parameters, and handles
//! defragmentation.
//!
//! Adapted from [the Rusticata TLS
//! parser](https://github.com/rusticata/rusticata/blob/master/src/tls.rs).
10
11use super::handshake::{
12    Certificate, ClientDHParams, ClientECDHParams, ClientHello, ClientKeyExchange, ClientRSAParams,
13    KeyShareEntry, ServerDHParams, ServerECDHParams, ServerHello, ServerKeyExchange,
14    ServerRSAParams,
15};
16use super::Tls;
17use crate::conntrack::pdu::L4Pdu;
18use crate::protocols::stream::{
19    ConnParsable, ParseResult, ParsingState, ProbeResult, Session, SessionData,
20};
21
22use tls_parser::*;
23
24/// Parses a single TLS handshake per connection.
25#[derive(Debug)]
26pub struct TlsParser {
27    /// Handshakes seen. We expect there to only be one.
28    sessions: Vec<Tls>,
29}
30
31impl Default for TlsParser {
32    fn default() -> Self {
33        TlsParser {
34            sessions: vec![Tls::new()],
35        }
36    }
37}
38
39impl ConnParsable for TlsParser {
40    fn parse(&mut self, pdu: &L4Pdu) -> ParseResult {
41        log::debug!("Updating parser tls");
42        let offset = pdu.offset();
43        let length = pdu.length();
44        if length == 0 {
45            return ParseResult::Skipped;
46        }
47
48        if let Ok(data) = (pdu.mbuf_ref()).get_data_slice(offset, length) {
49            self.sessions[0].parse_tcp_level(data, pdu.dir)
50        } else {
51            log::warn!("Malformed packet");
52            ParseResult::Skipped
53        }
54    }
55
56    fn probe(&self, pdu: &L4Pdu) -> ProbeResult {
57        if pdu.length() <= 2 {
58            return ProbeResult::Unsure;
59        }
60
61        let offset = pdu.offset();
62        let length = pdu.length();
63        if let Ok(data) = (pdu.mbuf_ref()).get_data_slice(offset, length) {
64            // First byte is record type (between 0x14 and 0x17, 0x16 is handhake) Second is TLS
65            // version major (0x3) Third is TLS version minor (0x0 for SSLv3, 0x1 for TLSv1.0, etc.)
66            // Does not support versions <= SSLv2
67            match (data[0], data[1], data[2]) {
68                (0x14..=0x17, 0x03, 0..=3) => ProbeResult::Certain,
69                _ => ProbeResult::NotForUs,
70            }
71        } else {
72            log::warn!("Malformed packet");
73            ProbeResult::Error
74        }
75    }
76
77    fn remove_session(&mut self, _session_id: usize) -> Option<Session> {
78        self.sessions.pop().map(|tls| Session {
79            data: SessionData::Tls(Box::new(tls)),
80            id: 0,
81        })
82    }
83
84    fn drain_sessions(&mut self) -> Vec<Session> {
85        self.sessions
86            .drain(..)
87            .map(|tls| Session {
88                data: SessionData::Tls(Box::new(tls)),
89                id: 0,
90            })
91            .collect()
92    }
93
94    fn session_parsed_state(&self) -> ParsingState {
95        ParsingState::Stop
96    }
97
98    fn body_offset(&mut self) -> Option<usize> {
99        match self.sessions.last_mut() {
100            Some(tls) => std::mem::take(&mut tls.last_body_offset),
101            None => None,
102        }
103    }
104}
105
106// ------------------------------------------------------------
107
108impl Tls {
109    /// Allocate a new TLS handshake instance.
110    pub(crate) fn new() -> Tls {
111        Tls {
112            client_hello: None,
113            server_hello: None,
114            server_certificates: vec![],
115            client_certificates: vec![],
116            server_key_exchange: None,
117            client_key_exchange: None,
118            state: TlsState::None,
119            tcp_buffer: vec![],
120            record_buffer: vec![],
121            last_body_offset: None,
122        }
123    }
124
125    /// Parse a ClientHello message.
126    pub(crate) fn parse_handshake_clienthello(&mut self, content: &TlsClientHelloContents) {
127        let mut client_hello = ClientHello {
128            version: content.version,
129            random: content.random.to_vec(),
130            session_id: match content.session_id {
131                Some(v) => v.to_vec(),
132                None => vec![],
133            },
134            cipher_suites: content.ciphers.to_vec(),
135            compression_algs: content.comp.to_vec(),
136            ..ClientHello::default()
137        };
138
139        let ext = parse_tls_client_hello_extensions(content.ext.unwrap_or(b""));
140        log::trace!("client extensions: {:#?}", ext);
141        match &ext {
142            Ok((rem, ref ext_lst)) => {
143                if !rem.is_empty() {
144                    log::debug!("warn: extensions not entirely parsed");
145                }
146                for extension in ext_lst {
147                    client_hello
148                        .extension_list
149                        .push(TlsExtensionType::from(extension));
150                    match *extension {
151                        TlsExtension::SNI(ref v) if !v.is_empty() => {
152                            let sni = v[0].1;
153                            client_hello.server_name = Some(match std::str::from_utf8(sni) {
154                                Ok(name) => name.to_string(),
155                                Err(_) => format!("<Invalid UTF-8: {}>", hex::encode(sni)),
156                            });
157                        }
158                        TlsExtension::SupportedGroups(ref v) => {
159                            client_hello.supported_groups = v.clone();
160                        }
161                        TlsExtension::EcPointFormats(v) => {
162                            client_hello.ec_point_formats = v.to_vec();
163                        }
164                        TlsExtension::SignatureAlgorithms(ref v) => {
165                            client_hello.signature_algs = v.clone();
166                        }
167                        TlsExtension::ALPN(ref v) => {
168                            for proto in v {
169                                client_hello.alpn_protocols.push(
170                                    match std::str::from_utf8(proto) {
171                                        Ok(proto) => proto.to_string(),
172                                        Err(_) => {
173                                            format!("<Invalid UTF-8: {}>", hex::encode(proto))
174                                        }
175                                    },
176                                );
177                            }
178                        }
179                        TlsExtension::KeyShare(ref v) => {
180                            log::debug!("Client Shares: {:?}", v);
181                            client_hello.key_shares = v
182                                .iter()
183                                .map(|k| KeyShareEntry {
184                                    group: k.group,
185                                    kx_data: k.kx.to_vec(),
186                                })
187                                .collect();
188                        }
189                        TlsExtension::SupportedVersions(ref v) => {
190                            client_hello.supported_versions = v.clone();
191                        }
192                        _ => (),
193                    }
194                }
195            }
196            e => log::debug!("Could not parse extensions: {:?}", e),
197        };
198        self.client_hello = Some(client_hello);
199    }
200
201    /// Parse a ServerHello message.
202    fn parse_handshake_serverhello(&mut self, content: &TlsServerHelloContents) {
203        let mut server_hello = ServerHello {
204            version: content.version,
205            random: content.random.to_vec(),
206            session_id: match content.session_id {
207                Some(v) => v.to_vec(),
208                None => vec![],
209            },
210            cipher_suite: content.cipher,
211            compression_alg: content.compression,
212            ..ServerHello::default()
213        };
214
215        let ext = parse_tls_server_hello_extensions(content.ext.unwrap_or(b""));
216        log::debug!("server_hello extensions: {:#?}", ext);
217        match &ext {
218            Ok((rem, ref ext_lst)) => {
219                if !rem.is_empty() {
220                    log::debug!("warn: extensions not entirely parsed");
221                }
222                for extension in ext_lst {
223                    server_hello
224                        .extension_list
225                        .push(TlsExtensionType::from(extension));
226                    match *extension {
227                        TlsExtension::EcPointFormats(v) => {
228                            server_hello.ec_point_formats = v.to_vec();
229                        }
230                        TlsExtension::ALPN(ref v) if !v.is_empty() => {
231                            server_hello.alpn_protocol = Some(match std::str::from_utf8(v[0]) {
232                                Ok(proto) => proto.to_string(),
233                                Err(_) => format!("<Invalid UTF-8: {}>", hex::encode(v[0])),
234                            });
235                        }
236                        TlsExtension::KeyShare(ref v) => {
237                            log::debug!("Server Share: {:?}", v);
238                            if !v.is_empty() {
239                                server_hello.key_share = Some(KeyShareEntry {
240                                    group: v[0].group,
241                                    kx_data: v[0].kx.to_vec(),
242                                });
243                            }
244                        }
245                        TlsExtension::SupportedVersions(ref v) if !v.is_empty() => {
246                            server_hello.selected_version = Some(v[0]);
247                        }
248                        _ => (),
249                    }
250                }
251            }
252            e => log::debug!("Could not parse extensions: {:?}", e),
253        };
254        self.server_hello = Some(server_hello);
255    }
256
257    /// Parse a Certificate message.
258    fn parse_handshake_certificate(&mut self, content: &TlsCertificateContents, direction: bool) {
259        log::trace!("cert chain length: {}", content.cert_chain.len());
260        if direction {
261            // client -> server
262            for cert in &content.cert_chain {
263                self.client_certificates.push(Certificate {
264                    raw: cert.data.to_vec(),
265                })
266            }
267        } else {
268            // server -> client
269            for cert in &content.cert_chain {
270                self.server_certificates.push(Certificate {
271                    raw: cert.data.to_vec(),
272                })
273            }
274        }
275    }
276
277    /// Parse a ServerKeyExchange message.
278    fn parse_handshake_serverkeyexchange(&mut self, content: &TlsServerKeyExchangeContents) {
279        log::trace!("SKE: {:?}", content);
280        if let Some(cipher) = self.cipher_suite() {
281            match &cipher.kx {
282                TlsCipherKx::Ecdhe | TlsCipherKx::Ecdh => {
283                    if let Ok((_sig, ref parsed)) = parse_server_ecdh_params(content.parameters) {
284                        if let ECParametersContent::NamedGroup(curve) =
285                            parsed.curve_params.params_content
286                        {
287                            let ecdh_params = ServerECDHParams {
288                                curve,
289                                kx_data: parsed.public.point.to_vec(),
290                            };
291                            self.server_key_exchange = Some(ServerKeyExchange::Ecdh(ecdh_params));
292                        };
293                    }
294                }
295                TlsCipherKx::Dhe | TlsCipherKx::Dh => {
296                    if let Ok((_sig, ref parsed)) = parse_server_dh_params(content.parameters) {
297                        let dh_params = ServerDHParams {
298                            prime: parsed.dh_p.to_vec(),
299                            generator: parsed.dh_g.to_vec(),
300                            kx_data: parsed.dh_ys.to_vec(),
301                        };
302                        self.server_key_exchange = Some(ServerKeyExchange::Dh(dh_params));
303                    }
304                }
305                TlsCipherKx::Rsa => {
306                    if let Ok((_sig, ref parsed)) = parse_server_rsa_params(content.parameters) {
307                        let rsa_params = ServerRSAParams {
308                            modulus: parsed.modulus.to_vec(),
309                            exponent: parsed.exponent.to_vec(),
310                        };
311                        self.server_key_exchange = Some(ServerKeyExchange::Rsa(rsa_params));
312                    }
313                }
314                _ => {
315                    self.server_key_exchange =
316                        Some(ServerKeyExchange::Unknown(content.parameters.to_vec()))
317                }
318            }
319        }
320    }
321
322    /// Parse a ClientKeyExchange message.
323    fn parse_handshake_clientkeyexchange(&mut self, content: &TlsClientKeyExchangeContents) {
324        log::trace!("CKE: {:?}", content);
325        if let Some(cipher) = self.cipher_suite() {
326            match &cipher.kx {
327                TlsCipherKx::Ecdhe | TlsCipherKx::Ecdh => {
328                    if let Ok((_rem, ref parsed)) = parse_client_ecdh_params(content.parameters) {
329                        let ecdh_params = ClientECDHParams {
330                            kx_data: parsed.ecdh_yc.point.to_vec(),
331                        };
332                        self.client_key_exchange = Some(ClientKeyExchange::Ecdh(ecdh_params));
333                    }
334                }
335                TlsCipherKx::Dhe | TlsCipherKx::Dh => {
336                    if let Ok((_rem, ref parsed)) = parse_client_dh_params(content.parameters) {
337                        let dh_params = ClientDHParams {
338                            kx_data: parsed.dh_yc.to_vec(),
339                        };
340                        self.client_key_exchange = Some(ClientKeyExchange::Dh(dh_params));
341                    }
342                }
343                TlsCipherKx::Rsa => {
344                    if let Ok((_rem, ref parsed)) = parse_client_rsa_params(content.parameters) {
345                        let rsa_params = ClientRSAParams {
346                            encrypted_pms: parsed.data.to_vec(),
347                        };
348                        self.client_key_exchange = Some(ClientKeyExchange::Rsa(rsa_params));
349                    }
350                }
351                _ => {
352                    self.client_key_exchange =
353                        Some(ClientKeyExchange::Unknown(content.parameters.to_vec()))
354                }
355            }
356        }
357        //self.client_key_exchange = Some(client_key_exchange);
358    }
359
360    /// Parse a TLS message.
361    pub(crate) fn parse_message_level(&mut self, msg: &TlsMessage, direction: bool) -> ParseResult {
362        log::trace!("parse_message_level {:?}", msg);
363
364        // do not parse if session is encrypted
365        if self.state == TlsState::ClientChangeCipherSpec {
366            log::trace!("TLS session encrypted, activating bypass");
367            return ParseResult::HeadersDone(0);
368        }
369
370        // update state machine
371        match tls_state_transition(self.state, msg, direction) {
372            Ok(s) => self.state = s,
373            Err(_) => {
374                self.state = TlsState::Invalid;
375            }
376        };
377        log::trace!("TLS new state: {:?}", self.state);
378
379        // extract variables
380        match *msg {
381            TlsMessage::Handshake(ref m) => match *m {
382                TlsMessageHandshake::ClientHello(ref content) => {
383                    self.parse_handshake_clienthello(content);
384                }
385                TlsMessageHandshake::ServerHello(ref content) => {
386                    self.parse_handshake_serverhello(content);
387                }
388                TlsMessageHandshake::Certificate(ref content) => {
389                    self.parse_handshake_certificate(content, direction);
390                }
391                TlsMessageHandshake::ServerKeyExchange(ref content) => {
392                    self.parse_handshake_serverkeyexchange(content);
393                }
394                TlsMessageHandshake::ClientKeyExchange(ref content) => {
395                    self.parse_handshake_clientkeyexchange(content);
396                }
397
398                _ => (),
399            },
400            TlsMessage::Alert(ref a) if a.severity == TlsAlertSeverity::Fatal => {
401                return ParseResult::HeadersDone(0);
402            }
403            _ => (),
404        }
405
406        ParseResult::Continue(0)
407    }
408
409    /// Parse a TLS record.
410    pub(crate) fn parse_record_level(
411        &mut self,
412        record: &TlsRawRecord<'_>,
413        direction: bool,
414        pdu_len: usize,
415    ) -> ParseResult {
416        let mut v: Vec<u8>;
417        let mut status = ParseResult::Continue(0);
418
419        log::trace!("parse_record_level ({} bytes)", record.data.len());
420        log::trace!("{:?}", record.hdr);
421        // log::trace!("{:?}", record.data);
422
423        // do not parse if session is encrypted
424        if self.state == TlsState::ClientChangeCipherSpec {
425            log::trace!("TLS session encrypted, activating bypass");
426            return ParseResult::HeadersDone(0);
427        }
428
429        // only parse some message types (the Content type, first byte of TLS record)
430        match record.hdr.record_type {
431            TlsRecordType::ChangeCipherSpec => (),
432            TlsRecordType::Handshake => (),
433            TlsRecordType::Alert => (),
434            _ => return ParseResult::Continue(0),
435        }
436
437        // Check if a record is being defragmented
438        let record_buffer = match self.record_buffer.len() {
439            0 => record.data,
440            _ => {
441                // sanity check vector length to avoid memory exhaustion maximum length may be 2^24
442                // (handshake message)
443                if self.record_buffer.len() + record.data.len() > 16_777_216 {
444                    return ParseResult::Skipped;
445                };
446                v = self.record_buffer.split_off(0);
447                v.extend_from_slice(record.data);
448                v.as_slice()
449            }
450        };
451
452        // NICE-TO-HAVE: record may be compressed Parse record contents as plaintext
453        match parse_tls_record_with_header(record_buffer, &record.hdr) {
454            Ok((rem, ref msg_list)) => {
455                for msg in msg_list {
456                    status = self.parse_message_level(msg, direction);
457                    if status != ParseResult::Continue(0) {
458                        // Handshake done, but data remaining
459                        let remaining = rem.len();
460                        if matches!(status, ParseResult::HeadersDone(_))
461                            && remaining > 0
462                            && remaining < pdu_len
463                        {
464                            self.last_body_offset = Some(pdu_len - remaining - 1);
465                        }
466                        return status;
467                    }
468                }
469                if !rem.is_empty() {
470                    log::debug!("warn: extra bytes in TLS record: {:?}", rem);
471                };
472            }
473            Err(Err::Incomplete(needed)) => {
474                log::trace!(
475                    "Defragmentation required (TLS record), missing {:?} bytes",
476                    needed
477                );
478                self.record_buffer.extend_from_slice(record.data);
479            }
480            Err(_e) => {
481                log::debug!("warn: parse_tls_record_with_header failed");
482                return ParseResult::Skipped;
483            }
484        };
485
486        status
487    }
488
489    /// Parse a TCP segment, handling TCP chunks fragmentation.
490    pub(crate) fn parse_tcp_level(&mut self, data: &[u8], direction: bool) -> ParseResult {
491        let mut v: Vec<u8>;
492        let mut status = ParseResult::Continue(0);
493        let pdu_len = data.len(); // new data len
494        log::trace!("parse_tcp_level ({} bytes)", data.len());
495        log::trace!("defrag buffer size: {}", self.tcp_buffer.len());
496
497        // do not parse if session is encrypted
498        if self.state == TlsState::ClientChangeCipherSpec {
499            log::trace!("TLS session encrypted, activating bypass");
500            return ParseResult::HeadersDone(0);
501        };
502        // Check if TCP data is being defragmented
503        let tcp_buffer = match self.tcp_buffer.len() {
504            0 => data,
505            _ => {
506                // sanity check vector length to avoid memory exhaustion maximum length may be 2^24
507                // (handshake message)
508                if self.tcp_buffer.len() + data.len() > 16_777_216 {
509                    return ParseResult::Skipped;
510                };
511                v = self.tcp_buffer.split_off(0);
512                v.extend_from_slice(data);
513                v.as_slice()
514            }
515        };
516        let mut cur_data = tcp_buffer;
517        while !cur_data.is_empty() {
518            // parse each TLS record in the TCP segment (there could be multiple)
519            match parse_tls_raw_record(cur_data) {
520                Ok((rem, ref record)) => {
521                    cur_data = rem;
522                    status = self.parse_record_level(record, direction, pdu_len);
523                    if status != ParseResult::Continue(0) {
524                        return status;
525                    }
526                }
527                Err(Err::Incomplete(needed)) => {
528                    log::trace!(
529                        "Defragmentation required (TCP level), missing {:?} bytes",
530                        needed
531                    );
532                    self.tcp_buffer.extend_from_slice(cur_data);
533                    break;
534                }
535                Err(_e) => {
536                    log::debug!("warn: Parsing raw record failed");
537                    break;
538                }
539            }
540        }
541        status
542    }
543}