(** * n - queens problem *)

Set Implicit Arguments.
Require Import Arith.
Require Import Omega.
Require Import Dec.
Require Import List.
Require Import Sets.
Require Import Coq.Program.Program.

Section Nqueens.

(** ** General definitions *)
(** The parametric number of queens on the board, which is also the size
    of the square board: rows and columns are numbered from [0] to [N-1]. *)
Variable N : nat.

(** A board is represented as a list of natural numbers (not as a function).
    In a partial board for the rows [k] to [N-1] (see [psol] below), the
    entry at index [i] is the column of the queen placed in row [i + k]. *)
Definition board := list nat.

(** [falist P k l] states that [P] holds for every element of [l], each
    element being paired with its position in [l] shifted by [k]:
    [falist P k [a0;..;an]] holds when
    [P k a0 /\ P (k+1) a1 /\ .. /\ P (k+n) an].
    The empty list always satisfies it ([fanil]).
*)
Inductive falist A (P:nat -> A -> Prop) (k : nat) : list A -> Prop :=
   fanil : falist P k nil
 | facons : forall a l, P k a -> falist P (S k) l -> falist P k (cons a l).
Hint Constructors falist.

(** [falist P k l] is decidable as soon as [P] is decidable pointwise.
    The proof is by induction on [l].  It ends with [Defined] rather than
    [Qed], so the decision procedure stays transparent and can be computed
    with (it is used through [ok_dec] by the search functions below). *)
Lemma falist_dec : forall A (P:nat -> A -> Prop), 
      (forall n a, {P n a}+{~ P n a}) -> 
      forall l k, {falist P k l}+{~ falist P k l}.
induction l; auto; intros.
(* cons case: first decide P on the head, at index k *)
destruct (X k a) as [HP|HnP].
(* head accepted: decide the tail, whose indices start at S k *)
destruct (IHl (S k)) as [Hl|Hnl]; auto.
(* tail rejected, hence the whole list is rejected *)
right; intros Hf; inversion Hf; auto.
(* head rejected, hence the whole list is rejected *)
right; intros Hf; inversion Hf; auto.
Defined.

(** [ok k qk q]: placing a queen in column [qk] of row [k] is compatible
    with the board [q], whose entries are the queens of rows [S k],
    [S k + 1], ...  For every queen of [q], in column [qi] of row [i]:
    - [qk <> qi]: the two queens are not in the same column;
    - [k + qk <> i + qi]: they are not on the same anti-diagonal;
    - [N - k + qk <> N - i + qi]: they are not on the same diagonal.
    NOTE: the last test uses truncated subtraction on [nat].  It is
    equivalent to [qk + i <> qi + k] only when [k] and [i] are at most [N],
    which holds for the boards described by [psol]. *)
Definition ok (k qk:nat) : board -> Prop := 
   falist (fun i qi => qk <> qi /\ k + qk <> i + qi /\ N - k + qk <> N - i + qi)
          (S k).

(** This condition is decidable: [falist_dec] reduces it to deciding the
    pointwise condition, which is a conjunction of three negated equalities
    on [nat].  Ends with [Defined] so the search functions can compute with
    it. *)
Lemma ok_dec : forall (k qk:nat) (q: board), {ok k qk q} + {~ ok k qk q}.
unfold ok; intros. 
apply falist_dec; intros.
(* first conjunct: the column test *)
apply and_dec.
apply not_dec; apply eq_nat_dec.
(* remaining conjuncts: the two diagonal tests *)
apply and_dec.
apply not_dec; apply eq_nat_dec.
apply not_dec; apply eq_nat_dec.
Defined.

(** [psol k p]: [p] is a partial solution placing the queens of rows [k] to
    [N-1].  It has exactly [N - k] entries, and its entry at index [i] (the
    queen of row [i + k]) is a column smaller than [N] that is compatible,
    in the sense of [ok], with all the entries stored after it.
    [psol 0 p] therefore means that [p] is a full solution. *)
Definition psol k (p : board) : Prop := 
   length p = N - k /\ 
   forall i, i < N - k -> let qi:= nth i p 0 in 
                          (qi < N /\ ok (i+k) qi (skipn (S i) p)).
Hint Unfold ok psol.


(** Correctness when extending the board.  If [q] is a partial solution from
    row [S k] and the column [qk < N] is [ok] for row [k] against [q], then
    [qk::q] is a partial solution from row [k]. *)
Lemma psol_ext : forall (q: board) (k qk:nat),
  k < N -> qk < N -> psol (S k) q -> ok k qk q -> psol k (qk::q).
intros q k qk Hk Hqk (H1,H2) H3; split; intros.
(* length part, using k < N *)
simpl; omega.
(* entry part, by case on the index *)
destruct i.
(* index 0 is the new queen: from Hqk and H3 *)
simpl; firstorder.
(* index S i: entry i of q, with the row shifted by one *)
simpl nth in qi; simpl skipn.
destruct (H2 i); intuition.
replace (S i + k) with (i + S k); intuition.
Save.
Hint Resolve psol_ext.

(** The empty board is a partial solution covering no row, i.e. starting
    at row [N]. *)
Lemma psol_nil : psol N nil.
red; simpl; intuition.
Save.
Hint Resolve psol_nil.

(** ** Naive solutions *)

(** Pure functional version.  [q] is a partial solution for the rows [k]
    to [N-1] (normally one satisfying [psol k q]); [sol k q] searches for a
    way to fill the rows [k-1] down to [0].  In each row the columns are
    tried from [N-1] down to [0], keeping only those accepted by [ok_dec],
    with backtracking.  The first complete board found is returned; [None]
    means that [q] cannot be extended. *)
Fixpoint sol k (q:board) {struct k} : option board :=
  match k with
  | O => Some q
  | S row =>
      (* [try_below c] tests the columns [c-1], ..., [0] for [row]; when it
         is called, the columns [c] to [N-1] have already failed *)
      let fix try_below c {struct c} : option board :=
        match c with
        | O => None
        | S col =>
            match ok_dec row col q with
            | right _ => try_below col
            | left _ =>
                match sol row (col :: q) with
                | Some b => Some b
                | None => try_below col
                end
            end
        end
      in try_below N
  end.

(** [! P] is the subset type [{_:unit | P}]: a value with no computational
    content (its witness is [tt]) paired with a proof of [P].  It is used
    below to pass preconditions as arguments of [Program] definitions. *)
Notation "! P " := {_:unit | P } (at level 35) : type_scope.

(* Obligations generated by [Program] in this file are attacked with
   [program_simpl], followed by [omega] on whatever remains. *)
Local Obligation Tactic := program_simpl; try omega.

(** Solution with dependent types.  [solr] performs the same search as
    [sol], with two differences.  It takes the precondition
    [psol k q /\ k <= N] as an argument of type [!(...)] (passed as [tt]),
    and the inner [search] carries the bound [i <= N] the same way.  When it
    finds a board, it returns it together with a proof that it is a full
    solution ([psol 0]).  [Program] turns the missing proofs into
    obligations.  The obligation tactic set above discharges all of them but
    one, which is closed by the [Next Obligation] script that follows. *)
Program Fixpoint solr k (q:board) (_:!(psol k q /\ k <= N)) {struct k}  
   : option {p : board | psol 0 p} := 
    match k with O => Some q 
       | S m => 
         let fix search i (_:!(i <=N))  {struct i} (* no sol found with (m,j) for i <= j < N *)  
                : option {b: board | psol 0 b} := 
         match i with O => None 
                   | S j => if ok_dec m j q then 
                     match solr (k:=m) (q:=j::q) tt with 
                        | None   => search j tt
                        | Some b => Some b
                      end
                            else search j tt
         end
         in search N tt
     end.
(* the obligation left by the obligation tactic *)
Next Obligation.
intuition.
Defined.


(** One solution of the N-queens problem, if any, together with its
    correctness proof.  The search starts from the empty board, which covers
    the rows from [N] onwards; the arguments [k] and [q] of [solr] are given
    by name. *)
Program Definition one_sol : option {p : board | psol 0 p} := solr (k:=N) (q:=nil) tt.


(** Computing all solutions.  [all_solr k q] lists every complete board
    extending the partial board [q] of rows [k] to [N-1] (normally one
    satisfying [psol k q]), filling the rows [k-1] down to [0].  For the
    queen of row [k-1] the columns are explored from [N-1] down to [0], so
    the boards using a larger column come first in the result. *)
Fixpoint all_solr k (q:board) {struct k} : list board :=
  match k with
  | O => q :: nil
  | S row =>
      (* [collect c] gathers the boards extending [q] whose queen of [row]
         is in some column [col < c] *)
      let fix collect c {struct c} : list board :=
        match c with
        | O => nil
        | S col =>
            match ok_dec row col q with
            | left _ => all_solr row (col :: q) ++ collect col
            | right _ => collect col
            end
        end
      in collect N
  end.

(** All the solutions of the N-queens problem, from the empty board. *)
Definition all_sol : list board := all_solr N nil.

(** Computing the number of solutions.  [nb_solr k q] counts the complete
    boards extending the partial board [q] of rows [k] to [N-1] (normally
    one satisfying [psol k q]).  It explores the same tree as [all_solr]:
    the columns of row [k-1] are tried from [N-1] down to [0], keeping only
    those accepted by [ok_dec]. *)
Fixpoint nb_solr k (q:board) {struct k} : nat :=
  match k with
  | O => 1
  | S row =>
      (* [count c] adds up the boards extending [q] whose queen of [row]
         is in some column [col < c] *)
      let fix count c {struct c} : nat :=
        match c with
        | O => 0
        | S col =>
            match ok_dec row col q with
            | left _ => nb_solr row (col :: q) + count col
            | right _ => count col
            end
        end
      in count N
  end.
(** The number of solutions of the N-queens problem, from the empty
    board. *)
Definition nb_sol : nat := nb_solr N nil.

(** Specifying the computation of the number of solutions *)


(** [ext_board p q k]: the board [q] is the segment of [p] starting at index
    [k].  For every index [i] of [q], [nth i q 0 = nth (k + i) p 0].  The
    length of [p] is not constrained; [nth] returns the default [0] past
    its end. *)
Definition ext_board (p q : board) k := 
           forall i, i < length q -> nth i q 0 = nth (k + i) p 0.

(** Every board is the segment of itself starting at index [0]. *)
Lemma ext_board_refl : forall p, ext_board p p 0.
red; auto.
Save.
Hint Resolve ext_board_refl.

(** Dropping the first entry of a segment of [p] starting at [k] gives the
    segment of [p] starting at [S k]. *)
Lemma ext_board_prev : forall a p q k, ext_board p (a::q) k -> ext_board p q (S k).
intros a p q k H i Hi.
(* entry i of q is entry S i of a::q *)
transitivity (nth (S i) (a::q) 0); auto.
rewrite (H (S i)); simpl; intuition.
replace (k + S i) with (S (k + i)); auto.
Save.

(** The empty board is a segment of any board at offset [N]; the condition
    on indices is vacuous since [nil] has no entry. *)
Lemma ext_board_nil : forall p, ext_board p nil N.
intros p i; simpl; omega.
Save.

Hint Resolve ext_board_prev ext_board_nil.

(** Lengths of partial solutions: if [p] is a full solution and [q] a
    partial solution from row [m <= N], then [q] has [m] entries fewer than
    [p]. *)
Lemma psol_length : forall p q m, psol 0 p -> psol m q -> 
  m <= N -> length q = length p - m. 
intros p q m (Lp,_) (Lq,_) Hm; omega.
Save.
Hint Immediate psol_length.

(** Two boards of the same length, one being the segment of the other at
    index [0], are equal.  Despite its name, the lemma does not involve
    [psol]. *)
Lemma psol0_eq : forall p q, length p = length q -> ext_board p q 0 -> p = q.
induction p; intros.
(* p = nil: q is nil too, by the length hypothesis *)
destruct q; auto; simpl in *; discriminate.
(* p = a :: p: q is a cons as well *)
destruct q; simpl in *; try discriminate.
(* heads: instance 0 of the segment hypothesis *)
assert (H1:=H0 0); simpl in H1.
rewrite H1; auto; f_equal; auto with arith.
(* tails: induction hypothesis, with indices shifted by one *)
apply IHp; auto with arith.
intros i Hi; simpl.
apply (H0 (S i)); simpl; auto with arith.
Save.
Hint Resolve psol0_eq.

(** If [q] is the segment of [p] starting at index [k] and [p] has exactly
    [k] more entries than [q], then [q] is [skipn k p]. *)
Lemma psol_skip : 
  forall k p q, length p = k + length q -> ext_board p q k -> skipn k p = q.
induction k; intros.
(* k = 0 *)
simpl in *; auto with arith.
(* k = S k: p is a cons, drop its head and use the induction hypothesis *)
destruct p; try discriminate H. 
simpl in *.
apply IHk; try omega.
intros i Hi; apply (H0 i); try omega.
Save.

(** The head [j] of a segment [j::q] of [p] starting at index [m] is the
    entry [m] of [p]. *)
Lemma psol_nth : forall p q m j, 
      ext_board p (j::q) m -> nth m p 0 = j.
intros p q m j H.
(* j is entry 0 of j::q *)
transitivity (nth 0 (j :: q) 0); auto.
rewrite (H 0); simpl; auto with arith.
replace (m+0) with m; auto.
Save.

(** Conversely, prepending the entry [m] of [p] to a segment of [p] starting
    at [S m] gives the segment of [p] starting at [m]. *)
Lemma ext_board_nth : forall p q m j, 
      ext_board p q (S m) ->  nth m p 0 = j -> ext_board p (j::q) m.
intros p q m j He Hj i Hi.
destruct i.
(* index 0: the new head j *)
simpl; replace (m+0) with m; auto with arith.
(* index S i: entry i of q *)
replace (m+S i) with ((S m)+i); try omega.
apply (He i); simpl in *; omega.
Save.
Hint Resolve ext_board_nth.

(** Developing the program for counting as a proof.  [nb_solp] returns a
    number [n] together with a proof that [n] is the cardinal of the set of
    full solutions [psol 0] ([card] presumably comes from the imported
    [Sets] library).  The proof follows the structure of [nb_solr]:
    - the local generalisation [nb_solr] counts the solutions extending a
      partial board [q] of rows [k] to [N-1];
    - the inner [search i] counts those whose queen in row [k] is in a
      column smaller than [i].
    It ends with [Defined] so that its computational content stays
    available, e.g. for extraction. *)
Lemma nb_solp : {n : nat  | card (fun p => psol 0 p) n}.
(* Generalisation *)
assert (nb_solr: forall k (q:board), 
   psol k q -> k <= N -> {n : nat  | card (fun p => psol 0 p /\ ext_board p q k) n}).
induction k; intros.
(* case k=0, one solution found *)
exists 1.
apply cardS with q (empty board); auto.
intro p; simpl.
unfold add; intuition; subst; auto.
destruct H; destruct H2; assert (length p = length q); try omega; auto.
(* case S k, add solutions for all (k,j) *)
assert (search:forall i, i <= N -> {n:nat | card (fun p => psol 0 p /\ ext_board p q (S k) /\ nth k p 0 < i) n}).
induction i; intros.
(* case i=0, no solutions *)
exists 0.
constructor.
intro p; simpl.
intuition; omega.
(* case (S i) : computes smaller solutions *)
destruct IHi as (ni,Hni); auto with arith.
(* tests if (k,j) is possible *)
destruct (ok_dec k i q) as [Hok|Hnok].
(* case ok *)
destruct (IHk (i::q)) as (nk,Hnk); auto with arith.
exists (nk+ni).
(* solutions with column i in row k, plus those with a column < i *)
apply card_eq_compat 
  with (union (fun p : board => psol 0 p /\ ext_board p (i :: q) k)
              (fun p : board => psol 0 p /\ ext_board p q (S k) /\ nth k p 0 < i)).
apply card_union; auto.
(* the two sets are disjoint: column = i versus column < i *)
intro p; simpl.
intros (_,He) (Hp,(_,Hi)).
assert (nth k p 0 = i).
apply psol_nth with q; auto.
omega.
(* the union is the set counted for S i *)
unfold union; intro p; simpl; intuition.
apply ext_board_prev with i; auto.
assert (nth k p 0 = i); try omega.
apply psol_nth with q; auto.
assert (nth k p 0 < i \/ nth k p 0 = i) as [Hi|He]; try omega; auto.
(* case not ok *)
exists ni.
apply card_eq_compat with (1:=Hni).
intros p; intuition.
(* column i is impossible in row k, otherwise ok k i q would hold *)
assert (nth k p 0 <> i); try omega.
intro Hp.
apply Hnok.
destruct H; destruct H3.
replace q with (skipn (S k) p).
destruct (H6 k); try omega.
subst i.
replace (k+0) with k in H8; auto with arith.
apply psol_skip; auto with arith; try omega.
(* final result for nb_solr from search N *)
destruct (search N) as (n,Hcn); clear search; auto with arith.
exists n; apply card_eq_compat with (1:=Hcn).
intros p; intuition.
destruct H2 as (_,Hpk).
destruct (Hpk k); auto with arith; try omega.
(* final count from nb_solr N nil *)
destruct (nb_solr N nil) as (n,Hcn); auto.
exists n; apply card_eq_compat with (1:=Hcn).
intros p; intuition.
Defined.

End Nqueens.

(** Extracting the functions: [all_sol], [nb_sol] and the certified counter
    [nb_solp] are extracted together under the name "nqueens".  Outside the
    section, each of them takes the number of queens [N] as its first
    argument. *)

Extraction "nqueens" all_sol nb_sol nb_solp.

(** Tests for computing the solution *)
(*
Eval compute in (all_sol 0).
Eval compute in (all_sol 1).
Eval compute in (all_sol 2).
Eval compute in (all_sol 3).
Eval compute in (all_sol 4).
Eval compute in (all_sol 5).
Eval compute in (all_sol 6).
Time Eval vm_compute in (all_sol 7).
Eval compute in (one_sol 0).
Eval compute in (one_sol 1).
Eval compute in (one_sol 2).
Eval compute in (one_sol 3).
Eval compute in (one_sol 4).
Eval compute in (one_sol 5).
Eval compute in (one_sol 6).
Eval compute in (one_sol 7).
Time Eval compute in (one_sol 8).
Time Eval vm_compute in (one_sol 8).
Time Eval vm_compute in (one_sol 10).
Time Eval vm_compute in (one_sol 12).
Time Eval vm_compute in (one_solp 12).
*)
