# test_citation_extractor.py
  1. """
  2. Unit tests for citation extraction utilities.
  3. """
  4. from unittest.mock import MagicMock
  5. from agents.items import MessageOutputItem, ToolCallItem
  6. from openai.types.responses import ResponseFunctionWebSearch
  7. from openai.types.responses.response_file_search_tool_call import (
  8. ResponseFileSearchToolCall,
  9. Result as FileSearchResult,
  10. )
  11. from openai.types.responses.response_function_web_search import ActionOpenPage, ActionSearch, ActionSearchSource
  12. from openai.types.responses.response_output_message import ResponseOutputMessage, ResponseOutputText
  13. from openai.types.responses.response_output_text import AnnotationFileCitation, AnnotationURLCitation
  14. from agency_swarm.agent.core import Agent
  15. from agency_swarm.utils.citation_extractor import (
  16. display_citations,
  17. extract_direct_file_annotations,
  18. extract_direct_file_citations_from_history,
  19. extract_vector_store_citations,
  20. extract_web_search_sources,
  21. )
  22. class TestExtractDirectFileAnnotations:
  23. """Test direct file citation extraction from message annotations."""
  24. @staticmethod
  25. def _message_item(content: list[ResponseOutputText], message_id: str) -> MessageOutputItem:
  26. agent = Agent(name="Annotator", instructions="Collect citations")
  27. message = ResponseOutputMessage(
  28. id=message_id,
  29. content=content,
  30. role="assistant",
  31. status="completed",
  32. type="message",
  33. )
  34. return MessageOutputItem(agent=agent, raw_item=message)
  35. def test_extracts_file_citations_from_annotations(self):
  36. """Ensure annotations backed by SDK models are captured."""
  37. annotation = AnnotationFileCitation(
  38. file_id="file-abc123",
  39. filename="test_document.pdf",
  40. index=42,
  41. type="file_citation",
  42. )
  43. content_item = ResponseOutputText(annotations=[annotation], text="Here you go", type="output_text")
  44. msg_item = self._message_item([content_item], message_id="msg_123")
  45. result = extract_direct_file_annotations([msg_item])
  46. assert "msg_123" in result
  47. citation = result["msg_123"][0]
  48. assert citation["file_id"] == "file-abc123"
  49. assert citation["filename"] == "test_document.pdf"
  50. assert citation["index"] == 42
  51. assert citation["type"] == "file_citation"
  52. assert citation["method"] == "direct_file"
  53. def test_handles_multiple_annotations_per_message(self):
  54. """Handle multiple file citations within a single message."""
  55. annotations = [
  56. AnnotationFileCitation(
  57. file_id=f"file-{i}",
  58. filename=f"doc_{i}.pdf",
  59. index=i * 10,
  60. type="file_citation",
  61. )
  62. for i in range(3)
  63. ]
  64. content_item = ResponseOutputText(annotations=annotations, text="Multiple refs", type="output_text")
  65. msg_item = self._message_item([content_item], message_id="msg_multi")
  66. result = extract_direct_file_annotations([msg_item])
  67. citations = result["msg_multi"]
  68. assert len(citations) == 3
  69. assert {c["file_id"] for c in citations} == {"file-0", "file-1", "file-2"}
  70. def test_skips_messages_without_content(self):
  71. """Messages with no content yield no citations."""
  72. msg_item = self._message_item([], message_id="msg_empty")
  73. assert extract_direct_file_annotations([msg_item]) == {}
  74. def test_skips_non_file_citation_annotations(self):
  75. """Annotations of other types are ignored."""
  76. url_annotation = AnnotationURLCitation(
  77. start_index=0,
  78. end_index=5,
  79. title="Example",
  80. type="url_citation",
  81. url="https://example.com",
  82. )
  83. content_item = ResponseOutputText(annotations=[url_annotation], text="See link", type="output_text")
  84. msg_item = self._message_item([content_item], message_id="msg_no_citations")
  85. assert extract_direct_file_annotations([msg_item]) == {}
  86. class TestExtractVectorStoreCitations:
  87. """Test vector store citation extraction from run results."""
  88. def test_extracts_file_search_citations(self):
  89. """Extract citations from a typed FileSearch tool call."""
  90. tool_call = ResponseFileSearchToolCall(
  91. id="call_sdk",
  92. queries=["report"],
  93. status="completed",
  94. type="file_search_call",
  95. results=[FileSearchResult(file_id="file-sdk", text="Findings content")],
  96. )
  97. run_result = MagicMock()
  98. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=tool_call)]
  99. result = extract_vector_store_citations(run_result)
  100. assert result[0]["file_id"] == "file-sdk"
  101. assert result[0]["text"] == "Findings content"
  102. assert result[0]["tool_call_id"] == "call_sdk"
  103. def test_handles_multiple_search_results(self):
  104. """Handle multiple search results within a single tool call."""
  105. tool_call = ResponseFileSearchToolCall(
  106. id="call_multi",
  107. queries=["reports"],
  108. status="completed",
  109. type="file_search_call",
  110. results=[FileSearchResult(file_id=f"vs-file-{i}", text=f"Content from file {i}") for i in range(3)],
  111. )
  112. run_result = MagicMock()
  113. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=tool_call)]
  114. citations = extract_vector_store_citations(run_result)
  115. assert len(citations) == 3
  116. assert {c["file_id"] for c in citations} == {"vs-file-0", "vs-file-1", "vs-file-2"}
  117. def test_missing_file_id_defaults_to_unknown(self):
  118. """Fallback to 'unknown' when file_id is absent in the tool result."""
  119. tool_call = ResponseFileSearchToolCall(
  120. id="call_unknown",
  121. queries=["reports"],
  122. status="completed",
  123. type="file_search_call",
  124. results=[FileSearchResult(file_id=None, text="Content without identifier")],
  125. )
  126. run_result = MagicMock()
  127. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=tool_call)]
  128. citations = extract_vector_store_citations(run_result)
  129. assert len(citations) == 1
  130. assert citations[0]["file_id"] == "unknown"
  131. def test_handles_missing_results(self):
  132. """Gracefully handle file search calls without results."""
  133. tool_call = ResponseFileSearchToolCall(
  134. id="call_empty",
  135. queries=["anything"],
  136. status="completed",
  137. type="file_search_call",
  138. results=None,
  139. )
  140. run_result = MagicMock()
  141. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=tool_call)]
  142. assert extract_vector_store_citations(run_result) == []
  143. def test_skips_non_file_search_items(self):
  144. """Ignore items that are not file-search tool calls."""
  145. tool_call = MagicMock()
  146. tool_call.type = "function_call"
  147. run_result = MagicMock()
  148. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=tool_call)]
  149. assert extract_vector_store_citations(run_result) == []
  150. class TestExtractWebSearchSources:
  151. """Test web search source URL extraction from run results."""
  152. def test_extracts_and_deduplicates_urls(self):
  153. """Extract unique URLs while preserving first-seen order."""
  154. web_search_call = ResponseFunctionWebSearch(
  155. id="web_1",
  156. action=ActionSearch(
  157. query="latest openai updates",
  158. type="search",
  159. sources=[
  160. ActionSearchSource(type="url", url="https://help.openai.com/a"),
  161. ActionSearchSource(type="url", url="https://help.openai.com/b"),
  162. ActionSearchSource(type="url", url="https://help.openai.com/a"),
  163. ],
  164. ),
  165. status="completed",
  166. type="web_search_call",
  167. )
  168. run_result = MagicMock()
  169. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=web_search_call)]
  170. assert extract_web_search_sources(run_result) == [
  171. "https://help.openai.com/a",
  172. "https://help.openai.com/b",
  173. ]
  174. def test_handles_missing_sources(self):
  175. """Return empty list when a web search call has no sources."""
  176. web_search_call = ResponseFunctionWebSearch(
  177. id="web_2",
  178. action=ActionSearch(query="openai docs", type="search", sources=None),
  179. status="completed",
  180. type="web_search_call",
  181. )
  182. run_result = MagicMock()
  183. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=web_search_call)]
  184. assert extract_web_search_sources(run_result) == []
  185. def test_skips_non_search_web_actions(self):
  186. """Ignore web search tool actions that are not search actions."""
  187. web_search_call = ResponseFunctionWebSearch(
  188. id="web_3",
  189. action=ActionOpenPage(type="open_page", url="https://help.openai.com"),
  190. status="completed",
  191. type="web_search_call",
  192. )
  193. run_result = MagicMock()
  194. run_result.new_items = [ToolCallItem(agent=MagicMock(), raw_item=web_search_call)]
  195. assert extract_web_search_sources(run_result) == []
  196. class TestExtractDirectFileCitationsFromHistory:
  197. """Test citation extraction from thread conversation history."""
  198. def test_extracts_new_format_citations(self):
  199. """Test extraction from new format (citations in message metadata)."""
  200. thread_items = [
  201. {
  202. "role": "assistant",
  203. "content": "Here's the information from the file.",
  204. "citations": [
  205. {"file_id": "file-new123", "filename": "new_format.pdf", "index": 15, "method": "direct_file"}
  206. ],
  207. }
  208. ]
  209. result = extract_direct_file_citations_from_history(thread_items)
  210. assert len(result) == 1
  211. citation = result[0]
  212. assert citation["file_id"] == "file-new123"
  213. assert citation["filename"] == "new_format.pdf"
  214. assert citation["index"] == 15
  215. assert citation["method"] == "direct_file"
  216. def test_extracts_legacy_format_citations(self):
  217. """Test extraction from legacy format (synthetic user messages)."""
  218. thread_items = [
  219. {
  220. "role": "user",
  221. "content": """[DIRECT_FILE_CITATIONS]
  222. File ID: file-legacy456
  223. Filename: legacy_document.docx
  224. Text Index: 25
  225. Type: file_citation
  226. File ID: file-legacy789
  227. Filename: another_doc.pdf
  228. Text Index: 50
  229. Type: file_citation
  230. """,
  231. }
  232. ]
  233. result = extract_direct_file_citations_from_history(thread_items)
  234. assert len(result) == 2
  235. first_citation = result[0]
  236. assert first_citation["file_id"] == "file-legacy456"
  237. assert first_citation["filename"] == "legacy_document.docx"
  238. assert first_citation["index"] == 25
  239. assert first_citation["type"] == "file_citation"
  240. assert first_citation["method"] == "direct_file"
  241. second_citation = result[1]
  242. assert second_citation["file_id"] == "file-legacy789"
  243. assert second_citation["filename"] == "another_doc.pdf"
  244. assert second_citation["index"] == 50
  245. def test_handles_mixed_message_types(self):
  246. """Test handling thread with both citation and non-citation messages."""
  247. thread_items = [
  248. {"role": "user", "content": "What's in this file?"},
  249. {
  250. "role": "assistant",
  251. "content": "According to the document...",
  252. "citations": [{"file_id": "file-mixed", "filename": "mixed.pdf"}],
  253. },
  254. {"role": "user", "content": "Thanks for the info!"},
  255. ]
  256. result = extract_direct_file_citations_from_history(thread_items)
  257. assert len(result) == 1
  258. assert result[0]["file_id"] == "file-mixed"
  259. def test_handles_empty_thread(self):
  260. """Test handling of empty thread items."""
  261. result = extract_direct_file_citations_from_history([])
  262. assert result == []
  263. def test_handles_malformed_legacy_format(self):
  264. """Test graceful handling of malformed legacy citation format."""
  265. thread_items = [
  266. {
  267. "role": "user",
  268. "content": """[DIRECT_FILE_CITATIONS]
  269. File ID: incomplete-citation
  270. Filename: test.pdf
  271. # Missing Text Index and Type fields
  272. """,
  273. }
  274. ]
  275. result = extract_direct_file_citations_from_history(thread_items)
  276. # Should not crash, but may not extract complete citation
  277. assert isinstance(result, list)
  278. class TestDisplayCitations:
  279. """Test citation display functionality."""
  280. def test_displays_vector_store_citations(self, capsys):
  281. """Test display of vector store citations."""
  282. citations = [
  283. {
  284. "method": "vector_store",
  285. "file_id": "vs-123",
  286. "text": (
  287. "This is a long piece of text that should be truncated in the preview "
  288. "because it exceeds the 100 character limit for display purposes."
  289. ),
  290. "tool_call_id": "call_123",
  291. }
  292. ]
  293. result = display_citations(citations, "vector store")
  294. captured = capsys.readouterr()
  295. assert result is True
  296. assert "✅ Found 1 citation(s) vector store:" in captured.out
  297. assert "Citation 1 [vector_store]:" in captured.out
  298. assert "File ID: vs-123" in captured.out
  299. assert "Tool Call: call_123" in captured.out
  300. assert (
  301. "Content: This is a long piece of text that should be truncated in the preview "
  302. "because it exceeds the 100 char..."
  303. ) in captured.out
  304. def test_displays_direct_file_citations(self, capsys):
  305. """Test display of direct file citations."""
  306. citations = [
  307. {
  308. "method": "direct_file",
  309. "file_id": "file-456",
  310. "filename": "document.pdf",
  311. "index": 42,
  312. "type": "file_citation",
  313. }
  314. ]
  315. result = display_citations(citations)
  316. captured = capsys.readouterr()
  317. assert result is True
  318. assert "✅ Found 1 citation(s):" in captured.out
  319. assert "Citation 1 [direct_file]:" in captured.out
  320. assert "File ID: file-456" in captured.out
  321. assert "Filename: document.pdf" in captured.out
  322. assert "Text Index: 42" in captured.out
  323. def test_handles_no_citations(self, capsys):
  324. """Test display when no citations are provided."""
  325. result = display_citations([])
  326. captured = capsys.readouterr()
  327. assert result is False
  328. assert "❌ No citations found" in captured.out
  329. def test_handles_no_citations_with_type(self, capsys):
  330. """Test display when no citations with specific type."""
  331. result = display_citations([], "direct file")
  332. captured = capsys.readouterr()
  333. assert result is False
  334. assert "❌ No direct file citations found" in captured.out
  335. def test_displays_multiple_citations(self, capsys):
  336. """Test display of multiple citations."""
  337. citations = [
  338. {"method": "direct_file", "file_id": "file-1", "filename": "doc1.pdf"},
  339. {"method": "vector_store", "file_id": "vs-2", "text": "Short text", "tool_call_id": "call_2"},
  340. ]
  341. result = display_citations(citations)
  342. captured = capsys.readouterr()
  343. assert result is True
  344. assert "✅ Found 2 citation(s):" in captured.out
  345. assert "Citation 1 [direct_file]:" in captured.out
  346. assert "Citation 2 [vector_store]:" in captured.out
  347. assert "Content: Short text" in captured.out # Short text not truncated
  348. def test_handles_missing_citation_fields(self, capsys):
  349. """Test display with citations missing some fields."""
  350. citations = [
  351. {
  352. "method": "unknown",
  353. # Missing most fields
  354. }
  355. ]
  356. result = display_citations(citations)
  357. captured = capsys.readouterr()
  358. assert result is True
  359. assert "Citation 1 [unknown]:" in captured.out
  360. assert "File ID: unknown" in captured.out