test_persistent_shell_tool.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """Integration tests for PersistentShellTool."""
  2. import os
  3. import sys
  4. import tempfile
  5. from pathlib import Path
  6. import pytest
  7. from agents.run_context import RunContextWrapper
  8. from agency_swarm import Agent
  9. from agency_swarm.context import MasterContext
  10. from agency_swarm.tools.built_in import PersistentShellTool
  11. from agency_swarm.utils.thread import ThreadManager
  12. @pytest.fixture
  13. def shared_context():
  14. """Create a shared context wrapped for tools to persist state."""
  15. thread_manager = ThreadManager()
  16. master_context = MasterContext(
  17. thread_manager=thread_manager,
  18. agents={},
  19. user_context={},
  20. )
  21. return RunContextWrapper(context=master_context)
  22. @pytest.fixture
  23. def agent_with_shell():
  24. """Create an agent with PersistentShellTool."""
  25. return Agent(
  26. name="ShellAgent",
  27. description="Test agent with shell access",
  28. instructions="Execute shell commands",
  29. tools=[PersistentShellTool],
  30. )
  31. @pytest.fixture
  32. def temp_test_dir():
  33. """Create a temporary directory for tests."""
  34. with tempfile.TemporaryDirectory() as tmpdir:
  35. yield tmpdir
  36. class TestPersistentShellToolBasics:
  37. """Test basic shell command execution."""
  38. @pytest.mark.asyncio
  39. async def test_simple_command_execution(self, agent_with_shell):
  40. """Test executing a simple command."""
  41. if sys.platform == "win32":
  42. tool = PersistentShellTool(command="echo 'test'")
  43. else:
  44. tool = PersistentShellTool(command="echo test")
  45. tool._caller_agent = agent_with_shell
  46. result = await tool.run()
  47. assert "test" in result
  48. assert "Working Directory:" in result
  49. @pytest.mark.asyncio
  50. async def test_command_with_no_output(self, agent_with_shell, temp_test_dir):
  51. """Test that commands with no output show success message."""
  52. test_file = os.path.join(temp_test_dir, "test.txt")
  53. if sys.platform == "win32":
  54. tool = PersistentShellTool(command=f"New-Item -Path '{test_file}' -ItemType File -Force")
  55. else:
  56. tool = PersistentShellTool(command=f"touch '{test_file}'")
  57. tool._caller_agent = agent_with_shell
  58. result = await tool.run()
  59. assert "executed successfully" in result.lower() or os.path.exists(test_file)
  60. @pytest.mark.asyncio
  61. async def test_command_error_handling(self, agent_with_shell):
  62. """Test that command errors are properly caught."""
  63. tool = PersistentShellTool(command="nonexistent_command_12345")
  64. tool._caller_agent = agent_with_shell
  65. result = await tool.run()
  66. # Should indicate error (either in stderr or error message)
  67. assert "Error" in result or "Stderr:" in result or "Exit Code:" in result
  68. class TestWorkingDirectoryPersistence:
  69. """Test that working directory persists within same agent."""
  70. @pytest.mark.asyncio
  71. async def test_cd_persistence(self, agent_with_shell, shared_context, temp_test_dir):
  72. """Test that cd command persists working directory."""
  73. # Change to temp directory
  74. tool1 = PersistentShellTool(command=f"cd '{temp_test_dir}'")
  75. tool1._caller_agent = agent_with_shell
  76. tool1._context = shared_context
  77. result1 = await tool1.run()
  78. assert "Error" not in result1
  79. # Check working directory - should be the temp directory
  80. if sys.platform == "win32":
  81. tool2 = PersistentShellTool(command="(Get-Location).Path")
  82. else:
  83. tool2 = PersistentShellTool(command="pwd")
  84. tool2._caller_agent = agent_with_shell
  85. tool2._context = shared_context
  86. result2 = await tool2.run()
  87. # Extract the path from the output (between ``` marks)
  88. output_lines = result2.split("```")
  89. if len(output_lines) >= 2:
  90. output_path = output_lines[1].strip()
  91. else:
  92. output_path = result2
  93. # Normalize paths for comparison
  94. assert Path(temp_test_dir).resolve() == Path(output_path).resolve()
  95. @pytest.mark.asyncio
  96. async def test_relative_paths_work_after_cd(self, agent_with_shell, shared_context, temp_test_dir):
  97. """Test that relative paths work correctly after changing directory."""
  98. # Change to temp directory
  99. tool1 = PersistentShellTool(command=f"cd '{temp_test_dir}'")
  100. tool1._caller_agent = agent_with_shell
  101. tool1._context = shared_context
  102. await tool1.run()
  103. # Create file in current (temp) directory using relative path
  104. if sys.platform == "win32":
  105. tool2 = PersistentShellTool(command="New-Item -Path './test_file.txt' -ItemType File -Force")
  106. else:
  107. tool2 = PersistentShellTool(command="touch ./test_file.txt")
  108. tool2._caller_agent = agent_with_shell
  109. tool2._context = shared_context
  110. await tool2.run()
  111. # Verify file was created in temp directory
  112. assert os.path.exists(os.path.join(temp_test_dir, "test_file.txt"))
  113. @pytest.mark.asyncio
  114. async def test_cd_with_tilde_expansion(self, agent_with_shell):
  115. """Test that ~ is properly expanded to home directory."""
  116. tool = PersistentShellTool(command="cd ~")
  117. tool._caller_agent = agent_with_shell
  118. result = await tool.run()
  119. assert "Error" not in result
  120. # Working directory should be home directory
  121. home_dir = Path.home()
  122. assert str(home_dir) in result or home_dir.name in result
  123. class TestWorkingDirectoryIsolation:
  124. """Test that working directories are isolated between agents."""
  125. @pytest.mark.asyncio
  126. async def test_cd_isolation_between_agents(self, shared_context, temp_test_dir):
  127. """Test that cd in one agent doesn't affect another."""
  128. agent_a = Agent(name="AgentA", description="", instructions="", tools=[PersistentShellTool])
  129. agent_b = Agent(name="AgentB", description="", instructions="", tools=[PersistentShellTool])
  130. # Agent A changes to temp directory
  131. tool_a = PersistentShellTool(command=f"cd '{temp_test_dir}'")
  132. tool_a._caller_agent = agent_a
  133. tool_a._context = shared_context
  134. result_a = await tool_a.run()
  135. assert temp_test_dir in result_a
  136. # Agent B checks its working directory - should NOT be temp directory
  137. if sys.platform == "win32":
  138. tool_b = PersistentShellTool(command="(Get-Location).Path")
  139. else:
  140. tool_b = PersistentShellTool(command="pwd")
  141. tool_b._caller_agent = agent_b
  142. tool_b._context = shared_context
  143. result_b = await tool_b.run()
  144. # Extract the path from the output
  145. output_lines = result_b.split("```")
  146. if len(output_lines) >= 2:
  147. output_path = output_lines[1].strip()
  148. else:
  149. output_path = result_b
  150. # Agent B should be in original directory, not temp_test_dir
  151. assert Path(temp_test_dir).resolve() != Path(output_path).resolve()
  152. @pytest.mark.asyncio
  153. async def test_concurrent_commands_different_agents(self, temp_test_dir):
  154. """Test that commands in different agents run independently."""
  155. import asyncio
  156. agent_a = Agent(name="AgentA", description="", instructions="", tools=[PersistentShellTool])
  157. agent_b = Agent(name="AgentB", description="", instructions="", tools=[PersistentShellTool])
  158. # Both agents run commands concurrently
  159. if sys.platform == "win32":
  160. cmd = "Get-Date"
  161. else:
  162. cmd = "date"
  163. tool_a = PersistentShellTool(command=cmd)
  164. tool_a._caller_agent = agent_a
  165. tool_b = PersistentShellTool(command=cmd)
  166. tool_b._caller_agent = agent_b
  167. results = await asyncio.gather(tool_a.run(), tool_b.run())
  168. # Both should succeed
  169. assert "Error" not in results[0]
  170. assert "Error" not in results[1]
  171. class TestChainedCommandsAndEdgeCases:
  172. """Test chained commands and edge cases."""
  173. @pytest.mark.asyncio
  174. async def test_cd_in_chained_command_warning(self, agent_with_shell, temp_test_dir):
  175. """Test that cd in chained command shows warning."""
  176. if sys.platform == "win32":
  177. # PowerShell uses semicolon for command chaining
  178. tool = PersistentShellTool(command=f"cd '{temp_test_dir}'; Get-Date")
  179. else:
  180. tool = PersistentShellTool(command=f"cd '{temp_test_dir}' && date")
  181. tool._caller_agent = agent_with_shell
  182. result = await tool.run()
  183. # Should either show warning or fail with an error
  184. assert "Warning" in result or "not persisted" in result or "separate" in result
  185. @pytest.mark.asyncio
  186. async def test_stderr_capture(self, agent_with_shell):
  187. """Test that stderr is captured separately."""
  188. if sys.platform == "win32":
  189. # Write to stderr in PowerShell
  190. tool = PersistentShellTool(command="Write-Error 'test error' 2>&1")
  191. else:
  192. tool = PersistentShellTool(command="echo 'test error' >&2")
  193. tool._caller_agent = agent_with_shell
  194. result = await tool.run()
  195. assert "test error" in result.lower()
  196. @pytest.mark.asyncio
  197. async def test_no_agent_context(self):
  198. """Test that tool works without agent context."""
  199. if sys.platform == "win32":
  200. tool = PersistentShellTool(command="echo test")
  201. else:
  202. tool = PersistentShellTool(command="echo test")
  203. # Don't set _caller_agent
  204. result = await tool.run()
  205. assert "test" in result
  206. assert "Working Directory:" in result