test_utils_llm_cache.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from unittest.mock import AsyncMock
  2. import pytest
  3. from lightrag.utils import use_llm_func_with_cache
  4. class _FakeKVStorage:
  5. def __init__(self):
  6. self.global_config = {"enable_llm_cache_for_entity_extract": True}
  7. self._store = {}
  8. async def get_by_id(self, key):
  9. return self._store.get(key)
  10. async def upsert(self, entries):
  11. self._store.update(entries)
  12. @pytest.mark.offline
  13. @pytest.mark.asyncio
  14. async def test_use_llm_func_with_cache_partitions_cache_by_response_format():
  15. cache = _FakeKVStorage()
  16. llm_func = AsyncMock(side_effect=["plain-text", '{"answer":"json"}'])
  17. plain_result, _ = await use_llm_func_with_cache(
  18. "same prompt",
  19. llm_func,
  20. llm_response_cache=cache,
  21. )
  22. json_result, _ = await use_llm_func_with_cache(
  23. "same prompt",
  24. llm_func,
  25. llm_response_cache=cache,
  26. response_format={"type": "json_object"},
  27. )
  28. assert plain_result == "plain-text"
  29. assert json_result == '{"answer":"json"}'
  30. assert llm_func.await_count == 2
  31. assert len(cache._store) == 2
  32. @pytest.mark.offline
  33. @pytest.mark.asyncio
  34. async def test_use_llm_func_with_cache_partitions_cache_by_llm_identity():
  35. cache = _FakeKVStorage()
  36. llm_func = AsyncMock(side_effect=["model-a", "model-b"])
  37. first_result, _ = await use_llm_func_with_cache(
  38. "same prompt",
  39. llm_func,
  40. llm_response_cache=cache,
  41. llm_cache_identity={
  42. "role": "query",
  43. "binding": "openai",
  44. "model": "model-a",
  45. "host": "https://api.example.com/v1",
  46. },
  47. )
  48. second_result, _ = await use_llm_func_with_cache(
  49. "same prompt",
  50. llm_func,
  51. llm_response_cache=cache,
  52. llm_cache_identity={
  53. "role": "query",
  54. "binding": "openai",
  55. "model": "model-b",
  56. "host": "https://api.example.com/v1",
  57. },
  58. )
  59. assert first_result == "model-a"
  60. assert second_result == "model-b"
  61. assert llm_func.await_count == 2
  62. assert len(cache._store) == 2
  63. @pytest.mark.offline
  64. @pytest.mark.asyncio
  65. async def test_use_llm_func_with_cache_rejects_json_schema_response_format():
  66. llm_func = AsyncMock()
  67. with pytest.raises(ValueError, match="json_schema"):
  68. await use_llm_func_with_cache(
  69. "same prompt",
  70. llm_func,
  71. response_format={
  72. "type": "json_schema",
  73. "json_schema": {
  74. "name": "answer_payload",
  75. "schema": {"type": "object"},
  76. },
  77. },
  78. )
  79. llm_func.assert_not_awaited()