test_usage_tracking.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. from __future__ import annotations
  2. import threading
  3. import time
  4. import typing
  5. from collections.abc import AsyncIterator
  6. from types import SimpleNamespace
  7. import pytest
  8. from agents import Tool
  9. from agents.agent_output import AgentOutputSchemaBase
  10. from agents.handoffs import Handoff
  11. from agents.items import ModelResponse, TResponseInputItem, TResponseStreamEvent
  12. from agents.model_settings import ModelSettings
  13. from agents.models.interface import Model, ModelTracing
  14. from agents.result import RunResult
  15. from agents.run_context import RunContextWrapper
  16. from agents.usage import Usage
  17. from openai.types.responses import ResponseOutputMessage, ResponseOutputText
  18. from openai.types.responses.response_prompt_param import ResponsePromptParam
  19. from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
  20. from agency_swarm.agent.core import Agent
  21. from agency_swarm.context import MasterContext
  22. from agency_swarm.utils import usage_tracking
  23. from agency_swarm.utils.thread import ThreadManager
  24. from agency_swarm.utils.usage_tracking import (
  25. UsageStats,
  26. calculate_openai_cost,
  27. calculate_usage_with_cost,
  28. extract_usage_from_run_result,
  29. format_usage_for_display,
  30. get_model_pricing,
  31. load_pricing_data,
  32. )
  33. class _HasSubAgentResponsesWithModel(typing.Protocol):
  34. _sub_agent_responses_with_model: list[tuple[str | None, ModelResponse]]
  35. class _HasMainAgentModel(typing.Protocol):
  36. _main_agent_model: str
  37. def _make_run_result(*, usage: Usage, raw_responses: list[ModelResponse] | None = None) -> RunResult:
  38. agent = Agent(name="TestAgent", instructions="Base instructions")
  39. thread_manager = ThreadManager()
  40. master_context = MasterContext(
  41. thread_manager=thread_manager,
  42. agents={agent.name: agent},
  43. user_context={},
  44. agent_runtime_state={},
  45. current_agent_name=agent.name,
  46. shared_instructions=None,
  47. )
  48. wrapper = RunContextWrapper(context=master_context, usage=usage)
  49. return RunResult(
  50. input="Hello",
  51. new_items=[],
  52. raw_responses=list(raw_responses or []),
  53. final_output="ok",
  54. input_guardrail_results=[],
  55. output_guardrail_results=[],
  56. tool_input_guardrail_results=[],
  57. tool_output_guardrail_results=[],
  58. context_wrapper=wrapper,
  59. _last_agent=agent,
  60. )
  61. def test_extract_usage_from_run_result_returns_none_without_run_result() -> None:
  62. assert extract_usage_from_run_result(None) is None
  63. def test_extract_usage_from_run_result_reads_requests_and_tokens() -> None:
  64. usage = Usage(
  65. requests=2,
  66. input_tokens=10,
  67. output_tokens=20,
  68. total_tokens=30,
  69. input_tokens_details=InputTokensDetails(cached_tokens=3),
  70. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  71. )
  72. run_result = _make_run_result(usage=usage)
  73. stats = extract_usage_from_run_result(run_result)
  74. assert stats == UsageStats(
  75. request_count=2,
  76. cached_tokens=3,
  77. input_tokens=10,
  78. output_tokens=20,
  79. total_tokens=30,
  80. total_cost=0.0,
  81. reasoning_tokens=None,
  82. audio_tokens=None,
  83. )
  84. def test_extract_usage_from_run_result_extracts_reasoning_and_sums_subagent_reasoning() -> None:
  85. main_usage = Usage(
  86. requests=1,
  87. input_tokens=10,
  88. output_tokens=20,
  89. total_tokens=30,
  90. input_tokens_details=InputTokensDetails(cached_tokens=0),
  91. output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
  92. )
  93. sub_usage = Usage(
  94. requests=1,
  95. input_tokens=1,
  96. output_tokens=2,
  97. total_tokens=3,
  98. input_tokens_details=InputTokensDetails(cached_tokens=0),
  99. output_tokens_details=OutputTokensDetails(reasoning_tokens=7),
  100. )
  101. run_result = _make_run_result(usage=main_usage)
  102. typing.cast(_HasSubAgentResponsesWithModel, run_result)._sub_agent_responses_with_model = [
  103. ("gpt-5.4-mini", ModelResponse(output=[], usage=sub_usage, response_id=None))
  104. ]
  105. stats = extract_usage_from_run_result(run_result)
  106. assert stats is not None
  107. assert stats.request_count == 2
  108. assert stats.input_tokens == 11
  109. assert stats.output_tokens == 22
  110. assert stats.total_tokens == 33
  111. assert stats.reasoning_tokens == 12 # 5 main + 7 sub
  112. def test_calculate_usage_with_cost_per_response_costs_all_token_types() -> None:
  113. """
  114. Single per-response costing test that verifies:
  115. - input token pricing
  116. - cached input token pricing (via input_tokens_details.cached_tokens)
  117. - output token pricing
  118. - reasoning token pricing (via output_tokens_details.reasoning_tokens)
  119. - dict-based usage (sub-agent) uses that sub-agent's model pricing
  120. """
  121. pricing_data = {
  122. "test/all-tokens-model": {
  123. "input_cost_per_token": 1.0,
  124. "cache_read_input_token_cost": 0.1,
  125. "output_cost_per_token": 2.0,
  126. "output_cost_per_reasoning_token": 0.01,
  127. },
  128. "test/sub-agent-model": {
  129. "input_cost_per_token": 10.0,
  130. "cache_read_input_token_cost": 1.0,
  131. "output_cost_per_token": 20.0,
  132. "output_cost_per_reasoning_token": 0.5,
  133. },
  134. }
  135. response_usage = Usage(
  136. requests=1,
  137. input_tokens=10,
  138. output_tokens=3,
  139. total_tokens=13,
  140. input_tokens_details=InputTokensDetails(cached_tokens=4),
  141. output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
  142. )
  143. response = ModelResponse(output=[], usage=response_usage, response_id=None)
  144. run_result = _make_run_result(usage=Usage(), raw_responses=[response])
  145. typing.cast(_HasMainAgentModel, run_result)._main_agent_model = "test/all-tokens-model"
  146. base = UsageStats(
  147. request_count=1,
  148. cached_tokens=0,
  149. input_tokens=10,
  150. output_tokens=3,
  151. total_tokens=13,
  152. total_cost=0.0,
  153. reasoning_tokens=None,
  154. audio_tokens=None,
  155. )
  156. with_cost = calculate_usage_with_cost(base, pricing_data=pricing_data, run_result=run_result)
  157. # Main response:
  158. # (10 - 4)*1.0 + 4*0.1 + 3*2.0 + 5*0.01 = 6 + 0.4 + 6 + 0.05 = 12.45
  159. assert with_cost.total_cost == pytest.approx(12.45)
  160. @pytest.mark.asyncio
  161. async def test_calculate_usage_with_cost_uses_model_name_from_model_instance() -> None:
  162. """Regression: costing should work when an Agent is configured with a Model instance."""
  163. class FakeModel(Model):
  164. def __init__(self, model: str) -> None:
  165. self.model = model
  166. async def get_response(
  167. self,
  168. system_instructions: str | None,
  169. input: str | list[TResponseInputItem],
  170. model_settings: ModelSettings,
  171. tools: list[Tool],
  172. output_schema: AgentOutputSchemaBase | None,
  173. handoffs: list[Handoff],
  174. tracing: ModelTracing,
  175. *,
  176. previous_response_id: str | None,
  177. conversation_id: str | None,
  178. prompt: ResponsePromptParam | None,
  179. ) -> ModelResponse:
  180. usage = Usage(
  181. requests=1,
  182. input_tokens=2,
  183. output_tokens=1,
  184. total_tokens=3,
  185. input_tokens_details=InputTokensDetails(cached_tokens=0),
  186. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  187. )
  188. msg = ResponseOutputMessage(
  189. id="msg_1",
  190. content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
  191. role="assistant",
  192. status="completed",
  193. type="message",
  194. )
  195. return ModelResponse(output=[msg], usage=usage, response_id="resp_1")
  196. def stream_response(
  197. self,
  198. system_instructions: str | None,
  199. input: str | list[TResponseInputItem],
  200. model_settings: ModelSettings,
  201. tools: list[Tool],
  202. output_schema: AgentOutputSchemaBase | None,
  203. handoffs: list[Handoff],
  204. tracing: ModelTracing,
  205. *,
  206. previous_response_id: str | None,
  207. conversation_id: str | None,
  208. prompt: ResponsePromptParam | None,
  209. ) -> AsyncIterator[TResponseStreamEvent]:
  210. async def _stream() -> AsyncIterator[TResponseStreamEvent]:
  211. if False:
  212. yield typing.cast(TResponseStreamEvent, {})
  213. return
  214. return _stream()
  215. model_name = "test/model-instance"
  216. agent = Agent(name="ModelInstanceAgent", instructions="Respond with 'ok'.", model=FakeModel(model_name))
  217. result = await agent.get_response("hi")
  218. assert typing.cast(_HasMainAgentModel, result)._main_agent_model == model_name
  219. usage_stats = extract_usage_from_run_result(result)
  220. assert usage_stats is not None
  221. pricing_data = {
  222. model_name: {
  223. "input_cost_per_token": 1.0,
  224. "cache_read_input_token_cost": 0.0,
  225. "output_cost_per_token": 1.0,
  226. "output_cost_per_reasoning_token": 0.0,
  227. }
  228. }
  229. with_cost = calculate_usage_with_cost(usage_stats, pricing_data=pricing_data, run_result=result)
  230. assert with_cost.total_cost == pytest.approx(3.0)
  231. @pytest.mark.asyncio
  232. async def test_calculate_usage_with_cost_prefers_usage_tracking_model_name() -> None:
  233. """Regression: proxy aliases should keep pricing tied to the real upstream model."""
  234. class FakeOpenClawModel(Model):
  235. model = "openclaw:main"
  236. _agency_swarm_usage_model_name = "openai/gpt-5.4"
  237. async def get_response(
  238. self,
  239. system_instructions: str | None,
  240. input: str | list[TResponseInputItem],
  241. model_settings: ModelSettings,
  242. tools: list[Tool],
  243. output_schema: AgentOutputSchemaBase | None,
  244. handoffs: list[Handoff],
  245. tracing: ModelTracing,
  246. *,
  247. previous_response_id: str | None,
  248. conversation_id: str | None,
  249. prompt: ResponsePromptParam | None,
  250. ) -> ModelResponse:
  251. usage = Usage(
  252. requests=1,
  253. input_tokens=2,
  254. output_tokens=1,
  255. total_tokens=3,
  256. input_tokens_details=InputTokensDetails(cached_tokens=0),
  257. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  258. )
  259. msg = ResponseOutputMessage(
  260. id="msg_2",
  261. content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
  262. role="assistant",
  263. status="completed",
  264. type="message",
  265. )
  266. return ModelResponse(output=[msg], usage=usage, response_id="resp_2")
  267. def stream_response(
  268. self,
  269. system_instructions: str | None,
  270. input: str | list[TResponseInputItem],
  271. model_settings: ModelSettings,
  272. tools: list[Tool],
  273. output_schema: AgentOutputSchemaBase | None,
  274. handoffs: list[Handoff],
  275. tracing: ModelTracing,
  276. *,
  277. previous_response_id: str | None,
  278. conversation_id: str | None,
  279. prompt: ResponsePromptParam | None,
  280. ) -> AsyncIterator[TResponseStreamEvent]:
  281. async def _stream() -> AsyncIterator[TResponseStreamEvent]:
  282. if False:
  283. yield typing.cast(TResponseStreamEvent, {})
  284. return
  285. return _stream()
  286. agent = Agent(name="OpenClawProxyAgent", instructions="Respond with 'ok'.", model=FakeOpenClawModel())
  287. result = await agent.get_response("hi")
  288. assert typing.cast(_HasMainAgentModel, result)._main_agent_model == "openai/gpt-5.4"
  289. usage_stats = extract_usage_from_run_result(result)
  290. assert usage_stats is not None
  291. pricing_data = {
  292. "gpt-5.4": {
  293. "input_cost_per_token": 1.0,
  294. "cache_read_input_token_cost": 0.0,
  295. "output_cost_per_token": 2.0,
  296. "output_cost_per_reasoning_token": 0.0,
  297. }
  298. }
  299. with_cost = calculate_usage_with_cost(usage_stats, pricing_data=pricing_data, run_result=result)
  300. assert with_cost.total_cost == pytest.approx(4.0)
  301. def test_load_pricing_data_is_single_load_under_concurrency(tmp_path, monkeypatch) -> None:
  302. """Regression: cache lock must prevent concurrent duplicate JSON parses."""
  303. pricing_file = tmp_path / "pricing.json"
  304. pricing_file.write_text('{"test/model": {"input_cost_per_token": 1}}', encoding="utf-8")
  305. monkeypatch.setattr(usage_tracking, "PRICING_FILE_PATH", pricing_file)
  306. monkeypatch.setattr(usage_tracking, "_PRICING_DATA_CACHE", None)
  307. original_json_load = usage_tracking.json.load
  308. call_count = 0
  309. first_load_started = threading.Event()
  310. allow_return = threading.Event()
  311. def blocking_load(fp: typing.Any) -> typing.Any:
  312. nonlocal call_count
  313. call_count += 1
  314. first_load_started.set()
  315. assert allow_return.wait(timeout=1.0), "Test timed out waiting to release json.load"
  316. return original_json_load(fp)
  317. monkeypatch.setattr(usage_tracking.json, "load", blocking_load)
  318. results: list[usage_tracking.PricingData] = []
  319. results_lock = threading.Lock()
  320. def worker() -> None:
  321. data = usage_tracking.load_pricing_data()
  322. with results_lock:
  323. results.append(data)
  324. first_thread = threading.Thread(target=worker, daemon=True)
  325. first_thread.start()
  326. assert first_load_started.wait(timeout=1.0)
  327. second_thread = threading.Thread(target=worker, daemon=True)
  328. second_thread.start()
  329. try:
  330. time.sleep(0.05)
  331. finally:
  332. allow_return.set()
  333. first_thread.join(timeout=1.0)
  334. second_thread.join(timeout=1.0)
  335. assert call_count == 1, f"Expected a single JSON load under lock, got {call_count}"
  336. assert results
  337. assert all("test/model" in data for data in results)
  338. def test_load_pricing_data_does_not_cache_invalid_json(tmp_path, monkeypatch) -> None:
  339. """Regression: invalid JSON must not poison the in-process cache."""
  340. pricing_file = tmp_path / "pricing.json"
  341. pricing_file.write_text("{", encoding="utf-8")
  342. monkeypatch.setattr(usage_tracking, "PRICING_FILE_PATH", pricing_file)
  343. monkeypatch.setattr(usage_tracking, "_PRICING_DATA_CACHE", None)
  344. first = usage_tracking.load_pricing_data()
  345. assert first == {}
  346. pricing_file.write_text('{"test/model": {"input_cost_per_token": 1}}', encoding="utf-8")
  347. second = usage_tracking.load_pricing_data()
  348. assert "test/model" in second
  349. def test_load_pricing_data_returns_empty_when_file_missing(tmp_path, monkeypatch) -> None:
  350. missing_file = tmp_path / "missing_pricing.json"
  351. monkeypatch.setattr(usage_tracking, "PRICING_FILE_PATH", missing_file)
  352. monkeypatch.setattr(usage_tracking, "_PRICING_DATA_CACHE", None)
  353. assert usage_tracking.load_pricing_data() == {}
  354. def test_load_pricing_data_handles_non_dict_payload_and_coerces_bool_prices(tmp_path, monkeypatch) -> None:
  355. pricing_file = tmp_path / "pricing.json"
  356. monkeypatch.setattr(usage_tracking, "PRICING_FILE_PATH", pricing_file)
  357. pricing_file.write_text('["unexpected"]', encoding="utf-8")
  358. monkeypatch.setattr(usage_tracking, "_PRICING_DATA_CACHE", None)
  359. assert usage_tracking.load_pricing_data() == {}
  360. pricing_file.write_text(
  361. '{"test/model":{"input_cost_per_token":true,"output_cost_per_token":2,"cache_read_input_token_cost":false}}',
  362. encoding="utf-8",
  363. )
  364. monkeypatch.setattr(usage_tracking, "_PRICING_DATA_CACHE", None)
  365. loaded = usage_tracking.load_pricing_data()
  366. assert loaded["test/model"]["input_cost_per_token"] == 0.0
  367. assert loaded["test/model"]["cache_read_input_token_cost"] == 0.0
  368. assert loaded["test/model"]["output_cost_per_token"] == 2.0
  369. def test_get_model_pricing_resolves_provider_and_version_fallbacks() -> None:
  370. pricing_data = {
  371. "azure/gpt-4o": {"input_cost_per_token": 2.0},
  372. "gpt-4o": {"input_cost_per_token": 1.0},
  373. }
  374. assert get_model_pricing("azure/gpt-4o", pricing_data) == pricing_data["azure/gpt-4o"]
  375. assert get_model_pricing("openai/gpt-4o", pricing_data) == pricing_data["gpt-4o"]
  376. assert get_model_pricing("gpt-4o-2024-05-13", pricing_data) == pricing_data["gpt-4o"]
  377. assert get_model_pricing("gpt-4o-mini", pricing_data) == pricing_data["gpt-4o"]
  378. assert get_model_pricing("missing-model", pricing_data) is None
  379. def test_calculate_openai_cost_handles_cached_and_reasoning_tokens() -> None:
  380. pricing_data = {
  381. "test/model": {
  382. "input_cost_per_token": 1.0,
  383. "output_cost_per_token": 2.0,
  384. "cache_read_input_token_cost": 0.5,
  385. "output_cost_per_reasoning_token": 3.0,
  386. }
  387. }
  388. cost = calculate_openai_cost(
  389. model_name="test/model",
  390. input_tokens=10,
  391. output_tokens=2,
  392. cached_tokens=4,
  393. reasoning_tokens=1,
  394. pricing_data=pricing_data,
  395. )
  396. assert cost == pytest.approx(15.0)
  397. assert calculate_openai_cost("missing", 1, 1, pricing_data=pricing_data) == 0.0
  398. def test_bundled_pricing_supports_gpt_5_4_and_mini_defaults() -> None:
  399. pricing_data = load_pricing_data()
  400. assert get_model_pricing("gpt-5.4", pricing_data) is not None
  401. assert get_model_pricing("openai/gpt-5.4", pricing_data) == pricing_data["gpt-5.4"]
  402. assert calculate_openai_cost("gpt-5.4", 1000, 1000, pricing_data=pricing_data) > 0.0
  403. assert calculate_openai_cost("openai/gpt-5.4", 1000, 1000, pricing_data=pricing_data) > 0.0
  404. assert get_model_pricing("gpt-5.4-mini", pricing_data) is not None
  405. assert get_model_pricing("openai/gpt-5.4-mini", pricing_data) == pricing_data["gpt-5.4-mini"]
  406. assert calculate_openai_cost("gpt-5.4-mini", 1000, 1000, pricing_data=pricing_data) > 0.0
  407. assert calculate_openai_cost("openai/gpt-5.4-mini", 1000, 1000, pricing_data=pricing_data) > 0.0
  408. def test_extract_usage_from_run_result_skips_malformed_subagent_entries() -> None:
  409. usage = Usage(
  410. requests=1,
  411. input_tokens=5,
  412. output_tokens=3,
  413. total_tokens=8,
  414. input_tokens_details=InputTokensDetails(cached_tokens=1),
  415. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  416. )
  417. run_result = _make_run_result(usage=usage)
  418. typing.cast(_HasSubAgentResponsesWithModel, run_result)._sub_agent_responses_with_model = [
  419. ("broken", typing.cast(ModelResponse, object())),
  420. ]
  421. stats = extract_usage_from_run_result(run_result)
  422. assert stats is not None
  423. assert stats.request_count == 1
  424. assert stats.total_tokens == 8
  425. def test_extract_usage_from_run_result_returns_none_for_unusable_context_wrapper() -> None:
  426. run_result = SimpleNamespace(context_wrapper=object())
  427. assert extract_usage_from_run_result(typing.cast(RunResult, run_result)) is None
  428. def test_calculate_usage_with_cost_falls_back_to_full_litellm_path() -> None:
  429. usage_stats = UsageStats(
  430. request_count=1,
  431. cached_tokens=0,
  432. input_tokens=2,
  433. output_tokens=3,
  434. total_tokens=5,
  435. total_cost=0.0,
  436. reasoning_tokens=None,
  437. audio_tokens=None,
  438. )
  439. pricing_data = {
  440. "litellm/anthropic/claude-sonnet-4": {
  441. "input_cost_per_token": 1.0,
  442. "output_cost_per_token": 2.0,
  443. "cache_read_input_token_cost": 0.0,
  444. "output_cost_per_reasoning_token": 0.0,
  445. }
  446. }
  447. result = calculate_usage_with_cost(
  448. usage_stats,
  449. model_name="litellm/anthropic/claude-sonnet-4",
  450. pricing_data=pricing_data,
  451. )
  452. assert result.total_cost == pytest.approx(8.0)
  453. def test_calculate_usage_with_cost_handles_run_result_without_model_name() -> None:
  454. usage = Usage(
  455. requests=1,
  456. input_tokens=2,
  457. output_tokens=1,
  458. total_tokens=3,
  459. input_tokens_details=InputTokensDetails(cached_tokens=0),
  460. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  461. )
  462. response = ModelResponse(output=[], usage=usage, response_id=None)
  463. run_result = SimpleNamespace(raw_responses=[response], _sub_agent_responses_with_model=[])
  464. usage_stats = UsageStats(
  465. request_count=1,
  466. cached_tokens=0,
  467. input_tokens=2,
  468. output_tokens=1,
  469. total_tokens=3,
  470. total_cost=123.0,
  471. reasoning_tokens=None,
  472. audio_tokens=None,
  473. )
  474. with_cost = calculate_usage_with_cost(usage_stats, run_result=typing.cast(RunResult, run_result))
  475. assert with_cost.total_cost == 0.0
  476. def test_format_usage_for_display_includes_optional_fields() -> None:
  477. usage_stats = UsageStats(
  478. request_count=2,
  479. cached_tokens=3,
  480. input_tokens=10,
  481. output_tokens=7,
  482. total_tokens=17,
  483. total_cost=1.234567,
  484. reasoning_tokens=2,
  485. audio_tokens=1,
  486. )
  487. formatted = format_usage_for_display(usage_stats, model_name="gpt-5.4-mini")
  488. assert "Model: gpt-5.4-mini" in formatted
  489. assert "Requests: 2" in formatted
  490. assert "Cached: 3" in formatted
  491. assert "Reasoning: 2" in formatted
  492. assert "Audio: 1" in formatted
  493. assert "Cost: $1.234567" in formatted