(*  Title:      Provers/Arith/assoc_fold.ML
    ID:         $Id: assoc_fold.ML,v 1.11 2005/08/01 17:20:26 wenzelm Exp $
    Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   1999  University of Cambridge

Simplification procedure for associative operators + and * on numeric types.
Performs constant folding when the literals are separated, as in 3+n+4.
*)

signature ASSOC_FOLD_DATA =
sig
  val ss            : simpset      (*basic simpset of the object-logic*)
  val eq_reflection : thm          (*object-equality to meta-equality*)
  val thy_ref       : theory_ref   (*the operator's signature*)
  val add_ac        : thm list     (*AC-rewrites for plus*)
end;

functor Assoc_Fold (Data: ASSOC_FOLD_DATA) =
struct

(*The basic simpset extended with the AC-rewrites; proc uses it to prove
  each generated equation.*)
val assoc_ss = Data.ss addsimps Data.add_ac;

(*Internal signal that no useful rewrite exists; proc converts it to NONE.*)
exception Assoc_fail;

(*Combine a list of terms with the binary operator plus (via foldr1).
  An empty list raises Assoc_fail.*)
fun mk_sum _ [] = raise Assoc_fail
  | mk_sum plus terms = foldr1 (fn (a, b) => plus $ a $ b) terms;

(*Walk a nest of plus-applications and split its leaves into literals
  (Numeral.number_of applied to an argument) and all other summands.
  The right operand is sifted before the left one and leaves are consed
  onto the front, so both lists keep the leaves' left-to-right order.*)
fun sift_terms _ (t as Const ("Numeral.number_of", _) $ _, (lits, others)) =
      (t :: lits, others)                                   (*new literal*)
  | sift_terms plus (t as (f as Const _) $ x $ y, (lits, others)) =
      if f = plus
      then sift_terms plus (x, sift_terms plus (y, (lits, others)))
      else (lits, t :: others)                              (*arbitrary summand*)
  | sift_terms _ (t, (lits, others)) = (lits, t :: others);

(*When set, proc prints its LHS and, if it gets as far as building one,
  the RHS via tracing.*)
val trace = ref false;

(*Simproc combining all literals in an associative nest.
  For lhs of the form  plus $ _ $ _  it proves
    lhs == plus $ (sum of the literals) $ (sum of the other summands)
  and returns SOME of that theorem.  It returns NONE (via Assoc_fail) when
  there are fewer than two literals, since the number of terms could not
  shrink, or when there are no non-literal summands (mk_sum on []).
  An lhs that is not a binary application is reported with error.*)
fun proc thy ss lhs =
  let
    fun trace_term msg t =
      if ! trace then tracing (msg ^ string_of_cterm (Thm.cterm_of thy t))
      else ();
    val () = trace_term "assoc_fold simproc: LHS = " lhs;
    val plus =
      (case lhs of
        f $ _ $ _ => f
      | _ => error "Assoc_fold: bad pattern");
    val (lits, others) = sift_terms plus (lhs, ([], []));
    (*need at least two literals, otherwise nothing gets folded*)
    val () = if length lits >= 2 then () else raise Assoc_fail;
    val rhs = plus $ mk_sum plus lits $ mk_sum plus others;
    val () = trace_term "RHS = " rhs;
    val goal = Logic.mk_equals (lhs, rhs);
    (*reduce the meta-equality to an object-equality, then let AC-rewriting
      rearrange the summands*)
    fun tac _ =
      rtac Data.eq_reflection 1 THEN
      simp_tac (Simplifier.inherit_bounds ss assoc_ss) 1;
  in SOME (Tactic.prove thy [] [] goal tac) end
  handle Assoc_fail => NONE;

end;

(*test data:
set timing;
Goal "(#3 * (a * #34)) * (#2 * b * #9) = (x::int)";
Goal "a + b + c + d + e + f + g + h + i + j + k + l + m + n + oo + p + q + r + s + t + u + v + (w + x + y + z + a + #2 + b + #2 + c + #2 + d + #2 + e) + #2 + f + (#2 + g + #2 + h + #2 + i) + #2 + (j + #2 + k + #2 + l + #2 + m + #2) + n + #2 + (oo + #2 + p + #2 + q + #2 + r) + #2 + s + #2 + t + #2 + u + #2 + v + #2 + w + #2 + x + #2 + y + #2 + z + #2 = (uu::nat)";
*)