| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
"""Offline regression tests for Ollama API role-specific kwargs."""

import importlib
import sys
from types import SimpleNamespace

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

# Module-level mark: pytest applies the `offline` marker to every test below.
pytestmark = pytest.mark.offline
class _FakeRag:
    """Minimal stand-in for the RAG object handed to ``OllamaAPI``.

    It records the keyword arguments every LLM callable receives, so a test
    can check which role's function handled a request and confirm that the
    stored kwargs dictionaries were left unmodified.
    """

    def __init__(self):
        # Keyword arguments captured from each call to the base / query
        # LLM callables defined below.
        self.base_calls = []
        self.query_calls = []
        # Distinct "route" values let a test tell which kwargs set was
        # forwarded to the LLM callable.
        self.llm_model_kwargs = {"route": "base"}
        self._query_kwargs = {"route": "query"}
        self.ollama_server_infos = SimpleNamespace(
            LIGHTRAG_MODEL="lightrag:latest",
            LIGHTRAG_CREATED_AT="2026-03-14T00:00:00Z",
            LIGHTRAG_SIZE=0,
        )

        async def base_func(*args, **kwargs):
            # Closure over `self`: remember the kwargs, return a fixed reply.
            self.base_calls.append(kwargs)
            return "base"

        async def query_func(*args, **kwargs):
            # Closure over `self`: remember the kwargs, return a fixed reply.
            self.query_calls.append(kwargs)
            return "query"

        self.llm_model_func = base_func
        # ollama_api.py reads `rag.role_llm_funcs[role]` and
        # `rag.role_llm_kwargs[role]`; expose them as plain dict mappings here.
        self.role_llm_funcs = {"query": query_func}
        self.role_llm_kwargs = {"query": self._query_kwargs}

    async def aquery(self, *args, **kwargs):
        """Return the fixed string ``"aquery"``, ignoring all arguments."""
        return "aquery"
def _make_client(monkeypatch) -> tuple[TestClient, _FakeRag]:
    """Build a test client for a freshly imported Ollama router.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; used to replace
            ``sys.argv`` for the duration of the test.

    Returns:
        A ``(client, rag)`` pair: the ``TestClient`` for an app that mounts
        the router under ``/api``, and the ``_FakeRag`` wired into it.
    """
    # Strip everything but the program name from argv.
    # NOTE(review): presumably so the API config's argument parsing does not
    # see pytest's own command-line options -- confirm against lightrag.api.config.
    monkeypatch.setattr(sys, "argv", [sys.argv[0]])

    # Evict any cached copies so the import below re-executes these modules.
    for module_name in [
        "lightrag.api.config",
        "lightrag.api.auth",
        "lightrag.api.utils_api",
        "lightrag.api.routers",
        "lightrag.api.routers.ollama_api",
    ]:
        sys.modules.pop(module_name, None)

    ollama_api_module = importlib.import_module("lightrag.api.routers.ollama_api")
    OllamaAPI = ollama_api_module.OllamaAPI

    rag = _FakeRag()
    api = OllamaAPI(rag, top_k=20, api_key=None)

    app = FastAPI()
    app.include_router(api.router, prefix="/api")
    return TestClient(app), rag
def test_generate_non_stream_uses_query_role_kwargs_without_mutating_base(monkeypatch):
    """Non-streaming ``/api/generate`` must go through the query-role LLM.

    Checks that the base LLM callable is never invoked, that the query-role
    callable receives both its role kwargs and the request's system prompt,
    and that neither stored kwargs dict gained a ``system_prompt`` entry.
    """
    client, rag = _make_client(monkeypatch)

    response = client.post(
        "/api/generate",
        json={
            "model": "lightrag:latest",
            "prompt": "Summarize this",
            "stream": False,
            "system": "custom system",
        },
    )

    assert response.status_code == 200
    assert rag.base_calls == []
    assert rag.query_calls[-1]["route"] == "query"
    assert rag.query_calls[-1]["system_prompt"] == "custom system"
    # The per-request system prompt must not leak into the shared kwargs.
    assert "system_prompt" not in rag.llm_model_kwargs
    assert "system_prompt" not in rag.role_llm_kwargs["query"]
def test_chat_bypass_stream_uses_query_role_kwargs_without_mutating_base(monkeypatch):
    """Streaming ``/api/chat`` with a ``/bypass`` message must use the query-role LLM.

    Checks that the base LLM callable is never invoked, that the query-role
    callable receives its role kwargs, the request's system prompt and the
    earlier messages as history, and that neither stored kwargs dict gained
    a ``system_prompt`` entry.
    """
    client, rag = _make_client(monkeypatch)

    with client.stream(
        "POST",
        "/api/chat",
        json={
            "model": "lightrag:latest",
            "stream": True,
            "system": "chat system",
            "messages": [
                {"role": "assistant", "content": "history"},
                {"role": "user", "content": "/bypass give me a title"},
            ],
        },
    ) as response:
        assert response.status_code == 200
        # Consume the streaming response fully.
        list(response.iter_lines())

    assert rag.base_calls == []
    assert rag.query_calls[-1]["route"] == "query"
    assert rag.query_calls[-1]["system_prompt"] == "chat system"
    assert rag.query_calls[-1]["history_messages"] == [
        {"role": "assistant", "content": "history"}
    ]
    # The per-request system prompt must not leak into the shared kwargs.
    assert "system_prompt" not in rag.llm_model_kwargs
    assert "system_prompt" not in rag.role_llm_kwargs["query"]
|