diff --git a/DatapathVerification/BitHeap/BitHeap.lean b/DatapathVerification/BitHeap/BitHeap.lean index 6f06b04..b69c6f4 100644 --- a/DatapathVerification/BitHeap/BitHeap.lean +++ b/DatapathVerification/BitHeap/BitHeap.lean @@ -6,9 +6,8 @@ import Mathlib.Tactic.SplitIfs import Std.Data.HashMap.Lemmas import Std.Tactic.Do -structure BitHeap where - width : Nat - columns : Std.HashMap Nat BitHeap.Column +structure BitHeap (width : Nat) where + columns : Vector BitHeap.Column width deriving Inhabited namespace BitHeap @@ -16,91 +15,72 @@ namespace BitHeap open Circuit open Column -instance : ToString BitHeap where +instance : ToString (BitHeap w) where toString h := - let entries := h.columns.toList.mergeSort (fun a b => a.1 ≤ b.1) - "{" ++ String.intercalate ", " (entries.map (fun (k, v) => s!"{k} ↦ {v}")) ++ "}" + let entries := h.columns.toList.zipIdx + "{" ++ String.intercalate ", " (entries.map (fun (v, k) => s!"{k} ↦ {v}")) ++ "}" -def empty : BitHeap := ⟨0, Std.HashMap.emptyWithCapacity 0⟩ +def empty (w : Nat) : BitHeap w := ⟨Vector.replicate w Column.empty⟩ /-- Evaluate a bit-heap, to compute the final sum of all the bits in the heap. -/ -def eval (h : BitHeap) (env : BitEnv) : Int := - (h.columns.fold (init := 0) (fun acc w col => acc + (2 ^ w) * col.eval env)) - -def eval' (h : BitHeap) (env : BitEnv) : Int := Id.run do - let mut acc := 0 - for (col, val) in h.columns do - acc := acc + (2 ^ col) * val.eval env - return acc +def eval (h : BitHeap w) (env : BitEnv) : Int := + h.columns.toList.zipIdx.foldl (fun acc (col, idx) => acc + 2^idx * col.eval env) 0 /-- Evaluate a bit-heap modulo 2^width, to compute the final sum of all the bits in the heap. -/ -def evalMod (h : BitHeap) (env : BitEnv) : Int := - h.eval env % 2^(h.width) - -structure AdderResult where - heap : BitHeap - sum : Circuit - carry : Circuit +def evalMod (h : BitHeap w) (env : BitEnv) : Int := + h.eval env % 2^(w) -def get (h : BitHeap) (column : Nat) : Column := +def get (h : BitHeap w) (column : Nat) : Column := h.columns.getD column (Column.empty) -theorem get_eq_getD (h : BitHeap) (column : Nat) : - h.get column = (h.columns[column]?).getD Column.empty := by - simp [get, Std.HashMap.getD_eq_getD_getElem?] +-- if index is in bounds, getD returns the value +theorem getD_in_bounds (h : BitHeap w) (column : Nat) (h1 : column < w) : + h.get column = h.columns[column] := by + simp [get, Vector.getD] + rw [Vector.getElem_eq_getElem?_get, Vector.getD_getElem?] + simp [h1] -instance : Membership Circuit BitHeap where +instance : Membership Circuit (BitHeap w) where mem h c := ∃ (col : Nat), c ∈ h.get col -def removeBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap := - ⟨h.width, h.columns.modify column (fun col => col.erase c)⟩ +def removeBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w := + ⟨h.columns.setIfInBounds column ((h.get column).erase c)⟩ -- Maximum height across all columns -def maxHeight (h : BitHeap) : Nat := - h.columns.fold (init := 0) (fun acc _ col => max acc col.height) +def maxHeight (h : BitHeap w) : Nat := + h.columns.foldl (fun acc col => max acc col.height) 0 -- Highest column of the BitHeap, return none if BitHeap is empty -def highestColumn (h : BitHeap) : Option Nat := +def highestColumn (h : BitHeap w) : Option Nat := let target := h.maxHeight - h.columns.toList.findSome? (fun (idx, col) => if col.height == target then some idx else none) + if target == 0 then none else + h.columns.toList.zipIdx.findSome? + (fun (col, idx) => if col.height == target then some idx else none) /-- Add a bit into the bit heap, returning a new bit heap. If the bit already exists in the column, remove it and add it to the next column. Stops carrying when the column exceeds the width of the bit heap. -/ -def addBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap := - if column >= h.width then h else +def addBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w := + if h_bounds : column >= w then h else + have h1 : column < w := by omega let col := h.get column - if !col.contains c then - ⟨h.width, h.columns.insert column (col.insert c)⟩ - else addBit (column + 1) c (h.removeBit column c) - termination_by h.width - column - decreasing_by - have hw : (removeBit column c h).width = h.width := by rfl - rw [hw] - omega - -@[simp] -theorem removeBit_width (column : Nat) (c : Circuit) (h : BitHeap) : - (removeBit column c h).width = h.width := by rfl + if !col.contains c then + ⟨h.columns.set column (col.insert c) h1⟩ + else addBit (column + 1) c (h.removeBit column c) -@[simp] -theorem addBit_width (column : Nat) (c : Circuit) (h : BitHeap) : - (addBit column c h).width = h.width := by - fun_induction addBit with - | case1 => rfl - | case2 => rfl - | case3 _ _ _ _ _ ih => - rw [removeBit_width] at ih - rw [ih] +structure AdderResult (w : Nat) where + heap : BitHeap w + sum : Circuit + carry : Circuit -def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap) : AdderResult := +def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap w) : AdderResult w := let h := h.removeBit column i let h := h.removeBit column j let sum := Circuit.binop .xor i j @@ -109,7 +89,7 @@ def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap) : AdderResult := let h := h.addBit (column + 1) carry ⟨h, sum, carry⟩ -def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) : AdderResult := +def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap w) : AdderResult w := let h := h.removeBit column i let h := h.removeBit column j let h := h.removeBit column k @@ -120,76 +100,32 @@ def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) : AdderResult := ⟨h, sum, carry⟩ @[simp] -theorem evalMod_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) (h1 : c ∈ h.get column) : - (h.removeBit column c).evalMod env = (h.evalMod env - 2^(column) * (c.eval env).toInt) % 2^(h.width) := by +theorem evalMod_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap w) (env : BitEnv) (h1 : c ∈ h.get column) : + (h.removeBit column c).evalMod env = (h.evalMod env - 2^(column) * (c.eval env).toInt) % 2^(w) := by unfold evalMod - rw [removeBit_width] simp [eval, removeBit] have : (h.get column |>.erase c).eval env = (h.get column).eval env - 2 ^ column * (c.eval env).toInt := by sorry -- have : (h.columns.modify column fun col => col.erase c) = h.columns - 2 ^ column * (c.eval env).toInt := by sorry - repeat rw [Std.HashMap.fold_eq_foldl_toList] sorry -theorem by_pow2_of_zero_eval (h : BitHeap) (h1 : col ≥ h.width) : - (2 : Int) ^ h.width ∣ (2 : Int) ^ col := by +theorem by_pow2_of_zero_eval (h : BitHeap w) (h1 : col ≥ w) : + (2 : Int) ^ w ∣ (2 : Int) ^ col := by sorry -- exact Nat.pow_dvd_pow_iff_le_right'.mpr h1 -> this works for Nat. -/-- -Relate BitHeap.env to sum of a list. (Nat x Column) comes from Std.HashMap.toList, since it returns (Key x Value) pairs. --/ -theorem foldl_sum (l : List (Nat × Column)) (env : BitEnv) (a : Int) : - l.foldl (fun acc (p : Nat × Column) => acc + 2 ^ p.1 * (p.2.eval env : Int)) a = - a + (l.map (fun p => 2 ^ p.1 * (p.2.eval env : Int))).sum := by - induction l generalizing a with - | nil => simp - | cons p ps ih => - grind - -@[grind => ] -theorem eval_insertColumn_eq_eval_add (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) : - (⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env - = h.eval env + 2 ^ k * (v.eval env : Int) - 2 ^ k * ((h.get k).eval env : Int) := by - cases h - case mk width cols => - simp [BitHeap.eval] - rw [Std.HashMap.fold_eq_foldl_toList] - rw [Std.HashMap.fold_eq_foldl_toList] - rw [foldl_sum] - rw [foldl_sum] - sorry - -@[grind => ] -theorem eval_eraseColumn_eq_eval_sub (h : BitHeap) (k : Nat) (env : BitEnv) : - (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env - = h.eval env - 2 ^ k * ((h.get k).eval env : Int) := by - simp [BitHeap.eval] - rw [Std.HashMap.fold_eq_foldl_toList] - rw [Std.HashMap.fold_eq_foldl_toList] - - rw? - sorry - - -theorem eval_insertColumn (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) : - (⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env - = (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * (v.eval env : Int) := by - sorry - - -theorem eval_eraseColumn (h : BitHeap) (k : Nat) (env : BitEnv) : - h.eval env - = (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * ((h.get k).eval env : Int) := by +theorem eval_insertColumn (h : BitHeap w) (k : Nat) (col : Column) (env : BitEnv) (h1 : column < w) : + ({ columns := h.columns.set column col h1 } : BitHeap w).eval env + = ({ columns := h.columns.set column (Column.empty) h1} : BitHeap w).eval env + 2 ^ column * (col.eval env : Int) := by sorry @[simp] -theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) : - (h.addBit column c).evalMod env = (h.evalMod env + 2^column * (c.eval env).toInt) % 2^(h.width) := by +theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap w) (env : BitEnv) : + (h.addBit column c).evalMod env = (h.evalMod env + 2^column * (c.eval env).toInt) % 2^(w) := by fun_induction addBit with | case1 col h h1 => simp [evalMod] - have h3 : 2 ^ col * (c.eval env).toInt % 2 ^ h.width = 0 := by + have h3 : 2 ^ col * (c.eval env).toInt % 2 ^ w = 0 := by generalize hvi : c.eval env = vi rcases vi · simp @@ -198,47 +134,55 @@ theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : Bi apply Int.emod_eq_zero_of_dvd exact_mod_cast by_pow2_of_zero_eval h h1 simp [Int.add_emod, h3] - | case2 col h h1 => + | case2 column h h4 h3 col h1 => simp only [evalMod, Int.emod_add_emod] - have h3 : (⟨h.width, h.columns.insert col ((h.get col).insert c)⟩ : BitHeap).eval env = h.eval env + 2 ^ col * (c.eval env).toInt := by - - rw [eval_insertColumn, eval_eraseColumn h col env] - - rw [Column.eval_insert] - · grind - · simp - grind - rw [h3] - | case3 _ _ _ h2 h1 ih => - simp only [ih, removeBit_width] + have : ({ columns := h.columns.set column (col.insert c) h3 } : BitHeap w).eval env = (h.eval env + 2 ^ column * (c.eval env).toInt) := by + simp_all + rw [eval_insertColumn] + · rw [Column.eval_insert] + · simp + rw [Int.mul_add] + simp [eval] + sorry + · simp + exact h1 + · exact h.maxHeight + rw [this] + | case3 col h h4 h3 h2 h1 ih => + rw [ih] rw [evalMod_heap_removeBit] - · simp only [Int.emod_add_emod] - grind - · simp at h1 - simp [mem_iff_contains] - grind - -@[simp] -theorem get_removeBit_of_ne (column : Nat) (h : BitHeap) (i j : Circuit) - (h1 : i ∈ h.get column) (hne : i ≠ j) : - i ∈ (removeBit column j h).get column := by - rw [get_eq_getD] at h1 - rw [get_eq_getD] - simp only [removeBit] - rcases hcol : h.columns[column]? - · simp_all only + have : (- 2 ^ col * (c.eval env).toInt + 2 ^ (col + 1) * (c.eval env).toInt) = 2 ^ col * (c.eval env).toInt := by grind + rw [← this] + simp_all -- this looks very ugly grind - · simp_all only [Option.getD_some, mem_iff_contains, ne_eq, Std.HashMap.getElem?_modify_self, - Option.map_some, Column.erase, Column.contains, Std.HashSet.contains_erase] + simp_all grind -theorem removeBit_decreases_size (col : Nat) (c : Circuit) (h : BitHeap) (h1: c ∈ h.get col) : +theorem get_removeBit_self (column : Nat) (c : Circuit) (h : BitHeap w) (hb : column < w) : + (removeBit column c h).get column = (h.get column).erase c := by + simp only [removeBit] + rw [getD_in_bounds] <;> grind + +@[simp] +theorem get_removeBit_of_ne (column : Nat) (h : BitHeap w) (i j : Circuit) + (h1 : i ∈ h.get column) (hne : i ≠ j) : i ∈ (removeBit column j h).get column := by + by_cases hb : column < w + · rw [get_removeBit_self _ _ _ hb] + exact (erase_eq_erase (h.get column) h1 (id (Ne.symm hne))).mpr h1 + · have hr : removeBit column j h = h := by + simp only [removeBit] + rw [Vector.setIfInBounds_eq_of_size_le] + grind + rw [hr] + exact h1 + +theorem removeBit_decreases_size (col : Nat) (c : Circuit) (h : BitHeap w) (h1: c ∈ h.get col) : ((removeBit col c h).get col).height < (h.get col).height := by simp only [removeBit, height_eq_size] simp [erase] sorry -theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHeap) +theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHeap w) (h1 : c₁ ∈ h.get col) (h2 : c₂ ∈ h.get col) (hne : c₁ ≠ c₂) : ((removeBit col c₁ (removeBit col c₂ h)).get col).height < (h.get col).height := by have h1' : c₁ ∈ (removeBit col c₂ h).get col := @@ -247,7 +191,7 @@ theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHea (removeBit_decreases_size col c₁ (removeBit col c₂ h) h1') (removeBit_decreases_size col c₂ h h2) -theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : BitHeap) +theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : BitHeap w) (h1 : c₁ ∈ h.get col) (h2 : c₂ ∈ h.get col) (h3 : c₃ ∈ h.get col) (hne12 : c₁ ≠ c₂) (hne13 : c₁ ≠ c₃) (hne23 : c₂ ≠ c₃) : ((removeBit col c₁ (removeBit col c₂ (removeBit col c₃ h))).get col).height < (h.get col).height := by @@ -259,17 +203,12 @@ theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : B (double_removeBit_decreases col c₁ c₂ (removeBit col c₃ h) h1' h2' hne12) (removeBit_decreases_size col c₃ h h3) -@[simp] -theorem halfAdder_preserves_width (column : Nat) (i j : Circuit) (h : BitHeap) : - (h.halfAdder column i j).heap.width = h.width := by - simp [halfAdder, removeBit] - -theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap) +theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap w) (h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (hne : i ≠ j) : ∀ (env : BitEnv), (h.halfAdder column i j).heap.evalMod env = h.evalMod env := by intros env have h3 := get_removeBit_of_ne column h j i h2 hne.symm - simp [halfAdder, evalMod_heap_addBit, addBit_width, removeBit_width] + simp [halfAdder, evalMod_heap_addBit] simp only [evalMod_heap_removeBit, h1, h3] simp [evalMod] generalize hvi : i.eval env = vi @@ -277,20 +216,14 @@ theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap) rcases vi <;> rcases vj <;> simp_all grind -@[simp] -theorem fullAdder_preserves_width (column : Nat) (i j k : Circuit) (h : BitHeap) : - (h.fullAdder column i j k).heap.width = h.width := by - simp [fullAdder, removeBit] - -theorem fullAdder_correct_mod (column : Nat) (i j k : Circuit) (h : BitHeap) - (h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column) - (hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) : +theorem fullAdder_correct_mod (column : Nat) (i j k : Circuit) (h : BitHeap w) + (h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column) (hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) : ∀ (env : BitEnv), (h.fullAdder column i j k).heap.evalMod env = h.evalMod env := by intros env have h4 := get_removeBit_of_ne column h j i h2 hne.symm have h5 := get_removeBit_of_ne column (removeBit column i h) k have h6 := h5 j (get_removeBit_of_ne column h k i h3 hne2.symm) hne3.symm - simp [fullAdder, evalMod_heap_addBit, addBit_width, removeBit_width] + simp [fullAdder, evalMod_heap_addBit] simp only [evalMod_heap_removeBit, h1, h4, h6] simp [evalMod] generalize hvi : i.eval env = vi diff --git a/DatapathVerification/BitHeap/Chain.lean b/DatapathVerification/BitHeap/Chain.lean index 63995e9..a13435a 100644 --- a/DatapathVerification/BitHeap/Chain.lean +++ b/DatapathVerification/BitHeap/Chain.lean @@ -28,19 +28,19 @@ def printSummary (adders : List Adder) : String := s!"FAs: {fa}, HAs: {ha}" /-- Apply an adder to a bit heap returning the updated heap -/ -def applyAdder (adder : Adder) (h : BitHeap) : BitHeap := +def applyAdder (adder : Adder) (h : BitHeap w) : BitHeap w := match adder with | .halfAdder column i j => (h.halfAdder column i j).heap | .fullAdder column i j k => (h.fullAdder column i j k).heap /-- Apply a list of adders, from front to the back -/ -def applyChain (adders : List Adder) (h : BitHeap) : BitHeap := +def applyChain (adders : List Adder) (h : BitHeap w) : BitHeap w := match adders with | [] => h | s :: rest => applyChain rest (applyAdder s h) /-- Preconditions for each step in the chain -/ -def ChainPreconditions (steps : List Adder) (h : BitHeap) : Prop := +def ChainPreconditions (steps : List Adder) (h : BitHeap w) : Prop := match steps with | [] => True | s :: rest => @@ -52,21 +52,7 @@ def ChainPreconditions (steps : List Adder) (h : BitHeap) : Prop := ∧ i ≠ j ∧ i ≠ k ∧ j ≠ k) ∧ ChainPreconditions rest (applyAdder s h) -@[simp] -theorem applyChain_preserves_width (steps : List Adder) (h : BitHeap) : (applyChain steps h).width = h.width := by - induction steps generalizing h with - | nil => rfl - | cons s rest ih => - simp [applyChain] - cases s with - | halfAdder => - simp [applyAdder, ih] - | fullAdder => - simp [applyAdder, ih] - -/-- Main correctness theorem for the chain. - Applying the chain preserves the heap's value under all evaluation environments -/ -theorem applyChain_correct_mod (steps : List Adder) (h : BitHeap) +theorem applyChain_correct_mod (steps : List Adder) (h : BitHeap w) (hwf : ChainPreconditions steps h) : ∀ (env : BitEnv), (applyChain steps h).evalMod env = h.evalMod env := by intros env @@ -84,7 +70,7 @@ theorem applyChain_correct_mod (steps : List Adder) (h : BitHeap) · grind /-- Check if a single step of the chain is applicable -/ -def isApplicable (step: Adder) (h : BitHeap) : Bool := +def isApplicable (step: Adder) (h : BitHeap w) : Bool := match step with | .halfAdder col i j => let column := h.get col @@ -94,7 +80,7 @@ def isApplicable (step: Adder) (h : BitHeap) : Bool := column.contains i && column.contains j && column.contains k && i != j && i != k && j != k /-- Apply a chain of adders if they are applicable, otherwise return none -/ -def applyChainSafe (steps : List Adder) (h : BitHeap) : Option BitHeap := +def applyChainSafe (steps : List Adder) (h : BitHeap w) : Option (BitHeap w) := match steps with | [] => some h | s :: rest => @@ -103,31 +89,7 @@ def applyChainSafe (steps : List Adder) (h : BitHeap) : Option BitHeap := else none -@[simp] -theorem applyChainSafe_preserves_width (steps : List Adder) (h h' : BitHeap) - (heq : applyChainSafe steps h = some h') : - h'.width = h.width := by - induction steps generalizing h with - | nil => - simp [applyChainSafe] at heq - rw [heq] - | cons s rest ih => - simp [applyChainSafe] at heq - obtain ⟨hleft, hright⟩ := heq - have ih_applied := ih (applyAdder s h) hright - rw [ih_applied] - simp_all [applyAdder, isApplicable] - cases s with - simp at hleft - | halfAdder => - obtain ⟨⟨i_in_col, j_in_col⟩, not_eq⟩ := hleft - rw [halfAdder_preserves_width _ _ _] - | fullAdder => - obtain ⟨⟨⟨⟨⟨hi, hj⟩, hk⟩, hij⟩, hik⟩, hjk⟩ := hleft - rw [fullAdder_preserves_width _ _ _] - --- /-- If a chain of adders is applicable (it does not return none), then it preserves the heap's value -/ -theorem applyChainSafe_correct_mod (steps : List Adder) (h h' : BitHeap) +theorem applyChainSafe_correct_mod (steps : List Adder) (h h' : BitHeap w) (heq : applyChainSafe steps h = some h') : ∀ (env : BitEnv), h'.evalMod env = h.evalMod env := by intros env @@ -150,7 +112,6 @@ theorem applyChainSafe_correct_mod (steps : List Adder) (h h' : BitHeap) obtain ⟨⟨⟨⟨⟨hi, hj⟩, hk⟩, hij⟩, hik⟩, hjk⟩ := hleft rw [fullAdder_correct_mod _ _ _ _ _ hi hj hk hij hik hjk env] - end Chain end BitHeap diff --git a/DatapathVerification/BitHeap/Column.lean b/DatapathVerification/BitHeap/Column.lean index f61d583..2acc1a4 100644 --- a/DatapathVerification/BitHeap/Column.lean +++ b/DatapathVerification/BitHeap/Column.lean @@ -49,17 +49,41 @@ theorem height_eq_size (col : Column) : col.height = col.elems.size := rfl def toList (col : Column) : List Circuit := col.elems.toList +theorem erase_eq_erase (col : Column) (he : d ∈ col) (hne : c ≠ d) : d ∈ col.erase c ↔ d ∈ col := by + simp_all [erase, contains] + +theorem foldl_sum (l : List Circuit) (env : BitEnv) (a : Nat) : + l.foldl (fun acc (c : Circuit) => acc + (c.eval env).toNat) a = + a + (l.map (fun c => (c.eval env).toNat)).sum := by + induction l generalizing a with + | nil => simp + | cons p ps ih => + grind + @[simp] theorem eval_erase (col : Column) (c : Circuit) (env : BitEnv) (h : c ∈ col) : (col.erase c).eval env = col.eval env - (c.eval env).toInt := by simp [eval, erase] + repeat rw [Std.HashSet.fold_eq_foldl_toList] + rw [eq_comm, Int.sub_eq_iff_eq_add'] + repeat rw [foldl_sum] + simp only [Nat.zero_add] + have : col.elems.toList.Perm (c :: (col.elems.erase c).toList) := by + sorry sorry @[simp] theorem eval_insert (col : Column) (c : Circuit) (env : BitEnv) (h : c ∉ col) : - (col.insert c).eval env = col.eval env + (c.eval env).toInt := by + (col.insert c).eval env = col.eval env + (c.eval env).toNat := by simp [eval, insert] - sorry + repeat rw [Std.HashSet.fold_eq_foldl_toList, foldl_sum] + simp only [Nat.zero_add] + have : ((col.elems.insert c).toList).Perm (c :: col.elems.toList) := by + have key := col.elems.toList_insert_perm (k := c) + have h' : c ∉ col.elems := by exact h + rw [if_neg h'] at key + exact key + grind end Column diff --git a/DatapathVerification/BitHeap/CompressionHelpers.lean b/DatapathVerification/BitHeap/CompressionHelpers.lean index a4c8b94..3f131d1 100644 --- a/DatapathVerification/BitHeap/CompressionHelpers.lean +++ b/DatapathVerification/BitHeap/CompressionHelpers.lean @@ -9,13 +9,13 @@ open Chain namespace Compression -def applyHalfAdder (column : Nat) (i j : Circuit) (h : BitHeap) (acc : BitHeap) : (BitHeap × BitHeap × Adder) := +def applyHalfAdder (column : Nat) (i j : Circuit) (h : BitHeap w) (acc : BitHeap w) : (BitHeap w × BitHeap w × Adder) := let HA := Adder.halfAdder column i j let newAcc := Chain.applyAdder HA acc -- applies a Half Adder, removing compressed bits and adding sum and carry bits to acc. let newOriginal := h.removeBit column i |>.removeBit column j -- removes the compressed bits from the original heap. (newOriginal, newAcc, HA) -def applyFullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) (acc : BitHeap) : (BitHeap × BitHeap × Adder) := +def applyFullAdder (column : Nat) (i j k : Circuit) (h : BitHeap w) (acc : BitHeap w) : (BitHeap w × BitHeap w × Adder) := let FA := Adder.fullAdder column i j k let newAcc := Chain.applyAdder FA acc -- applies a Full Adder, removing compressed bits and adding sum and carry bits to acc. let newOriginal := h.removeBit column i |>.removeBit column j |>.removeBit column k -- removes the compressed bits from the original heap. diff --git a/DatapathVerification/BitHeap/DaddaTree.lean b/DatapathVerification/BitHeap/DaddaTree.lean index e038d54..ce4b595 100644 --- a/DatapathVerification/BitHeap/DaddaTree.lean +++ b/DatapathVerification/BitHeap/DaddaTree.lean @@ -69,8 +69,8 @@ the updated accumulator, and the list of adders used to compress the column. For each column whose height exceeds mₗ₋₁, introduces the fewest number of FAs and at most one HA to bring the height down to mₗ₋₁. -/ -partial def DaddaStageColumn (col : Nat) (h : BitHeap) (acc : BitHeap) (daddaLevel : Nat) - : BitHeap × BitHeap × List Adder := +partial def DaddaStageColumn (col : Nat) (h : BitHeap w) (acc : BitHeap w) (daddaLevel : Nat) + : BitHeap w × BitHeap w × List Adder := let DaddaHeightPrev := DaddaSequence (daddaLevel - 1) -- mₗ₋₁ is the maximum height allowed in the column after this stage of compression. if (acc.get col).height - DaddaHeightPrev ≥ 2 then -- If the column height is more than one above the previous Dadda level, apply a Full Adder to compress it. @@ -96,7 +96,7 @@ One full stage of Dadda tree across all columns of the bit heap. Loops over every column in order, invoking DaddaStageColumn to compress each one and folding the resulting adders into a running list. -/ -partial def DaddaStage (h : BitHeap) (daddaLevel : Nat) : BitHeap × List Adder := +partial def DaddaStage (h : BitHeap w) (daddaLevel : Nat) : BitHeap w × List Adder := let (_, acc, adders) := (List.range h.columns.size).foldl (fun (original, acc, adders) col => @@ -112,9 +112,9 @@ Top-level Dadda Tree function. Repeatedly applies DaddaStage until every column has at most 2 bits. -/ -partial def DaddaTree (h : BitHeap) : BitHeap × List Adder := +partial def DaddaTree (h : BitHeap w) : BitHeap w × List Adder := let daddaLevel := findDaddaLevel h.maxHeight - let rec loop (h : BitHeap) (adders : List Adder) (daddaLevel : Nat) : BitHeap × List Adder := + let rec loop (h : BitHeap w) (adders : List Adder) (daddaLevel : Nat) : BitHeap w × List Adder := if h.maxHeight ≤ 2 then (h, adders) else let (h', newAdders) := DaddaStage h daddaLevel diff --git a/DatapathVerification/BitHeap/Examples/DaddaTree.lean b/DatapathVerification/BitHeap/Examples/DaddaTree.lean index ef5273f..6db1ea7 100644 --- a/DatapathVerification/BitHeap/Examples/DaddaTree.lean +++ b/DatapathVerification/BitHeap/Examples/DaddaTree.lean @@ -32,8 +32,8 @@ info: 4 abbrev BitEnv := Nat → Bool -def pp4 : BitHeap := bitHeapOfPartialProducts 4 -def pp8 : BitHeap := bitHeapOfPartialProducts 8 +def pp4 : BitHeap 8 := bitHeapOfPartialProducts 4 +def pp8 : BitHeap 16:= bitHeapOfPartialProducts 8 ------------- First Dadda stage ------------- @@ -50,7 +50,7 @@ info: [HA(3: b9, b6), HA(4: b10, b13)] #eval (DaddaTree.DaddaStage pp4 2).2 /-- -info: (some {0 ↦ [b0], 1 ↦ [b1, b4], 2 ↦ [(b2 ⊕ b5), b8], 3 ↦ [((b3 ⊕ b12) ⊕ (b9 ⊕ b6)), (b2 ∧ b5)], 4 ↦ [(((b3 ∧ b12) ∨ (b3 ∧ (b9 ⊕ b6))) ∨ (b12 ∧ (b9 ⊕ b6))), ((b7 ⊕ (b10 ⊕ b13)) ⊕ (b9 ∧ b6))], 5 ↦ [(((b7 ∧ (b10 ⊕ b13)) ∨ (b7 ∧ (b9 ∧ b6))) ∨ ((b10 ⊕ b13) ∧ (b9 ∧ b6))), ((b14 ⊕ b11) ⊕ (b10 ∧ b13))], 6 ↦ [b15, (((b14 ∧ b11) ∨ (b14 ∧ (b10 ∧ b13))) ∨ (b11 ∧ (b10 ∧ b13)))]})-/ +info: (some {0 ↦ [b0], 1 ↦ [b1, b4], 2 ↦ [(b2 ⊕ b5), b8], 3 ↦ [((b3 ⊕ b12) ⊕ (b9 ⊕ b6)), (b2 ∧ b5)], 4 ↦ [(((b3 ∧ b12) ∨ (b3 ∧ (b9 ⊕ b6))) ∨ (b12 ∧ (b9 ⊕ b6))), ((b7 ⊕ (b10 ⊕ b13)) ⊕ (b9 ∧ b6))], 5 ↦ [(((b7 ∧ (b10 ⊕ b13)) ∨ (b7 ∧ (b9 ∧ b6))) ∨ ((b10 ⊕ b13) ∧ (b9 ∧ b6))), ((b14 ⊕ b11) ⊕ (b10 ∧ b13))], 6 ↦ [b15, (((b14 ∧ b11) ∨ (b14 ∧ (b10 ∧ b13))) ∨ (b11 ∧ (b10 ∧ b13)))], 7 ↦ []})-/ #guard_msgs in #eval (applyChainSafe (DaddaTree.DaddaTree pp4).2 pp4) diff --git a/DatapathVerification/BitHeap/Examples/EvalModExamples.lean b/DatapathVerification/BitHeap/Examples/EvalModExamples.lean index 81d173a..e0c4c8a 100644 --- a/DatapathVerification/BitHeap/Examples/EvalModExamples.lean +++ b/DatapathVerification/BitHeap/Examples/EvalModExamples.lean @@ -13,8 +13,8 @@ def env1 : BitEnv := fun n => n = 0 || n = 2 || n = 4 -- A heap with width 3: only columns 0, 1, 2 contribute under evalMod. -- Bits in column 3 (and beyond) get truncated away. -def heapWidth3 : BitHeap := - let h : BitHeap := ⟨3, Std.HashMap.emptyWithCapacity 0⟩ +def heapWidth3 : BitHeap 3 := + let h := BitHeap.empty 3 let h := h.addBit 0 (Circuit.bit 0) let h := h.addBit 1 (Circuit.bit 1) let h := h.addBit 2 (Circuit.bit 2) diff --git a/DatapathVerification/BitHeap/Examples/Examples.lean b/DatapathVerification/BitHeap/Examples/Examples.lean index 2af8119..4702bae 100644 --- a/DatapathVerification/BitHeap/Examples/Examples.lean +++ b/DatapathVerification/BitHeap/Examples/Examples.lean @@ -7,8 +7,8 @@ open Chain namespace Examples -def addBitsExample : BitHeap := - let h : BitHeap := ⟨3, Std.HashMap.emptyWithCapacity 0⟩ +def addBitsExample (w : Nat) : BitHeap w := + let h := BitHeap.empty w let h := h.addBit 0 (Circuit.bit 0)-- add a bit in column 0 let h := h.addBit 1 (Circuit.bit 1) -- add a bit in column 1 let h := h.addBit 1 (Circuit.bit 1) -- add another bit in column 1 @@ -20,12 +20,12 @@ def addBitsExample : BitHeap := info: {0 ↦ [], 1 ↦ [b1], 2 ↦ [b1]} -/ #guard_msgs in -#eval addBitsExample +#eval addBitsExample 3 abbrev BitEnv := Nat → Bool -def fourBitsInCol1 : BitHeap := - let h : BitHeap := ⟨3, Std.HashMap.emptyWithCapacity 0⟩ +def fourBitsInCol1 : BitHeap 4 := + let h := BitHeap.empty 4 let h := h.addBit 1 (Circuit.bit 0) let h := h.addBit 1 (Circuit.bit 1) let h := h.addBit 1 (Circuit.bit 2) @@ -49,8 +49,8 @@ info: 6 #guard_msgs in #eval (applyChain compressionChain fourBitsInCol1).eval (show BitEnv from fun n => n = 1 || n = 2 || n = 3) -def exampleHeap : BitHeap := - let h : BitHeap := ⟨3, Std.HashMap.emptyWithCapacity 0⟩ +def exampleHeap : BitHeap 5:= + let h := BitHeap.empty 5 let h := h.addBit 1 (Circuit.bit 0) let h := h.addBit 1 (Circuit.bit 1) let h := h.addBit 2 (Circuit.bit 2) diff --git a/DatapathVerification/BitHeap/Examples/NaiveCompressionExamples.lean b/DatapathVerification/BitHeap/Examples/NaiveCompressionExamples.lean index 6fdd1ed..92e2139 100644 --- a/DatapathVerification/BitHeap/Examples/NaiveCompressionExamples.lean +++ b/DatapathVerification/BitHeap/Examples/NaiveCompressionExamples.lean @@ -8,15 +8,15 @@ open Chain namespace NaiveCompressionExamples -def threeBitsInCol1 : BitHeap := - let h : BitHeap := ⟨32, Std.HashMap.emptyWithCapacity 0⟩ +def threeBitsInCol1 : BitHeap 2 := + let h := BitHeap.empty 2 let h := h.addBit 1 (Circuit.bit 0) let h := h.addBit 1 (Circuit.bit 1) let h := h.addBit 1 (Circuit.bit 2) h -def fiveBitsInCol1 : BitHeap := - let h : BitHeap := ⟨32, Std.HashMap.emptyWithCapacity 0⟩ +def fiveBitsInCol1 : BitHeap 32 := + let h := BitHeap.empty 32 let h := h.addBit 1 (Circuit.bit 0) let h := h.addBit 1 (Circuit.bit 1) let h := h.addBit 1 (Circuit.bit 2) @@ -24,8 +24,8 @@ def fiveBitsInCol1 : BitHeap := let h := h.addBit 1 (Circuit.bit 4) h -def multiColSmall : BitHeap := - let h : BitHeap := ⟨3, Std.HashMap.emptyWithCapacity 0⟩ +def multiColSmall : BitHeap 3 := + let h := BitHeap.empty 3 let h := h.addBit 0 (Circuit.bit 0) let h := h.addBit 0 (Circuit.bit 1) let h := h.addBit 1 (Circuit.bit 2) @@ -85,7 +85,6 @@ info: 6 #guard_msgs in #eval (NaiveCompression.reduceColumn 1 fiveBitsInCol1).1.evalMod (show BitEnv from env1) - ------------- /-- info: 9 diff --git a/DatapathVerification/BitHeap/Examples/PartialProductGenerator.lean b/DatapathVerification/BitHeap/Examples/PartialProductGenerator.lean index 6aeb6c1..8961b36 100644 --- a/DatapathVerification/BitHeap/Examples/PartialProductGenerator.lean +++ b/DatapathVerification/BitHeap/Examples/PartialProductGenerator.lean @@ -9,8 +9,8 @@ open Chain namespace PartialProductGenerator /-- Build a bit heap of partial products for the given bit-width. -/ -def bitHeapOfPartialProducts (width : Nat) : BitHeap := Id.run do - let mut h := ⟨(width*2), Std.HashMap.emptyWithCapacity 0⟩ +def bitHeapOfPartialProducts (width : Nat) : BitHeap (width * 2) := Id.run do + let mut h := BitHeap.empty (width * 2) for i in [0:width] do for j in [i:width+i] do h := h.addBit j (Circuit.bit (i * (width - 1) + j)) @@ -23,7 +23,7 @@ info: [1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1] #eval (List.range 15).map (fun c => ((bitHeapOfPartialProducts 8).get c).height) /-- -info: {0 ↦ [b0], 1 ↦ [b1, b4], 2 ↦ [b2, b5, b8], 3 ↦ [b9, b6, b3, b12], 4 ↦ [b10, b13, b7], 5 ↦ [b11, b14], 6 ↦ [b15]} +info: {0 ↦ [b0], 1 ↦ [b1, b4], 2 ↦ [b2, b5, b8], 3 ↦ [b9, b6, b3, b12], 4 ↦ [b10, b13, b7], 5 ↦ [b11, b14], 6 ↦ [b15], 7 ↦ []} -/ #guard_msgs in #eval bitHeapOfPartialProducts 4 diff --git a/DatapathVerification/BitHeap/Examples/WallaceTree.lean b/DatapathVerification/BitHeap/Examples/WallaceTree.lean index fd7b627..ac0af53 100644 --- a/DatapathVerification/BitHeap/Examples/WallaceTree.lean +++ b/DatapathVerification/BitHeap/Examples/WallaceTree.lean @@ -12,8 +12,8 @@ namespace WallaceTreeExamples abbrev BitEnv := Nat → Bool -def pp4 : BitHeap := bitHeapOfPartialProducts 4 -def pp8 : BitHeap := bitHeapOfPartialProducts 8 +def pp4 : BitHeap 8 := bitHeapOfPartialProducts 4 +def pp8 : BitHeap 16 := bitHeapOfPartialProducts 8 /-- info: [1, 2, 3, 4, 3, 2, 1] -/ diff --git a/DatapathVerification/BitHeap/NaiveCompression.lean b/DatapathVerification/BitHeap/NaiveCompression.lean index ea03707..ad798a2 100644 --- a/DatapathVerification/BitHeap/NaiveCompression.lean +++ b/DatapathVerification/BitHeap/NaiveCompression.lean @@ -11,7 +11,7 @@ namespace NaiveCompression Another difference with the Wallace tree is that this naive approach consumes carries in the same round. -/ -- if height >= 4, apply FA. if height = 3, apply HA. -def reduceColumnStep (col : Nat) (h : BitHeap) : Option (BitHeap × Adder) := +def reduceColumnStep (col : Nat) (h : BitHeap w) : Option (BitHeap w × Adder) := match (h.get col).toList with | a :: b :: c :: _ :: _ => let FA := Adder.fullAdder col a b c @@ -22,14 +22,14 @@ def reduceColumnStep (col : Nat) (h : BitHeap) : Option (BitHeap × Adder) := | _ => none -- termination is guaranteed since each step reduces the height of the column. -partial def reduceColumn (col : Nat) (h : BitHeap) : BitHeap × List Adder := +partial def reduceColumn (col : Nat) (h : BitHeap w) : BitHeap w × List Adder := match reduceColumnStep col h with | none => (h, []) | some (h', adder) => let (h'', adders) := reduceColumn col h' (h'', adder :: adders) -def naiveCompression (h : BitHeap) : BitHeap × List Adder := +def naiveCompression (h : BitHeap w) : BitHeap w × List Adder := (List.range h.columns.size).foldl (fun (heap, adders) col => let (newHeap, newAdders) := reduceColumn col heap (newHeap, adders ++ newAdders)) (h, []) diff --git a/DatapathVerification/BitHeap/WallaceTree.lean b/DatapathVerification/BitHeap/WallaceTree.lean index a5aa8a2..b7302a1 100644 --- a/DatapathVerification/BitHeap/WallaceTree.lean +++ b/DatapathVerification/BitHeap/WallaceTree.lean @@ -26,8 +26,8 @@ If acc.get col has ≥ 4 bits but the original column has only 2 bits left, a half adder is used instead. This is because we cannot consume newly added bits, so we only consume the original bits in the column. If acc.get col has exactly 3 bits, a single half adder is applied and we stop. -/ -partial def WallaceStageColumn (col : Nat) (h : BitHeap) (acc : BitHeap) - : BitHeap × BitHeap × List Adder := +partial def WallaceStageColumn (col : Nat) (h : BitHeap w) (acc : BitHeap w) + : BitHeap w × BitHeap w × List Adder := match (h.get col).toList with | x :: y :: z :: _ => let ⟨newOriginal, newAcc, FA⟩ := Compression.applyFullAdder col x y z h acc @@ -40,8 +40,8 @@ partial def WallaceStageColumn (col : Nat) (h : BitHeap) (acc : BitHeap) | _ => (h, acc, []) -- TODO: This is the same function as WallaceStageColumn, with incomplete termination proof. -def WallaceStageColumnNotPartial (col : Nat) (h : BitHeap) (acc : BitHeap) - : BitHeap × BitHeap × List Adder := +def WallaceStageColumnNotPartial (col : Nat) (h : BitHeap w) (acc : BitHeap w) + : BitHeap w × BitHeap w × List Adder := match (h.get col).toList with | x :: y :: z :: _ => let ⟨newOriginal, newAcc, FA⟩ := Compression.applyFullAdder col x y z h acc @@ -84,7 +84,7 @@ compressed bits are removed from both acc and h (original heap). This separation lets us track which bits are carries, since generated carry bits contribute to the height calculation but are not themselves compressed in the same stage. -/ -def WallaceStage (h : BitHeap) : BitHeap × List Adder := +def WallaceStage (h : BitHeap w) : BitHeap w × List Adder := let (_, acc, adders) := (List.range h.columns.size).foldl (fun (original, acc, adders) col => @@ -100,7 +100,7 @@ Top-level Wallace Tree function. Repeatedly applies WallaceStage until every column has at most 2 bits. -/ -partial def WallaceTree (h : BitHeap) : BitHeap × List Adder := +partial def WallaceTree (h : BitHeap w) : BitHeap w × List Adder := if h.maxHeight ≤ 2 then (h, []) else