# test_llm_cache_identity.py

from typing import Any

import pytest

from lightrag.base import QueryParam
from lightrag.operate import naive_query


  4. class _FakeTokenizer:
  5. def encode(self, content: str) -> list[int]:
  6. return [ord(ch) for ch in content]
  7. def decode(self, tokens: list[int]) -> str:
  8. return "".join(chr(token) for token in tokens)
  9. class _FakeKVStorage:
  10. def __init__(self):
  11. self.global_config = {"enable_llm_cache": True}
  12. self._store = {}
  13. async def get_by_id(self, key):
  14. return self._store.get(key)
  15. async def upsert(self, entries):
  16. self._store.update(entries)
  17. class _FakeChunksVDB:
  18. cosine_better_than_threshold = 0.0
  19. async def query(self, *_args, **_kwargs):
  20. return [
  21. {
  22. "id": "chunk-1",
  23. "content": "LightRAG cache identity test chunk.",
  24. "file_path": "test.md",
  25. }
  26. ]
  27. def _query_global_config(model: str, llm_func) -> dict:
  28. return {
  29. "tokenizer": _FakeTokenizer(),
  30. "role_llm_funcs": {"query": llm_func},
  31. "llm_cache_identities": {
  32. "query": {
  33. "role": "query",
  34. "binding": "openai",
  35. "model": model,
  36. "host": "https://api.example.com/v1",
  37. }
  38. },
  39. "min_rerank_score": 0.0,
  40. "max_total_tokens": 4096,
  41. }
  42. @pytest.mark.offline
  43. @pytest.mark.asyncio
  44. async def test_naive_query_partitions_query_cache_by_llm_identity():
  45. cache = _FakeKVStorage()
  46. chunks_vdb = _FakeChunksVDB()
  47. calls = 0
  48. async def query_model(*_args, **_kwargs):
  49. nonlocal calls
  50. calls += 1
  51. return f"answer-{calls}"
  52. param = QueryParam(mode="naive", enable_rerank=False)
  53. first = await naive_query(
  54. "same query",
  55. chunks_vdb,
  56. param,
  57. _query_global_config("model-a", query_model),
  58. hashing_kv=cache,
  59. )
  60. second = await naive_query(
  61. "same query",
  62. chunks_vdb,
  63. param,
  64. _query_global_config("model-b", query_model),
  65. hashing_kv=cache,
  66. )
  67. assert first.content == "answer-1"
  68. assert second.content == "answer-2"
  69. assert calls == 2
  70. assert len(cache._store) == 2