Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 93 additions & 160 deletions DatapathVerification/BitHeap/BitHeap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,81 @@ import Mathlib.Tactic.SplitIfs
import Std.Data.HashMap.Lemmas
import Std.Tactic.Do

structure BitHeap where
width : Nat
columns : Std.HashMap Nat BitHeap.Column
structure BitHeap (width : Nat) where
columns : Vector BitHeap.Column width
Comment thread
osmanyasar05 marked this conversation as resolved.
deriving Inhabited

namespace BitHeap

open Circuit
open Column

instance : ToString BitHeap where
instance : ToString (BitHeap w) where
toString h :=
let entries := h.columns.toList.mergeSort (fun a b => a.1 ≤ b.1)
"{" ++ String.intercalate ", " (entries.map (fun (k, v) => s!"{k} ↦ {v}")) ++ "}"
let entries := h.columns.toList.zipIdx
"{" ++ String.intercalate ", " (entries.map (fun (v, k) => s!"{k} ↦ {v}")) ++ "}"

def empty : BitHeap := ⟨0, Std.HashMap.emptyWithCapacity 0
def empty (w : Nat) : BitHeap w := ⟨Vector.replicate w Column.empty

/--
Evaluate a bit-heap, to compute the final sum of all the bits in the heap.
-/
def eval (h : BitHeap) (env : BitEnv) : Int :=
(h.columns.fold (init := 0) (fun acc w col => acc + (2 ^ w) * col.eval env))

def eval' (h : BitHeap) (env : BitEnv) : Int := Id.run do
let mut acc := 0
for (col, val) in h.columns do
acc := acc + (2 ^ col) * val.eval env
return acc
def eval (h : BitHeap w) (env : BitEnv) : Int :=
h.columns.toList.zipIdx.foldl (fun acc (col, idx) => acc + 2^idx * col.eval env) 0

/--
Evaluate a bit-heap modulo 2^width, to compute the final sum of all the bits in the heap.
-/
def evalMod (h : BitHeap) (env : BitEnv) : Int :=
h.eval env % 2^(h.width)

structure AdderResult where
heap : BitHeap
sum : Circuit
carry : Circuit
def evalMod (h : BitHeap w) (env : BitEnv) : Int :=
h.eval env % 2^(w)

def get (h : BitHeap) (column : Nat) : Column :=
def get (h : BitHeap w) (column : Nat) : Column :=
h.columns.getD column (Column.empty)

theorem get_eq_getD (h : BitHeap) (column : Nat) :
h.get column = (h.columns[column]?).getD Column.empty := by
simp [get, Std.HashMap.getD_eq_getD_getElem?]
-- if index is in bounds, getD returns the value
theorem getD_in_bounds (h : BitHeap w) (column : Nat) (h1 : column < w) :
h.get column = h.columns[column] := by
simp [get, Vector.getD]
rw [Vector.getElem_eq_getElem?_get, Vector.getD_getElem?]
simp [h1]

instance : Membership Circuit BitHeap where
instance : Membership Circuit (BitHeap w) where
mem h c :=
∃ (col : Nat), c ∈ h.get col

def removeBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap :=
⟨h.width, h.columns.modify column (fun col => col.erase c)⟩
def removeBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w :=
⟨h.columns.setIfInBounds column ((h.get column).erase c)⟩

-- Maximum height across all columns
def maxHeight (h : BitHeap) : Nat :=
h.columns.fold (init := 0) (fun acc _ col => max acc col.height)
def maxHeight (h : BitHeap w) : Nat :=
h.columns.foldl (fun acc col => max acc col.height) 0

-- Highest column of the BitHeap, return none if BitHeap is empty
def highestColumn (h : BitHeap) : Option Nat :=
def highestColumn (h : BitHeap w) : Option Nat :=
let target := h.maxHeight
h.columns.toList.findSome? (fun (idx, col) => if col.height == target then some idx else none)
if target == 0 then none else
h.columns.toList.zipIdx.findSome?
(fun (col, idx) => if col.height == target then some idx else none)
Comment on lines -68 to +63

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop this, since it is dead code now.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might use this function in the future, so I'd keep it.


/--
Add a bit into the bit heap, returning a new bit heap.
If the bit already exists in the column, remove it and add it to the next column.
Stops carrying when the column exceeds the width of the bit heap.
-/
def addBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap :=
if column >= h.width then h else
def addBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w :=
if h_bounds : column >= w then h else
have h1 : column < w := by omega
let col := h.get column
if !col.contains c then
⟨h.width, h.columns.insert column (col.insert c)⟩
else addBit (column + 1) c (h.removeBit column c)
termination_by h.width - column
decreasing_by
have hw : (removeBit column c h).width = h.width := by rfl
rw [hw]
omega

@[simp]
theorem removeBit_width (column : Nat) (c : Circuit) (h : BitHeap) :
(removeBit column c h).width = h.width := by rfl
if !col.contains c then
⟨h.columns.set column (col.insert c) h1⟩
else addBit (column + 1) c (h.removeBit column c)

@[simp]
theorem addBit_width (column : Nat) (c : Circuit) (h : BitHeap) :
(addBit column c h).width = h.width := by
fun_induction addBit with
| case1 => rfl
| case2 => rfl
| case3 _ _ _ _ _ ih =>
rw [removeBit_width] at ih
rw [ih]
structure AdderResult (w : Nat) where
heap : BitHeap w
sum : Circuit
carry : Circuit

def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap) : AdderResult :=
def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap w) : AdderResult w :=
let h := h.removeBit column i
let h := h.removeBit column j
let sum := Circuit.binop .xor i j
Expand All @@ -109,7 +89,7 @@ def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap) : AdderResult :=
let h := h.addBit (column + 1) carry
⟨h, sum, carry⟩

def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) : AdderResult :=
def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap w) : AdderResult w :=
let h := h.removeBit column i
let h := h.removeBit column j
let h := h.removeBit column k
Expand All @@ -120,76 +100,32 @@ def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) : AdderResult :=
⟨h, sum, carry⟩

@[simp]
theorem evalMod_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) (h1 : c ∈ h.get column) :
(h.removeBit column c).evalMod env = (h.evalMod env - 2^(column) * (c.eval env).toInt) % 2^(h.width) := by
theorem evalMod_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap w) (env : BitEnv) (h1 : c ∈ h.get column) :
(h.removeBit column c).evalMod env = (h.evalMod env - 2^(column) * (c.eval env).toInt) % 2^(w) := by
unfold evalMod
rw [removeBit_width]
simp [eval, removeBit]
have : (h.get column |>.erase c).eval env = (h.get column).eval env - 2 ^ column * (c.eval env).toInt := by
sorry
-- have : (h.columns.modify column fun col => col.erase c) = h.columns - 2 ^ column * (c.eval env).toInt := by sorry
repeat rw [Std.HashMap.fold_eq_foldl_toList]
sorry

theorem by_pow2_of_zero_eval (h : BitHeap) (h1 : col ≥ h.width) :
(2 : Int) ^ h.width ∣ (2 : Int) ^ col := by
theorem by_pow2_of_zero_eval (h : BitHeap w) (h1 : col ≥ w) :
(2 : Int) ^ w ∣ (2 : Int) ^ col := by
sorry
-- exact Nat.pow_dvd_pow_iff_le_right'.mpr h1 -> this works for Nat.

/--
Relate BitHeap.env to sum of a list. (Nat x Column) comes from Std.HashMap.toList, since it returns (Key x Value) pairs.
-/
theorem foldl_sum (l : List (Nat × Column)) (env : BitEnv) (a : Int) :
l.foldl (fun acc (p : Nat × Column) => acc + 2 ^ p.1 * (p.2.eval env : Int)) a =
a + (l.map (fun p => 2 ^ p.1 * (p.2.eval env : Int))).sum := by
induction l generalizing a with
| nil => simp
| cons p ps ih =>
grind

@[grind => ]
theorem eval_insertColumn_eq_eval_add (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) :
(⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env
= h.eval env + 2 ^ k * (v.eval env : Int) - 2 ^ k * ((h.get k).eval env : Int) := by
cases h
case mk width cols =>
simp [BitHeap.eval]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [foldl_sum]
rw [foldl_sum]
sorry

@[grind => ]
theorem eval_eraseColumn_eq_eval_sub (h : BitHeap) (k : Nat) (env : BitEnv) :
(⟨h.width, h.columns.erase k⟩ : BitHeap).eval env
= h.eval env - 2 ^ k * ((h.get k).eval env : Int) := by
simp [BitHeap.eval]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [Std.HashMap.fold_eq_foldl_toList]

rw?
sorry


theorem eval_insertColumn (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) :
(⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env
= (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * (v.eval env : Int) := by
sorry


theorem eval_eraseColumn (h : BitHeap) (k : Nat) (env : BitEnv) :
h.eval env
= (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * ((h.get k).eval env : Int) := by
theorem eval_insertColumn (h : BitHeap w) (k : Nat) (col : Column) (env : BitEnv) (h1 : column < w) :
({ columns := h.columns.set column col h1 } : BitHeap w).eval env
= ({ columns := h.columns.set column (Column.empty) h1} : BitHeap w).eval env + 2 ^ column * (col.eval env : Int) := by
sorry

@[simp]
theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) :
(h.addBit column c).evalMod env = (h.evalMod env + 2^column * (c.eval env).toInt) % 2^(h.width) := by
theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap w) (env : BitEnv) :
(h.addBit column c).evalMod env = (h.evalMod env + 2^column * (c.eval env).toInt) % 2^(w) := by
fun_induction addBit with
| case1 col h h1 =>
simp [evalMod]
have h3 : 2 ^ col * (c.eval env).toInt % 2 ^ h.width = 0 := by
have h3 : 2 ^ col * (c.eval env).toInt % 2 ^ w = 0 := by
generalize hvi : c.eval env = vi
rcases vi
· simp
Expand All @@ -198,47 +134,55 @@ theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : Bi
apply Int.emod_eq_zero_of_dvd
exact_mod_cast by_pow2_of_zero_eval h h1
simp [Int.add_emod, h3]
| case2 col h h1 =>
| case2 column h h4 h3 col h1 =>
simp only [evalMod, Int.emod_add_emod]
have h3 : (⟨h.width, h.columns.insert col ((h.get col).insert c)⟩ : BitHeap).eval env = h.eval env + 2 ^ col * (c.eval env).toInt := by

rw [eval_insertColumn, eval_eraseColumn h col env]

rw [Column.eval_insert]
· grind
· simp
grind
rw [h3]
| case3 _ _ _ h2 h1 ih =>
simp only [ih, removeBit_width]
have : ({ columns := h.columns.set column (col.insert c) h3 } : BitHeap w).eval env = (h.eval env + 2 ^ column * (c.eval env).toInt) := by
simp_all
rw [eval_insertColumn]
· rw [Column.eval_insert]
· simp
rw [Int.mul_add]
simp [eval]
sorry
· simp
exact h1
· exact h.maxHeight
rw [this]
| case3 col h h4 h3 h2 h1 ih =>
rw [ih]
rw [evalMod_heap_removeBit]
· simp only [Int.emod_add_emod]
grind
· simp at h1
simp [mem_iff_contains]
grind

@[simp]
theorem get_removeBit_of_ne (column : Nat) (h : BitHeap) (i j : Circuit)
(h1 : i ∈ h.get column) (hne : i ≠ j) :
i ∈ (removeBit column j h).get column := by
rw [get_eq_getD] at h1
rw [get_eq_getD]
simp only [removeBit]
rcases hcol : h.columns[column]?
· simp_all only
have : (- 2 ^ col * (c.eval env).toInt + 2 ^ (col + 1) * (c.eval env).toInt) = 2 ^ col * (c.eval env).toInt := by grind
rw [← this]
simp_all -- this looks very ugly
grind
· simp_all only [Option.getD_some, mem_iff_contains, ne_eq, Std.HashMap.getElem?_modify_self,
Option.map_some, Column.erase, Column.contains, Std.HashSet.contains_erase]
simp_all
grind

theorem removeBit_decreases_size (col : Nat) (c : Circuit) (h : BitHeap) (h1: c ∈ h.get col) :
theorem get_removeBit_self (column : Nat) (c : Circuit) (h : BitHeap w) (hb : column < w) :
(removeBit column c h).get column = (h.get column).erase c := by
simp only [removeBit]
rw [getD_in_bounds] <;> grind

@[simp]
theorem get_removeBit_of_ne (column : Nat) (h : BitHeap w) (i j : Circuit)
(h1 : i ∈ h.get column) (hne : i ≠ j) : i ∈ (removeBit column j h).get column := by
by_cases hb : column < w
· rw [get_removeBit_self _ _ _ hb]
exact (erase_eq_erase (h.get column) h1 (id (Ne.symm hne))).mpr h1
· have hr : removeBit column j h = h := by
simp only [removeBit]
rw [Vector.setIfInBounds_eq_of_size_le]
grind
rw [hr]
exact h1

theorem removeBit_decreases_size (col : Nat) (c : Circuit) (h : BitHeap w) (h1: c ∈ h.get col) :
((removeBit col c h).get col).height < (h.get col).height := by
simp only [removeBit, height_eq_size]
simp [erase]
sorry

theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHeap)
theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHeap w)
(h1 : c₁ ∈ h.get col) (h2 : c₂ ∈ h.get col) (hne : c₁ ≠ c₂) :
((removeBit col c₁ (removeBit col c₂ h)).get col).height < (h.get col).height := by
have h1' : c₁ ∈ (removeBit col c₂ h).get col :=
Expand All @@ -247,7 +191,7 @@ theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHea
(removeBit_decreases_size col c₁ (removeBit col c₂ h) h1')
(removeBit_decreases_size col c₂ h h2)

theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : BitHeap)
theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : BitHeap w)
(h1 : c₁ ∈ h.get col) (h2 : c₂ ∈ h.get col) (h3 : c₃ ∈ h.get col)
(hne12 : c₁ ≠ c₂) (hne13 : c₁ ≠ c₃) (hne23 : c₂ ≠ c₃) :
((removeBit col c₁ (removeBit col c₂ (removeBit col c₃ h))).get col).height < (h.get col).height := by
Expand All @@ -259,38 +203,27 @@ theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : B
(double_removeBit_decreases col c₁ c₂ (removeBit col c₃ h) h1' h2' hne12)
(removeBit_decreases_size col c₃ h h3)

@[simp]
theorem halfAdder_preserves_width (column : Nat) (i j : Circuit) (h : BitHeap) :
(h.halfAdder column i j).heap.width = h.width := by
simp [halfAdder, removeBit]

theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap)
theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap w)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (hne : i ≠ j) :
∀ (env : BitEnv), (h.halfAdder column i j).heap.evalMod env = h.evalMod env := by
intros env
have h3 := get_removeBit_of_ne column h j i h2 hne.symm
simp [halfAdder, evalMod_heap_addBit, addBit_width, removeBit_width]
simp [halfAdder, evalMod_heap_addBit]
simp only [evalMod_heap_removeBit, h1, h3]
simp [evalMod]
generalize hvi : i.eval env = vi
generalize hvj : j.eval env = vj
rcases vi <;> rcases vj <;> simp_all
grind

@[simp]
theorem fullAdder_preserves_width (column : Nat) (i j k : Circuit) (h : BitHeap) :
(h.fullAdder column i j k).heap.width = h.width := by
simp [fullAdder, removeBit]

theorem fullAdder_correct_mod (column : Nat) (i j k : Circuit) (h : BitHeap)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column)
(hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) :
theorem fullAdder_correct_mod (column : Nat) (i j k : Circuit) (h : BitHeap w)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column) (hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) :
∀ (env : BitEnv), (h.fullAdder column i j k).heap.evalMod env = h.evalMod env := by
intros env
have h4 := get_removeBit_of_ne column h j i h2 hne.symm
have h5 := get_removeBit_of_ne column (removeBit column i h) k
have h6 := h5 j (get_removeBit_of_ne column h k i h3 hne2.symm) hne3.symm
simp [fullAdder, evalMod_heap_addBit, addBit_width, removeBit_width]
simp [fullAdder, evalMod_heap_addBit]
simp only [evalMod_heap_removeBit, h1, h4, h6]
simp [evalMod]
generalize hvi : i.eval env = vi
Expand Down
Loading