1#[cfg(feature = "std")]
36use alloc::collections::VecDeque;
37use alloc::vec::Vec;
38use core::fmt::Debug;
39#[cfg(feature = "std")]
40use std::sync::Mutex;
41
42use crate::enums::CertificateCompressionAlgorithm;
43use crate::msgs::base::{Payload, PayloadU24};
44use crate::msgs::codec::Codec;
45use crate::msgs::handshake::{CertificatePayloadTls13, CompressedCertificatePayload};
46use crate::sync::Arc;
47
48pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
51 &[
52 #[cfg(feature = "brotli")]
53 BROTLI_DECOMPRESSOR,
54 #[cfg(feature = "zlib")]
55 ZLIB_DECOMPRESSOR,
56 ]
57}
58
/// An available certificate decompression algorithm.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared
/// as `&'static dyn CertDecompressor` values.
pub trait CertDecompressor: Debug + Send + Sync {
    /// Decompress `input`, writing the result into `output`.
    ///
    /// The implementations in this module fail unless the decompressed data
    /// exactly fills `output`. Callers are therefore expected to size `output`
    /// to the expected uncompressed length.
    ///
    /// Returns `Err(DecompressionFailed)` if `input` is not valid compressed
    /// data for this algorithm, or if its decompressed length differs from
    /// `output.len()`.
    fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;

    /// The algorithm this decompressor handles.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
72
73pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
76 &[
77 #[cfg(feature = "brotli")]
78 BROTLI_COMPRESSOR,
79 #[cfg(feature = "zlib")]
80 ZLIB_COMPRESSOR,
81 ]
82}
83
/// An available certificate compression algorithm.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared
/// as `&'static dyn CertCompressor` values.
pub trait CertCompressor: Debug + Send + Sync {
    /// Compress `input`, returning the compressed bytes.
    ///
    /// `level` is a resource hint. Within this module, `CompressionCache`
    /// asks for `Interactive` when the result is used once, and for `Amortized`
    /// when the result is cached and reused.
    ///
    /// Returns `Err(CompressionFailed)` if the underlying compressor reports
    /// an error.
    fn compress(
        &self,
        input: Vec<u8>,
        level: CompressionLevel,
    ) -> Result<Vec<u8>, CompressionFailed>;

    /// The algorithm this compressor implements.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
103
/// A hint for how much effort a [`CertCompressor`] should spend.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Requested when the compressed result is used once and not cached.
    ///
    /// The compressors in this module favour speed here: zlib uses its default
    /// configuration and brotli uses its "fast" quality.
    Interactive,

    /// Requested when the compressed result is stored in a cache and reused.
    ///
    /// The compressors in this module favour output size here: zlib uses
    /// `best_compression` and brotli uses its "slow" quality.
    Amortized,
}
117
/// Error returned by [`CertDecompressor::decompress`].
///
/// It carries no detail: the input was invalid, or it did not decompress to
/// exactly the expected length.
#[derive(Debug)]
pub struct DecompressionFailed;
121
/// Error returned by [`CertCompressor::compress`] and by the compression cache.
///
/// It carries no detail about the underlying cause.
#[derive(Debug)]
pub struct CompressionFailed;
125
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::{
        DeflateConfig, InflateConfig, ReturnCode, compress_bound, compress_slice, decompress_slice,
    };

    use super::*;

    /// A certificate decompressor for the Zlib algorithm, backed by the `zlib-rs` crate.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Capture the length first: `decompress_slice` borrows `output` mutably.
            let expected_len = output.len();
            let (filled, code) = decompress_slice(output, input, InflateConfig::default());

            // Inflation must finish cleanly AND fill the buffer exactly.
            if code == ReturnCode::Ok && filled.len() == expected_len {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// A certificate compressor for the Zlib algorithm, backed by the `zlib-rs` crate.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let deflate_config = match level {
                CompressionLevel::Amortized => DeflateConfig::best_compression(),
                CompressionLevel::Interactive => DeflateConfig::default(),
            };

            // `compress_bound` sizes the buffer for the worst case; it is trimmed
            // to the bytes actually written below.
            let mut buffer = alloc::vec![0u8; compress_bound(input.len())];
            let written = match compress_slice(&mut buffer, &input, deflate_config) {
                (compressed, ReturnCode::Ok) => compressed.len(),
                _ => return Err(CompressionFailed),
            };

            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
186
187#[cfg(feature = "zlib")]
188pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
189
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// A certificate decompressor for the Brotli algorithm, backed by the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let mut source = Cursor::new(input);
            let mut sink = Cursor::new(output);

            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // A clean decode is not enough: the output buffer must also be
            // filled exactly, otherwise the data has the wrong length.
            let written = sink.position() as usize;
            if written == sink.into_inner().len() {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// A certificate compressor for the Brotli algorithm, backed by the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = quality_for(level);
            // Initial capacity is a guess of half the input size; the Vec grows if needed.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);

            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // NOTE(review): relies on `CompressorWriter::into_inner` finishing the
            // brotli stream before returning the sink; the round-trip tests exercise this.
            let sink = writer.into_inner();
            Ok(sink.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// Maps the caller's effort hint onto a brotli quality setting.
    fn quality_for(level: CompressionLevel) -> u32 {
        match level {
            CompressionLevel::Interactive => QUALITY_FAST,
            CompressionLevel::Amortized => QUALITY_SLOW,
        }
    }

    /// Internal buffer size handed to `brotli::CompressorWriter::new`.
    const BUFFER_SIZE: usize = 4096;

    /// Window-size parameter (`lgwin`) handed to `brotli::CompressorWriter::new`.
    const LGWIN: u32 = 22;

    /// Brotli quality used for [`CompressionLevel::Interactive`].
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for [`CompressionLevel::Amortized`].
    const QUALITY_SLOW: u32 = 11;
}
266
267#[cfg(feature = "brotli")]
268pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
269
/// A cache of compressed certificate messages.
///
/// When enabled, entries are keyed on the compression algorithm together with
/// the encoded certificate payload. Only payloads with an empty `context` are
/// cached.
#[derive(Debug)]
pub enum CompressionCache {
    /// No caching: every request compresses the certificate payload afresh.
    Disabled,

    /// Cache compressed certificate payloads, evicting the least-recently-used
    /// entry once the configured number of entries is reached.
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
285
/// Innards of an enabled [`CompressionCache`].
///
/// The fields are private, so this cannot be built directly; use
/// [`CompressionCache::new`].
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries kept before the least-recently-used is evicted.
    size: usize,

    /// Cached entries, ordered from least-recently used (front) to
    /// most-recently used (back). Guarded by a `Mutex` because lookups and
    /// insertions happen through `&self`.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
300
301impl CompressionCache {
302 #[cfg(feature = "std")]
305 pub fn new(size: usize) -> Self {
306 if size == 0 {
307 return Self::Disabled;
308 }
309
310 Self::Enabled(CompressionCacheInner {
311 size,
312 entries: Mutex::new(VecDeque::with_capacity(size)),
313 })
314 }
315
316 pub(crate) fn compression_for(
322 &self,
323 compressor: &dyn CertCompressor,
324 original: &CertificatePayloadTls13<'_>,
325 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
326 match self {
327 Self::Disabled => Self::uncached_compression(compressor, original),
328
329 #[cfg(feature = "std")]
330 Self::Enabled(_) => self.compression_for_impl(compressor, original),
331 }
332 }
333
334 #[cfg(feature = "std")]
335 fn compression_for_impl(
336 &self,
337 compressor: &dyn CertCompressor,
338 original: &CertificatePayloadTls13<'_>,
339 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
340 let (max_size, entries) = match self {
341 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
342 _ => unreachable!(),
343 };
344
345 if !original.context.0.is_empty() {
348 return Self::uncached_compression(compressor, original);
349 }
350
351 let encoding = original.get_encoding();
353 let algorithm = compressor.algorithm();
354
355 let mut cache = entries
356 .lock()
357 .map_err(|_| CompressionFailed)?;
358 for (i, item) in cache.iter().enumerate() {
359 if item.algorithm == algorithm && item.original == encoding {
360 let item = cache.remove(i).unwrap();
362 cache.push_back(item.clone());
363 return Ok(item);
364 }
365 }
366 drop(cache);
367
368 let uncompressed_len = encoding.len() as u32;
370 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
371 let new_entry = Arc::new(CompressionCacheEntry {
372 algorithm,
373 original: encoding,
374 compressed: CompressedCertificatePayload {
375 alg: algorithm,
376 uncompressed_len,
377 compressed: PayloadU24(Payload::new(compressed)),
378 },
379 });
380
381 let mut cache = entries
383 .lock()
384 .map_err(|_| CompressionFailed)?;
385 if cache.len() == max_size {
386 cache.pop_front();
387 }
388 cache.push_back(new_entry.clone());
389 Ok(new_entry)
390 }
391
392 fn uncached_compression(
394 compressor: &dyn CertCompressor,
395 original: &CertificatePayloadTls13<'_>,
396 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
397 let algorithm = compressor.algorithm();
398 let encoding = original.get_encoding();
399 let uncompressed_len = encoding.len() as u32;
400 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
401
402 Ok(Arc::new(CompressionCacheEntry {
405 algorithm,
406 original: Vec::new(),
407 compressed: CompressedCertificatePayload {
408 alg: algorithm,
409 uncompressed_len,
410 compressed: PayloadU24(Payload::new(compressed)),
411 },
412 }))
413 }
414}
415
416impl Default for CompressionCache {
417 fn default() -> Self {
418 #[cfg(feature = "std")]
419 {
420 Self::new(4)
422 }
423
424 #[cfg(not(feature = "std"))]
425 {
426 Self::Disabled
427 }
428 }
429}
430
// Without `std` the cache is never enabled, so nothing reads the lookup-key
// fields (`algorithm`, `original`); silence the resulting dead-code lint.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Algorithm that produced `compressed`; one half of the cache key.
    algorithm: CertificateCompressionAlgorithm,

    /// Encoding of the uncompressed certificate payload; the other half of the
    /// cache key. Left empty for entries that are never cached.
    original: Vec<u8>,

    /// The compressed certificate payload.
    compressed: CompressedCertificatePayload<'static>,
}
441
442impl CompressionCacheEntry {
443 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
444 self.compressed.as_borrowed()
445 }
446}
447
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Runs the round-trip and failure-mode checks against one
    /// compressor/decompressor pair, which must agree on the algorithm.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        assert_eq!(comp.algorithm(), decomp.algorithm());
        [16, 64, 512, 2048, 8192, 16384]
            .into_iter()
            .for_each(|len| test_trivial_pairwise(comp, decomp, len));
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Compresses `plain_len` zero bytes at every level and checks the round trip.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let plain = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let packed = comp
                .compress(plain.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                plain.len(),
                packed.len(),
                level
            );

            // Pre-fill with non-zero bytes so a decompressor that writes nothing cannot pass.
            let mut unpacked = vec![0xffu8; plain_len];
            decomp
                .decompress(&packed, &mut unpacked)
                .unwrap();
            assert_eq!(plain, unpacked);
        }
    }

    /// Decompressing into a buffer one byte too long or too short must fail.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let plain = vec![0u8; 2048];
        let compressed = comp
            .compress(plain.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        for wrong_len in [plain.len() + 1, plain.len() - 1] {
            let mut unpacked = vec![0xffu8; wrong_len];
            decomp
                .decompress(&compressed, &mut unpacked)
                .unwrap_err();
        }
    }

    /// Input that is not valid compressed data must be rejected.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut unpacked = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut unpacked)
            .unwrap_err();
    }

    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        let cache = CompressionCache::default();

        let cert = CertificateDer::from(vec![1]);

        let cert1 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"4"));

        // Each step is (compressor, payload, must_compress).
        // `must_compress == true` means a cache miss is expected;
        // `false` means the result must come from the cache.
        let steps = [
            // fill the (4-entry) cache with zlib entries
            (ZLIB_COMPRESSOR, &cert1, true),
            (ZLIB_COMPRESSOR, &cert2, true),
            (ZLIB_COMPRESSOR, &cert3, true),
            (ZLIB_COMPRESSOR, &cert4, true),
            // different algorithm => different key: miss, evicts zlib/cert1
            (BROTLI_COMPRESSOR, &cert4, true),
            // every surviving entry hits
            (ZLIB_COMPRESSOR, &cert2, false),
            (ZLIB_COMPRESSOR, &cert3, false),
            (ZLIB_COMPRESSOR, &cert4, false),
            (BROTLI_COMPRESSOR, &cert4, false),
            // zlib/cert1 was evicted: miss, evicts zlib/cert2 (now LRU)
            (ZLIB_COMPRESSOR, &cert1, true),
            // touch these so brotli/cert4 becomes LRU
            (ZLIB_COMPRESSOR, &cert4, false),
            (ZLIB_COMPRESSOR, &cert3, false),
            (ZLIB_COMPRESSOR, &cert1, false),
            // miss, evicts brotli/cert4
            (BROTLI_COMPRESSOR, &cert1, true),
            // brotli/cert4 was evicted: miss again
            (BROTLI_COMPRESSOR, &cert4, true),
        ];

        for (inner, payload, must_compress) in steps {
            // The temporary `RequireCompress` is dropped at the end of this
            // statement, and its `Drop` impl checks `must_compress`.
            cache
                .compression_for(
                    &RequireCompress(inner, AtomicBool::default(), must_compress),
                    payload,
                )
                .unwrap();
        }

        /// Wraps a compressor, records whether `compress` was called, and
        /// asserts on drop that this matches the expectation in field `.2`.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }
    }
}