(* playground/coq/ohe.v
   175 lines, 3.8 KiB, Coq *)

(* One hot encoding *)
From mathcomp Require Import all_ssreflect.
Set Bullet Behavior "Strict Subproofs".
(* Set Implicit Arguments. *)
Require Import Bvector.
(* Exploratory queries: look for an existing boolean comparison on vectors
   and display the definitions of BVeq and Vector.eqb.  These commands only
   print information. *)
Search (Vector.t _ _ -> Vector.t _ _ -> bool).
Print BVeq.
Print Vector.eqb.
About Vector.eqb.
(* One-hot encoding, indexed by its size and by a flag telling whether the
   mark has already been placed.  Valid encodings are exactly the values
   whose flag is true. *)
Inductive t: nat -> bool -> Type :=
  (* Empty encoding: size 0, no mark yet. *)
  | Nil: t 0 false
  (* One more unmarked position; the flag of the tail is kept as is. *)
  | Pad {n: nat} {b: bool}: t n b -> t (S n) b
  (* The marked position; the tail must not contain a mark already. *)
  | Mark {n: nat}: t n false -> t (S n) true.
(* Boolean comparison of two encodings whose sizes and flags may differ.
   The body is built with tactics and closed with Defined.
   Outline: compare the sizes first; when they agree, compare the head
   constructors and recurse on the tails; every mismatch yields false. *)
Fixpoint eqb {b1 b2: bool} {n1 n2: nat} (o1: t n1 b1) (o2: t n2 b2): bool.
(* Split on the boolean size comparison, recording the outcome in Hneq. *)
case (Nat.eqb n1 n2) eqn:Hneq.
(* First goal (comparison returned true): apply the reflection lemma as a
   view to Hneq.  NOTE(review): Hneq is not used afterwards in this goal. *)
move/PeanoNat.Nat.eqb_spec: Hneq => Hneq.
(* Compare head constructors.  Nil/Nil is equal, mixed constructors are
   different, and the two same-constructor cases are left as holes that the
   first two bullets below fill in. *)
refine (
match o1, o2 with
| Nil, Nil => true
| Pad n1' b1' o1', Pad n2' b2' o2' => _
| Mark n1' o1', Mark n2' o2' => _
| _, _ => false
end).
(* Pad/Pad: put the tails back into the goal so that their sizes can be
   compared before recursing. *)
- move: n1' n2' o1' o2' => n1' n2'.
case (Nat.eqb n1' n2') eqn:Hneq'.
(* Tail sizes agree: rewrite one size into the other, reintroduce the
   tails and recurse on them. *)
+ move/PeanoNat.Nat.eqb_spec: Hneq' => Hneq'.
rewrite -Hneq'.
move => o1' o2'.
exact: eqb _ _ _ _ o1' o2'.
(* Tail sizes differ: ignore both tails and answer false. *)
+ exact: (fun _ _ => false).
(* Mark/Mark: same steps as the Pad/Pad case. *)
- move: n1' n2' o1' o2' => n1' n2'.
case (Nat.eqb n1' n2') eqn:Hneq'.
+ move/PeanoNat.Nat.eqb_spec: Hneq' => Hneq'.
rewrite -Hneq'.
move => o1' o2'.
exact: eqb _ _ _ _ o1' o2'.
+ exact: (fun _ _ => false).
(* Remaining goal from the initial case split (presumably the branch where
   the size comparison returned false): answer false. *)
- exact: false.
Defined.
(* Count the Mark constructors occurring in an encoding, whatever its flag:
   Nil counts 0, Pad leaves the count of its tail unchanged, Mark adds one. *)
Fixpoint markCount_aux {n: nat} {b: bool} (a: t n b): nat :=
  match a with
  | Nil => 0
  | @Mark _ rest => S (markCount_aux rest)
  | @Pad _ _ rest => markCount_aux rest
  end.
(* Mark count of an encoding whose flag is true; simply delegates to
   markCount_aux with the flag instantiated. *)
Definition markCount {n: nat} (a: t n true): nat :=
  @markCount_aux n true a.
(* Convert an encoding whose flag is true into a boolean vector of the same
   size: each Pad contributes false and the Mark contributes true.  Built
   with tactics and closed with Defined. *)
Definition to_bv {n: nat} (a: t n true): Bvector n.
(* Push n and a back into the goal and reintroduce only n, leaving the
   encoding on top of the goal for elim. *)
move: n a => n.
elim.
(* Nil: the empty vector. *)
- exact: [].
(* Pad: bv' is the vector already built for the tail a'. *)
- move => n' b a' bv'.
exact: (false :: bv').
(* Mark: same, with true in front. *)
- move => n' a' bv'.
exact: (true :: bv').
Defined.
(* The number of marks in an encoding is determined by its flag:
   1 when the flag is true, 0 when it is false. *)
Lemma lmCountMark: forall (n: nat) (b: bool) (a: t n b),
markCount_aux a = if b then 1 else 0.
Proof.
move => n b a.
(* Induction on the encoding; //= simplifies and closes the goals it can
   solve trivially. *)
elim: a => //=.
(* Remaining goal: rewrite with the induction hypothesis for the tail. *)
by move => n' a' ->.
Qed.
(* Every encoding whose flag is true carries exactly one mark.
   Direct instance of lmCountMark at flag true. *)
Lemma lmCountMarkTrue: forall (n: nat) (a: t n true),
  markCount a = 1.
Proof.
  exact: (fun n a => lmCountMark n true a).
Qed.
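(*
NOTE(review): the error recorded below is about applying markCount to a term
of type t n0 b. It does not look like the error produced by the Fail Check
that follows; presumably it was kept from an earlier attempt at stating a
lemma with markCount on an arbitrary flag. Confirm or remove.

The command has indeed failed with message:
Illegal application:
The term "@markCount" of type "forall n : nat, t n true -> nat"
cannot be applied to the terms
"n0" : "nat"
"t" : "t n0 b"
The 2nd term has type "t n0 b" which should be a subtype of
"t n0 true".
*)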
(* A second mark is rejected by typing: Mark expects a tail with flag false,
   while Pad (Mark ...) has flag true. *)
Fail Check Mark (Pad (Mark (Pad Nil))).
(* Sample encodings of sizes 0, 2, 2, 3 and 4. *)
Example eg1: t 0 false := Nil.
Example eg2: t 2 true := Pad (Mark Nil).
Example eg3: t 2 true := Mark (Pad Nil).
Example eg4: t 3 true := Pad (Mark (Pad Nil)).
Example eg5: t 4 true := Pad (Pad (Mark (Pad Nil))).
(* Evaluate the mark count of each valid example. *)
Compute markCount eg2.
Compute markCount eg3.
Compute markCount eg4.
Compute markCount eg5.
(* Conversion of eg5 to a boolean vector; recorded output below. *)
Compute to_bv eg5.
(* = [false; false; true; false] : Bvector 4 *)
(* Build the unmarked encoding of size z: z nested Pad constructors
   around Nil. *)
Fixpoint genPad (z: nat): t z false :=
  match z as k return t k false with
  | S k' => Pad (genPad k')
  | O => Nil
  end.
(* Left padding: wrap z additional Pad constructors around the valid
   encoding a.  With z = 0 the encoding is returned as is. *)
Fixpoint padL {n: nat} (a: t n true) (z: nat): t (z+n) true :=
  match z as k return t (k+n) true with
  | O => a
  | S k' => Pad (@padL n a k')
  end.
(* Right padding: add z unmarked positions at the Nil end of the encoding a,
   keeping the existing constructors (and therefore the position of the mark
   counted from the left) in place.

   The recursion is structural on a: Pad and Mark are copied unchanged and
   the z extra pads are generated once Nil is reached.  The earlier version
   recursed on z and wrapped each new Pad around the outside of a, which put
   the padding on the left and made the result coincide with padL.

   The rewrites only adjust the size index (z + 0 and z + S n') so that the
   constructors type-check; they do not change which constructors are
   produced. *)
Fixpoint padR {n: nat} {b: bool} (a: t n b) (z: nat) {struct a}: t (z+n) b.
refine (
  match a in t n0 b0 return t (z+n0) b0 with
  | Nil => _
  | @Pad n' b' a' => _
  | @Mark n' a' => _
  end).
(* Nil: end of the encoding reached, emit the z extra pads here. *)
- rewrite addn0.
  exact: genPad z.
(* Pad: keep this pad and continue padding behind its tail. *)
- rewrite addnS.
  exact: Pad (padR _ _ a' z).
(* Mark: keep the mark and continue padding behind its tail. *)
- rewrite addnS.
  exact: Mark (padR _ _ a' z).
Defined.
(* Evaluate both padding functions on the same size-2 encoding with 3 extra
   positions. *)
Compute padL (Pad (Mark Nil)) 3.
Compute padR (Pad (Mark Nil)) 3.
(* Store the computed padR result as a constant so that it can be extracted. *)
Example eg := Eval compute in padR (Pad (Mark Nil)) 3.
(* Extract eg, together with everything it depends on, to Haskell. *)
Require Import Extraction.
Extraction Language Haskell.
Recursive Extraction eg.
(*
data T =
Nil
| Pad Nat Bool T
| Mark Nat T
eg :: T
eg =
Pad (S (S (S (S O)))) True (Pad (S (S (S O))) True (Pad (S (S O)) True (Pad
(S O) True (Mark O Nil))))
*)