test_additional_instructions.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from collections.abc import AsyncIterator
  2. import pytest
  3. from agents import Tool
  4. from agents.agent_output import AgentOutputSchemaBase
  5. from agents.handoffs import Handoff as SDKHandoff
  6. from agents.items import ModelResponse, TResponseInputItem, TResponseStreamEvent
  7. from agents.model_settings import ModelSettings
  8. from agents.models.interface import Model, ModelTracing
  9. from openai.types.responses.response_prompt_param import ResponsePromptParam
  10. from agency_swarm import Agency, Agent
  11. from tests.deterministic_model import _build_message_response, _stream_text_events
  12. class SystemInstructionsEchoModel(Model):
  13. def __init__(self, model: str = "test-system-instructions") -> None:
  14. self.model = model
  15. async def get_response(
  16. self,
  17. system_instructions: str | None,
  18. input: str | list[TResponseInputItem],
  19. model_settings: ModelSettings,
  20. tools: list[Tool],
  21. output_schema: AgentOutputSchemaBase | None,
  22. handoffs: list[SDKHandoff],
  23. tracing: ModelTracing,
  24. *,
  25. previous_response_id: str | None,
  26. conversation_id: str | None,
  27. prompt: ResponsePromptParam | None,
  28. ) -> ModelResponse:
  29. text = system_instructions or ""
  30. return _build_message_response(text, self.model)
  31. def stream_response(
  32. self,
  33. system_instructions: str | None,
  34. input: str | list[TResponseInputItem],
  35. model_settings: ModelSettings,
  36. tools: list[Tool],
  37. output_schema: AgentOutputSchemaBase | None,
  38. handoffs: list[SDKHandoff],
  39. tracing: ModelTracing,
  40. *,
  41. previous_response_id: str | None,
  42. conversation_id: str | None,
  43. prompt: ResponsePromptParam | None,
  44. ) -> AsyncIterator[TResponseStreamEvent]:
  45. text = system_instructions or ""
  46. return _stream_text_events(text, self.model)
  47. @pytest.mark.asyncio
  48. async def test_agent_get_response_applies_additional_instructions_and_restores_original() -> None:
  49. original_instructions = "Base agent instructions."
  50. additional_instructions = "Additional run instructions."
  51. agent = Agent(
  52. name="TestAgent",
  53. instructions=original_instructions,
  54. model=SystemInstructionsEchoModel(),
  55. )
  56. result = await agent.get_response("hello", additional_instructions=additional_instructions)
  57. assert isinstance(result.final_output, str)
  58. assert result.final_output == f"{original_instructions}\n\n{additional_instructions}"
  59. assert agent.instructions == original_instructions
  60. @pytest.mark.asyncio
  61. async def test_agent_get_response_stream_applies_additional_instructions_and_restores_original() -> None:
  62. original_instructions = "Base agent instructions."
  63. additional_instructions = "Streaming run instructions."
  64. agent = Agent(
  65. name="TestAgent",
  66. instructions=original_instructions,
  67. model=SystemInstructionsEchoModel(),
  68. )
  69. stream = agent.get_response_stream("hello", additional_instructions=additional_instructions)
  70. async for _event in stream:
  71. pass
  72. assert stream.final_output == f"{original_instructions}\n\n{additional_instructions}"
  73. assert agent.instructions == original_instructions
  74. @pytest.mark.asyncio
  75. async def test_agency_shared_instructions_precede_base_and_additional() -> None:
  76. shared_instructions = "Shared agency instructions."
  77. base_instructions = "Base agent instructions."
  78. additional_instructions = "Additional run instructions."
  79. agent = Agent(
  80. name="TestAgent",
  81. instructions=base_instructions,
  82. model=SystemInstructionsEchoModel(),
  83. )
  84. agency = Agency(agent, shared_instructions=shared_instructions)
  85. result = await agency.get_response("hello", additional_instructions=additional_instructions)
  86. assert isinstance(result.final_output, str)
  87. assert result.final_output == (f"{shared_instructions}\n\n{base_instructions}\n\n---\n\n{additional_instructions}")
  88. assert agent.instructions == base_instructions
  89. @pytest.mark.asyncio
  90. async def test_agency_uses_latest_shared_instructions_between_runs() -> None:
  91. base_instructions = "Base agent instructions."
  92. additional_instructions = "Additional run instructions."
  93. initial_shared_instructions = "Initial shared instructions."
  94. updated_shared_instructions = "Updated shared instructions."
  95. agent = Agent(
  96. name="TestAgent",
  97. instructions=base_instructions,
  98. model=SystemInstructionsEchoModel(),
  99. )
  100. agency = Agency(agent, shared_instructions=initial_shared_instructions)
  101. first = await agency.get_response("hello", additional_instructions=additional_instructions)
  102. assert isinstance(first.final_output, str)
  103. assert first.final_output == (
  104. f"{initial_shared_instructions}\n\n{base_instructions}\n\n---\n\n{additional_instructions}"
  105. )
  106. agency.shared_instructions = updated_shared_instructions
  107. second = await agency.get_response("hello", additional_instructions=additional_instructions)
  108. assert isinstance(second.final_output, str)
  109. assert second.final_output == (
  110. f"{updated_shared_instructions}\n\n{base_instructions}\n\n---\n\n{additional_instructions}"
  111. )
  112. @pytest.mark.asyncio
  113. @pytest.mark.parametrize("additional_instructions", ["", None])
  114. async def test_empty_or_none_additional_instructions_do_not_add_separator(
  115. additional_instructions: str | None,
  116. ) -> None:
  117. original_instructions = "Base agent instructions."
  118. agent = Agent(
  119. name="TestAgent",
  120. instructions=original_instructions,
  121. model=SystemInstructionsEchoModel(),
  122. )
  123. result = await agent.get_response("hello", additional_instructions=additional_instructions)
  124. assert isinstance(result.final_output, str)
  125. assert result.final_output == original_instructions
  126. assert "---" not in result.final_output
  127. assert agent.instructions == original_instructions
  128. @pytest.mark.asyncio
  129. async def test_additional_instructions_with_none_base_instructions() -> None:
  130. additional_instructions = "Additional run instructions only."
  131. agent = Agent(
  132. name="NoBaseInstructionsAgent",
  133. instructions=None,
  134. model=SystemInstructionsEchoModel(),
  135. )
  136. result = await agent.get_response("hello", additional_instructions=additional_instructions)
  137. assert isinstance(result.final_output, str)
  138. assert result.final_output == additional_instructions
  139. assert agent.instructions is None