run_loop.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. #!/usr/bin/env python3
  2. """Run the eval + improve loop until all pass or max iterations reached.
  3. Combines run_eval.py and improve_description.py in a loop, tracking history
  4. and returning the best description found. Supports train/test split to prevent
  5. overfitting.
  6. """
  7. import argparse
  8. import json
  9. import random
  10. import sys
  11. import tempfile
  12. import time
  13. import webbrowser
  14. from pathlib import Path
  15. from scripts.generate_report import generate_html
  16. from scripts.improve_description import improve_description
  17. from scripts.run_eval import find_project_root, run_eval
  18. from scripts.utils import parse_skill_md
  19. def split_eval_set(eval_set: list[dict], holdout: float, seed: int = 42) -> tuple[list[dict], list[dict]]:
  20. """Split eval set into train and test sets, stratified by should_trigger."""
  21. random.seed(seed)
  22. # Separate by should_trigger
  23. trigger = [e for e in eval_set if e["should_trigger"]]
  24. no_trigger = [e for e in eval_set if not e["should_trigger"]]
  25. # Shuffle each group
  26. random.shuffle(trigger)
  27. random.shuffle(no_trigger)
  28. # Calculate split points
  29. n_trigger_test = max(1, int(len(trigger) * holdout))
  30. n_no_trigger_test = max(1, int(len(no_trigger) * holdout))
  31. # Split
  32. test_set = trigger[:n_trigger_test] + no_trigger[:n_no_trigger_test]
  33. train_set = trigger[n_trigger_test:] + no_trigger[n_no_trigger_test:]
  34. return train_set, test_set
  35. def run_loop(
  36. eval_set: list[dict],
  37. skill_path: Path,
  38. description_override: str | None,
  39. num_workers: int,
  40. timeout: int,
  41. max_iterations: int,
  42. runs_per_query: int,
  43. trigger_threshold: float,
  44. holdout: float,
  45. model: str,
  46. verbose: bool,
  47. live_report_path: Path | None = None,
  48. log_dir: Path | None = None,
  49. ) -> dict:
  50. """Run the eval + improvement loop."""
  51. project_root = find_project_root()
  52. name, original_description, content = parse_skill_md(skill_path)
  53. current_description = description_override or original_description
  54. # Split into train/test if holdout > 0
  55. if holdout > 0:
  56. train_set, test_set = split_eval_set(eval_set, holdout)
  57. if verbose:
  58. print(f"Split: {len(train_set)} train, {len(test_set)} test (holdout={holdout})", file=sys.stderr)
  59. else:
  60. train_set = eval_set
  61. test_set = []
  62. history = []
  63. exit_reason = "unknown"
  64. for iteration in range(1, max_iterations + 1):
  65. if verbose:
  66. print(f"\n{'='*60}", file=sys.stderr)
  67. print(f"Iteration {iteration}/{max_iterations}", file=sys.stderr)
  68. print(f"Description: {current_description}", file=sys.stderr)
  69. print(f"{'='*60}", file=sys.stderr)
  70. # Evaluate train + test together in one batch for parallelism
  71. all_queries = train_set + test_set
  72. t0 = time.time()
  73. all_results = run_eval(
  74. eval_set=all_queries,
  75. skill_name=name,
  76. description=current_description,
  77. num_workers=num_workers,
  78. timeout=timeout,
  79. project_root=project_root,
  80. runs_per_query=runs_per_query,
  81. trigger_threshold=trigger_threshold,
  82. model=model,
  83. )
  84. eval_elapsed = time.time() - t0
  85. # Split results back into train/test by matching queries
  86. train_queries_set = {q["query"] for q in train_set}
  87. train_result_list = [r for r in all_results["results"] if r["query"] in train_queries_set]
  88. test_result_list = [r for r in all_results["results"] if r["query"] not in train_queries_set]
  89. train_passed = sum(1 for r in train_result_list if r["pass"])
  90. train_total = len(train_result_list)
  91. train_summary = {"passed": train_passed, "failed": train_total - train_passed, "total": train_total}
  92. train_results = {"results": train_result_list, "summary": train_summary}
  93. if test_set:
  94. test_passed = sum(1 for r in test_result_list if r["pass"])
  95. test_total = len(test_result_list)
  96. test_summary = {"passed": test_passed, "failed": test_total - test_passed, "total": test_total}
  97. test_results = {"results": test_result_list, "summary": test_summary}
  98. else:
  99. test_results = None
  100. test_summary = None
  101. history.append({
  102. "iteration": iteration,
  103. "description": current_description,
  104. "train_passed": train_summary["passed"],
  105. "train_failed": train_summary["failed"],
  106. "train_total": train_summary["total"],
  107. "train_results": train_results["results"],
  108. "test_passed": test_summary["passed"] if test_summary else None,
  109. "test_failed": test_summary["failed"] if test_summary else None,
  110. "test_total": test_summary["total"] if test_summary else None,
  111. "test_results": test_results["results"] if test_results else None,
  112. # For backward compat with report generator
  113. "passed": train_summary["passed"],
  114. "failed": train_summary["failed"],
  115. "total": train_summary["total"],
  116. "results": train_results["results"],
  117. })
  118. # Write live report if path provided
  119. if live_report_path:
  120. partial_output = {
  121. "original_description": original_description,
  122. "best_description": current_description,
  123. "best_score": "in progress",
  124. "iterations_run": len(history),
  125. "holdout": holdout,
  126. "train_size": len(train_set),
  127. "test_size": len(test_set),
  128. "history": history,
  129. }
  130. live_report_path.write_text(generate_html(partial_output, auto_refresh=True, skill_name=name))
  131. if verbose:
  132. def print_eval_stats(label, results, elapsed):
  133. pos = [r for r in results if r["should_trigger"]]
  134. neg = [r for r in results if not r["should_trigger"]]
  135. tp = sum(r["triggers"] for r in pos)
  136. pos_runs = sum(r["runs"] for r in pos)
  137. fn = pos_runs - tp
  138. fp = sum(r["triggers"] for r in neg)
  139. neg_runs = sum(r["runs"] for r in neg)
  140. tn = neg_runs - fp
  141. total = tp + tn + fp + fn
  142. precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
  143. recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
  144. accuracy = (tp + tn) / total if total > 0 else 0.0
  145. print(f"{label}: {tp+tn}/{total} correct, precision={precision:.0%} recall={recall:.0%} accuracy={accuracy:.0%} ({elapsed:.1f}s)", file=sys.stderr)
  146. for r in results:
  147. status = "PASS" if r["pass"] else "FAIL"
  148. rate_str = f"{r['triggers']}/{r['runs']}"
  149. print(f" [{status}] rate={rate_str} expected={r['should_trigger']}: {r['query'][:60]}", file=sys.stderr)
  150. print_eval_stats("Train", train_results["results"], eval_elapsed)
  151. if test_summary:
  152. print_eval_stats("Test ", test_results["results"], 0)
  153. if train_summary["failed"] == 0:
  154. exit_reason = f"all_passed (iteration {iteration})"
  155. if verbose:
  156. print(f"\nAll train queries passed on iteration {iteration}!", file=sys.stderr)
  157. break
  158. if iteration == max_iterations:
  159. exit_reason = f"max_iterations ({max_iterations})"
  160. if verbose:
  161. print(f"\nMax iterations reached ({max_iterations}).", file=sys.stderr)
  162. break
  163. # Improve the description based on train results
  164. if verbose:
  165. print(f"\nImproving description...", file=sys.stderr)
  166. t0 = time.time()
  167. # Strip test scores from history so improvement model can't see them
  168. blinded_history = [
  169. {k: v for k, v in h.items() if not k.startswith("test_")}
  170. for h in history
  171. ]
  172. new_description = improve_description(
  173. skill_name=name,
  174. skill_content=content,
  175. current_description=current_description,
  176. eval_results=train_results,
  177. history=blinded_history,
  178. model=model,
  179. log_dir=log_dir,
  180. iteration=iteration,
  181. )
  182. improve_elapsed = time.time() - t0
  183. if verbose:
  184. print(f"Proposed ({improve_elapsed:.1f}s): {new_description}", file=sys.stderr)
  185. current_description = new_description
  186. # Find the best iteration by TEST score (or train if no test set)
  187. if test_set:
  188. best = max(history, key=lambda h: h["test_passed"] or 0)
  189. best_score = f"{best['test_passed']}/{best['test_total']}"
  190. else:
  191. best = max(history, key=lambda h: h["train_passed"])
  192. best_score = f"{best['train_passed']}/{best['train_total']}"
  193. if verbose:
  194. print(f"\nExit reason: {exit_reason}", file=sys.stderr)
  195. print(f"Best score: {best_score} (iteration {best['iteration']})", file=sys.stderr)
  196. return {
  197. "exit_reason": exit_reason,
  198. "original_description": original_description,
  199. "best_description": best["description"],
  200. "best_score": best_score,
  201. "best_train_score": f"{best['train_passed']}/{best['train_total']}",
  202. "best_test_score": f"{best['test_passed']}/{best['test_total']}" if test_set else None,
  203. "final_description": current_description,
  204. "iterations_run": len(history),
  205. "holdout": holdout,
  206. "train_size": len(train_set),
  207. "test_size": len(test_set),
  208. "history": history,
  209. }
  210. def main():
  211. parser = argparse.ArgumentParser(description="Run eval + improve loop")
  212. parser.add_argument("--eval-set", required=True, help="Path to eval set JSON file")
  213. parser.add_argument("--skill-path", required=True, help="Path to skill directory")
  214. parser.add_argument("--description", default=None, help="Override starting description")
  215. parser.add_argument("--num-workers", type=int, default=10, help="Number of parallel workers")
  216. parser.add_argument("--timeout", type=int, default=30, help="Timeout per query in seconds")
  217. parser.add_argument("--max-iterations", type=int, default=5, help="Max improvement iterations")
  218. parser.add_argument("--runs-per-query", type=int, default=3, help="Number of runs per query")
  219. parser.add_argument("--trigger-threshold", type=float, default=0.5, help="Trigger rate threshold")
  220. parser.add_argument("--holdout", type=float, default=0.4, help="Fraction of eval set to hold out for testing (0 to disable)")
  221. parser.add_argument("--model", required=True, help="Model for improvement")
  222. parser.add_argument("--verbose", action="store_true", help="Print progress to stderr")
  223. parser.add_argument("--report", default="auto", help="Generate HTML report at this path (default: 'auto' for temp file, 'none' to disable)")
  224. parser.add_argument("--results-dir", default=None, help="Save all outputs (results.json, report.html, log.txt) to a timestamped subdirectory here")
  225. args = parser.parse_args()
  226. eval_set = json.loads(Path(args.eval_set).read_text())
  227. skill_path = Path(args.skill_path)
  228. if not (skill_path / "SKILL.md").exists():
  229. print(f"Error: No SKILL.md found at {skill_path}", file=sys.stderr)
  230. sys.exit(1)
  231. name, _, _ = parse_skill_md(skill_path)
  232. # Set up live report path
  233. if args.report != "none":
  234. if args.report == "auto":
  235. timestamp = time.strftime("%Y%m%d_%H%M%S")
  236. live_report_path = Path(tempfile.gettempdir()) / f"skill_description_report_{skill_path.name}_{timestamp}.html"
  237. else:
  238. live_report_path = Path(args.report)
  239. # Open the report immediately so the user can watch
  240. live_report_path.write_text("<html><body><h1>Starting optimization loop...</h1><meta http-equiv='refresh' content='5'></body></html>")
  241. webbrowser.open(str(live_report_path))
  242. else:
  243. live_report_path = None
  244. # Determine output directory (create before run_loop so logs can be written)
  245. if args.results_dir:
  246. timestamp = time.strftime("%Y-%m-%d_%H%M%S")
  247. results_dir = Path(args.results_dir) / timestamp
  248. results_dir.mkdir(parents=True, exist_ok=True)
  249. else:
  250. results_dir = None
  251. log_dir = results_dir / "logs" if results_dir else None
  252. output = run_loop(
  253. eval_set=eval_set,
  254. skill_path=skill_path,
  255. description_override=args.description,
  256. num_workers=args.num_workers,
  257. timeout=args.timeout,
  258. max_iterations=args.max_iterations,
  259. runs_per_query=args.runs_per_query,
  260. trigger_threshold=args.trigger_threshold,
  261. holdout=args.holdout,
  262. model=args.model,
  263. verbose=args.verbose,
  264. live_report_path=live_report_path,
  265. log_dir=log_dir,
  266. )
  267. # Save JSON output
  268. json_output = json.dumps(output, indent=2)
  269. print(json_output)
  270. if results_dir:
  271. (results_dir / "results.json").write_text(json_output)
  272. # Write final HTML report (without auto-refresh)
  273. if live_report_path:
  274. live_report_path.write_text(generate_html(output, auto_refresh=False, skill_name=name))
  275. print(f"\nReport: {live_report_path}", file=sys.stderr)
  276. if results_dir and live_report_path:
  277. (results_dir / "report.html").write_text(generate_html(output, auto_refresh=False, skill_name=name))
  278. if results_dir:
  279. print(f"Results saved to: {results_dir}", file=sys.stderr)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()