run_eval.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. #!/usr/bin/env python3
  2. """Run trigger evaluation for a skill description.
  3. Tests whether a skill's description causes Claude to trigger (read the skill)
  4. for a set of queries. Outputs results as JSON.
  5. """
  6. import argparse
  7. import json
  8. import os
  9. import select
  10. import subprocess
  11. import sys
  12. import time
  13. import uuid
  14. from concurrent.futures import ProcessPoolExecutor, as_completed
  15. from pathlib import Path
  16. from scripts.utils import parse_skill_md
  17. def find_project_root() -> Path:
  18. """Find the project root by walking up from cwd looking for .claude/.
  19. Mimics how Claude Code discovers its project root, so the command file
  20. we create ends up where claude -p will look for it.
  21. """
  22. current = Path.cwd()
  23. for parent in [current, *current.parents]:
  24. if (parent / ".claude").is_dir():
  25. return parent
  26. return current
  27. def run_single_query(
  28. query: str,
  29. skill_name: str,
  30. skill_description: str,
  31. timeout: int,
  32. project_root: str,
  33. model: str | None = None,
  34. ) -> bool:
  35. """Run a single query and return whether the skill was triggered.
  36. Creates a command file in .claude/commands/ so it appears in Claude's
  37. available_skills list, then runs `claude -p` with the raw query.
  38. Uses --include-partial-messages to detect triggering early from
  39. stream events (content_block_start) rather than waiting for the
  40. full assistant message, which only arrives after tool execution.
  41. """
  42. unique_id = uuid.uuid4().hex[:8]
  43. clean_name = f"{skill_name}-skill-{unique_id}"
  44. project_commands_dir = Path(project_root) / ".claude" / "commands"
  45. command_file = project_commands_dir / f"{clean_name}.md"
  46. try:
  47. project_commands_dir.mkdir(parents=True, exist_ok=True)
  48. # Use YAML block scalar to avoid breaking on quotes in description
  49. indented_desc = "\n ".join(skill_description.split("\n"))
  50. command_content = (
  51. f"---\n"
  52. f"description: |\n"
  53. f" {indented_desc}\n"
  54. f"---\n\n"
  55. f"# {skill_name}\n\n"
  56. f"This skill handles: {skill_description}\n"
  57. )
  58. command_file.write_text(command_content)
  59. cmd = [
  60. "claude",
  61. "-p", query,
  62. "--output-format", "stream-json",
  63. "--verbose",
  64. "--include-partial-messages",
  65. ]
  66. if model:
  67. cmd.extend(["--model", model])
  68. # Remove CLAUDECODE env var to allow nesting claude -p inside a
  69. # Claude Code session. The guard is for interactive terminal conflicts;
  70. # programmatic subprocess usage is safe.
  71. env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"}
  72. process = subprocess.Popen(
  73. cmd,
  74. stdout=subprocess.PIPE,
  75. stderr=subprocess.DEVNULL,
  76. cwd=project_root,
  77. env=env,
  78. )
  79. triggered = False
  80. start_time = time.time()
  81. buffer = ""
  82. # Track state for stream event detection
  83. pending_tool_name = None
  84. accumulated_json = ""
  85. try:
  86. while time.time() - start_time < timeout:
  87. if process.poll() is not None:
  88. remaining = process.stdout.read()
  89. if remaining:
  90. buffer += remaining.decode("utf-8", errors="replace")
  91. break
  92. ready, _, _ = select.select([process.stdout], [], [], 1.0)
  93. if not ready:
  94. continue
  95. chunk = os.read(process.stdout.fileno(), 8192)
  96. if not chunk:
  97. break
  98. buffer += chunk.decode("utf-8", errors="replace")
  99. while "\n" in buffer:
  100. line, buffer = buffer.split("\n", 1)
  101. line = line.strip()
  102. if not line:
  103. continue
  104. try:
  105. event = json.loads(line)
  106. except json.JSONDecodeError:
  107. continue
  108. # Early detection via stream events
  109. if event.get("type") == "stream_event":
  110. se = event.get("event", {})
  111. se_type = se.get("type", "")
  112. if se_type == "content_block_start":
  113. cb = se.get("content_block", {})
  114. if cb.get("type") == "tool_use":
  115. tool_name = cb.get("name", "")
  116. if tool_name in ("Skill", "Read"):
  117. pending_tool_name = tool_name
  118. accumulated_json = ""
  119. else:
  120. return False
  121. elif se_type == "content_block_delta" and pending_tool_name:
  122. delta = se.get("delta", {})
  123. if delta.get("type") == "input_json_delta":
  124. accumulated_json += delta.get("partial_json", "")
  125. if clean_name in accumulated_json:
  126. return True
  127. elif se_type in ("content_block_stop", "message_stop"):
  128. if pending_tool_name:
  129. return clean_name in accumulated_json
  130. if se_type == "message_stop":
  131. return False
  132. # Fallback: full assistant message
  133. elif event.get("type") == "assistant":
  134. message = event.get("message", {})
  135. for content_item in message.get("content", []):
  136. if content_item.get("type") != "tool_use":
  137. continue
  138. tool_name = content_item.get("name", "")
  139. tool_input = content_item.get("input", {})
  140. if tool_name == "Skill" and clean_name in tool_input.get("skill", ""):
  141. triggered = True
  142. elif tool_name == "Read" and clean_name in tool_input.get("file_path", ""):
  143. triggered = True
  144. return triggered
  145. elif event.get("type") == "result":
  146. return triggered
  147. finally:
  148. # Clean up process on any exit path (return, exception, timeout)
  149. if process.poll() is None:
  150. process.kill()
  151. process.wait()
  152. return triggered
  153. finally:
  154. if command_file.exists():
  155. command_file.unlink()
  156. def run_eval(
  157. eval_set: list[dict],
  158. skill_name: str,
  159. description: str,
  160. num_workers: int,
  161. timeout: int,
  162. project_root: Path,
  163. runs_per_query: int = 1,
  164. trigger_threshold: float = 0.5,
  165. model: str | None = None,
  166. ) -> dict:
  167. """Run the full eval set and return results."""
  168. results = []
  169. with ProcessPoolExecutor(max_workers=num_workers) as executor:
  170. future_to_info = {}
  171. for item in eval_set:
  172. for run_idx in range(runs_per_query):
  173. future = executor.submit(
  174. run_single_query,
  175. item["query"],
  176. skill_name,
  177. description,
  178. timeout,
  179. str(project_root),
  180. model,
  181. )
  182. future_to_info[future] = (item, run_idx)
  183. query_triggers: dict[str, list[bool]] = {}
  184. query_items: dict[str, dict] = {}
  185. for future in as_completed(future_to_info):
  186. item, _ = future_to_info[future]
  187. query = item["query"]
  188. query_items[query] = item
  189. if query not in query_triggers:
  190. query_triggers[query] = []
  191. try:
  192. query_triggers[query].append(future.result())
  193. except Exception as e:
  194. print(f"Warning: query failed: {e}", file=sys.stderr)
  195. query_triggers[query].append(False)
  196. for query, triggers in query_triggers.items():
  197. item = query_items[query]
  198. trigger_rate = sum(triggers) / len(triggers)
  199. should_trigger = item["should_trigger"]
  200. if should_trigger:
  201. did_pass = trigger_rate >= trigger_threshold
  202. else:
  203. did_pass = trigger_rate < trigger_threshold
  204. results.append({
  205. "query": query,
  206. "should_trigger": should_trigger,
  207. "trigger_rate": trigger_rate,
  208. "triggers": sum(triggers),
  209. "runs": len(triggers),
  210. "pass": did_pass,
  211. })
  212. passed = sum(1 for r in results if r["pass"])
  213. total = len(results)
  214. return {
  215. "skill_name": skill_name,
  216. "description": description,
  217. "results": results,
  218. "summary": {
  219. "total": total,
  220. "passed": passed,
  221. "failed": total - passed,
  222. },
  223. }
  224. def main():
  225. parser = argparse.ArgumentParser(description="Run trigger evaluation for a skill description")
  226. parser.add_argument("--eval-set", required=True, help="Path to eval set JSON file")
  227. parser.add_argument("--skill-path", required=True, help="Path to skill directory")
  228. parser.add_argument("--description", default=None, help="Override description to test")
  229. parser.add_argument("--num-workers", type=int, default=10, help="Number of parallel workers")
  230. parser.add_argument("--timeout", type=int, default=30, help="Timeout per query in seconds")
  231. parser.add_argument("--runs-per-query", type=int, default=3, help="Number of runs per query")
  232. parser.add_argument("--trigger-threshold", type=float, default=0.5, help="Trigger rate threshold")
  233. parser.add_argument("--model", default=None, help="Model to use for claude -p (default: user's configured model)")
  234. parser.add_argument("--verbose", action="store_true", help="Print progress to stderr")
  235. args = parser.parse_args()
  236. eval_set = json.loads(Path(args.eval_set).read_text())
  237. skill_path = Path(args.skill_path)
  238. if not (skill_path / "SKILL.md").exists():
  239. print(f"Error: No SKILL.md found at {skill_path}", file=sys.stderr)
  240. sys.exit(1)
  241. name, original_description, content = parse_skill_md(skill_path)
  242. description = args.description or original_description
  243. project_root = find_project_root()
  244. if args.verbose:
  245. print(f"Evaluating: {description}", file=sys.stderr)
  246. output = run_eval(
  247. eval_set=eval_set,
  248. skill_name=name,
  249. description=description,
  250. num_workers=args.num_workers,
  251. timeout=args.timeout,
  252. project_root=project_root,
  253. runs_per_query=args.runs_per_query,
  254. trigger_threshold=args.trigger_threshold,
  255. model=args.model,
  256. )
  257. if args.verbose:
  258. summary = output["summary"]
  259. print(f"Results: {summary['passed']}/{summary['total']} passed", file=sys.stderr)
  260. for r in output["results"]:
  261. status = "PASS" if r["pass"] else "FAIL"
  262. rate_str = f"{r['triggers']}/{r['runs']}"
  263. print(f" [{status}] rate={rate_str} expected={r['should_trigger']}: {r['query'][:70]}", file=sys.stderr)
  264. print(json.dumps(output, indent=2))
  265. if __name__ == "__main__":
  266. main()