Theory SimplifyRCRS

theory SimplifyRCRS
imports Simulink
subsection{*Automated Simplification*}

theory SimplifyRCRS imports Simulink
keywords "simplify_RCRS" "simplify_RCRS_f" :: thy_decl
begin

(* Diagnostic: print the imported lemma update_assert_comp, which the
   "simp only" proof of the sanity-check lemma further below relies on. *)
thm update_assert_comp

ML 
{*
  (* Render a term with the pretty printer of the given context. *)
  fun term_to_string ctxt tm = Pretty.string_of (Syntax.pretty_term ctxt tm)

  (* Same as term_to_string, for a certified term. *)
  fun cterm_to_string ctxt ct = term_to_string ctxt (Thm.term_of ct)
  (* Render a type. *)
  fun type_to_string ctxt ty = Pretty.string_of (Syntax.pretty_typ ctxt ty)
  (* Render the whole proposition of a theorem. *)
  fun thm_to_string ctxt th = term_to_string ctxt (Thm.prop_of th);

  (* Emit a plain string as a pretty-printed output message. *)
  fun pwriteln s = Pretty.writeln (Pretty.str s);

  (* Debug channel identifiers, passed to debug / time below. *)
  val DEBUG_ASSERT = 1;
  val DEBUG_UPDATE = 2;
  val DEBUG_SERIAL = 3;
  val DEBUG_PROD = 4;
  val DEBUG_FEEDBACK = 5;
  val DEBUG_PROD_PREC = 6;
  val DEBUG_PROD_FUN = 7;
  val DEBUG_ASSERT_UPDATE = 8;
  (* DEBUG n tells whether channel n is enabled.  Each fun DEBUG below shadows
     the previous one, so only the LAST definition is in effect.  As written,
     the final catch-all disables every channel.  The earlier lines are
     presumably kept as toggles: remove or comment out the definitions that
     follow the one you want to be active. *)
  fun DEBUG n = (n = DEBUG_ASSERT);
  fun DEBUG n = (n = DEBUG_UPDATE);
  fun DEBUG n = (n = DEBUG_SERIAL);
  fun DEBUG n = (n = DEBUG_PROD);
  fun DEBUG n = (n = DEBUG_FEEDBACK);
  fun DEBUG n = (n = DEBUG_PROD_PREC);
  fun DEBUG n = (n = DEBUG_PROD_FUN);
  fun DEBUG _ = false

  (* Print str when channel n is enabled.  SML is strict, so the caller builds
     str even when the channel is disabled. *)
  fun debug n str = (if DEBUG n then writeln str else ());

*}

ML{*
  (* Scratch timestamp taken when this ML block is evaluated. *)
  val T = Time.now();

  (* time n f x evaluates f x and returns its result.  When debug channel n is
     enabled it also prints the wall-clock duration of the call.  The duration
     is end - start, so the printed value is non-negative. *)
  fun time n f x =
    if DEBUG n then
      let
        val start = Time.now ()
        val result = f x
        val elapsed = Time.now () - start
        val _ = writeln ("Time: " ^ Time.toString elapsed)
      in
        result
      end
    else f x


  (* Debug rendering of the raw structure of a term: constants as C(name),
     frees as (F(name)), schematic variables as V(name followed by index),
     bound variables as B index, and abstractions / applications as
     parenthesised groups of their rendered parts. *)
  fun print_term_to_update_or_delete term =
    let
      val render = print_term_to_update_or_delete
      fun group a b = "(" ^ a ^ b ^ ") "
    in
      case term of
        Const (name, _) => "C(" ^ name ^ ") "
      | Free (name, _) => "(F(" ^ name ^ ")) "
      | Var ((name, idx), _) => "V(" ^ name ^ Int.toString idx ^ ") "
      | Bound idx => "B" ^ Int.toString idx ^ " "
      | Abs (name, _, body) => group ("A(" ^ name ^ ") ") (render body)
      | fnc $ arg => group (render fnc) (render arg)
    end;
*}

(* prod_fun f g applies f and g componentwise to a pair.
   prod_prec p q holds on a pair when p holds on the first component and
   q on the second. *)
definition "prod_fun f g  = (λ (x, y) . (f x, g y))"
definition "prod_prec p q  = (λ (x, y) . p x ∧ q y)"

(* Serial composition of two assert-update programs, pointwise form:
   if for every x, p'' x = (p x ∧ p' (f x)) and f'' x = f' (f x), then
   ({.p.} o [-f-]) o ({.p'.} o [-f'-]) = {.p''.} o [-f''-]. *)
lemma asseert_update_comp: "(⋀ x . let y = f x in p'' x = (p x ∧ p' y) ∧ f'' x = f' y) ⟹ ({.p.} o [-f-]) o ({.p'.} o [-f'-]) = {.p''.} o [-f''-]"
  by (simp add: update_def fun_eq_iff demonic_def assert_def le_fun_def Let_def)

(* Point-free variant: p'' = p ⊓ (p' o f) and f'' = f' o f. *)
lemma asseert_update_comp_abs_aux: "p'' = p ⊓ (p' o f) ⟹  f'' = f' o f ⟹ ({.p.} o [-f-]) o ({.p'.} o [-f'-]) = {.p''.} o [-f''-]"
  by (simp add: update_def fun_eq_iff demonic_def assert_def le_fun_def Let_def)

(* The same rule with meta-equalities (≡) as premises.  This lets the ML
   function pt_comp_th discharge the premises directly with rewrite theorems
   produced by the simplifier (Thm.implies_elim). *)
lemma asseert_update_comp_abs: "p ⊓ (p' o f) ≡ p'' ⟹  f' o f ≡ f'' ⟹ ({.p.} o [-f-]) o ({.p'.} o [-f'-]) = {.p''.} o [-f''-]"
  by (rule asseert_update_comp_abs_aux, auto)

(* Parallel (product) composition of two assert-update programs: the
   precondition is prod_prec p p' and the function is prod_fun f f'.  The
   premises are meta-equalities, for use from ML (pt_prod_th). *)
lemma asseert_update_prod_abs: "prod_prec p p' ≡ p'' ⟹  prod_fun f f' ≡ f'' ⟹ ({.p.} o [-f-]) ** ({.p'.} o [-f'-]) = {.p''.} o [-f''-]"
  apply (simp add: prod_assert_update)
  by (simp add: prod_prec_def prod_fun_def assert_def update_def demonic_def fun_eq_iff le_fun_def)


(* Diagnostics: inspect the facts about case_prod / If_prod used below. *)
thm If_prod

term "Product_Type.prod.case_prod"

lemma "case_prod f (a, b) = f a b"
  by simp

thm  Product_Type.case_prod_conv

(* Show sorts in subsequent output (switched off again further below). *)
declare [[show_sorts]]

(* Symmetric eta rule for pairs, stated as a meta-equality.  It is
   instantiated by the ML function case_prod_eta_sym_conversion below. *)
lemma case_prod_eta_eq_sym: "f ≡ (λ (x, y) . f (x, y))"
  by simp

thm Product_Type.case_prod_eta 



(* Scratch experiments with tupled functions.  NOTE(review): TtestTerm,
   TTtestTerm and the lemmas below are not referenced by the ML code visible
   in this theory. *)
term "T ((x,y) ,z) = (x+y,x+z)"


definition "TtestTerm x ≡ x + 3"

definition "TTtestTerm ≡ (λ (x, (u,v), y) . (x, x+y, u+v))"


lemma TT_simp: "TTtestTerm (x, (u,v), y) ≡ (x, x + y, u+v)"
  by (simp add: TTtestTerm_def)

lemma TTa_simp: "(G ≡ TTtestTerm) ⟹ (G (x, (u,v), y) ≡ (x, x + y, u+v))"
  by (simp add: TTtestTerm_def)

thm TtestTerm_def [of "x"]

lemmas T_inst = TtestTerm_def [of "x"]

ML{*
    local
      (* ctxt with its simpset replaced by exactly the rules thms *)
      fun only_rules ctxt thms = Raw_Simplifier.clear_simpset ctxt addsimps thms;
      (* right-hand side of an equational theorem, as a raw term *)
      fun rhs_term eq = Thm.term_of (Thm.rhs_of eq);
    in
      (* Rewrite cterm using only thms; returns the theorem cterm == cterm'. *)
      fun simp_only_cterm_thm ctxt thms cterm =
        Simplifier.rewrite (only_rules ctxt thms) cterm;

      (* Rewrite term using only thms; returns the rewritten term. *)
      fun simp_only_term ctxt thms term =
        let
          val ctxt' = only_rules ctxt thms
        in
          rhs_term (Simplifier.rewrite ctxt' (Thm.cterm_of ctxt' term))
        end;

      (* Rewrite term with the full simpset of ctxt; returns the rewritten term. *)
      fun simp_term ctxt term =
        rhs_term (Simplifier.rewrite ctxt (Thm.cterm_of ctxt term));

      (* Rewrite term using only simps_add; returns the theorem term == term'. *)
      fun simplify_only_term ctxt simps_add term =
        Simplifier.rewrite (only_rules ctxt simps_add) (Thm.cterm_of ctxt term);

      (* Rewrite term with the simpset of ctxt extended by simps_add; returns
         the theorem term == term'. *)
      fun simplify_term ctxt simps_add term =
        Simplifier.rewrite (ctxt addsimps simps_add) (Thm.cterm_of ctxt term);

      (* As simplify_term, but for an already certified term. *)
      fun simplify_cterm ctxt simps_add cterm =
        Simplifier.rewrite (ctxt addsimps simps_add) cterm;
    end;
*}

ML{*

(* Scratch value. *)
val TT = Drule.instantiate_normalize;

val case_prod_eta_eq_sym = @{thm case_prod_eta_eq_sym}

(* Instantiate the single schematic variable ?f of case_prod_eta_eq_sym
   (?f == %(x, y). ?f (x, y)) with the certified term cterm. *)
fun case_prod_eta_sym_conversion ctxt cterm =
  let
    val [(fun_var_name, _)] = Thm.fold_terms Term.add_vars case_prod_eta_eq_sym [];
  in
    Drule.infer_instantiate ctxt [(fun_var_name, cterm)] case_prod_eta_eq_sym
  end;
*}

ML{*



(* Smoke test: instantiate case_prod_eta_eq_sym for a sample function FF. *)
val T = case_prod_eta_sym_conversion @{context} @{cterm "FF::('a::{plus,minus}) × 'b ⇒ 'c"}

*}

(* Turn sort display off again. *)
declare [[show_sorts = false]]

(* Diagnostics: eta rules for case_prod.  cond_case_prod_eta is used by the
   ML function tupled_abs_aux below. *)
thm cond_case_prod_eta

thm case_prod_eta


thm eta_contract_eq


(* Eliminate an auxiliary variable X constrained by X ≡ A.
   NOTE(review): not referenced by the ML code visible in this theory. *)
lemma remove_aux_var: "(⋀ X . X ≡ A ⟹ X ≡ B) ⟹ (A ≡ B)"
  by auto

thm Product_Type.case_prod_eta
  
thm cond_case_prod_eta

ML{*

local
  (* Abstract the variable x (named name) out of thm : l == r, giving
     (%name. l) == (%name. r), then eta-contract the new left-hand side. *)
  fun abstract_var ctxt name x thm =
    let
      val abs_eq = Thm.abstract_rule name (Thm.cterm_of ctxt x) thm
      val eta_eq = Thm.symmetric (Thm.eta_conversion (Thm.lhs_of abs_eq))
    in
      Thm.transitive eta_eq abs_eq
    end
in

(* tupled_abs_aux ctxt tm thm: given thm : l == r and a tuple tm of variables
   (Free / Var leaves, nested pairs, or the unit value), abstract the variables
   of tm out of both sides so that the right-hand side becomes a tupled
   lambda over tm.  Each pair node is wrapped in case_prod, and the new
   left-hand side is normalised with cond_case_prod_eta followed by
   eta-conversion.  Any other shape of tm raises TERM. *)
fun tupled_abs_aux ctxt tm thm =
  (case tm of
    Free (name, _) => abstract_var ctxt name tm thm
  | Var ((name, _), _) => abstract_var ctxt name tm thm
  | Const ("Product_Type.Pair", _) $ u $ v =>
      let
        (* abstract the second component first, then the first one *)
        val curried = tupled_abs_aux ctxt u (tupled_abs_aux ctxt v thm)
        val curriedT as Type ("fun", [A, Type ("fun", [B, C])]) =
          fastype_of (Thm.term_of (Thm.rhs_of curried))
        val case_prod =
          Thm.cterm_of ctxt
            (Const ("Product_Type.prod.case_prod", curriedT --> HOLogic.mk_prodT (A, B) --> C))
        (* case_prod l' == case_prod r' *)
        val uncurried = Thm.combination (Thm.reflexive case_prod) curried
        (* normalise the new left-hand side *)
        val lhs_simp =
          simp_only_cterm_thm ctxt [@{thm cond_case_prod_eta}] (Thm.lhs_of uncurried)
        val lhs_norm = Thm.transitive lhs_simp (Thm.eta_conversion (Thm.rhs_of lhs_simp))
      in
        Thm.transitive (Thm.symmetric lhs_norm) uncurried
      end
  | Const ("Product_Type.Unity", _) => Thm.abstract_rule "x" @{cterm "x::unit"} thm
  | _ => raise TERM ("tupled_abs: bad tuple", [tm]))

end;
*}


ML{*


(* tupled_abs ctxt thm: given thm : F arg == rhs, where arg is a variable tuple,
   prove F == (tupled lambda over arg of rhs).  F is first replaced by a fresh
   variable GGAux__ under the hypothesis GGAux__ == F.  tupled_abs_aux is then
   applied, and the hypothesis is discharged by instantiating GGAux__ := F and
   closing F == F by reflexivity. *)
fun tupled_abs ctxt thm =
  let
    val (cfun, carg) = Thm.dest_comb (Thm.lhs_of thm);
    val fun_term = Thm.term_of cfun;

    (* fresh stand-in G for F, related to it by the hypothesis G == F *)
    val aux = Free ("GGAux__", fastype_of fun_term);
    val caux = Thm.cterm_of ctxt aux;
    val hyp = Thm.cterm_of ctxt (Logic.mk_equals (aux, fun_term));

    (* G arg == rhs, under the hypothesis *)
    val aux_app_eq = Thm.combination (Thm.assume hyp) (Thm.reflexive carg);
    val aux_eq = Thm.transitive aux_app_eq thm;

    (* G == (tupled lambda), still under the hypothesis *)
    val abs_eq = tupled_abs_aux ctxt (Thm.term_of carg) aux_eq;

    (* generalise over G, instantiate it with F, discharge F == F *)
    val general = Thm.forall_intr caux (Thm.implies_intr hyp abs_eq);
    val instantiated = Thm.forall_elim cfun general;
  in
    Thm.implies_elim instantiated (Thm.reflexive cfun)
  end;

*}


ML{*
  (* Symbolic representation of a component:
       Func (xs, p, es)    input tuple xs, precondition p, output expression es
       Rel (xs, ys, p, r)  input tuple xs, output tuple ys, precondition p,
                           relation r
     p, es and r are terms over the variables of xs (and ys). *)
  datatype pt = Func of term * term * term | Rel of term * term * term * term

  (* Input tuple of a component. *)
  fun tuple rep =
    (case rep of
      Func (xs, _, _) => xs
    | Rel (xs, _, _, _) => xs);

  (* Output tuple of a relational component. *)
  fun out_tuple rep = (case rep of Rel (_, ys, _, _) => ys);

  (* Precondition as a tupled lambda over the inputs. *)
  fun prec rep =
    let
      val (xs, p) =
        (case rep of
          Func (xs, p, _) => (xs, p)
        | Rel (xs, _, p, _) => (xs, p))
    in
      HOLogic.tupled_lambda xs p
    end;

  (* Function of a Func component as a tupled lambda over the inputs. *)
  fun func rep = (case rep of Func (xs, _, es) => HOLogic.tupled_lambda xs es);

  (* Relation of a Rel component, curried over inputs then outputs. *)
  fun rel rep =
    (case rep of
      Rel (xs, ys, _, r) => HOLogic.tupled_lambda xs (HOLogic.tupled_lambda ys r));

  (* Refinement.assert applied to the precondition. *)
  fun assert_prec rep =
    let
      val p = prec rep
      val predT = fastype_of p
    in
      Const ("Refinement.assert", predT --> predT --> predT) $ p
    end;

  (* Refinement.update applied to the function; the resulting transformer
     maps predicates on the output type to predicates on the input type. *)
  fun update_func rep =
    let
      val f = func rep
      val funT = fastype_of f
      val (domT, ranT) = dest_funT funT
    in
      Const ("Refinement.update",
        funT --> (ranT --> HOLogic.boolT) --> (domT --> HOLogic.boolT)) $ f
    end;

  (* Refinement.demonic applied to the relation. *)
  fun demonic_rel rep =
    let
      val r = rel rep
      val relT = fastype_of r
      val (domT, restT) = dest_funT relT
      val (ranT, _) = dest_funT restT
    in
      Const ("Refinement.demonic",
        relT --> (ranT --> HOLogic.boolT) --> (domT --> HOLogic.boolT)) $ r
    end;
*}

(* Disable eta-contraction when printing terms. *)
declare [[eta_contract=false]]

(* Sanity check: a concrete serial composition (division followed by a square
   root) is normalised using only the assert/update composition rules. *)
lemma "({.(x,y). y≠0.} o [-λ(x,y). x/y-]) o ({.z. z≥0.} o [-λz. sqrt z-]) = {. (λ(x, y). y ≠ 0) ⊓ ((λz. z≥0) ∘ (λ(x, y). x / y)) .} ∘ [-(λz. sqrt z) o (λ(x, y). x / y)-]"
  apply(simp only: comp_assoc  update_assert_comp)
  apply(simp only: comp_assoc[THEN sym]  update_assert_comp)
  apply(simp only: assert_assert_comp update_comp)
  by(simp only: comp_assoc update_comp)


ML{*

  (* Sample component used by the smoke tests below: inputs (x, y),
     precondition x = y, outputs (y + y, x + y + a). *)
  val test_rep = Func (@{term "(x::'a::plus, y::'a::plus)"}, @{term "(x::'a::plus) = y"}, @{term "(y + y::'a::plus, x + y + a)"});

  val ctxt = @{context};

  (* Smoke test: certify the update term of test_rep. *)
  val TT = Thm.cterm_of ctxt (update_func test_rep);

  (* Type of the lambda-abstracted precondition of a component. *)
  fun prec_type rep = fastype_of (prec rep);

  (* mk_comp t t' builds the HOL composition  t o t'  (constant Fun.comp)
     for t : 'b => 'c and t' : 'a => 'b; the result has type 'a => 'c. *)
  fun mk_comp t t' =
    let
      val outerT = fastype_of t
      val innerT = fastype_of t'
      val resultRangeT = range_type outerT
      val resultDomT = domain_type innerT
      val compT = outerT --> innerT --> resultDomT --> resultRangeT
    in
      Const ("Fun.comp", compT) $ t $ t'
    end;
    

  (* The predicate transformer {.p.} o [-f-] of a Func component. *)
  fun ptran rep = mk_comp (assert_prec rep) (update_func rep);

  (* Scratch checks on test_rep. *)
  val t = fastype_of (assert_prec test_rep);

  val TT = Thm.cterm_of ctxt (assert_prec test_rep);
  val TT = Thm.cterm_of ctxt (func test_rep);
  val t = @{term "f o g"};
  val TT = Thm.cterm_of ctxt (ptran test_rep);

  (* Lambda [x1, ..., xn] e abstracts x1 innermost and xn outermost:
     %xn. ... %x1. e *)
  fun Lambda vars e = fold lambda vars e

  (* Apply e [e1, ..., en] builds  e $ en $ ... $ e1 : the last list element
     is applied first. *)
  fun Apply e args = list_comb (e, rev args)

  (* Subst xs es e substitutes es for xs in e, represented as the beta-redex
     (Lambda xs e) applied to es.  The argument order of Apply matches the
     binding order of Lambda, so xi is bound to ei. *)
  fun Subst xs es e = Apply (Lambda xs e) es;

  (* Flatten a tuple term into its leaves, descending into both components
     of every pair. *)
  fun flatten_tuple tm =
    (case tm of
      Const ("Product_Type.Pair", _) $ l $ r => flatten_tuple l @ flatten_tuple r
    | _ => [tm]);
  val ctxt = @{context};

  (* NOTE(review): asseert_update_comp_th, comp_unfold and serial_prove_thms
     are not referenced by the ML code visible in this theory. *)
  val asseert_update_comp_th = @{thm "asseert_update_comp"};

  val comp_unfold = @{thms split_tupled_all} @[@{thm case_prod_conv}, @{thm If_prod}];

  val serial_prove_thms = [@{thm "split_paired_all"}, @{thm "comp_def"}, @{thm "case_prod_conv"}, 
      @{thm simp_thms(6)}, @{thm simp_thms(21)}, @{thm  simp_thms(22)}, @{thm "triv_forall_equality"}, @{thm If_prod}];

  (* Restricted rule set that pt_comp applies before full simplification:
     comp_def, case_prod_conv, a few simp_thms and If_prod. *)
  val serial_simp_thms = [@{thm "comp_def"}, @{thm "case_prod_conv"}, 
      @{thm simp_thms(6)}, @{thm simp_thms(21)}, @{thm  simp_thms(22)}, @{thm If_prod}];


  (* Constant / type constructors used to assemble composite components.
     mk_id takes the full type of the identity constant; mk_prod_typ takes
     the two component types. *)
  fun mk_id typ = Const ("Fun.id",  typ);
  fun mk_prod_typ T1 T2 = Type ("Product_Type.prod", [T1, T2]);

  (* mk_prod_prec p q builds  prod_prec p q  for p : 'a => bool and
     q : 'b => bool; the result is a predicate on 'a * 'b. *)
  fun mk_prod_prec p q =
    let
      val _ = debug DEBUG_PROD_PREC " - prod_prec const begin"
      val ti = Time.now()
      val t = fastype_of p;
      val t' = fastype_of q;
      val tp = t --> t' --> (mk_prod_typ (domain_type t) (domain_type t')) --> HOLogic.boolT;
      (* elapsed time is now - start; the reverse order gives a negative duration *)
      val _=  debug DEBUG_PROD_PREC ("- prod_prec end  - time proof: " ^ (Time.toString (Time.now() - ti)));
    in
      Const ("SimplifyRCRS.prod_prec", tp) $ p $ q
    end;

  (* mk_prod_fun f g builds  prod_fun f g  for f : 'a => 'c and g : 'b => 'd;
     the result has type 'a * 'b => 'c * 'd. *)
  fun mk_prod_fun f g =
    let
      val _ = debug DEBUG_PROD_FUN " - prod_fun const begin"
      val ti = Time.now()
      val t = fastype_of f;
      val t' = fastype_of g;
      val tp = t --> t' --> (mk_prod_typ (domain_type t) (domain_type t')) --> (mk_prod_typ (range_type t) (range_type t'));
      (* elapsed time is now - start *)
      val _=  debug DEBUG_PROD_FUN ("- prod_fun end  - time proof: " ^ (Time.toString (Time.now() - ti)));
    in
      Const ("SimplifyRCRS.prod_fun", tp) $ f $ g
    end;

  (* Further constant constructors; typ is the full type of the constant. *)
  fun mk_dup typ = Const ("SimplifyRCRS.dup", typ);
  fun mk_fst typ =  Const ("Product_Type.prod.fst", typ);
  fun mk_snd typ =  Const ("Product_Type.prod.snd", typ);

  (* Symbolic serial composition of two Func components, without proof.
     The simplified outputs es of the first component are substituted for the
     inputs xs' of the second.  The result has the inputs xs of the first
     component, precondition  p /\ p'[es/xs']  and outputs  es'[es/xs'], each
     simplified first with serial_simp_thms and then with the full simpset. *)
  fun pt_comp ctxt (Func (xs, p, es)) (Func (xs', p', es')) = 
    let
      val xs_flat = flatten_tuple xs';
      val es_simp = simp_term ctxt es;
      val es_flat = flatten_tuple es_simp;
      (* restricted rules first, then the full simpset *)
      fun simplify tm = simp_term ctxt (simp_only_term ctxt serial_simp_thms tm);

      (* update part: the new output expressions *)
      val time_start = Time.now();
      val es'' = simplify (Subst xs_flat es_flat es');
      val _ = debug DEBUG_SERIAL ("  - time simp update: " ^ (Time.toString (Time.now() - time_start)));

      (* assert part: the new precondition *)
      val time_start = Time.now();
      val p'' = simplify (Const ("HOL.conj", HOLogic.boolT --> HOLogic.boolT --> HOLogic.boolT) $ p $ Subst xs_flat es_flat p');
      val _ = debug DEBUG_SERIAL ("  - time simp assert: " ^ (Time.toString (Time.now() - time_start)));
    in
      Func (xs, p'', es'')
    end;

    (* Scratch terms. *)
    val a = @{term "x ∧ y"};
    val b= @{term "(λ(x, y). y ≠ 0) ⊓ ((λz. z≥0) ∘ (λ(x, y). x / y))"}
    val c = @{term "(λz. sqrt z) o (λ(x, y). x / y)"}

    (* Schematic variables of asseert_update_comp_abs, used by pt_comp_th.  The
       list order is the order produced by Term.add_vars, so this binding
       depends on the exact statement of the lemma. *)
    val asseert_update_comp_abs = @{thm asseert_update_comp_abs};
    val  [Var_f'', Var_f', Var_p'', Var_f, Var_p', Var_p]  = (Thm.fold_terms Term.add_vars asseert_update_comp_abs []);

    (* Serial composition with proof.  For Func components rep (input tuple x)
       and rep', returns (comp_term, th) where th proves
         ({.p.} o [-f-]) o ({.p'.} o [-f'-]) = {.p''.} o [-f''-]
       by instantiating asseert_update_comp_abs.  p'' and f'' are obtained by
       simplifying  inf p (p' o f)  and  f' o f  applied to x, then abstracting
       x again.  comp_term is Func (x, simplified p'' body, simplified f'' body). *)
    fun pt_comp_th ctxt (rep as Func(x, _, _)) (rep' as Func(_, _, _)) =
      let
        val _ = debug DEBUG_SERIAL " - comp const begin"
        val t = Time.now()
        val p = prec rep;
        val p' = prec rep';

        val f = func rep;
        val f' = func rep';

        (* predicates over the input tuple *)
        val predT = (fastype_of x) --> HOLogic.boolT;

        val p'' = (Const ("Lattices.inf_class.inf", predT --> predT --> predT) $ p $ (mk_comp p' f)) $ x;
        val p''_simp_thm = simplify_term ctxt [] p'';
        val p''_simp_abs_thm = tupled_abs ctxt p''_simp_thm;

        val f'' = (mk_comp f' f) $ x;
        val f''_simp_thm = simplify_term ctxt [] f'';
        val f''_simp_abs_thm = tupled_abs ctxt f''_simp_thm;

        val asseert_update_comp_abs_inst = Drule.infer_instantiate ctxt 
          [(fst Var_f'', Thm.rhs_of f''_simp_abs_thm), (fst Var_f', Thm.cterm_of ctxt f'), 
           (fst Var_p'', Thm.rhs_of p''_simp_abs_thm), (fst Var_f,  Thm.cterm_of ctxt f), 
           (fst Var_p', Thm.cterm_of ctxt p'), (fst Var_p, Thm.cterm_of ctxt p)] asseert_update_comp_abs;

        (* discharge the two meta-equality premises, in order *)
        val res_thm = Thm.implies_elim asseert_update_comp_abs_inst p''_simp_abs_thm;
        val th = Thm.implies_elim res_thm f''_simp_abs_thm;

        val comp_term = Func (x, Thm.term_of (Thm.rhs_of p''_simp_thm), Thm.term_of (Thm.rhs_of f''_simp_thm)) 

        (* elapsed time is now - start *)
        val _=  debug DEBUG_SERIAL ("- comp proof end  - time proof: " ^ (Time.toString (Time.now() - t)));

      in (comp_term,th) end;

  (* Smoke test: compose test_rep with itself. *)
  val TTa = pt_comp_th ctxt test_rep test_rep;

*}

(* Duplicate a value: dup y = (y, y). *)
definition "dup y = (y,y)"

(* Sanity check relating  Pair (g x y)  to  prod_fun (g x) id o dup. *)
lemma "(snd o f o Pair (g x y)) y = (snd o f o (prod_fun (g x) id) o dup) y"
  by (simp add: prod_fun_def dup_def)

  
(* Feedback of an assert-update program whose first output component
   g x = fst o f o Pair x does not depend on the fed-back input x (premise
   g x = g x').  The result is again an assert-update program, with
   precondition p o (prod_fun (g x) id o dup) and function
   snd o (f o (prod_fun (g x) id o dup)). *)
lemma feedback_asseert_update_abs_aux: "g = (λ x . fst o f o Pair x) ⟹ (⋀ x x' . g x = g x') ⟹  snd o (f o (prod_fun (g x) id o dup)) = f' ⟹ 
  p o (prod_fun (g x) id o dup) = p' ⟹ feedback ({.p.} o [-f-]) = {.p'.} o [-f'-]"
  apply (subst feedback_simp_c)
  apply simp
  apply (simp add: fun_eq_iff update_def demonic_def le_fun_def prod_fun_def dup_def assert_def, safe, auto)
  apply metis
  apply metis
  apply metis
  by metis
  
(* The same rule with meta-equality premises, for use from ML (pt_feedback_th). *)
lemma feedback_asseert_update_abs: "(λ x . fst o f o Pair x) ≡ g ⟹ (⋀ x x' . g x ≡ g x') ⟹  snd o (f o (prod_fun (g x) id o dup)) ≡ f' ⟹ 
  p o (prod_fun (g x) id o dup) ≡ p' ⟹ feedback ({.p.} o [-f-]) = {.p'.} o [-f'-]"
  by (rule_tac g = g and x = x in feedback_asseert_update_abs_aux, simp_all)



(* Keep eta-contraction of printed terms disabled (as declared above). *)
declare [[eta_contract = false]]

(* Diagnostics. *)
thm eta_contract_eq

thm transitive

ML{*
  (* NOTE(review): BB is not referenced by the ML code visible here. *)
  val BB = Thm.beta_conversion
  val feedback_asseert_update_abs = @{thm feedback_asseert_update_abs}
  (* Schematic variables of feedback_asseert_update_abs.  The list order is
     the order produced by Term.add_vars, so this binding depends on the exact
     statement of the lemma. *)
  val  [Var_p', Var_p, Var_f', Var_x, Var_g, Var_f]  = (Thm.fold_terms Term.add_vars feedback_asseert_update_abs []);

  (* Rules for unfolding tuple quantification, case_prod, composition and
     pair projections. *)
  val feedback_unfold = @{thms split_tupled_all} @ [@{thm case_prod_conv}, @{thm comp_def}] @ @{thms prod.sel};

  (* Feedback with proof.  rep must be a Func whose input tuple is (x, y), with
     x a free variable: x is the input fed back from the first output
     component.  With g = (%x. fst o f o Pair x), this function
       - simplifies g and proves that g x does not depend on x,
       - builds  p' = p o (prod_fun (g x) id o dup)  and
                 f' = snd o (f o (prod_fun (g x) id o dup))  applied to y,
         simplifies them and re-abstracts over y,
       - instantiates feedback_asseert_update_abs to obtain
           feedback ({.p.} o [-f-]) = {.p'.} o [-f'-].
     Returns (Func (y, p' body, f' body), theorem). *)
  fun pt_feedback_th ctxt (rep as Func (z,_,_)) =
    let
      val t = Time.now()

      val p = prec rep
      val f = func rep
      (* split the input tuple z = (x, y); Pair_x is the partial application Pair x *)
      val (Pair_x as Const ("Product_Type.Pair", _) $ x) $ y = z;

      val Tf as Type ("Product_Type.prod", [Tf1, Tf2]) = range_type (fastype_of f);

      (* g y = (fst o f o Pair x) y, unfolded and beta-normalised *)
      val g = (mk_comp (mk_comp (mk_fst (Tf --> Tf1)) f) Pair_x) $ y;
      val g_simp_thm_a = simplify_only_term ctxt feedback_unfold g;
      val g_simp_beta = Thm.beta_conversion true (Thm.rhs_of g_simp_thm_a)
      val g_simp_thm1 = Thm.transitive g_simp_thm_a g_simp_beta

      (* abstract over the tuple y, then over x *)
      val g_simp_abs_thm_aux = tupled_abs ctxt g_simp_thm1
      val Free (name_x, _) = x;
      val g_simp_abs_thm = Thm.abstract_rule name_x (Thm.cterm_of ctxt x) g_simp_abs_thm_aux
      val _=  debug DEBUG_FEEDBACK ("  - time simp g: " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      val g_simp = Thm.rhs_of g_simp_abs_thm;
      val _=  debug DEBUG_FEEDBACK ("  - time abs 1 g: " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      val g_simp_term = Thm.term_of g_simp;
      val _=  debug DEBUG_FEEDBACK ("  - time abs 2 g: " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      (* prove  !!x x'. g x == g x'  by simplifying g x' (x' fresh) and g x *)
      val x' = Free ("var__fb", fastype_of x);
      val g1 = Thm.term_of (Thm.rhs_of (Thm.beta_conversion true (Thm.cterm_of ctxt (g_simp_term $ x'))))
      val g1_simp_thm = simplify_term ctxt [] g1;

      val g_x = Thm.term_of (Thm.rhs_of (Thm.beta_conversion true (Thm.cterm_of ctxt (g_simp_term $ x))));
      val g2_simp_thm = simplify_term ctxt [] g_x;

      (* NOTE(review): the Pure rule transitive is applied via OF rather than
         Thm.transitive, presumably because the two normal forms need only be
         unifiable, not syntactically identical. *)
      val g_eq = @{thm transitive} OF [g2_simp_thm, Thm.symmetric g1_simp_thm]
      val g_eq_all = Thm.forall_intr (Thm.cterm_of ctxt x) (Thm.forall_intr (Thm.cterm_of ctxt x') g_eq)

      val _=  debug DEBUG_FEEDBACK ("  - time g x = g x': " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      (* T = prod_fun (g x) id o dup *)
      val Ty = fastype_of y;
      val T = mk_comp (mk_prod_fun (g_x) (mk_id (Ty --> Ty))) (mk_dup (Ty --> (mk_prod_typ Ty Ty)))
      val unfold_T_thms = feedback_unfold @ [@{thm dup_def}, @{thm prod_fun_def}, @{thm id_def}]

      (* new precondition: (p o T) y *)
      val p' = (mk_comp p T) $ y;
      val p'_simp_thm = simplify_only_term ctxt unfold_T_thms p';
      val p'_simp_abs_thm = tupled_abs ctxt p'_simp_thm

      val _=  debug DEBUG_FEEDBACK ("  - time prec: " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      (* new function: (snd o (f o T)) y *)
      val f' = mk_comp (mk_snd (Tf --> Tf2)) (mk_comp f T) $ y;
      val f'_simp_thm = simplify_only_term ctxt unfold_T_thms f';
      val f'_simp_abs_thm = tupled_abs ctxt f'_simp_thm

      val _=  debug DEBUG_FEEDBACK ("  - time func: " ^ (Time.toString (Time.now() - t)));
      val t = Time.now()

      val feedback_asseert_update_abs_inst = Drule.infer_instantiate ctxt 
        [(fst Var_p', Thm.rhs_of p'_simp_abs_thm), (fst Var_p, Thm.cterm_of ctxt p), 
        (fst Var_f',  Thm.rhs_of f'_simp_abs_thm), (fst Var_f, Thm.cterm_of ctxt f),
        (fst Var_g, g_simp)] feedback_asseert_update_abs

      (* discharge the four premises of the rule, in order *)
      val res_thm = Thm.implies_elim feedback_asseert_update_abs_inst g_simp_abs_thm
      val res_thm1 = Thm.implies_elim res_thm g_eq_all
      val res_thm2 = Thm.implies_elim res_thm1 f'_simp_abs_thm
      val th = Thm.implies_elim res_thm2 p'_simp_abs_thm

      val _=  debug DEBUG_FEEDBACK ("  - time inst: " ^ (Time.toString (Time.now() - t)));

      val feedback_term = Func (y, Thm.term_of (Thm.rhs_of p'_simp_thm), Thm.term_of (Thm.rhs_of f'_simp_thm)) 

    in (feedback_term, th) end;

  val ctxt = @{context};

  (* Smoke test: feedback of test_rep, feeding back its first input x. *)
  val TT =  (pt_feedback_th ctxt test_rep);
*}

ML{*
  (* Second sample component (inputs x', y') for the product smoke test below. *)
  val test_rep' = Func (@{term "(x'::'a::plus, y'::'a::plus)"}, @{term "(x'::'a::plus) = y'"}, @{term "(y' + y'::'a::plus, x' + y')"});

  (* mk_prod S S' builds  Refinement.Prod S S'  for two predicate transformers
     S : (outA => bool) => (inA => bool) and S' : (outB => bool) => (inB => bool).
     The result maps predicates on outA * outB to predicates on inA * inB. *)
  fun mk_prod t t' =
    let
      val tT = fastype_of t
      val tT' = fastype_of t'
      val outA = domain_type (domain_type tT)
      val inA = domain_type (range_type tT)
      val outB = domain_type (domain_type tT')
      val inB = domain_type (range_type tT')
      fun predT T = T --> HOLogic.boolT
      val prodT =
        tT --> tT' --> predT (HOLogic.mk_prodT (outA, outB)) --> predT (HOLogic.mk_prodT (inA, inB))
    in
      Const ("Refinement.Prod", prodT) $ t $ t'
    end;

  (* NOTE(review): tupled_lambda_thms is not referenced by the ML code visible here. *)
  val tupled_lambda_thms = @{thms split_tupled_all} @ [@{thm case_prod_conv}, @{thm HOL.simp_thms(21)}, @{thm HOL.simp_thms(22)}]

  (* Schematic variables of asseert_update_prod_abs, in Term.add_vars order,
     used by pt_prod_th.  This rebinds the Var_* names bound above for
     asseert_update_comp_abs; pt_comp_th keeps using the earlier bindings. *)
  val asseert_update_prod_abs = @{thm asseert_update_prod_abs};
  val  [Var_f'', Var_f', Var_f, Var_p'', Var_p', Var_p]  = (Thm.fold_terms Term.add_vars asseert_update_prod_abs []);

  (* Parallel composition with proof.  For Func components rep (inputs v) and
     rep' (inputs v'), builds the component with input tuple (v, v'),
     precondition prod_prec p p' and function prod_fun f f' (each simplified
     after unfolding its definition).  It also proves
       ({.p.} o [-f-]) ** ({.p'.} o [-f'-]) = {.p''.} o [-f''-]
     by instantiating asseert_update_prod_abs.  Returns (prod_term, th). *)
  fun pt_prod_th ctxt (rep as Func(v, _, _)) (rep' as Func(v', _, _)) =
    let
      val _ = debug DEBUG_PROD " - prod const begin"
      val t = Time.now()

      val p = prec rep;
      val p' = prec rep';

      (* combined input tuple y = (v, v') *)
      val T = (fastype_of v) --> (fastype_of v') --> (mk_prod_typ (fastype_of v) (fastype_of v'));
      val y = Const ("Product_Type.Pair", T) $ v $ v'

      (* prod_prec p p' applied to y, unfolded and simplified, then re-abstracted *)
      val p'' = (mk_prod_prec p p') $ y;
      val p''_simp_thm = simplify_term ctxt [@{thm prod_prec_def}] p'';
      val p''_simp_abs_thm = tupled_abs ctxt p''_simp_thm;

      val f = func rep;
      val f' = func rep';

      (* the same for prod_fun f f' *)
      val f'' = (mk_prod_fun f f') $ y;
      val f''_simp_thm = simplify_term ctxt [@{thm prod_fun_def}] f'';
      val f''_simp_abs_thm = tupled_abs ctxt f''_simp_thm; 

      val asseert_update_prod_abs_inst = Drule.infer_instantiate ctxt 
        [(fst Var_f'', Thm.rhs_of f''_simp_abs_thm), (fst Var_f', Thm.cterm_of ctxt f'), 
        (fst Var_f,  Thm.cterm_of ctxt f), (fst Var_p'', Thm.rhs_of p''_simp_abs_thm),
        (fst Var_p', Thm.cterm_of ctxt p'), (fst Var_p, Thm.cterm_of ctxt p)] asseert_update_prod_abs;

      (* discharge  prod_prec p p' == p''  and  prod_fun f f' == f'' *)
      val res_thm = Thm.implies_elim asseert_update_prod_abs_inst p''_simp_abs_thm;
      val th = Thm.implies_elim res_thm f''_simp_abs_thm;

      val prod_term = Func (y, Thm.term_of (Thm.rhs_of p''_simp_thm), Thm.term_of (Thm.rhs_of f''_simp_thm)) 

      (* elapsed time is now - start *)
      val _=  debug DEBUG_PROD ("- prod proof end  - time proof: " ^ (Time.toString (Time.now() - t)));

    in (prod_term, th)  end;

  val ctxt = @{context};

  (* Smoke test: parallel composition of test_rep and test_rep'. *)
  val TT = (pt_prod_th ctxt ( test_rep) ( test_rep'));

*}


ML{*
  (* Is T a (binary) product type? *)
  fun is_prodT T =
    (case T of
      Type ("Product_Type.prod", _) => true
    | _ => false);

  (* Flatten a right-nested product type into its components; a left-nested
     product stays a single element. *)
  fun strip_tupleT_a T =
    (case T of
      Type ("Product_Type.prod", [T1, T2]) => T1 :: strip_tupleT_a T2
    | _ => [T]);


  (* create_var_tuple T n builds a tuple of fresh free variables named
     v__n_, v__(n+1)_, ... that mirrors the nested product structure of T.
     It returns the tuple together with the next unused index.
     create_var_tuple_list does the same for a list of types, threading the
     index from left to right. *)
  fun create_var_tuple_list Ts n = fold_map create_var_tuple Ts n
  and create_var_tuple T n =
    if is_prodT T then
      let
        val (components, n') = create_var_tuple_list (strip_tupleT_a T) n
      in
        (HOLogic.mk_tuple components, n')
      end
    else
      (Free ("v__" ^ string_of_int n ^ "_", T), n + 1);
*}

ML{*
  (* create_undefined_tuple T builds a tuple of HOL.undefined constants with
     the nested product structure of T; create_undefined_tuple_list maps it
     over a list of types. *)
  fun create_undefined_tuple_list Ts = List.map create_undefined_tuple Ts
  and create_undefined_tuple T =
    if is_prodT T then
      HOLogic.mk_tuple (create_undefined_tuple_list (strip_tupleT_a T))
    else
      Const ("HOL.undefined", T);

*}

  ML{*

  (* Pair up the elements of two lists of equal length; raises an error if
     the lengths differ. *)
  fun zip xs ys =
    (case (xs, ys) of
      ([], []) => []
    | (x :: xs', y :: ys') => (x, y) :: zip xs' ys'
    | _ => error "lists have different sizes");

  (* Give the free variable v the type typ; any other term is an error. *)
  fun change_type v typ =
    (case v of
      Free (name, _) => Free (name, typ)
    | _ => error (String.concat["Variable ", term_to_string @{context} v, " is not free"]));

  (* Is the term an explicit pair? *)
  fun is_tuple tm =
    (case tm of
      Const ("Product_Type.Pair", _) $ _ $ _ => true
    | _ => false);

  (* NOTE(review): identical redefinition of strip_tupleT_a from an earlier ML
     block: flattens a right-nested product type. *)
  fun strip_tupleT_a T =
    (case T of
      Type ("Product_Type.prod", [T1, T2]) => T1 :: strip_tupleT_a T2
    | _ => [T]);

  (* Flatten a right-nested tuple term: (a, (b, c)) gives [a, b, c]. *)
  fun strip_tuple_a tm =
    (case tm of
      Const ("Product_Type.Pair", _) $ t1 $ t2 => t1 :: strip_tuple_a t2
    | _ => [tm]);

  (* Retype a tuple of free variables to the product type T, following the
     binary pair structure of the term. *)
  fun distribute_types tm T =
    (case tm of
      Const ("Product_Type.Pair", _) $ t1 $ t2 =>
        let
          val Type ("Product_Type.prod", [T1, T2]) = T
        in
          HOLogic.mk_tuple [distribute_types t1 T1, distribute_types t2 T2]
        end
    | _ => change_type tm T)

  

  (* Like distribute_types, but matches the flattened (right-nested)
     components of the tuple against those of the type; takes (term, type). *)
  fun distribute_types_a (trm, typ) =
    if is_tuple trm then
      HOLogic.mk_tuple (List.map distribute_types_a (zip (strip_tuple_a trm) (strip_tupleT_a typ)))
    else
      change_type trm typ;

  *}

  (* Skip as an assert-update program, from  top == p  and  id == f.
     The meta-equality premises are discharged from ML (pt_skip_th). *)
  lemma Skip_th: "⊤ ≡ p ⟹ id ≡ f ⟹ Skip = {.p.} o [-f-]"
    by (simp add: Skip_def fun_eq_iff assert_def update_def demonic_def le_fun_def)

ML{*
    val ctxt = @{context};

   (* Constructor for the Skip constant at type typ. *)
   fun mk_skip typ = Const ("Refinement.Skip", typ);
   val Skip_th = @{thm "Skip_th"};

   (* Skip component over the variable tuple invar, with proof.  Returns
      (Func (invar, True, invar), th), where th : Skip = {.p.} o [-f-] and p, f
      are the tupled forms of top and id obtained by simplification. *)
   fun pt_skip_th ctxt invar =
     let
       val rep = Func (invar, @{term "True"}, invar)
       val invarT = fastype_of invar

       val id_eq = simplify_only_term ctxt [@{thm id_def}] (mk_id (invarT --> invarT) $ invar)
       val top_eq =
         simplify_only_term ctxt [@{thm "top_fun_def"}, @{thm top_bool_def}]
           (Const ("Orderings.top_class.top", invarT --> HOLogic.boolT) $ invar)

       val id_abs = tupled_abs ctxt id_eq
       val top_abs = tupled_abs ctxt top_eq

       (* discharge  top == p  first, then  id == f *)
       val with_prec = Drule.compose (top_abs, 1, Skip_th)
     in
       (rep, Drule.compose (id_abs, 1, with_prec))
     end;

  val T = pt_skip_th @{context} @{term "(a,(u,v,x), b,d)"}
*}

  (* Fail (bottom) as an assert-update program with precondition p == bot.
     The premise f == f places no constraint on f.  fail_pt_th discharges it
     with a reflexivity theorem, presumably to keep the same premise shape as
     Skip_th and assert_th. *)
  lemma Fail_th: "⊥ ≡ p ⟹ f ≡ f ⟹ ⊥ = {.p.} o [-f-]"
    by (simp add: Fail_def fun_eq_iff assert_def update_def demonic_def le_fun_def)



ML{*
   (* Constructor for the HOL bottom element at type typ. *)
   fun mk_bottom typ = Const ("Orderings.bot_class.bot", typ);
   val Fail_th = @{thm "Fail_th"};


  (* Is the term the unit value ()? *)
  fun is_unit (Const ("Product_Type.Unity", _)) = true |
    is_unit _ = false;


  (* Fail component with proof.  typ is the type of the predicate transformer,
     (out => bool) => (in => bool), as for Refinement.update above, so typ' is
     the output type and typ'' the input type.  inpt_vars is the input tuple.
     If it is the unit value (), a fresh variable tuple numbered from n is
     created instead.  Returns (rep, th, n') where rep = Func (invar, False,
     undefined outputs), th : bot = {.p.} o [-f-], and n' is the next unused
     variable number. *)
  fun fail_pt_th ctxt inpt_vars n typ = 
      let
        val typ' = domain_type (domain_type typ);
        val typ'' = domain_type (range_type typ)
        (* The input tuple must have the INPUT type typ'' in both branches.
           Building it from the output type typ' gave rep a precondition and
           function of the wrong domain whenever input and output types differ. *)
        val (invar, n') = if is_unit inpt_vars then create_var_tuple typ'' n else (distribute_types inpt_vars typ'', n) ;
        val undefined_res = create_undefined_tuple typ'
        val rep = Func (invar, @{term "False"}, undefined_res);

        val t = fastype_of invar
        val assert_simp_th = (simplify_only_term ctxt [@{thm "bot_fun_def"}, @{thm bot_bool_def}] (mk_bottom (t --> HOLogic.boolT) $ invar))
        val assert_simp_abs_th = tupled_abs ctxt assert_simp_th
        (* premise f == f of Fail_th: plain reflexivity *)
        val func_simp_abs_th = Thm.reflexive (Thm.cterm_of ctxt (func rep))

        val th = Drule.compose (assert_simp_abs_th, 1, Fail_th);
        val th' = Drule.compose (func_simp_abs_th, 1, th);
      in (rep, th', n') end

   (* Smoke test. *)
   val t = @{term "(a,(u,v,x), b,d)"};
   val tt =  fastype_of t;
   val T = fail_pt_th @{context} t 1 ((@{typ "'a ×( 'c × 'a)× 'v"} --> HOLogic.boolT) --> tt  --> HOLogic.boolT)
*}

  (* Assert as an assert-update program: {.p.} = {.p'.} o [-f-], from p == p'
     and id == f.  The premises are discharged from ML (pt_assert_th). *)
  lemma assert_th: "p ≡ p' ⟹ id ≡ f ⟹ {.p.} = {.p'.} o [-f-]"
    by (simp add: Skip_def fun_eq_iff assert_def update_def demonic_def le_fun_def)

ML{*

    val assert_th = @{thm "assert_th"};

    (* Rules for simplifying a predicate applied to a variable tuple.  (An
       earlier, shorter list bound to the same name was immediately shadowed
       by this one and has been dropped.) *)
    val prove_assert_thms = [@{thm "split_paired_all"}, @{thm "comp_def"}, @{thm "case_prod_conv"}, 
      @{thm "triv_forall_equality"}, @{thm If_prod}] @ @{thms simp_thms};

    (* Assert component with proof.  p is a predicate over the input tuple vars.
       Returns (Func (vars, p'', vars), th), where p'' is p vars simplified with
       prove_assert_thms and th : {.p.} = {.p'.} o [-f-].  Here p' and f are the
       tupled forms of the simplified predicate and of id. *)
    fun pt_assert_th ctxt vars p =
      let
        val p' = (p $ vars);
        val assert_simp_th = (simplify_only_term ctxt prove_assert_thms p')
        val assert_simp_abs_th = tupled_abs ctxt assert_simp_th

        val p'' = Thm.term_of (Thm.rhs_of assert_simp_th);

        val rep = Func (vars, p'', vars);

        val t = fastype_of vars
  
        val func_simp_th = (simplify_only_term ctxt [@{thm id_def}] (mk_id (t --> t) $ vars))
        val func_simp_abs_th = tupled_abs ctxt  func_simp_th

        (* discharge  p == p'  first, then  id == f *)
        val th = Drule.compose (assert_simp_abs_th, 1, assert_th);
        val th' = Drule.compose (func_simp_abs_th, 1, th);
    in (rep, th') end;

    (* Smoke test. *)
    val T = pt_assert_th @{context} @{term "(x::'a,y::'a)"} @{term "λ (x,y) . x = y ∧ x ≠ x"}

*}



ML{*


    (* Rules used by simp_update_thm: the same list as prove_assert_thms. *)
    val update_simp_thms = prove_assert_thms;


    (* Pretty-print the proposition of a theorem. *)
    fun pretty_thm ctxt thm = Syntax.pretty_term ctxt (Thm.prop_of thm)

    local
      fun pretty_term_list ctxt ts =
        Pretty.block (Pretty.commas (map (Syntax.pretty_term ctxt) ts))
    in
      (* Comma-separated hypotheses of a theorem. *)
      fun pretty_hyps_thm ctxt thm = pretty_term_list ctxt (Thm.hyps_of thm)
      (* Comma-separated premises of a theorem. *)
      fun pretty_prems_thm ctxt thm = pretty_term_list ctxt (Thm.prems_of thm)
    end

    (* pretty_thm without the "?" marker on schematic variables. *)
    fun pretty_thm_no_vars ctxt thm =
      pretty_thm (Config.put show_question_marks false ctxt) thm

    (* Comma-separated lists of theorems. *)
    fun pretty_thms ctxt thms =
      Pretty.block (Pretty.commas (map (pretty_thm ctxt) thms))

    fun pretty_thms_no_vars ctxt thms =
      Pretty.block (Pretty.commas (map (pretty_thm_no_vars ctxt) thms))

    (* NOTE(review): these redefine simp_term and simp_only_term with the same
       behaviour as the earlier definitions in this theory. *)

    (* Rewrite term with the full simpset of ctxt; returns the rewritten term. *)
    fun simp_term ctxt term =
      Thm.term_of (Thm.rhs_of (Simplifier.rewrite ctxt (Thm.cterm_of ctxt term)));

    (* Rewrite term using only thms; returns the rewritten term. *)
    fun simp_only_term ctxt thms term =
      let
        val only_ctxt = (Raw_Simplifier.clear_simpset ctxt) addsimps thms
        val eq = Simplifier.rewrite only_ctxt (Thm.cterm_of only_ctxt term)
      in
        Thm.term_of (Thm.rhs_of eq)
      end;
*}
 
  
(* Normal form for a bare update: [-f-] equals {.p.} o [-g-] whenever p is
   equivalent to top and f to g.  pt_update_th composes its simplification
   theorems into the two premises. *)
lemma update_eq: "⊤ ≡ p ⟹ f ≡ g ⟹ [-f-] =  {.p.} o [-g-]"
  by (simp add: fun_eq_iff update_def assert_def le_fun_def)

(* Normal form for a bare demonic choice: [:r:] equals {.p.} o [:r':]
   whenever p is equivalent to top and r to r'.  Used by pt_demonic_th. *)
lemma demonic_eq: "⊤ ≡ p ⟹ r ≡ r' ⟹ [:r:] =  {.p.} o [:r':]"
  by (simp add: fun_eq_iff assert_def le_fun_def)

ML{*

    (* Simplify term in two passes and return the combined equation
       term == term'':
         1. rewrite with only update_simp_thms (on a cleared simpset);
         2. rewrite that result with the full simpset of ctxt, with the
            if_weak_cong congruence removed (that rule otherwise stops the
            simplifier from rewriting inside if-then-else branches).
       As a sanity check, a warning is issued if the lhs of the resulting
       equation is not alpha-equivalent to the input term.  Nothing is
       printed when the check succeeds. *)
    fun simp_update_thm ctxt term =
      let
        val ctxt' = (clear_simpset ctxt) addsimps update_simp_thms;
        val th1 = Simplifier.rewrite ctxt' (Thm.cterm_of ctxt' term)
        val ctxt'' = Raw_Simplifier.del_cong @{thm "if_weak_cong"} ctxt
        val th2 = Simplifier.rewrite ctxt'' (Thm.rhs_of th1)
        val th = Thm.transitive th1 th2;
        val _ =
          if Thm.term_of (Thm.lhs_of th) aconv term then ()
          else warning "simp_update_thm: lhs of the rewrite differs from the input term"
      in
        th
      end

    (* Result term of simp_update_thm. *)
    fun simp_update ctxt term = simp_update_thm ctxt term |> Thm.rhs_of |> Thm.term_of

    (* Debugging tactic: trace the premises of the goal state and return the
       state unchanged as the single result. *)
    fun print_tac ctxt thm =
      (tracing (Pretty.string_of (pretty_prems_thm ctxt thm));
       Seq.single thm);


    (* Simplify the update [-f-] over the state tuple vars.
       The update f applied to vars is simplified with simp_update_thm; the
       trivial precondition top is simplified with only top_fun_def and
       top_bool_def.
       Returns (Func (vars, True, f'), th), where f' is the simplified update
       and th comes from update_eq with both premises filled in. *)
    fun pt_update_th ctxt vars f =
      let
        val upd_eq = simp_update_thm ctxt (f $ vars);
        val top_pred = Const ("Orderings.top_class.top", fastype_of vars --> HOLogic.boolT);
        val top_eq = simplify_only_term ctxt [@{thm "top_fun_def"}, @{thm top_bool_def}] (top_pred $ vars);
        val new_state = Thm.term_of (Thm.rhs_of upd_eq);
        val upd_eq_abs = tupled_abs ctxt upd_eq;
        val top_eq_abs = tupled_abs ctxt top_eq;
        val with_top = Drule.compose (top_eq_abs, 1, @{thm update_eq});
      in
        (Func (vars, @{term True}, new_state), Drule.compose (upd_eq_abs, 1, with_top))
      end;

   (* Scratch check: simplify an update over the nested tuple (x, (a, b), y). *)
   val T = pt_update_th @{context} @{term "(x::nat,(a,b),y)"} @{term "λ (x::nat,(a,b),y) . (a,b,y,x+1)"}
*}
ML{*
    (* Simplify the demonic choice [:r:] for input tuple vars and output tuple
       outvars.
       The relation r applied to vars and outvars is simplified with
       simp_update_thm; the trivial precondition top is simplified with only
       top_fun_def and top_bool_def.
       Returns (Rel (vars, outvars, True, r'), th), where r' is the
       simplified relation body and th comes from demonic_eq with both
       premises filled in. *)
    fun pt_demonic_th ctxt vars outvars r =
      let
        val rel_eq = simp_update_thm ctxt (r $ vars $ outvars);
        val top_pred = Const ("Orderings.top_class.top", fastype_of vars --> HOLogic.boolT);
        val top_eq = simplify_only_term ctxt [@{thm "top_fun_def"}, @{thm top_bool_def}] (top_pred $ vars);
        val rel_body = Thm.term_of (Thm.rhs_of rel_eq);
        (* tupled_abs is applied twice, presumably once per argument tuple
           (vars and outvars) of the relation *)
        val rel_eq_abs = tupled_abs ctxt (tupled_abs ctxt rel_eq);
        val top_eq_abs = tupled_abs ctxt top_eq;
        val with_top = Drule.compose (top_eq_abs, 1, @{thm demonic_eq});
        val result = Drule.compose (rel_eq_abs, 1, with_top);
      in
        (Rel (vars, outvars, @{term True}, rel_body), result)
      end;

   (* NOTE(review): this exercises pt_update_th, not pt_demonic_th defined
      above; it looks like a copy of the previous scratch check.  A
      pt_demonic_th check would need an output tuple and a relation argument. *)
   val T = pt_update_th @{context} @{term "(x::nat,(a,b),y)"} @{term "λ (x::nat,(a,b),y) . (a,b,y,x+1)"}
*}

  
  
(* Congruence for {.p.} o [-f-]: replace the precondition and the update by
   equivalent ones.  Used by pt_assert_update_th. *)
lemma assert_update_eq: "p ≡ q ⟹ f ≡ g ⟹ {.p.} o [-f-] =  {.q.} o [-g-]"
  by simp

(* Congruence for {.p.} o [:r:]: replace the precondition and the relation
   by equivalent ones.  Used by pt_assert_demonic_th. *)
lemma assert_demonic_eq: "p ≡ q ⟹ r ≡ r' ⟹ {.p.} o [:r:] =  {.q.} o [:r':]"
  by simp


ML{*
    (*val assert_update_th = @{thm assert_update_th}*)

    (* Simplify {.p.} o [-f-] over the state tuple vars.
       The update and the precondition, applied to vars, are both simplified
       with simp_update_thm (update first).
       Returns (Func (vars, p', f'), th), where th comes from
       assert_update_eq with both premises filled in. *)
    fun pt_assert_update_th ctxt vars  p f =
      let
        val upd_eq = simp_update_thm ctxt (f $ vars);
        val pre_eq = simp_update_thm ctxt (p $ vars);
        val new_state = Thm.term_of (Thm.rhs_of upd_eq);
        val pre = Thm.term_of (Thm.rhs_of pre_eq);
        val upd_eq_abs = tupled_abs ctxt upd_eq;
        val pre_eq_abs = tupled_abs ctxt pre_eq;
        val result =
          Drule.compose (upd_eq_abs, 1, Drule.compose (pre_eq_abs, 1, @{thm assert_update_eq}));
      in
        (Func (vars, pre, new_state), result)
      end;
*}

ML{*
  (* Scratch check: rewrite with update_simp_thms only, then with the full
     simpset, on a beta-redex containing an if-expression. *)
  val Z = simp_term @{context} (simp_only_term @{context} update_simp_thms @{term "(λ B . (if (1::real) ≤ 2 then B else C)) (X + Y)"})
*}
 
ML{*
    (* Simplify term under the assumption asm.
       The implication  asm ==> term  (both wrapped in Trueprop) is rewritten
       with Simplifier.asm_rewrite in two passes: first with only
       update_simp_thms, then with the full simpset of ctxt minus
       if_weak_cong.
       Returns (th1, th2, th): the two single-pass equations and their
       composition th, whose lhs is the implication.
       A warning is issued only if the lhs of th is not alpha-equivalent to
       the implication that was built.  The check compares against the
       implication, not against the bare term, so it does not report an
       error on every call. *)
    fun simp_asm_thm ctxt asm term =
      let
        val ctxt' = (clear_simpset ctxt) addsimps update_simp_thms;
        val goal = Logic.implies $ (HOLogic.mk_Trueprop asm) $ (HOLogic.mk_Trueprop term)
        val th1 = Simplifier.asm_rewrite ctxt' (Thm.cterm_of ctxt' goal)

        val ctxt'' = Raw_Simplifier.del_cong @{thm "if_weak_cong"} ctxt
        val th2 = Simplifier.asm_rewrite ctxt'' (Thm.rhs_of th1)
        val th = Thm.transitive th1 th2;

        val _ =
          if Thm.term_of (Thm.lhs_of th) aconv goal then ()
          else warning "simp_asm_thm: lhs of the rewrite differs from the input implication"
      in
        (th1, th2, th)
      end
*}
  
(* If r and r' are equivalent under the hypothesis p, then p & r == p & r'.
   prec_rel_abs applies this to the equation produced by simp_asm_thm. *)
lemma prec_simp_rel: "((p ⟹ r) ≡ (p ⟹ r')) ⟹ p ∧ r ≡ p ∧ r'"
  apply (subgoal_tac "((p ⟶ r) = (p ⟶ r'))")
   apply (rule eq_reflection)
   apply auto [1]
  apply auto
    by (drule symmetric, simp)

(* Special case of prec_simp_rel: if r holds under p, then p & r == p.
   The lemma is unnamed, so it serves only as a check. *)
lemma "((p ⟹ r) ≡ Trueprop True) ⟹ p ∧ r ≡ p"
  apply (subgoal_tac "((p ⟶ r) = True)")
   apply (rule eq_reflection)
   by auto

(* Conjunction of a precondition p and a relation r.  It is named so that
   prec_rel_abs can fold such conjunctions (inter_pre_rel_sym), and so that
   assert_simp_demonic_eq can be stated. *)
definition "inter_pre_rel p r x y = (p x ∧ r x y)"
  
   
(* From X == True conclude X.  beta_conv_tuple_rel uses this on the result of
   rewriting an equation, presumably to True. *)
lemma prop_eq_true: "X ≡ True ⟹ X"
  by auto
    
    
    
ML{*
(* Prove, with the full simpset of ctxt, that the tupled abstraction of term
   over the tuple x, applied to x, equals term.  The result is the
   Simplifier.rewrite equation for  (tupled_lambda x term) x = term.
   A rewrite with only cond_case_prod_eta and case_prod_conv used to be
   computed here as well, but its result was never used, so it is omitted. *)
fun beta_conv_tuple ctxt x term =
    let
      val t = HOLogic.mk_eq (HOLogic.tupled_lambda x term $ x, term)
    in
      Simplifier.rewrite ctxt (Thm.cterm_of ctxt t)
    end;

(* Two-argument version of beta_conv_tuple: derive the HOL equation
   (tupled_lambda x (tupled_lambda y term)) x y = term by rewriting it with
   the full simpset and applying prop_eq_true. *)
fun beta_conv_tuple_rel ctxt x y term =
    let
      val applied = HOLogic.tupled_lambda x (HOLogic.tupled_lambda y term) $ x $ y;
      val eq_thm = Simplifier.rewrite ctxt (Thm.cterm_of ctxt (HOLogic.mk_eq (applied, term)));
    in
      @{thm prop_eq_true} OF [eq_thm]
    end;

*}
  
ML{*
(* Scratch check of beta_conv_tuple_rel.  The first (invar, outvar, exp)
   binding is immediately shadowed by the second.  In the second binding,
   z occurs in exp but in neither invar nor outvar. *)
val ctxt = @{context}
  val (invar, outvar, exp) = (@{term "(x::'a::plus,y::'a::plus,z::'a::plus)"}, 
    @{term "(a::'a::plus, b::'a::plus)"}, @{term "(a = (x::'a::plus) + y + z) ∧ (b + y = z)"})

  val (invar, outvar, exp) = (@{term "(x::'a::plus, y::'a::plus)"},
    @{term "(a::'a::plus, b::'a::plus)"}, @{term "(a = (x::'a::plus) + y + z) ∧ (b + y = z)"})

val beta = beta_conv_tuple_rel ctxt invar outvar exp;

*}

(* Fold  p x & r x y  into  inter_pre_rel p r x y.  prec_rel_abs uses this
   as a rewrite rule. *)
lemma inter_pre_rel_sym: "(p x ∧ r x y) = inter_pre_rel p r x y"
  by (simp add: inter_pre_rel_def)

  
ML{*
(* Simplify the relation r under the precondition p, for the input tuple
   invar and the output tuple outvar.
   Returns (th8, rel'):
     rel' - the simplified relation body (the right conjunct of the rhs
            of th1);
     th8  - an equation between inter_pre_rel forms, built by the chain
            th1 .. th8 below.
   NOTE(review): the exact shape of th8 depends on simp_asm_thm,
   beta_conv_tuple_rel and the inter_pre_rel_sym rewriting.  Confirm it
   against assert_simp_demonic_eq, into which pt_assert_demonic_thA
   composes it. *)
fun prec_rel_abs ctxt invar outvar p r =
  let

    (* t : (p invar ==> r invar outvar) == simplified implication.
       t1 and t2 (the single-pass equations) are not used. *)
    val (t1, t2, t) = simp_asm_thm ctxt (p $ invar) (r $ invar $ outvar)

    (*
    val _ = writeln ("simplified asm: " ^ thm_to_string ctxt t)
    *)
  
    (* th1 : p invar & r invar outvar == pre & rel'.  This requires the
       simplified implication in t to still have the form  pre ==> rel'. *)
    val th1 = @{thm "prec_simp_rel"} OF [t];

    (*
    val _ = writeln ("simplified eq aaaaaaaaa: " ^ thm_to_string ctxt th1)
    *)

    (* Non-exhaustive binding: raises Bind unless the rhs of th1 is a
       binary application conj pre rel'. *)
    val conj $ (pre ) $ rel' = Thm.term_of (Thm.rhs_of th1);

    (* th2 : (tupled lambda over invar/outvar of rel') invar outvar == rel' *)
    val th2 = @{thm "eq_reflection"} OF [beta_conv_tuple_rel ctxt invar outvar rel']

    (*
    val _ = writeln ("simplified beta: " ^ thm_to_string ctxt th2)
    *)

    (* th3 : conj pre (lhs of th2) == conj pre rel' *)
    val th3 = Thm.combination (Thm.reflexive (Thm.cterm_of ctxt (conj $ (pre)))) th2
    (* th4 : lhs of th1 == lhs of th3 *)
    val th4  = Thm.transitive th1 (Thm.symmetric th3);
(*
    val th5 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.rhs_of th4)
*)
    (* th5 : fold the lhs of th3 into inter_pre_rel form *)
    val th5 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.lhs_of th3)

    (*
    val _ = writeln ("simplified inter_pre_rel_sym 1: " ^ thm_to_string ctxt th5)
   *)

    (* th6 : lhs of th1 == inter_pre_rel form of the simplified relation *)
    val th6 = Thm.transitive th4 th5;
(*
    val th7 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.lhs_of th4)
*)
    (* th7 : fold the original conjunction (lhs of th1) into inter_pre_rel form *)
    val th7 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.lhs_of th1)

    (*
    val _ = writeln ("simplified inter_pre_rel_sym 2: " ^ thm_to_string ctxt th7)
    *)

    (* th8 : original inter_pre_rel form == simplified inter_pre_rel form *)
    val th8 = Thm.transitive (Thm.symmetric th7) th6;
  in
    (th8, rel')
  end;

*}
  
   

ML{*
  (* Scratch experiments with asm_rewrite, simp_asm_thm and prec_rel_abs. *)
  val ctxt = @{context};
  val th = Simplifier.asm_rewrite ctxt (Thm.cterm_of ctxt @{term "x = 2 ⟹ True"});
  val t = @{term "p ⟹ T"}
  (* NOTE(review): M and N apply HOL connectives (mk_imp, mk_eq) to Trueprop
     terms of type prop, so they are not type-correct.  They are only
     constructed here, never certified. *)
  val M =  HOLogic.mk_imp (HOLogic.mk_Trueprop @{term p}, HOLogic.mk_Trueprop @{term T});
  val N =   HOLogic.mk_eq(HOLogic.mk_Trueprop @{term p}, HOLogic.mk_Trueprop @{term T});
  val Z = Logic.implies $ M $ N;

  val oupt = @{term "b::nat"}

  (* relation  x + 2 = y  simplified under the precondition  x = 0 *)
  val (t1, t2, t) = simp_asm_thm @{context} (@{term "(λ (x, z::nat) . (x::nat) = 0)"} $ @{term "(a::nat, c::nat)"}) 
      (@{term "(λ (x,z::nat) y . (x::nat) + 2 = y)"} $ @{term "(a::nat, c::nat)"} $ @{term "b::nat"});

  (* the same example through prec_rel_abs, with a pair and with a single input *)
  val (th, t) = prec_rel_abs ctxt @{term "(a::nat, c::nat)"} @{term "b::nat"}
    @{term "(λ (x, z::nat) . (x::nat) = 0)"} @{term "(λ (x,z::nat) y . (x::nat) + 2 = y)"}

  val t = Thm.cterm_of ctxt t;

  val (th, t) = prec_rel_abs ctxt @{term "(a::nat)"} @{term "b::nat"}
    @{term "(λ x . (x::nat) = 0)"} @{term "(λ x y . (x::nat) + 2 = y)"}

(*
  val th1 = @{thm "prec_simp_rel"} OF [t];
  val conj $ (pre $ inpt) $ rel' = Thm.term_of (Thm.rhs_of th1);
  val th2 = @{thm "eq_reflection"} OF [beta_conv_tuple_rel ctxt inpt oupt rel']
  val th3 = Thm.combination (Thm.reflexive (Thm.cterm_of ctxt (conj $ (pre $ inpt)))) th2
  val th4  = Thm.transitive th1 (Thm.symmetric th3);
  val th5 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.rhs_of th4)
  val th6 = Thm.transitive th4 th5;
  val th7 = simp_only_cterm_thm ctxt [@{thm inter_pre_rel_sym}] (Thm.lhs_of th4)
  val t8 = Thm.transitive (Thm.symmetric th7) th6;
*)

*}

(* Congruence for {.p.} o [:r:] in which the relation may be simplified
   under the simplified precondition p'.  The relations only need to agree
   as inter_pre_rel p' r and inter_pre_rel p' r'.  Used by
   pt_assert_demonic_thA. *)
theorem assert_simp_demonic_eq: "p ≡ p' ⟹ inter_pre_rel p' r ≡ inter_pre_rel p' r' ⟹ {.p.} o [:r:] = {.p'.} o [:r':]"
  apply (subgoal_tac "inter_pre_rel p' r = inter_pre_rel p' r'")
   apply (simp add: inter_pre_rel_def fun_eq_iff demonic_def assert_def le_fun_def)
  by auto
    
ML{*
(*with simplification of relation based on precondition*)
(* Simplify {.p.} o [:r:] for input tuple vars and output tuple outvars.
   The precondition is simplified with simp_update_thm.  The relation is
   then simplified under that simplified precondition via prec_rel_abs.
   Returns (Rel (vars, outvars, p', r'), th), where th comes from
   assert_simp_demonic_eq with both premises filled in by Drule.compose. *)
fun pt_assert_demonic_thA ctxt vars outvars p r =
      let
        val pre_eq = simp_update_thm ctxt (p $ vars);
        val pre = Thm.term_of (Thm.rhs_of pre_eq);
        (* the simplified precondition, re-abstracted over the input tuple *)
        val pre_fun = HOLogic.tupled_lambda vars pre;
        val (rel_eq, rel_body) = prec_rel_abs ctxt vars outvars pre_fun r;
        val pre_eq_abs = tupled_abs ctxt pre_eq;
        val with_pre = Drule.compose (pre_eq_abs, 1, @{thm assert_simp_demonic_eq});
      in
        (Rel (vars, outvars, pre, rel_body), Drule.compose (rel_eq, 1, with_pre))
      end;

(*without simplification of relation based on precondition*)
(* Simplify {.p.} o [:r:] for input tuple vars and output tuple outvars.
   The precondition and the relation are simplified independently with
   simp_update_thm (precondition first).
   Returns (Rel (vars, outvars, p', r'), th), where th comes from
   assert_demonic_eq with both premises filled in by Drule.compose. *)
fun pt_assert_demonic_th ctxt vars outvars p r =
      let
        val pre_eq = simp_update_thm ctxt (p $ vars);
        val rel_eq = simp_update_thm ctxt (r $ vars $ outvars);
        val rel_body = Thm.term_of (Thm.rhs_of rel_eq);
        val pre = Thm.term_of (Thm.rhs_of pre_eq);
        (* tupled_abs twice: presumably once per argument tuple of r *)
        val rel_eq_abs = tupled_abs ctxt (tupled_abs ctxt rel_eq);
        val pre_eq_abs = tupled_abs ctxt pre_eq;
        val with_pre = Drule.compose (pre_eq_abs, 1, @{thm assert_demonic_eq});
      in
        (Rel (vars, outvars, pre, rel_body), Drule.compose (rel_eq_abs, 1, with_pre))
      end;

*}


  (* Congruence step for feedback.  SimpPredTran uses it to combine the
     theorem for the body with the theorem from pt_feedback_th. *)
  lemma feedback_cong: "B = A ⟹ feedback A = F ⟹ feedback B = F"
    by simp

  (* Congruence step for serial composition.  SimpPredTran uses it to
     combine the theorems of both components with pt_comp_th's theorem. *)
  lemma comp_cong: "S = A ⟹ T = B ⟹ A o B = F ⟹ S o T = F"
    by simp

  (* Congruence step for parallel product (**).  SimpPredTran uses it to
     combine the theorems of both components with pt_prod_th's theorem. *)
  lemma prod_cong: "S = A ⟹ T = B ⟹ A ** B = F ⟹ S ** T = F"
    by simp


ML{*
  (* True iff the type is a HOL product type. *)
  fun is_prodT T =
    (case T of
      Type ("Product_Type.prod", _) => true
    | _ => false);

  (* Split a right-nested product type T1 * (T2 * ...) into [T1, T2, ...].
     There is intentionally no special case for unit: a unit type, like any
     other non-product type, yields a singleton list. *)
  fun strip_tupleT_a T =
    (case T of
      Type ("Product_Type.prod", [T1, T2]) => T1 :: strip_tupleT_a T2
    | _ => [T]);


  (* Create a tuple of fresh free variables v__<n>_, v__<n+1>_, ... whose
     shape follows the (possibly nested) product type T.
     Returns the term together with the next unused index. *)
  fun create_var_tuple T n =
      if not (is_prodT T) then
        (Free ("v__" ^ string_of_int n ^ "_", T), n + 1)
      else
        let
          val (components, next) = create_var_tuple_list (strip_tupleT_a T) n
        in
          (HOLogic.mk_tuple components, next)
        end
  (* create_var_tuple for each type of a list, threading the index. *)
  and create_var_tuple_list [] n = ([], n)
    | create_var_tuple_list (T :: Ts) n =
        let
          val (first, n1) = create_var_tuple T n;
          val (rest, n2) = create_var_tuple_list Ts n1;
        in
          (first :: rest, n2)
        end;
*}

ML{*
  (* Build a term of type T whose leaves are HOL.undefined constants.
     Product types are split with strip_tupleT_a and rebuilt as tuples. *)
  fun create_undefined_tuple_list Ts = map create_undefined_tuple Ts
  and create_undefined_tuple T =
    if is_prodT T
    then HOLogic.mk_tuple (create_undefined_tuple_list (strip_tupleT_a T))
    else Const ("HOL.undefined", T);

*}


ML{*

  (* Scratch value; T is also bound by other scratch checks in this theory. *)
  val T = @{term "⊥"}

  (* Elapsed wall-clock time since t, formatted for the debug trace.  It is
     computed as now - t so that reported durations are non-negative. *)
  fun elapsed_since t = Time.toString (Time.now () - t);

  (*version with initial list of variable names*)
  (* Translate a predicate-transformer term into a simplified representation.
     Arguments:
       inpt_vars, out_vars - tuples to use as input / output state variables.
                             When they are unit (is_unit), fresh variables
                             v__<k>_ are created with create_var_tuple;
                             otherwise distribute_types retypes them.
       n                   - index of the next fresh variable
       ctxt                - proof context
     Returns (rep, th, n'): the Func/Rel representation, a theorem relating
     the input term to it (as produced by the pt_*_th helpers and combined
     with comp_cong / prod_cong / feedback_cong), and the next free index.
     Handled forms: Skip, Fail, bot, assert, update, demonic,
     assert o update, assert o demonic, serial composition, Prod and
     feedback.  Any other term is printed and raises ERROR "No match". *)
  fun SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.Skip", typ)) =
      let
        val typ' = domain_type (domain_type typ);
        val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
        val (rep, th) = pt_skip_th ctxt invar;
      in (rep, th, n') end |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.Fail", typ)) = fail_pt_th ctxt inpt_vars n typ |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Orderings.bot_class.bot", typ)) = fail_pt_th ctxt inpt_vars n typ |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.assert", typ) $ p) =
        let
          val _ = debug DEBUG_ASSERT ("assert: " ^ (string_of_int n));
          val t = Time.now();
          val typ' = domain_type (domain_type typ);
          val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
          val (rep, th) = pt_assert_th ctxt invar p;
          val _ = debug DEBUG_ASSERT ("  - time: " ^ elapsed_since t);
        in (rep, th, n') end |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.update", typ) $ f) =
        let
          val _ = debug DEBUG_UPDATE ("update: " ^ (string_of_int n));
          val typ' = domain_type (domain_type typ);
          val t = Time.now();
          val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
          val (rep, th) = pt_update_th ctxt invar f;
          val _ = debug DEBUG_UPDATE ("  - time: " ^ elapsed_since t);
        in (rep, th, n') end |

      SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.demonic", typ) $ r) =
        let
          val _ = debug DEBUG_UPDATE ("demonic: " ^ (string_of_int n));
          val typ' = domain_type (domain_type typ);
          (* output state type: argument type of the relation's second argument *)
          val typ'' = domain_type (range_type (domain_type typ));
          val t = Time.now();
          val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
          val (outvar, n'') = if is_unit out_vars then create_var_tuple typ'' n' else (distribute_types out_vars typ'', n') ;
          val (rep, th) = pt_demonic_th ctxt invar outvar r;
          val _ = debug DEBUG_UPDATE ("  - time: " ^ elapsed_since t);
        in (rep, th, n'') end |

      (* assert o update and assert o demonic are matched before general
         composition so that they are simplified as one unit *)
      SimpPredTran inpt_vars out_vars n ctxt ((Const ("Fun.comp", _)
          $ ((Const ("Refinement.assert", _) $ p)) $ (Const ("Refinement.update", typ_update) $ f))) =
        let
          val _ = debug DEBUG_ASSERT_UPDATE ("assert update: " ^ (string_of_int n));
          val t = Time.now();
          val typ' = domain_type (domain_type typ_update);
          val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
          val (rep, th) = pt_assert_update_th ctxt invar p f;
          val _ = debug DEBUG_ASSERT_UPDATE ("  - time: " ^ elapsed_since t);
        in
          (rep, th, n')
        end |
      SimpPredTran inpt_vars out_vars n ctxt ((Const ("Fun.comp", _)
          $ ((Const ("Refinement.assert", _) $ p)) $ (Const ("Refinement.demonic", typ_demonic) $ r))) =
        let
          val _ = debug DEBUG_ASSERT_UPDATE ("assert demonic: " ^ (string_of_int n));
          val t = Time.now();
          val typ' = domain_type (domain_type typ_demonic);
          val typ'' = domain_type (range_type (domain_type typ_demonic));
          val (invar, n') = if is_unit inpt_vars then create_var_tuple typ' n else (distribute_types inpt_vars typ', n) ;
          val (outvar, n'') = if is_unit out_vars then create_var_tuple typ'' n' else (distribute_types out_vars typ'', n') ;
          val (rep, th) = pt_assert_demonic_th ctxt invar outvar p r;
          val _ = debug DEBUG_ASSERT_UPDATE ("  - time: " ^ elapsed_since t);
        in
          (rep, th, n'')
        end |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Fun.comp", _) $ S $ T) =
        let
          val _ = debug DEBUG_SERIAL ("serial: " ^ (string_of_int n));

          val (rep, th, n') = SimpPredTran inpt_vars  out_vars n ctxt S;
          (* the second component receives unit as input variables, so its
             input variables are created afresh *)
          val (rep', th', n'') = SimpPredTran (@{term "()"})  out_vars n' ctxt T;

          val t = Time.now();
          val (comp_rep, comp_th) = pt_comp_th ctxt rep rep';
          val t' = Time.now();
          val res_th = @{thm comp_cong} OF [th, th', comp_th]
          val _ = debug DEBUG_SERIAL ("  - time OF: " ^ elapsed_since t');

          val _ = debug DEBUG_SERIAL ("  - time: " ^ elapsed_since t);
        in
          (comp_rep, res_th, n'')
        end |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("Refinement.Prod", _) $ S $ T) =
        let
          val _ = debug DEBUG_PROD ("prod: " ^ (string_of_int n));
          (* split a given input tuple between the two components *)
          val (inpt_vars_a, inpt_vars_b) = if is_unit inpt_vars then (inpt_vars, inpt_vars) else HOLogic.dest_prod inpt_vars;
          val (rep, th, n') = SimpPredTran inpt_vars_a  out_vars n ctxt S;
          val (rep', th', n'') = SimpPredTran inpt_vars_b  out_vars n' ctxt T;

          val t = Time.now();

          val (prod_rep, prod_th) = pt_prod_th ctxt rep rep';

          val res_th = @{thm prod_cong} OF [th, th', prod_th]

          val _ = debug DEBUG_PROD ("  - time: " ^ elapsed_since t);
        in
          (prod_rep, res_th, n'')
        end |
      SimpPredTran inpt_vars out_vars n ctxt (Const ("TransitionFeedback.feedback", _) $ S) =
        let
          (* a given input tuple is extended with one fresh variable for the feedback wire *)
          val (inpt_vars', n') = if is_unit inpt_vars then (inpt_vars, n) else  (HOLogic.mk_tuple [Free("v__" ^ (string_of_int n) ^ "_", @{typ "'a"}), inpt_vars] , n + 1)
          val (rep, th, n') = SimpPredTran inpt_vars' out_vars n' ctxt S;

          val _ = debug DEBUG_FEEDBACK ("feedback: " ^ (string_of_int n'));
          val t = Time.now();

          val (fb_rep, fb_th) = pt_feedback_th ctxt rep;

          val _ = debug DEBUG_FEEDBACK ("  - time thm:" ^ elapsed_since t);

          val res_th = @{thm feedback_cong} OF [th, fb_th]

          val _ = debug DEBUG_FEEDBACK ("  - time: " ^ elapsed_since t);
        in
          (fb_rep, res_th, n')
        end |
      SimpPredTran inpt_vars out_vars n ctxt S = let
        val _ = writeln "No match";
        val _ = writeln (term_to_string ctxt S);
        val _ = raise (ERROR "No match")
        in
          (Func (@{term "()"}, @{term "True"}, @{term "()"}), @{thm prod_cong}, n)
        end;
*}
  
ML{*
(*examples of simplifications*)
    val update = @{term "[- x,y  ↝ x + y, x + 123 -]"};
    val X = SimpPredTran @{term "(x,y)"} @{term "(z,z')"} 10 @{context} update;

    val demonic = @{term "[: x,y ↝ z . x < z ∧ z < y :]"};
    (* NOTE(review): this applies SimpPredTran to update, not to demonic bound
       on the previous line; demonic is otherwise unused.  Probably a
       copy-paste slip.  The commented-out examples below use an older
       argument order of SimpPredTran (n ctxt term, without variable
       tuples). *)
    val X = SimpPredTran @{term "(x,y)"} @{term "z"} 10 @{context} update;

(*
    val skip = @{term Skip};
    val X = SimpPredTran 10 ctxt skip;

    val assert = @{term "{.x,y.(x + y = 123).}"};
    val X = SimpPredTran 10 ctxt assert;

    val update = @{term "[-λ (x,y) . (x + y, x + 123)-]"};
    val X = SimpPredTran 10 ctxt update;

    val (update_rep, thm, _)  = SimpPredTran 1 ctxt @{term " [-λ (x,y) . (y, x + 123)-]"};


    val fb_term = @{term "feedback [-λ (x,y) . (y, x + 123)-]"};
    val X = SimpPredTran 10 ctxt fb_term;

    val comp_term = @{term "[-λ (x,y) . (y, x + 123)-] o [-(λ (u,v) . (u*u,v*v,u+v))-]"};
    val X = SimpPredTran 10 ctxt comp_term;

    val fb_comp_term = @{term "feedback ([-λ (x,y) . (y, x + 123)-] o [-(λ (u,v) . (u*u,v*v,u+v))-])"};
    val X = SimpPredTran 10 ctxt fb_comp_term;

    val prod_term = @{term "[-λ (x,y) . (y, x + 123)-] ** [-(λ (u,v) . (u*u,v*v,u+v))-]"};
    val X = SimpPredTran 10 ctxt prod_term;
*)
   *}


(* Chain a = b, b == c and c = d.  SimpSimulinkDef uses it to link the
   definition, its unfolding and the simplification result. *)
lemma eq_eq_tran: "a = b ⟹ b ≡ c ⟹ c = d ⟹ a = d"
  by simp

(* If Skip = A and A o B = C, then any M equal to B also equals C. *)
lemma rename_vars: "Skip = A ⟹ A o B = C ⟹ M = B ⟹ M = C"
    by auto

(* A transformer of the form {.p.} o T whose precondition is everywhere
   False is the failing transformer (bottom).  Used by SimpFailSkip. *)
lemma simp_to_fail: "A = {.p.} o T ⟹ (⋀ x . p x = False) ⟹ A = ⊥"
  by (metis assert_false_fail bot.extremum_uniqueI fail_comp predicate1I)


(* A transformer of the form {.p.} o T whose precondition is everywhere True
   equals T.  SimpFailSkip uses it to drop a trivial assertion. *)
lemma assert_true_comp: "A = {.p.} o T ⟹ (⋀ x . p x = True) ⟹ A = T"
  apply (subgoal_tac "p = ⊤")
  apply (simp add: assert_true_skip skip_comp)
  by (simp add: fun_eq_iff)

  ML{*
    (* Prove  lhs = mk_rhs lhs, where lhs is the left-hand side of the
       equation th (extracted by pattern matching).  The proof applies
       rule OF [th], unfolds serial_prove_thms and then tries auto. *)
    fun prove_simp_fail_skip ctxt rule th mk_rhs =
      let
        val _ $ (_ $ lhs $ _) = Thm.prop_of th;
        val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, mk_rhs lhs));
        fun tac _ =
          resolve_tac ctxt [rule OF [th]] 1
          THEN (unfold_tac ctxt serial_prove_thms)
          THEN (TRY (auto_tac ctxt));
      in
        Goal.prove ctxt [] [] goal tac
      end;

    (* Post-process the final theorem th (of the form A = {.p.} o T, as
       expected by simp_to_fail / assert_true_comp):
         - precondition literally False: prove A = bottom;
         - precondition literally True: prove A equals the plain update
           (update_func) or demonic relation (demonic_rel);
         - otherwise return th unchanged. *)
    fun SimpFailSkip ctxt (Func (_, @{term False}, _)) th =
          prove_simp_fail_skip ctxt @{thm simp_to_fail} th (mk_bottom o fastype_of)
      | SimpFailSkip ctxt (rep as Func (_, @{term True}, _)) th =
          prove_simp_fail_skip ctxt @{thm assert_true_comp} th (fn _ => update_func rep)
      | SimpFailSkip ctxt (Rel (_, _, @{term False}, _)) th =
          prove_simp_fail_skip ctxt @{thm simp_to_fail} th (mk_bottom o fastype_of)
      | SimpFailSkip ctxt (rep as Rel (_, _, @{term True}, _)) th =
          prove_simp_fail_skip ctxt @{thm assert_true_comp} th (fn _ => demonic_rel rep)
      | SimpFailSkip _ _ th = th

  (* Look up the named facts in ctxt and concatenate them, in list order. *)
  fun thms ctxt names = maps (Proof_Context.get_thms ctxt) names;

(* Scratch value; not used by the functions below. *)
val TT = @{term "feedback A"};
(* True iff the term is an application  feedback S. *)
fun is_feedback t =
  (case t of
    Const ("TransitionFeedback.feedback", _) $ _ => true
  | _ => false);

(* Feedback is simplified with fewer simplification rules, so when the
   original term is  feedback S  and the result is a Func representation,
   re-simplify its precondition and update with pt_assert_update_th and
   chain the new theorem onto th with trans.  In every other case
   (rep, th) is returned unchanged. *)
fun SimpFeedback ctxt (rep as Func (invar, _, _)) th term =
      (if not (is_feedback term) then (rep, th)
       else
         let
           val pre = prec rep;
           val upd = func rep;
           val (rep', th_b) = pt_assert_update_th ctxt invar pre upd;
         in
           (rep', @{thm trans} OF [th, th_b])
         end)
  | SimpFeedback _ rep th _ = (rep, th)

    (* Simplify a Simulink model definition def_th of the form  name = rhs.
       Steps:
         1. unfold rhs using only basic_simps plus thmsa;
         2. translate the unfolded model with SimpPredTran (input/output
            tuples inpt and out, first fresh index 1), in ctxt extended with
            the same rules;
         3. post-process with SimpFeedback and SimpFailSkip;
         4. chain everything with eq_eq_tran and generalize the remaining
            free variables to schematic ones.
       Returns the generalized theorem. *)
    fun SimpSimulinkDef ctxt thmsa inpt out def_th =
      let
        val _ $ (_ $ _ $ def_rhs) = Thm.prop_of def_th;

        val simps = thms ctxt ["basic_simps"] @ thmsa;
        val unfold_th = simplify_only_term ctxt simps def_rhs;
        val simp_ctxt = ctxt addsimps simps;
        val model = Thm.term_of (Thm.rhs_of unfold_th);

        val (rep0, th0, _) = SimpPredTran inpt out 1 simp_ctxt model;
        val (rep1, th1) = SimpFeedback ctxt rep0 th0 model;
        val th2 = SimpFailSkip ctxt rep1 th1;

        val chained = @{thm eq_eq_tran} OF [def_th, unfold_th, th2];
        val frees = Variable.add_free_names ctxt (Thm.prop_of chained) [];
      in
        Thm.generalize ([], frees) (Thm.maxidx_of chained + 1) chained
      end;
  *}

ML{*

(* Replace every type variable (TFree or TVar) occurring in T by typ. *)
fun replace_tfreeT typ T =
  Term.map_atyps (fn TFree _ => typ | TVar _ => typ | U => U) T

(* Scratch check: every type variable ('a, 'b, 'c) is replaced by bool. *)
val TT = replace_tfreeT @{typ "bool"} @{typ "('a × bool) ⇒ 'b ⇒ 'c"}

(* Apply replace_tfreeT typ to the type of every constant, free variable,
   schematic variable and abstraction in t. *)
fun replace_tfree typ t = Term.map_types (replace_tfreeT typ) t;

*}

ML{*

(* Scratch binding of Local_Defs.derived_def.  It is not used below;
   new_gen_def calls Local_Defs.derived_def directly. *)
val XX = Local_Defs.derived_def;

(* Apparently a copy of Local_Defs.derived_def with a different prove
   function.  The original (kept below in a comment) chooses the variables
   to fix from prop itself unless ctxt' is a body context.  Here the caller
   passes the names of the variables to fix explicitly as vars.
   Returns (((c, T), rhs), prove), where prove ctxt' vars def proves prop by
   meta-rewriting and unfolding def. *)
fun new_derived_def ctxt get_pos {conditional} prop =
  let
    val ((c, T), rhs) = prop
      |> Thm.cterm_of ctxt
      |> Local_Defs.meta_rewrite_conv ctxt
      |> (snd o Logic.dest_equals o Thm.prop_of)
      |> conditional ? Logic.strip_imp_concl
      |> (Local_Defs.abs_def o #2 o Local_Defs.cert_def ctxt get_pos);
(*
    fun prove ctxt' def =
      Goal.prove ctxt'
        ((not (Variable.is_body ctxt') ? Variable.add_free_names ctxt' prop) []) [] prop
        (fn {context = ctxt'', ...} =>
          ALLGOALS
            (CONVERSION (Local_Defs.meta_rewrite_conv ctxt'') THEN'
              rewrite_goal_tac ctxt'' [def] THEN'
              resolve_tac ctxt'' [Drule.reflexive_thm]))
      handle ERROR msg => cat_error msg "Failed to prove definitional specification";
*)
    (* vars: names of the free variables that Goal.prove fixes *)
    fun prove ctxt' vars def =
      Goal.prove ctxt'
        vars [] prop
        (fn {context = ctxt'', ...} =>
          ALLGOALS
            (CONVERSION (Local_Defs.meta_rewrite_conv ctxt'') THEN'
              rewrite_goal_tac ctxt'' [def] THEN'
              resolve_tac ctxt'' [Drule.reflexive_thm]))
      handle ERROR msg => cat_error msg "Failed to prove definitional specification";

  in (((c, T), rhs), prove) end;
*}
  
ML{*
(*to delete next two functions*)
(* Check the equation eq as a definition with Primitive_Defs.dest_def.  The
   head must be a Free, and eq must not contain schematic variables
   (Sign.no_vars).  Returns the (name, type) of the head together with the
   equation eq' produced by dest_def.  TERM/ERROR failures are re-raised
   with the offending definition appended to the message.  Reporting of
   argument positions is disabled (commented out).  Marked for deletion. *)
fun DEL_cert_def ctxt  eq =
  let
    fun err msg =
      cat_error msg ("The error(s) above occurred in definition:\n" ^
        quote (Syntax.string_of_term ctxt eq));
    val ((lhs, _), args, eq') = eq
      |> Sign.no_vars ctxt
      |> Primitive_Defs.dest_def ctxt
        {check_head = Term.is_Free,
         check_free_lhs = not o Variable.is_fixed ctxt,
         check_free_rhs = if Variable.is_body ctxt then K true else Variable.is_fixed ctxt,
         check_tfree = K true}
      handle TERM (msg, _) => err msg | ERROR msg => err msg;
(*
    val _ =
      Context_Position.reports ctxt
        (maps (fn Free (x, _) => Syntax_Phases.reports_of_scope (get_pos x) | _ => []) args);
*)
  in (Term.dest_Free (Term.head_of lhs), eq') end;

(* Like new_derived_def, but conditional is a plain bool and the equation is
   checked with DEL_cert_def (no position information).
   Returns (((c, T), rhs), prove), where prove ctxt' vars def proves prop
   with the variables vars fixed, by meta-rewriting and unfolding def.
   Marked for deletion. *)
fun DEL_derived_def ctxt conditional prop =
  let
    val ((c, T), rhs) = prop
      |> Thm.cterm_of ctxt
      |>  Local_Defs.meta_rewrite_conv ctxt
      |> (snd o Logic.dest_equals o Thm.prop_of)
      |> conditional ? Logic.strip_imp_concl
      |> ( Local_Defs.abs_def o #2 o  DEL_cert_def ctxt);
    fun prove ctxt' vars def =
      Goal.prove ctxt' vars [] prop
        (fn {context = ctxt'', ...} =>
          ALLGOALS
            (CONVERSION ( Local_Defs.meta_rewrite_conv ctxt'') THEN'
              rewrite_goal_tac ctxt'' [def] THEN'
              resolve_tac ctxt'' [Drule.reflexive_thm]))
      handle ERROR msg => cat_error msg "Failed to prove definitional specification";
  in (((c, T), rhs), prove) end;
*}

ML{*

(* Variant of the implementation behind the definition command (presumably
   adapted from Isabelle's Specification.gen_def). It:
     - reads the specification with prep_spec;
     - derives the definition with Local_Defs.derived_def;
     - defines the constant with Local_Theory.define_internal (the raw
       theorem is named <name>_raw);
     - proves the user-level equation and registers it as an equational
       spec rule;
     - stores it under name with the attributes atts.
   Unlike the commented-out variant, the stored theorem does not get
   Code.add_default_eqn_attrib.
   Returns ((lhs, (def_name, th')), lthy4). *)
fun new_gen_def prep_spec prep_att raw_var raw_params raw_prems ((a, raw_atts), raw_spec) int lthy =
  let
    val atts = map (prep_att lthy) raw_atts;

    val ((vars, xs, get_pos, spec), _) = prep_spec (the_list raw_var) raw_params raw_prems raw_spec lthy;

    val (((x, T), rhs), prove) = Local_Defs.derived_def lthy get_pos {conditional = true} spec;
    val _ = Name.reject_internal (x, []);

    (* binding and mixfix: from the declaration if one was given, which must
       name the same constant as the head of the definition *)
    val (b, mx) =
      (case (vars, xs) of
        ([], []) => (Binding.make (x, (case get_pos x of [] => Position.none | p :: _ => p)), NoSyn)
      | ([(b, _, mx)], [y]) =>
          if x = y then (b, mx)
          else
            error ("Head of definition " ^ quote x ^ " differs from declaration " ^ quote y ^
              Position.here (Binding.pos_of b)));
    val name = Thm.def_binding_optional b a;

    val ((lhs, (_, raw_th)), lthy2) = lthy
      |> Local_Theory.define_internal ((b, mx), ((Binding.suffix_name "_raw" name, []), rhs));

    val th = prove lthy2 raw_th;

    val lthy3 = lthy2 |> Spec_Rules.add Spec_Rules.Equational ([lhs], [th]);
(*
    val ([(def_name, [th'])], lthy4) = lthy3
      |> Local_Theory.notes [((name, Code.add_default_eqn_attrib Code.Equation :: atts), [([th], [])])];
*)
    val ([(def_name, [th'])], lthy4) = lthy3
      |> Local_Theory.notes [((name, atts), [([th], [])])];

    val lhs' = Morphism.term (Local_Theory.target_morphism lthy4) lhs;

    val _ =
      Proof_Display.print_consts int (Position.thread_data ()) lthy4
        (member (op =) (Term.add_frees lhs' [])) [(x, T)];

  in ((lhs, (def_name, th')), lthy4) end;

(* Pick the theorem binding for a definition: an explicitly given, non-empty
   name wins; otherwise the default definition binding derived from the
   constant's binding b (Thm.def_binding) is used. *)
fun def_binding_optional b name =
  if not (Binding.is_empty name) then name
  else Thm.def_binding b;

(* Define the constant given by raw_spec, without any simplification.

   str_typ: when non-empty, it is parsed with Syntax.read_typ and applied to
            the parsed specification with replace_tfree (defined elsewhere in
            this theory) before the definition is made.
   raw_var: optional explicit declaration of the constant; an explicit
            declaration must have the same name as the head of the equation.
   int:     not used here.

   Returns (lthy4, th, th_gen', lhs, prop, def_name, const_name) where
     lthy4      - the local theory after the definition and the note,
     th         - the defining equation proved in the extended context,
     th_gen'    - the noted version of th, in which the free variables that
                  Variable.add_free_names reports for prop are generalized,
     lhs        - the defined term returned by Local_Theory.define_internal,
     prop       - the (possibly type-instantiated) specification,
     def_name   - the name under which th_gen was noted,
     const_name - the binding of the new constant. *)
fun simulink_definition_only_aux str_typ (raw_var, raw_spec) int lthy =
  let
    (* No attributes are attached to the stored defining equation. *)
    val atts = [];

    val ((vars, xs, get_pos, specA), _) =
      Specification.read_spec_open (the_list raw_var) [] [] raw_spec lthy;

    (* If str_typ is non-empty, replace type variables of the specification. *)
    val prop =
      if str_typ = "" then specA
      else replace_tfree (Syntax.read_typ lthy str_typ) specA;

    val (((x, _), rhs), prove) = new_derived_def lthy get_pos {conditional = true} prop;
    val _ = Name.reject_internal (x, []);

    (* Binding/mixfix of the constant: derived from the head name x when there is
       no explicit declaration, otherwise the declared name must equal x. *)
    val (b, mx) =
      (case (vars, xs) of
        ([], []) => (Binding.make (x, (case get_pos x of [] => Position.none | p :: _ => p)), NoSyn)
      | ([(b, _, mx)], [y]) =>
          if x = y then (b, mx)
          else
            error ("Head of definition " ^ quote x ^ " differs from declaration " ^ quote y ^
              Position.here (Binding.pos_of b)));

    val name = Thm.def_binding b;
    val const_name = b;

    (* The primitive definition is stored under "<name>_raw". *)
    val ((lhs, (_, raw_th)), lthy2) = lthy
      |> Local_Theory.define_internal ((b, mx), ((Binding.suffix_name "_raw" name, []), rhs));

    val th = prove lthy2 [] raw_th;

    (* Generalize the free variables of the specification so that the stored
       equation can be instantiated. *)
    val free_names = Variable.add_free_names lthy2 prop [];
    val th_gen = Thm.generalize ([], free_names) (Thm.maxidx_of th + 1) th;
    val lthy3 = lthy2 |> Spec_Rules.add Spec_Rules.Equational ([lhs], [th_gen]);

    val ([(def_name, [th_gen'])], lthy4) = lthy3
      |> Local_Theory.notes [((name, atts), [([th_gen], [])])];
  in
    (lthy4, th, th_gen', lhs, prop, def_name, const_name)
  end;
*}
  
ML{*

(* Entry point used by simulink_def.  The first component of the pair (the
   optional theorem name/attributes parsed before ":") is not used, and no
   explicit constant declaration is passed on (NONE). *)
fun simulink_definition_only str_typ named_spec (int : bool) lthy =
  let
    val (_, raw_spec) = named_spec;
  in
    simulink_definition_only_aux str_typ (NONE, raw_spec) int lthy
  end;

(*
fun simulink_definition_only str_typ raw_spec (int:bool) lthy = simulink_definition_only_aux str_typ Specification.read_free_spec (NONE, raw_spec) int lthy;
*)

(* Worker of the simplify_RCRS command.
   1. Defines the constant of raw_spec (simulink_definition_only).
   2. Reads invar/outvar as terms and looks up the theorems named in
      th_name_list (via thms).
   3. Simplifies the defining equation th with SimpSimulinkDef.
   4. Stores the result as "<constant name>_simp".
   Returns ((lhs, (def_name, res_thm)), lthy_out).  def_name is the name of
   the unsimplified defining equation, while res_thm is the simplified
   theorem.
   NOTE(review): simulinkf_def calls this with outvar = ""; confirm that
   Syntax.read_term accepts the empty string here. *)
fun simulink_def ((((invar, outvar), th_name_list:string list), str_typ), raw_spec) int lthy =
  let

    val (lthy', th, th_gen, lhs, prop, def_name, const_name) = simulink_definition_only str_typ raw_spec int lthy;
    val _ = debug 0 "end definition";
    val simp_thms = thms lthy' th_name_list;
    val inpt = Syntax.read_term lthy' invar
    val out = Syntax.read_term lthy' outvar
    val res_thm = SimpSimulinkDef lthy' simp_thms inpt out th;
    (* Binding "<constant name>_simp" for the simplified theorem. *)
    val name_sc_c = (Binding.name ((Binding.name_of const_name) ^ "_simp"));
    val ([(_, [_])], lthy_out) = lthy' |> Local_Theory.notes [((name_sc_c, []), [([res_thm], [])])];
  in ((lhs, (def_name, res_thm)), lthy_out) end;

(* Worker of simplify_RCRS_f.  The specification comes with an input-variable
   term only, so simulink_def is called with "" as the output-variable
   argument. *)
fun simulinkf_def (((invar, th_name_list : string list), str_typ), raw_spec) int lthy =
  let
    val no_outvar = "";
    val io_args = (((invar, no_outvar), th_name_list), str_typ);
  in
    simulink_def (io_args, raw_spec) int lthy
  end;

*}

(* Sample lemma with three type variables (the type of b, and 'a, 'b in the
   type of c).  It is used by the type-instantiation ML experiment that follows. *)
lemma test_types: "(a::real) = a ∧ b + 0 = b + 0 ∧ (c :: 'a ⇒ 'b) = c"
  by simp

declare [[show_types]]
(* Scratch experiment on type instantiation: inspect the schematic
   (type) variables of theorems and instantiate the type variables of
   test_types with real.  Types are shown in the output while it runs. *)
ML{*
    val asseert_update_comp_abs = @{thm asseert_update_comp_abs};
    (* Bind the schematic variables of the theorem; this pattern fails unless
       there are exactly six of them. *)
    val  [Var_f'', Var_f', Var_p'', Var_f, Var_p', Var_p]  = (Thm.fold_terms Term.add_vars asseert_update_comp_abs []);
    (* TT is bound twice below; neither binding is used later in this block. *)
    val TT = Specification.definition_cmd

    val TT = Term.add_tvarsT;

    (* Schematic type variables of a ctyp. *)
    fun UU t = Term.add_tvarsT (Thm.typ_of t);

    val  A  = (Thm.fold_atomic_ctyps (Term.add_tvarsT o Thm.typ_of) asseert_update_comp_abs []);
    val th = @{thm "test_types"}
    (* Exactly three schematic type variables are expected in test_types. *)
    val  [A_t,B_t,C_t]  = (Thm.fold_atomic_ctyps (Term.add_tvarsT o Thm.typ_of) th []);

    (* Instantiate all three type variables with real. *)
    val th' = Drule.instantiate_normalize ([(B_t, @{ctyp "real"}), (A_t, @{ctyp "real"}), (C_t, @{ctyp "real"})],[]) th;

*}

declare [[show_types=false]]
  
ML{*

*}
  
ML{*
(*
fun opt_thm_name s =
  Scan.optional
    ((Parse.binding -- Parse.opt_attribs || Parse.attribs >> pair Binding.empty) --| Parse.$$$ s)
    Attrib.empty_binding;
*)
(* Optional "name [attributes] s" or "[attributes] s" prefix of a statement,
   where s is the separator keyword (e.g. ":").  Without such a prefix the
   result is (Binding.empty, []). *)
fun opt_thm_name s =
  let
    val named = Parse.binding -- Parse.opt_attribs;
    val unnamed = Parse.attribs >> pair Binding.empty;
    val prefix = (named || unnamed) --| Parse.$$$ s;
  in
    Scan.optional prefix (Binding.empty, [])
  end;


(* Optional clause "use (name1 name2 ...)": names of the theorems that
   simulink_def looks up with thms for simplification.  Defaults to []. *)
val parse_use = Scan.optional(Parse.reserved "use" |-- Parse.$$$ "(" |-- (Scan.repeat1 Parse.name) --| Parse.$$$ ")") [];

(* Optional clause "type T".  Defaults to "", which means no type instantiation. *)
val parse_type = Scan.optional(Parse.reserved "type" |-- Parse.name) "";

(* Arguments of simplify_RCRS:
   [name [attribs] :] prop invar outvar [use (...)] [type T] *)
val simulinkdefC = (opt_thm_name ":" -- Parse.prop) -- (Parse.term -- Parse.term -- parse_use -- parse_type);
(* Same as simulinkdefC without the "type" clause.
   NOTE(review): not referenced in this part of the theory. *)
val simulinkdefC' = (opt_thm_name ":" -- Parse.prop) -- (Parse.term -- Parse.term -- parse_use);

(* Run simulinkdefC and swap the two halves of its result, so that the
   (invar, outvar, use, type) arguments come first and the named
   specification second, matching the argument shape of simulink_def. *)
fun simulinkdefB toks =
  let
    val ((named_spec, io_args), rest) = simulinkdefC toks;
  in
    ((io_args, named_spec), rest)
  end;

(* Register the local-theory command simplify_RCRS, with syntax
     simplify_RCRS [name:] "prop" invar outvar [use (...)] [type T]
   It runs simulink_def.  snd keeps only the resulting local theory. *)
val _ =
  Outer_Syntax.local_theory' @{command_keyword simplify_RCRS} "simulink definition"
    (simulinkdefB >> (fn args => snd oo simulink_def args));

(* Arguments of simplify_RCRS_f, the input-variables-only variant:
   [name [attribs] :] prop invar [use (...)] [type T]   (no output term) *)
val simulinkdefD = (opt_thm_name ":" -- Parse.prop) -- (Parse.term -- parse_use -- parse_type);

(* Run simulinkdefD and swap the two halves of its result, so that the
   (invar, use, type) arguments come first and the named specification
   second, matching the argument shape of simulinkf_def. *)
fun simulinkdefE toks =
  let
    val ((named_spec, io_args), rest) = simulinkdefD toks;
  in
    ((io_args, named_spec), rest)
  end;

(* Register the local-theory command simplify_RCRS_f, with syntax
     simplify_RCRS_f [name:] "prop" invar [use (...)] [type T]
   It runs simulinkf_def.  snd keeps only the resulting local theory. *)
val _ =
  Outer_Syntax.local_theory' @{command_keyword simplify_RCRS_f} "simulink definition"
    (simulinkdefE >> (fn args => snd oo simulinkf_def args));
*}

  ML{*

  (* Return the part of s after its last ".", or all of s when it contains no
     ".".  For example, get_last "Groups.plus_class.plus" = "plus" and
     get_last "a." = "".  term_to_python uses it to strip the qualifier from
     constant names.
     A single right-to-left scan replaces the former character-by-character
     rebuild, which was quadratic in the length of s. *)
  fun get_last s =
    Substring.string (Substring.taker (fn c => c <> #".") (Substring.full s));
  
  (* Decode a binary numeral term into an int:
     One = 1, Bit0 x = 2 * x, Bit1 x = 2 * x + 1.
     Any other term raises Match. *)
  fun nat_of_num t =
    (case t of
      Const ("Num.num.One", _) => 1
    | Const ("Num.num.Bit0", _) $ x => 2 * nat_of_num x
    | Const ("Num.num.Bit1", _) $ x => 1 + 2 * nat_of_num x);

  (* Render a term as Python expression text.  Clauses are tried top to
     bottom, so the order matters:
     - zero / one / uminus map to "0", "1", "-";
     - any other constant becomes its unqualified name (get_last);
     - Free and Var become their name (a Var's index is dropped);
     - numeral u is printed as a decimal integer (nat_of_num);
     - Pair t u prints "t, u", wrapped in parentheses only when par is true.
       The left component is rendered with par = true and the right one with
       par = false, so right-nested pairs print flat: "a, b, c";
     - any other application f $ u prints as "f(u)", so curried applications
       print as "f(a)(b)".
     Bound and Abs are not handled and raise Match. *)
  fun 
    term_to_python _ (Const ("Groups.zero_class.zero", _)) = "0"
    | term_to_python _ (Const ("Groups.one_class.one", _)) = "1"
    | term_to_python _ (Const ("Groups.uminus_class.uminus", _)) = "-"
    | term_to_python _ (Const (c, _)) = get_last c
    | term_to_python _ (Free (x, _)) = x
    | term_to_python _ (Var ((x, _), _))  = x
    | term_to_python _ (Const ("Num.numeral_class.numeral", _) $ u) = Int.toString (nat_of_num u)
(*
    | term_to_python _ (Bound i) = "Bound i"
    | term_to_python _ (Abs (x, T, b)) = "Abs"
*)
    | term_to_python par (Const ("Product_Type.Pair", _) $ t $ u) = (if par then "(" else "") ^ (term_to_python true t) ^ ", " ^ (term_to_python false u) ^  (if par then ")" else "")
    | term_to_python _ (t $ u) = (term_to_python true t) ^ "(" ^ (term_to_python false u) ^ ")";
 
    (* Write a Python function "step" to file (overwriting it):
         def step (<vars>, dt):
         	assert (<prec>)
         	return <func>
       vars, prec and func are terms rendered with term_to_python.
       All terms are rendered before the file is opened.  If writing fails,
       the stream is closed before the exception is re-raised, so a failed
       attempt does not leave the file handle open. *)
    fun write_prec_func_python file vars prec func =
      let
        val vars_s = "(" ^ term_to_python false vars ^ ", dt)";
        val prec_s = term_to_python true prec;
        val func_s = term_to_python true func;
        val text =
          "def step " ^ vars_s ^ ":\n" ^
          "\tassert (" ^ prec_s ^ ")\n" ^
          "\treturn " ^ func_s ^ "\n";
        val writestream = TextIO.openOut file;
        val _ =
          TextIO.output (writestream, text)
            handle exn => (TextIO.closeOut writestream; raise exn);
      in
        TextIO.closeOut writestream
      end;

(* Turn every schematic type variable ?'x.i :: S into the free type variable
   'x :: S (the index is dropped), recursing through type constructors.
   Free type variables are returned unchanged. *)
fun replace_schematicT ty =
  (case ty of
    TVar ((x, _), S) => TFree (x, S)
  | Type (name, args) => Type (name, map replace_schematicT args)
  | TFree _ => ty);

(* Smoke check of replace_schematicT.  The @{typ} antiquotation presumably
   already yields free type variables, so the result should equal the input. *)
val TT = replace_schematicT @{typ "('a × bool) ⇒ 'b ⇒ 'c"}

(* Term-level counterpart of replace_schematicT.
   - Every schematic variable ?x.i becomes the free variable x (index dropped).
   - replace_schematicT is applied to the type of every Const, Free, Var and Abs.
   - Bound variables are left unchanged.
   NOTE(review): because the index is dropped, distinct ?x.0 and ?x.1 would
   both become Free x. *)
fun replace_schematic tm =
  (case tm of
    Const (c, T) => Const (c, replace_schematicT T)
  | Free (x, T) => Free (x, replace_schematicT T)
  | Var ((x, _), T) => Free (x, replace_schematicT T)
  | Bound _ => tm
  | Abs (x, T, body) => Abs (x, replace_schematicT T, replace_schematic body)
  | f $ arg => replace_schematic f $ replace_schematic arg);
*}


ML{*
    (* Resolve filename relative to the master directory of ctxt's theory
       and return it as a platform-specific path string. *)
    fun get_output_file_path ctxt filename =
      let
        val master_dir = Resources.master_directory (Proof_Context.theory_of ctxt);
        val full_path = Path.append master_dir (Path.explode filename);
      in
        File.platform_path full_path
      end;
    (* Export thm as Python code to file, resolved relative to the theory's
       master directory (get_output_file_path).
       - The proposition of thm is matched as _ $ (_ $ _ $ rhs); presumably it
         has the form Trueprop (lhs = rhs) -- TODO confirm.
       - Schematic (type) variables of rhs are made free with
         replace_schematic, and the result is analysed by SimpPredTran with
         invars.  SimpPredTran must return Func (vars, p, f).
       - write_prec_func_python writes the code for vars, p, f.
       A failure while writing or rendering the file is reported with writeln,
       including the underlying exception message, rather than raised.
       Interrupts are re-raised so that the command can still be cancelled. *)
    fun write_thm_python file ctxt thm invars =
      let
        val  _ $ (_ $ _ $ rhs) = Thm.prop_of thm;
        val rhs' = replace_schematic rhs;
        val (Func (vars, p, f), _, _) = SimpPredTran invars invars 1 ctxt rhs';
      in
        (write_prec_func_python (get_output_file_path ctxt file) vars p f;
         writeln ("The file " ^ file ^ " was written successfully"))
        handle exn =>
          if Exn.is_interrupt exn then Exn.reraise exn
          else
            writeln ("Write operation fails, please try deleting existing copy of the file " ^ file ^
              ", as it may be non-over-writable. (" ^ General.exnMessage exn ^ ")")
      end;
*}

  
  ML{*
    (* Tactic that succeeds once, leaving the goal state unchanged. *)
    val all_tac = fn st => Seq.single st;
    (* Tactic behind the unfold_safe_auto method.
       1. simp_tac on subgoal 1, using only the theorems named
          "equiv_pt_simps" added to a cleared simpset.
       2. Optionally (TRY), the following chain:
          - resolve subgoal 1 with spec_eq_iff,
          - unfold fun_eq_iff,
          - safe_tac, then auto_tac. *)
    fun unfold_safe_auto_tac ctxt =
      let
        val equiv_ctxt =
          Raw_Simplifier.clear_simpset ctxt addsimps (thms ctxt ["equiv_pt_simps"]);
        val ext_tac =
          resolve_tac ctxt @{thms "spec_eq_iff"} 1
          THEN unfold_tac ctxt [@{thm "fun_eq_iff"}]
          THEN safe_tac ctxt
          THEN auto_tac ctxt;
      in
        simp_tac equiv_ctxt 1 THEN TRY ext_tac
      end;
     (* Register the proof method unfold_safe_auto.  It runs
        unfold_safe_auto_tac through Classical.cla_method, wrapped in
        CHANGED_PROP so that it fails when the goal is not changed. *)
     val _ =
      Theory.setup
        ( Method.setup @{binding unfold_safe_auto} (Classical.cla_method (CHANGED_PROP o unfold_safe_auto_tac))
        "classical prover (apply unfold - safe - auto rules)");
  *}  

end