utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import json
  2. import shutil
  3. import tempfile
  4. from collections.abc import Iterable, Iterator
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. from typing import Any
# ANSI terminal escape sequences used to highlight output:
# _GREEN switches the foreground color to green, _RESET restores the default attributes.
_GREEN = "\033[32m"
_RESET = "\033[0m"
  10. def _extract_text(content: object) -> str:
  11. """Return human-readable text from message content.
  12. - If content is a list of dict parts with "text" keys, join them.
  13. - Else fall back to str(content).
  14. """
  15. if isinstance(content, list):
  16. parts: list[str] = []
  17. for part in content:
  18. if isinstance(part, dict) and "text" in part:
  19. parts.append(str(part.get("text")))
  20. if parts:
  21. return " ".join(parts)
  22. return str(content)
  23. def print_history(thread_manager, roles: Iterable[str] = ("assistant", "system")) -> None:
  24. """Print a minimal, chronological history since the last user message.
  25. - Shows only role and content for roles in `roles` (default: assistant/system)
  26. """
  27. messages = thread_manager.get_all_messages()
  28. for m in messages:
  29. if not isinstance(m, dict):
  30. continue
  31. role_obj = m.get("role") or m.get("type")
  32. role = str(role_obj) if role_obj is not None else ""
  33. if role and role not in roles:
  34. continue
  35. if role == "assistant":
  36. role = f"{m.get('agent')}:"
  37. elif role == "user" and m.get("callerAgent") is not None:
  38. role = f"{m.get('callerAgent')}:"
  39. content = _extract_text(m.get("content") or m.get("output") or m.get("arguments"))
  40. print(f" [{role}] {content}")
  41. @contextmanager
  42. def temporary_files_folder(source_subdir: str = "data") -> Iterator[Path]:
  43. """Copy example files into a disposable `files` directory.
  44. The provided directory (relative to the examples folder) is copied into a
  45. temporary location where the folder is named exactly `files`. Vector store
  46. renaming can freely mutate that directory without touching the original
  47. assets. The temporary tree is removed on exit.
  48. """
  49. examples_dir = Path(__file__).parent
  50. source_dir = examples_dir / source_subdir
  51. if not source_dir.exists():
  52. raise FileNotFoundError(f"Example source directory not found: {source_dir}")
  53. temp_root = Path(tempfile.mkdtemp(prefix="agency-swarm-files-"))
  54. destination = temp_root / "files"
  55. shutil.copytree(source_dir, destination)
  56. try:
  57. yield destination
  58. finally:
  59. shutil.rmtree(temp_root, ignore_errors=True)
  60. def iter_agent_messages(agency, *, agent_name: str | None = None) -> Iterator[dict[str, Any]]:
  61. """Yield stored messages filtered by optional caller/agent name."""
  62. for message in agency.thread_manager.get_all_messages():
  63. if not isinstance(message, dict):
  64. continue
  65. if agent_name and not (message.get("callerAgent") == agent_name or message.get("agent") == agent_name):
  66. continue
  67. yield message
  68. def iter_send_message_calls(agency, *, agent_name: str | None = None) -> Iterator[dict[str, Any]]:
  69. """Yield send_message function call records matching optional agent filter."""
  70. for message in iter_agent_messages(agency, agent_name=agent_name):
  71. if message.get("type") == "function_call" and str(message.get("name", "")).startswith("send_message"):
  72. yield message
  73. def format_json_call(arguments: str | dict[str, Any]) -> str:
  74. """Return pretty JSON from arguments string/dict."""
  75. if isinstance(arguments, str):
  76. parsed = json.loads(arguments or "{}")
  77. else:
  78. parsed = arguments
  79. return json.dumps(parsed, indent=2)
  80. def print_highlighted_send_message_args(agency, *, agent_name: str | None = None) -> None:
  81. """Pretty-print send_message call arguments with highlighted keys."""
  82. for message in iter_send_message_calls(agency, agent_name=agent_name):
  83. rendered = format_json_call(message.get("arguments", {}))
  84. rendered = rendered.replace('"key_moments"', f'{_GREEN}"key_moments"{_RESET}').replace(
  85. '"decisions"', f'{_GREEN}"decisions"{_RESET}'
  86. )
  87. print(rendered)
  88. def print_send_message_exchange(agency, *, owner: str) -> None:
  89. """Print send_message requests/responses involving a specific agent."""
  90. call_ids: dict[str, dict[str, Any]] = {}
  91. for message in iter_agent_messages(agency, agent_name=owner):
  92. if message.get("type") == "function_call" and str(message.get("name", "")).startswith("send_message"):
  93. call_ids[str(message.get("parent_run_id"))] = message
  94. payload = format_json_call(message.get("arguments", {}))
  95. print(f"Request {message.get('callerAgent')} -> {message.get('agent')}:\n{payload}\n")
  96. elif message.get("type") == "assistant" and str(message.get("parent_run_id")) in call_ids:
  97. response_text = _extract_text(message.get("content") or message.get("output"))
  98. print(f"Response {message.get('agent')} -> {message.get('callerAgent')}: {response_text}\n")