// Reference: // - https://blog.0x7d0.dev/education/how-aes-is-implemented/ // - https://github.com/boppreh/aes/ // TODO: Salt and stuff use crate::hash::Hashable; pub type Key = [u8; 128 / 8]; type Word = [u8; 4]; type Block = [Word; 4]; type Array16 = [u8; 16]; pub type Result = std::result::Result; #[derive(Debug)] pub enum AesError { MissingCiphertext, InvalidPadding, InvalidRoundKeys, } const SUBSTITUTION_BOX: [u8; 256] = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, ]; const INVERSE_SUBSTITUTION_BOX: [u8; 256] = [ 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, ]; const ROUNDS: usize = 10; // const ROUND_CONSTANT: [u8; 16] = [ // 0x7d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, // ]; const ROUND_CONSTANT: [u8; 32] = [ 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, ]; pub fn encrypt_cbc(plaintext: &mut Vec, key: &H) where H: Hashable, { let key = key.hash().as_bytes_capped::<16>(); let round_keys = expand_key(&key); pad(plaintext); // TODO: Derive this properly let mut previous_block = empty_block(); for ch in plaintext.chunks_mut(16) { let mut buf = empty_array_16(); buf.copy_from_slice(ch); let mut block = array_to_block(&buf); xor_blocks(&mut block, &previous_block); encrypt_block(&mut block, &round_keys) .expect("expand_key() should always return the correct amount of round keys"); previous_block = block.clone(); let mut chunk_ciphertext = block_to_array(&block); ch.swap_with_slice(&mut chunk_ciphertext); } } pub fn decrypt_cbc(ciphertext: &mut Vec, key: &H) -> Result<()> where H: Hashable, { if ciphertext.is_empty() { return Err(AesError::MissingCiphertext); } let key = key.hash().as_bytes_capped::<16>(); let round_keys = expand_key(&key); let mut previous_block = empty_block(); for ch in ciphertext.chunks_mut(16) { let mut buf = empty_array_16(); buf.copy_from_slice(ch); let mut block = array_to_block(&buf); let prev_temp = block.clone(); decrypt_block(&mut block, &round_keys) .expect("expand_key() should always return the correct amount of round keys"); xor_blocks(&mut block, &previous_block); previous_block = prev_temp; let mut chunk_plaintext = block_to_array(&block); ch.swap_with_slice(&mut chunk_plaintext); } unpad(ciphertext)?; Ok(()) } fn encrypt_block(block: &mut Block, round_keys: &[Block]) -> Result<()> { // NOTE: Only works for 128 bit keys. if round_keys.len() != ROUNDS + 1 { return Err(AesError::InvalidRoundKeys); } add_round_key(block, &round_keys[0]); for round in 1..=ROUNDS { substitute_block(block); shift_rows_left(block); if round != ROUNDS { mix_words(block); } add_round_key(block, &round_keys[round]); } Ok(()) } fn decrypt_block(block: &mut Block, round_keys: &[Block]) -> Result<()> { // NOTE: Only works for 128 bit keys. if round_keys.len() != ROUNDS + 1 { return Err(AesError::InvalidRoundKeys); } for round in (1..=ROUNDS).rev() { add_round_key(block, &round_keys[round]); if round != ROUNDS { inverse_mix_words(block); } shift_rows_right(block); inverse_substitute_block(block); } add_round_key(block, &round_keys[0]); Ok(()) } fn pad(plaintext: &mut Vec) { let padding_len = 16 - (plaintext.len() % 16) as u8; let padding = [padding_len].repeat(padding_len as usize); plaintext.extend(padding); } fn unpad(ciphertext: &mut Vec) -> Result<()> { const E: AesError = AesError::InvalidPadding; let m1 = ciphertext.len().checked_sub(1).ok_or(E)?; let padding_len = *ciphertext.get(m1).ok_or(E)?; let start = ciphertext .len() .checked_sub(padding_len as usize) .ok_or(E)?; let padding = ciphertext.get(start..).ok_or(E)?.to_vec(); if !padding.iter().all(|x| x == &padding_len) { return Err(E); } ciphertext.resize(start, 0); Ok(()) } #[inline] fn empty_array_16() -> Array16 { [0; 16] } #[inline] fn empty_word() -> Word { [0; 4] } #[inline] fn empty_block() -> Block { [empty_word(); 4] } fn array_to_block(array: &[u8; 16]) -> Block { let mut block: Block = empty_block(); for (idx, value) in array.iter().enumerate() { let idx_b = idx % 4; let idx_a = (idx - idx_b) / 4; block[idx_a][idx_b] = *value; } block } fn block_to_array(block: &Block) -> [u8; 16] { let mut array = empty_array_16(); let mut ida = 0; for idx in 0..4 { for idy in 0..4 { array[ida] = block[idx][idy]; ida += 1; } } array } fn substitute_block(block: &mut Block) { for idx in 0..4 { substitute_word(&mut block[idx]); } } fn substitute_word(word: &mut Word) { for idx in 0..4 { word[idx] = SUBSTITUTION_BOX[word[idx] as usize]; } } fn inverse_substitute_block(block: &mut Block) { for idx in 0..4 { inverse_substitute_word(&mut block[idx]); } } fn inverse_substitute_word(word: &mut Word) { for idx in 0..4 { word[idx] = INVERSE_SUBSTITUTION_BOX[word[idx] as usize]; } } fn shift_rows_left(block: &mut Block) { for idx in 1..4 { let temp = block[idx]; for idy in 0..4 { let shifted_idy = (idy + idx) % 4; block[idx][idy] = temp[shifted_idy]; } } } fn shift_rows_right(block: &mut Block) { for idx in 1..4 { let temp = block[idx]; for idy in 0..4 { let shifted_idy = (idy + idx) % 4; block[idx][shifted_idy] = temp[idy]; } } } fn xtime(byte: u8) -> u8 { if byte & 0x80 > 0 { return (byte << 1) ^ 0x1B; } else { return byte << 1; } } fn mix_words(block: &mut Block) { for idx in 0..4 { let xor = block[idx][0] ^ block[idx][1] ^ block[idx][2] ^ block[idx][3]; let first = block[idx][0]; block[idx][0] ^= xtime(block[idx][0] ^ block[idx][1]) ^ xor; block[idx][1] ^= xtime(block[idx][1] ^ block[idx][2]) ^ xor; block[idx][2] ^= xtime(block[idx][2] ^ block[idx][3]) ^ xor; block[idx][3] ^= xtime(block[idx][3] ^ first) ^ xor; } } fn inverse_mix_words(block: &mut Block) { for idx in 0..4 { let a = xtime(xtime(block[idx][0] ^ block[idx][2])); let b = xtime(xtime(block[idx][1] ^ block[idx][3])); block[idx][0] ^= a; block[idx][1] ^= b; block[idx][2] ^= a; block[idx][3] ^= b; } mix_words(block); } fn add_round_key(block: &mut Block, round_key: &Block) { for idx in 0..4 { for idy in 0..4 { block[idx][idy] ^= round_key[idx][idy]; } } } fn xor_words(target: &mut Word, modifier: &Word) { for idx in 0..4 { target[idx] ^= modifier[idx]; } } fn xor_blocks(target: &mut Block, modifier: &Block) { for idx in 0..4 { xor_words(&mut target[idx], &modifier[idx]); } } fn expand_key(key: &Key) -> Vec { let mut key_words = array_to_block(key).to_vec(); let key_initial_word_count = key_words.len(); let mut idx = 1; while key_words.len() < (ROUNDS + 1) * 4 { let mut word = key_words[key_words.len() - 1]; if key_words.len() % key_initial_word_count == 0 { let first = word[0]; word[0] = word[1]; word[1] = word[2]; word[2] = word[3]; word[3] = first; substitute_word(&mut word); word[0] ^= ROUND_CONSTANT[idx]; idx += 1; } // TODO: >128 bit keys xor_words( &mut word, &key_words[key_words.len() - key_initial_word_count], ); key_words.push(word); } let mut expanded_keys = Vec::new(); let full_key_count = (key_words.len() - (key_words.len() % 4)) / 4; for idx in 0..full_key_count { let mut block = empty_block(); block.copy_from_slice(&key_words[idx * 4..(idx + 1) * 4]); expanded_keys.push(block); } expanded_keys } #[cfg(test)] mod test { use super::*; const TEST_KEY: Key = [ 130, 191, 5, 162, 175, 104, 200, 14, 32, 0, 97, 170, 10, 83, 159, 90, ]; const TEST_ARRAY_16: [u8; 16] = [ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, ]; const TEST_BLOCK: Block = [ [0x00, 0x01, 0x02, 0x03], [0x04, 0x05, 0x06, 0x07], [0x08, 0x09, 0x0A, 0x0B], [0x0C, 0x0D, 0x0E, 0x0F], ]; const TEST_SUBSTITUTED_BLOCK: Block = [ [0x63, 0x7c, 0x77, 0x7b], [0xf2, 0x6b, 0x6f, 0xc5], [0x30, 0x01, 0x67, 0x2b], [0xfe, 0xd7, 0xab, 0x76], ]; const TEST_SHIFTED_BLOCK: Block = [ [0x00, 0x01, 0x02, 0x03], [0x05, 0x06, 0x07, 0x04], [0x0A, 0x0B, 0x08, 0x09], [0x0F, 0x0C, 0x0D, 0x0E], ]; #[test] fn test_array_to_block() { let block = array_to_block(&TEST_ARRAY_16); assert_eq!(block, TEST_BLOCK); let array = block_to_array(&block); assert_eq!(array, TEST_ARRAY_16); } #[test] fn test_substitute() { let mut input = TEST_BLOCK.clone(); substitute_block(&mut input); assert_eq!(input, TEST_SUBSTITUTED_BLOCK); inverse_substitute_block(&mut input); assert_eq!(input, TEST_BLOCK); } #[test] fn test_shift_rows() { let mut input = TEST_BLOCK.clone(); shift_rows_left(&mut input); assert_eq!(input, TEST_SHIFTED_BLOCK); shift_rows_right(&mut input); assert_eq!(input, TEST_BLOCK); } fn _test_pad(input: Vec, expected: Vec) { let mut buf = input.clone(); pad(&mut buf); assert_eq!(buf, expected); unpad(&mut buf).unwrap(); assert_eq!(buf, input); } #[test] fn test_pad() { let input = vec![1, 2, 3, 4, 5, 6, 7]; let expected = vec![1, 2, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9]; _test_pad(input, expected); let input = vec![10].repeat(100); let mut expected = input.clone(); expected.resize(112, 12); _test_pad(input, expected); } #[test] fn test_xtime() { let tests = &[(84, 168), (255, 229), (0, 0), (240, 251)]; for (a, b) in tests { assert_eq!(xtime(*a), *b); } } #[test] fn test_mix_words() { let mut input = TEST_BLOCK; mix_words(&mut input); inverse_mix_words(&mut input); assert_eq!(input, TEST_BLOCK); } #[test] fn test_xor() { let mut input = TEST_BLOCK; let modifier = TEST_SHIFTED_BLOCK; xor_blocks(&mut input, &modifier); xor_blocks(&mut input, &modifier); assert_eq!(input, TEST_BLOCK); } #[test] fn test_expand_key() { let round_keys = expand_key(&TEST_KEY); let expected: &[Block] = &[ [ [0b10000010, 0b10111111, 0b101, 0b10100010], [0b10101111, 0b1101000, 0b11001000, 0b1110], [0b100000, 0b0, 0b1100001, 0b10101010], [0b1010, 0b1010011, 0b10011111, 0b1011010], ], [ [0b1101110, 0b1100100, 0b10111011, 0b11000101], [0b11000001, 0b1100, 0b1110011, 0b11001011], [0b11100001, 0b1100, 0b10010, 0b1100001], [0b11101011, 0b1011111, 0b10001101, 0b111011], ], [ [0b10100011, 0b111001, 0b1011001, 0b101100], [0b1100010, 0b110101, 0b101010, 0b11100111], [0b10000011, 0b111001, 0b111000, 0b10000110], [0b1101000, 0b1100110, 0b10110101, 0b10111101], ], [ [0b10010100, 0b11101100, 0b100011, 0b1101001], [0b11110110, 0b11011001, 0b1001, 0b10001110], [0b1110101, 0b11100000, 0b110001, 0b1000], [0b11101, 0b10000110, 0b10000100, 0b10110101], ], [ [0b11011000, 0b10110011, 0b11110110, 0b11001101], [0b101110, 0b1101010, 0b11111111, 0b1000011], [0b1011011, 0b10001010, 0b11001110, 0b1001011], [0b1000110, 0b1100, 0b1001010, 0b11111110], ], [ [0b110110, 0b1100101, 0b1001101, 0b10010111], [0b11000, 0b1111, 0b10110010, 0b11010100], [0b1000011, 0b10000101, 0b1111100, 0b10011111], [0b101, 0b10001001, 0b110110, 0b1100001], ], [ [0b10110001, 0b1100000, 0b10100010, 0b11111100], [0b10101001, 0b1101111, 0b10000, 0b101000], [0b11101010, 0b11101010, 0b1101100, 0b10110111], [0b11101111, 0b1100011, 0b1011010, 0b11010110], ], [ [0b1010, 0b11011110, 0b1010100, 0b100011], [0b10100011, 0b10110001, 0b1000100, 0b1011], [0b1001001, 0b1011011, 0b101000, 0b10111100], [0b10100110, 0b111000, 0b1110010, 0b1101010], ], [ [0b10001101, 0b10011110, 0b1010110, 0b111], [0b101110, 0b101111, 0b10010, 0b1100], [0b1100111, 0b1110100, 0b111010, 0b10110000], [0b11000001, 0b1001100, 0b1001000, 0b11011010], ], [ [0b10111111, 0b11001100, 0b1, 0b1111111], [0b10010001, 0b11100011, 0b10011, 0b1110011], [0b11110110, 0b10010111, 0b101001, 0b11000011], [0b110111, 0b11011011, 0b1100001, 0b11001], ], [ [0b110000, 0b100011, 0b11010101, 0b11100101], [0b10100001, 0b11000000, 0b11000110, 0b10010110], [0b1010111, 0b1010111, 0b11101111, 0b1010101], [0b1100000, 0b10001100, 0b10001110, 0b1001100], ], ]; assert_eq!(round_keys, expected); } #[test] fn test_encrypt_block() { let round_keys = expand_key(&TEST_KEY); let input = TEST_BLOCK.clone(); let mut buf = input.clone(); let _ = encrypt_block(&mut buf, &round_keys); let _ = decrypt_block(&mut buf, &round_keys); assert_eq!(buf, input); } #[test] fn test_aes() { let input = vec![1, 2, 3, 4, 5]; let mut buf = input.clone(); encrypt_cbc(&mut buf, &"password".to_string()); println!("Ciphertext: {:?}", buf); decrypt_cbc(&mut buf, &"password").unwrap(); assert_eq!(buf, input); let input = vec![10].repeat(100); let mut buf = input.clone(); encrypt_cbc(&mut buf, &"password".to_string()); decrypt_cbc(&mut buf, &"password").unwrap(); assert_eq!(buf, input); } #[test] fn test_empty_indexing() { encrypt_cbc(&mut vec![], &""); let _ = decrypt_cbc(&mut vec![], &""); let _ = encrypt_block(&mut empty_block(), &[]); let _ = decrypt_block(&mut empty_block(), &[]); pad(&mut vec![]); let _ = unpad(&mut vec![]); } #[test] fn test_empty_password() { let mut input = TEST_ARRAY_16.to_vec(); encrypt_cbc(&mut input, &""); decrypt_cbc(&mut input, &"").unwrap(); assert_eq!(input, TEST_ARRAY_16); } }