test_ollama_role_kwargs.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. """Offline regression tests for Ollama API role-specific kwargs."""
  2. import importlib
  3. import sys
  4. from types import SimpleNamespace
  5. import pytest
  6. from fastapi import FastAPI
  7. from fastapi.testclient import TestClient
  8. pytestmark = pytest.mark.offline
  9. class _FakeRag:
  10. def __init__(self):
  11. self.base_calls = []
  12. self.query_calls = []
  13. self.llm_model_kwargs = {"route": "base"}
  14. self._query_kwargs = {"route": "query"}
  15. self.ollama_server_infos = SimpleNamespace(
  16. LIGHTRAG_MODEL="lightrag:latest",
  17. LIGHTRAG_CREATED_AT="2026-03-14T00:00:00Z",
  18. LIGHTRAG_SIZE=0,
  19. )
  20. async def base_func(*args, **kwargs):
  21. self.base_calls.append(kwargs)
  22. return "base"
  23. async def query_func(*args, **kwargs):
  24. self.query_calls.append(kwargs)
  25. return "query"
  26. self.llm_model_func = base_func
  27. # ollama_api.py reads `rag.role_llm_funcs[role]` and
  28. # `rag.role_llm_kwargs[role]`; expose them as plain dict mappings here.
  29. self.role_llm_funcs = {"query": query_func}
  30. self.role_llm_kwargs = {"query": self._query_kwargs}
  31. async def aquery(self, *args, **kwargs):
  32. return "aquery"
  33. def _make_client(monkeypatch) -> tuple[TestClient, _FakeRag]:
  34. monkeypatch.setattr(sys, "argv", [sys.argv[0]])
  35. for module_name in [
  36. "lightrag.api.config",
  37. "lightrag.api.auth",
  38. "lightrag.api.utils_api",
  39. "lightrag.api.routers",
  40. "lightrag.api.routers.ollama_api",
  41. ]:
  42. sys.modules.pop(module_name, None)
  43. ollama_api_module = importlib.import_module("lightrag.api.routers.ollama_api")
  44. OllamaAPI = ollama_api_module.OllamaAPI
  45. rag = _FakeRag()
  46. api = OllamaAPI(rag, top_k=20, api_key=None)
  47. app = FastAPI()
  48. app.include_router(api.router, prefix="/api")
  49. return TestClient(app), rag
  50. def test_generate_non_stream_uses_query_role_kwargs_without_mutating_base(monkeypatch):
  51. client, rag = _make_client(monkeypatch)
  52. response = client.post(
  53. "/api/generate",
  54. json={
  55. "model": "lightrag:latest",
  56. "prompt": "Summarize this",
  57. "stream": False,
  58. "system": "custom system",
  59. },
  60. )
  61. assert response.status_code == 200
  62. assert rag.base_calls == []
  63. assert rag.query_calls[-1]["route"] == "query"
  64. assert rag.query_calls[-1]["system_prompt"] == "custom system"
  65. assert "system_prompt" not in rag.llm_model_kwargs
  66. assert "system_prompt" not in rag.role_llm_kwargs["query"]
  67. def test_chat_bypass_stream_uses_query_role_kwargs_without_mutating_base(monkeypatch):
  68. client, rag = _make_client(monkeypatch)
  69. with client.stream(
  70. "POST",
  71. "/api/chat",
  72. json={
  73. "model": "lightrag:latest",
  74. "stream": True,
  75. "system": "chat system",
  76. "messages": [
  77. {"role": "assistant", "content": "history"},
  78. {"role": "user", "content": "/bypass give me a title"},
  79. ],
  80. },
  81. ) as response:
  82. assert response.status_code == 200
  83. # Consume the streaming response fully.
  84. list(response.iter_lines())
  85. assert rag.base_calls == []
  86. assert rag.query_calls[-1]["route"] == "query"
  87. assert rag.query_calls[-1]["system_prompt"] == "chat system"
  88. assert rag.query_calls[-1]["history_messages"] == [
  89. {"role": "assistant", "content": "history"}
  90. ]
  91. assert "system_prompt" not in rag.llm_model_kwargs
  92. assert "system_prompt" not in rag.role_llm_kwargs["query"]