Skip to content

Commit

Permalink
Add specification for Advanced Encryption Standard (AES)
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Mar 27, 2024
1 parent b148912 commit cecc746
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 148 deletions.
101 changes: 3 additions & 98 deletions Arm/Insts/DPSFP/Crypto_aes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,14 @@ Author(s): Shilpi Goel, Yan Peng
import Arm.Decode
import Arm.Insts.Common
import Arm.BitVec
import Specs.AESCommon

----------------------------------------------------------------------

namespace DPSFP

open BitVec

def SBox :=
-- F E D C B A 9 8 7 6 5 4 3 2 1 0
[ 0x16bb54b00f2d99416842e6bf0d89a18c#128, -- F
0xdf2855cee9871e9b948ed9691198f8e1#128, -- E
0x9e1dc186b95735610ef6034866b53e70#128, -- D
0x8a8bbd4b1f74dde8c6b4a61c2e2578ba#128, -- C
0x08ae7a65eaf4566ca94ed58d6d37c8e7#128, -- B
0x79e4959162acd3c25c2406490a3a32e0#128, -- A
0xdb0b5ede14b8ee4688902a22dc4f8160#128, -- 9
0x73195d643d7ea7c41744975fec130ccd#128, -- 8
0xd2f3ff1021dab6bcf5389d928f40a351#128, -- 7
0xa89f3c507f02f94585334d43fbaaefd0#128, -- 6
0xcf584c4a39becb6a5bb1fc20ed00d153#128, -- 5
0x842fe329b3d63b52a05a6e1b1a2c8309#128, -- 4
0x75b227ebe28012079a059618c323c704#128, -- 3
0x1531d871f1e5a534ccf73f362693fdb7#128, -- 2
0xc072a49cafa2d4adf04759fa7dc982ca#128, -- 1
0x76abd7fe2b670130c56f6bf27b777c63#128 -- 0
]

def AESShiftRows (op : BitVec 128) : BitVec 128 :=
extractLsb 95 88 op ++ extractLsb 55 48 op ++
extractLsb 15 8 op ++ extractLsb 103 96 op ++
extractLsb 63 56 op ++ extractLsb 23 16 op ++
extractLsb 111 104 op ++ extractLsb 71 64 op ++
extractLsb 31 24 op ++ extractLsb 119 112 op ++
extractLsb 79 72 op ++ extractLsb 39 32 op ++
extractLsb 127 120 op ++ extractLsb 87 80 op ++
extractLsb 47 40 op ++ extractLsb 7 0 op

def AESSubBytes_aux (i : Nat) (op : BitVec 128) (out : BitVec 128)
: BitVec 128 :=
if h₀ : 16 <= i then
out
else
let idx := (extractLsb (i * 8 + 7) (i * 8) op).toNat
let val := extractLsb (idx * 8 + 7) (idx * 8) $ BitVec.flatten SBox
have h₁ : idx * 8 + 7 - idx * 8 = i * 8 + 7 - i * 8 := by omega
let out := BitVec.partInstall (i * 8 + 7) (i * 8) (h₁ ▸ val) out
have _ : 15 - i < 16 - i := by omega
AESSubBytes_aux (i + 1) op out
termination_by (16 - i)

def AESSubBytes (op : BitVec 128) : BitVec 128 :=
AESSubBytes_aux 0 op (BitVec.zero 128)

@[state_simp_rules]
def exec_aese
(inst : Crypto_aes_cls) (s : ArmState) : ArmState :=
Expand All @@ -69,7 +24,7 @@ def exec_aese
let operand1 := read_sfp 128 inst.Rd s
let operand2 := read_sfp 128 inst.Rn s
let result := operand1 ^^^ operand2
let result := AESSubBytes $ AESShiftRows result
let result := AESCommon.SubBytes $ AESCommon.ShiftRows result
-- State Updates
let s := write_sfp 128 inst.Rd result s
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -125,58 +80,8 @@ def FFmul03 (b : BitVec 8) : BitVec 8 :=
have h : hi - lo + 1 = 8 := by omega
h ▸ extractLsb hi lo $ BitVec.flatten FFmul_03

def AESMixColumns_aux (c : Nat)
(in0 : BitVec 32) (in1 : BitVec 32) (in2 : BitVec 32) (in3 : BitVec 32)
(out0 : BitVec 32) (out1 : BitVec 32) (out2 : BitVec 32) (out3 : BitVec 32)
: BitVec 32 × BitVec 32 × BitVec 32 × BitVec 32 :=
if h₀ : 4 <= c then
(out0, out1, out2, out3)
else
let lo := c * 8
let hi := lo + 7
have h₁ : hi - lo + 1 = 8 := by omega
let in0_byte := h₁ ▸ extractLsb hi lo in0
let in1_byte := h₁ ▸ extractLsb hi lo in1
let in2_byte := h₁ ▸ extractLsb hi lo in2
let in3_byte := h₁ ▸ extractLsb hi lo in3
let val0 := h₁.symm ▸ (FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte)
let out0 := BitVec.partInstall hi lo val0 out0
let val1 := h₁.symm ▸ (FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte)
let out1 := BitVec.partInstall hi lo val1 out1
let val2 := h₁.symm ▸ (FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte)
let out2 := BitVec.partInstall hi lo val2 out2
let val3 := h₁.symm ▸ (FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte)
let out3 := BitVec.partInstall hi lo val3 out3
have _ : 3 - c < 4 - c := by omega
AESMixColumns_aux (c + 1) in0 in1 in2 in3 out0 out1 out2 out3
termination_by (4 - c)

def AESMixColumns (op : BitVec 128) : BitVec 128 :=
let in0 :=
extractLsb 103 96 op ++ extractLsb 71 64 op ++
extractLsb 39 32 op ++ extractLsb 7 0 op
let in1 :=
extractLsb 111 104 op ++ extractLsb 79 72 op ++
extractLsb 47 40 op ++ extractLsb 15 8 op
let in2 :=
extractLsb 119 112 op ++ extractLsb 87 80 op ++
extractLsb 55 48 op ++ extractLsb 23 16 op
let in3 :=
extractLsb 127 120 op ++ extractLsb 95 88 op ++
extractLsb 63 56 op ++ extractLsb 31 24 op
let (out0, out1, out2, out3) :=
(BitVec.zero 32, BitVec.zero 32,
BitVec.zero 32, BitVec.zero 32)
let (out0, out1, out2, out3) :=
AESMixColumns_aux 0 in0 in1 in2 in3 out0 out1 out2 out3
extractLsb 31 24 out3 ++ extractLsb 31 24 out2 ++
extractLsb 31 24 out1 ++ extractLsb 31 24 out0 ++
extractLsb 23 16 out3 ++ extractLsb 23 16 out2 ++
extractLsb 23 16 out1 ++ extractLsb 23 16 out0 ++
extractLsb 15 8 out3 ++ extractLsb 15 8 out2 ++
extractLsb 15 8 out1 ++ extractLsb 15 8 out0 ++
extractLsb 7 0 out3 ++ extractLsb 7 0 out2 ++
extractLsb 7 0 out1 ++ extractLsb 7 0 out0
AESCommon.MixColumns op FFmul02 FFmul03

@[state_simp_rules]
def exec_aesmc
Expand Down
226 changes: 226 additions & 0 deletions Specs/AES.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Yan Peng
-/
import Arm.BitVec
import Arm.Insts.DPSFP.Crypto_aes
import Specs.AESCommon

-- References : https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197-upd1.pdf
-- https://csrc.nist.gov/csrc/media/projects/cryptographic-standards-and-guidelines/documents/aes-development/rijndael-ammended.pdf
--
--------------------------------------------------
-- The NIST specification has the following rounds:
--
-- AddRoundKey key0
-- for k in key1 to key9
-- SubBytes
-- ShiftRows
-- MixColumns
-- AddRoundKey
-- SubBytes
-- ShiftRows
-- AddRoundKey key10
--
-- The Arm implementation has an optimization that shifts the rounds:
--
-- for k in key0 to key8
-- AddRoundKey + ShiftRows + SubBytes (AESE k)
-- MixColumns (AESMC)
-- AddRoundKey + ShiftRows + SubBytes (AESE key9)
-- AddRoundKey key10
--
-- Note: SubBytes and ShiftRows are commutative because
-- SubBytes is a byte-wise operation
--
--------------------------------------------------

namespace AES

open BitVec

def WordSize := 32
def BlockSize := 128

-- Maybe consider Lists vs Vectors?
-- https://github.com/joehendrix/lean-crypto/blob/323ee9b1323deed5240762f4029700a246ecd9d5/lib/Crypto/Vector.lean#L96
def Rcon : List (BitVec WordSize) :=
[ 0x00000001#32,
0x00000002#32,
0x00000004#32,
0x00000008#32,
0x00000010#32,
0x00000020#32,
0x00000040#32,
0x00000080#32,
0x0000001b#32,
0x00000036#32 ]

-------------------------------------------------------
-- types

-- Key-Block-Round Combinations
structure KBR where
key_len : Nat
block_size : Nat
Nk := key_len / 32
Nb := block_size / 32
Nr : Nat
h : block_size = BlockSize
deriving DecidableEq, Repr

def AES128KBR : KBR :=
{key_len := 128, block_size := BlockSize, Nr := 10, h := by decide}
def AES192KBR : KBR :=
{key_len := 192, block_size := BlockSize, Nr := 12, h := by decide}
def AES256KBR : KBR :=
{key_len := 256, block_size := BlockSize, Nr := 14, h := by decide}

def KeySchedule : Type := List (BitVec WordSize)

-- Declare KeySchedule to be an instance HAppend
-- so we can apply `++` to KeySchedules propertly
instance : HAppend KeySchedule KeySchedule KeySchedule where
hAppend := List.append

-------------------------------------------------------

def sbox (ind : BitVec 8) : BitVec 8 :=
match_bv ind with
| [x:4, y:4] =>
have h : (x.toNat * 128 + y.toNat * 8 + 7) - (x.toNat * 128 + y.toNat * 8) + 1 = 8 :=
by omega
h ▸ extractLsb
(x.toNat * 128 + y.toNat * 8 + 7)
(x.toNat * 128 + y.toNat * 8) $ BitVec.flatten AESCommon.SBOX
| _ => ind -- unreachable case

-- Little endian
def RotWord (w : BitVec WordSize) : BitVec WordSize :=
match_bv w with
| [a3:8, a2:8, a1:8, a0:8] => a0 ++ a3 ++ a2 ++ a1
| _ => w -- unreachable case

def SubWord (w : BitVec WordSize) : BitVec WordSize :=
match_bv w with
| [a3:8, a2:8, a1:8, a0:8] => (sbox a3) ++ (sbox a2) ++ (sbox a1) ++ (sbox a0)
| _ => w -- unreachable case

protected def InitKey {Param : KBR} (i : Nat) (key : BitVec Param.key_len)
(acc : KeySchedule) : KeySchedule :=
if h₀ : Param.Nk ≤ i then acc
else
have h₁ : i * 32 + 32 - 1 - i * 32 + 1 = WordSize := by
simp only [WordSize]; omega
let wd := h₁ ▸ extractLsb (i * 32 + 32 - 1) (i * 32) key
let (x:KeySchedule) := [wd]
have _ : Param.Nk - (i + 1) < Param.Nk - i := by omega
AES.InitKey (Param := Param) (i + 1) key (acc ++ x)
termination_by (Param.Nk - i)

protected def KeyExpansion_helper {Param : KBR} (i : Nat) (ks : KeySchedule)
: KeySchedule :=
if h : 4 * Param.Nr + 4 ≤ i then
ks
else
let tmp := List.get! ks (i - 1)
let tmp :=
if i % Param.Nk == 0 then
(SubWord (RotWord tmp)) ^^^ (List.get! Rcon $ (i / Param.Nk) - 1)
else if Param.Nk > 6 && i % Param.Nk == 4 then
SubWord tmp
else
tmp
let res := (List.get! ks (i - Param.Nk)) ^^^ tmp
let ks := List.append ks [ res ]
have _ : 4 * Param.Nr + 4 - (i + 1) < 4 * Param.Nr + 4 - i := by omega
AES.KeyExpansion_helper (Param := Param) (i + 1) ks
termination_by (4 * Param.Nr + 4 - i)

def KeyExpansion {Param : KBR} (key : BitVec Param.key_len)
: KeySchedule :=
let seeded := AES.InitKey (Param := Param) 0 key []
AES.KeyExpansion_helper (Param := Param) Param.Nk seeded

def SubBytes {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
h ▸ AESCommon.SubBytes (h ▸ state)

def ShiftRows {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
h ▸ AESCommon.ShiftRows (h ▸ state)

def XTimes (bv : BitVec 8) : BitVec 8 :=
let res := extractLsb 6 0 bv ++ 0b0#1
if extractLsb 7 7 bv == 0b0#1 then res else res ^^^ 0b00011011#8

def MixColumns {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
let FFmul02 := fun (x : BitVec 8) => XTimes x
let FFmul03 := fun (x : BitVec 8) => x ^^^ XTimes x
h ▸ AESCommon.MixColumns (h ▸ state) FFmul02 FFmul03

-- TODO : Prove the following lemma
theorem MixColumns_table_lookup_equiv {Param : KBR}
(state : BitVec Param.block_size):
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
MixColumns (Param := Param) state = h ▸ DPSFP.AESMixColumns (h ▸ state) := by
simp only [MixColumns, DPSFP.AESMixColumns]
have h₀ : (fun x => XTimes x) = DPSFP.FFmul02 := by
funext x
simp only [XTimes, DPSFP.FFmul02]
simp only [Nat.reduceSub, Nat.reduceAdd, beq_iff_eq, Nat.sub_zero, List.length_cons, List.length_nil,
Nat.reduceSucc, Nat.reduceMul]
sorry -- looks like a sat problem
have h₁ : (fun x => x ^^^ XTimes x) = DPSFP.FFmul03 := by
funext x
simp only [XTimes, DPSFP.FFmul03]
simp only [Nat.reduceSub, Nat.reduceAdd, beq_iff_eq, Nat.sub_zero, List.length_cons, List.length_nil,
Nat.reduceSucc, Nat.reduceMul]
sorry -- looks like a sat problem
rw [h₀, h₁]

def AddRoundKey {Param : KBR} (state : BitVec Param.block_size)
(roundKey : BitVec Param.block_size) : BitVec Param.block_size :=
state ^^^ roundKey

protected def getKey {Param : KBR} (n : Nat) (w : KeySchedule) : BitVec Param.block_size :=
let ind := 4 * n
have h : WordSize + WordSize + WordSize + WordSize = Param.block_size := by
simp only [WordSize, BlockSize, Param.h]
h ▸ ((List.get! w (ind + 3)) ++ (List.get! w (ind + 2)) ++
(List.get! w (ind + 1)) ++ (List.get! w ind))

protected def AES_encrypt_with_ks_loop {Param : KBR} (round : Nat)
(state : BitVec Param.block_size) (w : KeySchedule)
: BitVec Param.block_size :=
if Param.Nr ≤ round then
state
else
let state := SubBytes state
let state := ShiftRows state
let state := MixColumns state
let state := AddRoundKey state $ AES.getKey round w
AES.AES_encrypt_with_ks_loop (Param := Param) (round + 1) state w
termination_by (Param.Nr - round)

def AES_encrypt_with_ks {Param : KBR} (input : BitVec Param.block_size)
(w : KeySchedule) : BitVec Param.block_size :=
have h₀ : WordSize + WordSize + WordSize + WordSize = Param.block_size := by
simp only [WordSize, BlockSize, Param.h]
let state := AddRoundKey input $ (h₀ ▸ AES.getKey 0 w)
let state := AES.AES_encrypt_with_ks_loop (Param := Param) 1 state w
let state := SubBytes (Param := Param) state
let state := ShiftRows (Param := Param) state
AddRoundKey state $ h₀ ▸ AES.getKey Param.Nr w

def AES_encrypt {Param : KBR} (input : BitVec Param.block_size)
(key : BitVec Param.key_len) : BitVec Param.block_size :=
let ks := KeyExpansion (Param := Param) key
AES_encrypt_with_ks (Param := Param) input ks

end AES
Loading

0 comments on commit cecc746

Please sign in to comment.