use super::{nonce::Nonce, overlapping, quic::Sample, NONCE_LEN};
use crate::{
constant_time,
cpu::{self, GetFeature as _},
error,
polyfill::unwrap_const,
};
use cfg_if::cfg_if;
use core::num::NonZeroU32;
pub(super) use ffi::Counter;
#[macro_use]
mod ffi;
mod bs;
pub(super) mod fallback;
pub(super) mod hw;
pub(super) mod vp;
/// `overlapping::Overlapping` specialized to `u8` elements; this is the buffer
/// type accepted by `EncryptCtr32::ctr32_encrypt_within`.
pub type Overlapping<'o> = overlapping::Overlapping<'o, u8>;
/// `overlapping::PartialBlock` specialized to `u8` elements with `BLOCK_LEN`
/// as its const length parameter.
pub type OverlappingPartialBlock<'o> = overlapping::PartialBlock<'o, u8, BLOCK_LEN>;
// Bring `ffi::AES_KEY` into scope. On little-endian AArch64 and on x86_64 the
// name is also re-exported to the parent module (`pub(super)`); on every other
// target the import stays private to this module.
cfg_if! {
if #[cfg(any(all(target_arch = "aarch64", target_endian = "little"), target_arch = "x86_64"))] {
pub(super) use ffi::AES_KEY;
} else {
use ffi::AES_KEY;
}
}
/// An AES key, held in the representation of whichever implementation
/// `Key::new` selected.
///
/// The `Hw` and `Vp` variants only exist on the targets named in their `cfg`
/// attributes; `Fallback` exists on every target.
#[derive(Clone)]
pub(super) enum Key {
/// Key for the `hw` implementation.
// NOTE(review): presumably the CPU's dedicated AES instructions - confirm in `hw`.
#[cfg(any(
all(target_arch = "aarch64", target_endian = "little"),
target_arch = "x86_64",
target_arch = "x86"
))]
Hw(hw::Key),
/// Key for the `vp` implementation.
// NOTE(review): presumably a vector-permute (SIMD) implementation - confirm in `vp`.
#[cfg(any(
all(target_arch = "aarch64", target_endian = "little"),
all(target_arch = "arm", target_endian = "little"),
target_arch = "x86",
target_arch = "x86_64"
))]
Vp(vp::Key),
/// Key for the `fallback` implementation, which is compiled on every target.
Fallback(fallback::Key),
}
impl Key {
    /// Builds a key from `bytes`, choosing an implementation from
    /// `cpu_features`.
    ///
    /// Candidates are tried in a fixed order: `hw`, then `vp`, then
    /// `fallback`. The `hw` and `vp` attempts are compiled only for the
    /// targets listed in their `cfg` attributes, and each one is taken only
    /// when `cpu_features` yields the feature set that implementation asks
    /// for.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the selected implementation's
    /// constructor.
    #[inline]
    pub fn new(
        bytes: KeyBytes<'_>,
        cpu_features: cpu::Features,
    ) -> Result<Self, error::Unspecified> {
        #[cfg(any(
            all(target_arch = "aarch64", target_endian = "little"),
            target_arch = "x86",
            target_arch = "x86_64"
        ))]
        if let Some(required) = cpu_features.get_feature() {
            // The third argument is a second feature lookup; its result type
            // is inferred from the signature of `hw::Key::new`.
            let key = hw::Key::new(bytes, required, cpu_features.get_feature())?;
            return Ok(Self::Hw(key));
        }

        #[cfg(any(
            all(target_arch = "aarch64", target_endian = "little"),
            all(target_arch = "arm", target_endian = "little"),
            target_arch = "x86_64",
            target_arch = "x86"
        ))]
        if let Some(required) = cpu_features.get_feature() {
            let key = vp::Key::new(bytes, required)?;
            return Ok(Self::Vp(key));
        }

        // On targets where both branches above are compiled out, this keeps
        // `cpu_features` from being reported as unused.
        let _ = cpu_features;

        let key = fallback::Key::new(bytes)?;
        Ok(Self::Fallback(key))
    }

    /// Encrypts the single block `a` with whichever implementation backs this
    /// key and returns the resulting block.
    #[inline]
    fn encrypt_block(&self, a: Block) -> Block {
        match self {
            #[cfg(any(
                all(target_arch = "aarch64", target_endian = "little"),
                target_arch = "x86_64",
                target_arch = "x86"
            ))]
            Self::Hw(key) => key.encrypt_block(a),

            #[cfg(any(
                all(target_arch = "aarch64", target_endian = "little"),
                all(target_arch = "arm", target_endian = "little"),
                target_arch = "x86",
                target_arch = "x86_64"
            ))]
            Self::Vp(key) => key.encrypt_block(a),

            Self::Fallback(key) => key.encrypt_block(a),
        }
    }

    /// Encrypts `sample` as one block and returns the first five bytes of the
    /// result.
    // NOTE(review): `Sample` comes from the `quic` module, so this is
    // presumably the QUIC header-protection mask - confirm with the caller.
    pub fn new_mask(&self, sample: Sample) -> [u8; 5] {
        let encrypted = self.encrypt_block(sample);
        let mut mask = [0u8; 5];
        mask.copy_from_slice(&encrypted[..5]);
        mask
    }
}
/// Length in bytes of a 128-bit AES key.
pub const AES_128_KEY_LEN: usize = 128 / 8;
/// Length in bytes of a 256-bit AES key.
pub const AES_256_KEY_LEN: usize = 256 / 8;
/// Borrowed raw key material, tagged by key size; this is the input accepted
/// by `Key::new`.
pub enum KeyBytes<'a> {
/// A reference to exactly `AES_128_KEY_LEN` key bytes.
AES_128(&'a [u8; AES_128_KEY_LEN]),
/// A reference to exactly `AES_256_KEY_LEN` key bytes.
AES_256(&'a [u8; AES_256_KEY_LEN]),
}
impl Counter {
    /// Creates the counter block for `nonce` with the trailing counter set to
    /// one.
    ///
    /// The first `NONCE_LEN` bytes of the block are copied from `nonce`; every
    /// remaining byte is zero except the final byte, which is set to `1`.
    pub fn one(nonce: Nonce) -> Self {
        let mut value = [0u8; BLOCK_LEN];
        value[..NONCE_LEN].copy_from_slice(nonce.as_ref());
        value[BLOCK_LEN - 1] = 1;
        Self(value)
    }

    /// Returns the current counter block as an [`Iv`] and then advances the
    /// counter by one, so successive calls yield successive counter blocks.
    pub fn increment(&mut self) -> Iv {
        const ONE: NonZeroU32 = unwrap_const(NonZeroU32::new(1));
        let iv = Iv(self.0);
        self.increment_by_less_safe(ONE);
        iv
    }

    /// Adds `increment_by` to the big-endian `u32` stored in the last four
    /// bytes of the block, wrapping modulo 2^32. The leading bytes are left
    /// untouched.
    fn increment_by_less_safe(&mut self, increment_by: NonZeroU32) {
        let [.., c0, c1, c2, c3] = &mut self.0;
        let old_value = u32::from_be_bytes([*c0, *c1, *c2, *c3]);
        // The counter is a 32-bit field that must wrap around. A plain `+`
        // would panic on overflow in builds with overflow checks enabled and
        // wrap silently otherwise; `wrapping_add` gives the same, intended
        // result in every build profile.
        let new_value = old_value.wrapping_add(increment_by.get());
        [*c0, *c1, *c2, *c3] = u32::to_be_bytes(new_value);
    }
}
/// A single block used as the input to `EncryptBlock::encrypt_iv_xor_block`.
///
/// Within this module it is produced from a [`Counter`], either by
/// `Counter::increment` or through the `From<Counter>` conversion.
pub struct Iv(Block);

impl From<Counter> for Iv {
    /// Consumes `counter` and reuses its current block, unchanged, as the IV.
    fn from(counter: Counter) -> Self {
        let block: Block = counter.0;
        Iv(block)
    }
}
/// One cipher block: a byte array of length `BLOCK_LEN`.
pub(super) type Block = [u8; BLOCK_LEN];
/// Size of a cipher block in bytes.
pub(super) const BLOCK_LEN: usize = 16;
/// A block with every byte set to zero.
pub(super) const ZERO_BLOCK: Block = [0u8; BLOCK_LEN];
/// Single-block encryption operations provided by a key implementation.
///
/// The two methods can be expressed in terms of each other; see the
/// `encrypt_block_using_encrypt_iv_xor_block` and
/// `encrypt_iv_xor_block_using_encrypt_block` helpers in this module.
pub(super) trait EncryptBlock {
/// Encrypts `block` and returns the resulting block.
fn encrypt_block(&self, block: Block) -> Block;
/// Returns the encryption of `iv` XORed with `block`; this is how the
/// `encrypt_iv_xor_block_using_encrypt_block` helper in this module computes it.
fn encrypt_iv_xor_block(&self, iv: Iv, block: Block) -> Block;
}
/// Bulk encryption driven by a `Counter`.
pub(super) trait EncryptCtr32 {
/// Processes `in_out` using `ctr`, which is passed mutably.
// NOTE(review): presumably counter-mode encryption of `in_out` that advances
// `ctr` once per block processed - confirm against the implementations.
fn ctr32_encrypt_within(&self, in_out: Overlapping<'_>, ctr: &mut Counter);
}
/// Implements single-block encryption on top of
/// `EncryptBlock::encrypt_iv_xor_block`.
///
/// `block` is passed in the IV position and the all-zero block in the XOR
/// position, so the XOR step leaves the encrypted block unchanged.
#[allow(dead_code)]
fn encrypt_block_using_encrypt_iv_xor_block(key: &impl EncryptBlock, block: Block) -> Block {
    let block_as_iv = Iv(block);
    key.encrypt_iv_xor_block(block_as_iv, ZERO_BLOCK)
}
/// Implements `EncryptBlock::encrypt_iv_xor_block` on top of
/// `EncryptBlock::encrypt_block`: encrypts the block held in `iv`, then XORs
/// the result with `block` using `constant_time::xor_16`.
fn encrypt_iv_xor_block_using_encrypt_block(
    key: &impl EncryptBlock,
    iv: Iv,
    block: Block,
) -> Block {
    constant_time::xor_16(key.encrypt_block(iv.0), block)
}
#[allow(dead_code)]
fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block {
let mut ctr = Counter(iv.0); key.ctr32_encrypt_within(block.as_mut().into(), &mut ctr);
block
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Checks single-block encryption against every vector in
    /// `aes_tests.txt`.
    #[test]
    pub fn test_aes() {
        test::run(test_file!("aes_tests.txt"), |section, test_case| {
            assert_eq!(section, "");

            let key = consume_key(test_case, "Key");

            let plaintext = test_case.consume_bytes("Input");
            let plaintext: Block = plaintext.as_slice().try_into()?;

            let expected = test_case.consume_bytes("Output");
            let actual = key.encrypt_block(plaintext);
            assert_eq!(&actual[..], &expected[..]);

            Ok(())
        })
    }

    /// Reads the attribute `name` from `test_case` and turns it into a `Key`.
    ///
    /// Only 16-byte and 32-byte keys are expected in the vector file.
    fn consume_key(test_case: &mut test::TestCase, name: &str) -> Key {
        let raw = test_case.consume_bytes(name);
        let raw = raw.as_slice();
        let key_bytes = match raw.len() {
            AES_128_KEY_LEN => KeyBytes::AES_128(raw.try_into().unwrap()),
            AES_256_KEY_LEN => KeyBytes::AES_256(raw.try_into().unwrap()),
            _ => unreachable!(),
        };
        Key::new(key_bytes, cpu::features()).unwrap()
    }
}