(* playground/coq/ohe.v *)

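(*
One-hot encoding: a vector of n bits in which exactly one bit is set
(the mark), represented by the indexed inductive type t below.
Requirements:
- Exactly one marked bit
- Size = n
*)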
From mathcomp Require Import all_ssreflect.
Set Bullet Behavior "Strict Subproofs".
(* Set Implicit Arguments. *)
Require Import Bvector.
(* Require Import Arith. *)
(* Indices: the size, and a flag recording whether the mark has already been placed *)
(* Valid (complete) encodings have the flag set to true *)
(* [t n b]: an encoding of length [n]; [b] is [true] iff a [Mark] occurs.
   [Pad] prepends an unmarked bit and keeps the flag.
   [Mark] prepends the marked bit and is only allowed on an unmarked tail,
   which is what rules out a second mark. *)
Inductive t : nat -> bool -> Type :=
| Nil : t 0 false
| Pad {n : nat} {b : bool} : t n b -> t (S n) b
| Mark {n : nat} : t n false -> t (S n) true.
(* Structural equality of two one-hot encodings whose sizes and flags may
   differ. Returns [true] iff both are built from the same sequence of
   constructors.
   Identical constructors at every level force identical size indices,
   so the sizes need no separate comparison.
   The definition is a plain structural fixpoint with no casts along
   equality proofs, so closed applications reduce to [true] or [false]
   under [Compute]. *)
Fixpoint eqb {b1 b2: bool} {n1 n2: nat} (o1: t n1 b1) (o2: t n2 b2) {struct o1}: bool :=
  match o1, o2 with
  | Nil, Nil => true
  | Pad _ _ o1', Pad _ _ o2' => eqb o1' o2'
  | Mark _ o1', Mark _ o2' => eqb o1' o2'
  | _, _ => false
  end.
(* Structural equality of two one-hot encodings of possibly different
   sizes and flags. Returns [true] iff both are built from the same
   sequence of constructors.
   Equal constructors at every level imply equal size indices, so no
   size comparison (and no rewriting by an equality proof) is needed.
   Because the fixpoint is purely structural, it evaluates fully under
   [Compute]. *)
Fixpoint eqb' {b1 b2: bool} {n1 n2: nat} (o1: t n1 b1) (o2: t n2 b2) {struct o1}: bool :=
  match o1, o2 with
  | Nil, Nil => true
  | Pad _ _ o1', Pad _ _ o2' => eqb' o1' o2'
  | Mark _ o1', Mark _ o2' => eqb' o1' o2'
  | _, _ => false
  end.
(* move/PeanoNat.Nat.eqb_spec: Hneq => Hneq. *)
(* refine ( *)
(* match o1, o2 with *)
(* | Nil, Nil => true *)
(* | Pad n1' b1' o1', Pad n2' b2' o2' => _ *)
(* | Mark n1' o1', Mark n2' o2' => _ *)
(* | _, _ => false *)
(* end). *)
(* - move: n1' n2' o1' o2' => n1' n2'. *)
(* case (Nat.eqb n1' n2') eqn:Hneq'. *)
(* + Search Nat.eqb. *)
(* (1* PeanoNat.Nat.eqb_eq *1) *)
(* Check (EqNat.beq_nat_true_stt n1' n2'). *)
(* Search Nat.eqb Nat.eqb. *)
(* move: Hneq'. *)
(* Check Nat.eqb_eq n1' n2'. *)
(* Check (EqNat.beq_nat_true_stt n1' n2'). *)
(* rewrite (EqNat.beq_nat_true_stt n1' n2') in Hneq'. *)
(* move => o1' o2'. *)
(* exact: eqb _ _ _ _ o1' o2'. *)
(* Search (_ = _) Nat.eqb. *)
(* Check PeanoNat.Nat.eqb_eq n1' n2'. *)
(* rewrite (EqNat.beq_nat_true_stt n1' n2') in Hneq'. *)
(* rewrite -(EqNat.beq_nat_true_stt n1' n2'). *)
(* S *)
(* + exact: (fun _ _ => false). *)
(* - move: n1' n2' o1' o2' => n1' n2'. *)
(* case (Nat.eqb n1' n2') eqn:Hneq'. *)
(* + move/PeanoNat.Nat.eqb_spec: Hneq' => Hneq'. *)
(* rewrite -Hneq'. *)
(* move => o1' o2'. *)
(* exact: eqb _ _ _ _ o1' o2'. *)
(* + exact: (fun _ _ => false). *)
(* - exact: false. *)
(* Defined. *)
(* Count the [Mark] constructors in [a]. [Pad] and [Nil] contribute
   nothing. Works for any flag [b]. *)
Fixpoint markCount_aux {n: nat} {b: bool} (a: t n b): nat :=
  match a with
  | Mark _ rest => S (markCount_aux rest)
  | Pad _ _ rest => markCount_aux rest
  | Nil => 0
  end.
(* Mark count of a complete encoding (flag [true]); lmCountMarkTrue
   proves it is always 1. *)
Definition markCount {n: nat} (a: t n true): nat :=
  @markCount_aux n true a.
(* Convert a complete encoding to a boolean vector of the same length.
   Each [Pad] becomes [false] and the [Mark] becomes [true]. The
   outermost constructor gives the first element of the vector.
   The local [go] handles tails of either flag. *)
Definition to_bv {n: nat} (a: t n true): Bvector n :=
  (fix go (m: nat) (c: bool) (x: t m c) {struct x} : Bvector m :=
     match x in t k _ return Bvector k with
     | Nil => @Vector.nil bool
     | Pad k' c' x' => @Vector.cons bool false k' (go k' c' x')
     | Mark k' x' => @Vector.cons bool true k' (go k' false x')
     end) n true a.
(* An encoding contains one mark if its flag is [true], and none otherwise. *)
Lemma lmCountMark: forall (n: nat) (b: bool) (a: t n b),
  markCount_aux a = if b then 1 else 0.
Proof.
  move=> n b; elim=> [| m c x IHx | m x IHx] /=.
  - by [].
  - exact: IHx.
  - by rewrite IHx.
Qed.
(* Specialisation of lmCountMark to complete encodings. *)
Lemma lmCountMarkTrue: forall (n: nat) (a: t n true),
  markCount a = 1.
Proof.
  by move=> n a; rewrite /markCount lmCountMark.
Qed.
(* An unmarked encoding of length [z]: [z] nested [Pad]s around [Nil]. *)
Fixpoint genPad (z: nat): t z false :=
  match z return t z false with
  | S k => Pad (genPad k)
  | 0 => Nil
  end.
(* Wrap [z] extra [Pad] constructors around the outside of [a], i.e.
   at the front of the vector produced by to_bv. The size index
   becomes [z + n]. *)
Fixpoint padL {n: nat} (a: t n true) (z: nat) {struct z}: t (z+n) true :=
  match z return t (z+n) true with
  | 0 => a
  | S k => Pad (padL a k)
  end.
(* Extend [a] with [z] extra [Pad] constructors on the right. The padding
   goes innermost, next to [Nil], i.e. at the end of the vector produced
   by to_bv, so the existing bits keep their positions. The size index
   becomes [n + z] and the flag is unchanged.
   Example: padR (Pad (Mark Nil)) 2 = Pad (Mark (Pad (Pad Nil))).
   The recursion is structural on [a]. The index arithmetic holds by
   conversion (0 + z reduces to z, and S m + z reduces to S (m + z)), so
   no rewriting by equality lemmas is needed and the result reduces
   under [Compute]. *)
Fixpoint padR {n: nat} {b: bool} (a: t n b) (z: nat) {struct a}: t (n+z) b :=
  match a in t m c return t (m+z) c with
  | Nil => genPad z               (* the padding replaces the empty tail *)
  | Pad _ _ a' => Pad (padR a' z)
  | Mark _ a' => Mark (padR a' z)
  end.
(**************************************)
(* Examples and sanity checks. *)
(* Rejected: Mark needs an argument with flag false, but Pad (Mark _)
   carries flag true, so a second mark cannot be added. *)
Fail Check Mark (Pad (Mark (Pad Nil))).
Example eg1: t 0 false := Nil.
Example eg2: t 2 true := Pad (Mark Nil).
Example eg3: t 2 true := Mark (Pad Nil).
Example eg4: t 3 true := Pad (Mark (Pad Nil)).
Example eg5: t 4 true := Pad (Pad (Mark (Pad Nil))).
(* eqb on arguments of equal and of different sizes/flags. *)
Compute eqb Nil Nil.
Compute eqb (Mark Nil) Nil.
Compute eqb (Pad (Mark Nil)) (Mark (Pad Nil)).
Compute eqb (Pad (Mark Nil)) (Pad (Mark Nil)).
(* Mark counts of the complete examples. *)
Compute markCount eg2.
Compute markCount eg3.
Compute markCount eg4.
Compute markCount eg5.
Compute to_bv eg5.
(* = [false; false; true; false] : Bvector 4 *)
(* Padding the same encoding by 3 with padL and with padR. *)
Compute padL (Pad (Mark Nil)) 3.
Compute padR (Pad (Mark Nil)) 3.
(* eg8 and eg bind the same computed value under two names; both are
   extracted below. *)
Example eg8 := Eval compute in padR (Pad (Mark Nil)) 3.
Example eg := Eval compute in padR (Pad (Mark Nil)) 3.
(* Haskell extraction of the example values and of both equality tests. *)
Require Import Extraction.
Extraction Language Haskell.
Recursive Extraction eg.
Recursive Extraction eqb.
Recursive Extraction eqb'.
Recursive Extraction eg8.
(*
data T =
Nil
| Pad Nat Bool T
| Mark Nat T
eg :: T
eg =
Pad (S (S (S (S O)))) True (Pad (S (S (S O))) True (Pad (S (S O)) True (Pad
(S O) True (Mark O Nil))))
*)