test_keyword_parsing.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import pytest
  2. from unittest.mock import patch
  3. from lightrag.base import QueryParam
  4. from lightrag.operate import _parse_keywords_payload, extract_keywords_only
  5. class _FakeKeywordModel:
  6. def model_dump(self):
  7. return {
  8. "high_level_keywords": ["AI"],
  9. "low_level_keywords": ["RAG", "Graph"],
  10. }
  11. class _FakeTokenizer:
  12. def encode(self, content: str) -> list[int]:
  13. return [ord(ch) for ch in content]
  14. class _FakeKVStorage:
  15. def __init__(self):
  16. self.global_config = {"enable_llm_cache": True}
  17. self._store = {}
  18. async def get_by_id(self, key):
  19. return self._store.get(key)
  20. async def upsert(self, entries):
  21. self._store.update(entries)
  22. def _keyword_global_config(
  23. model: str, binding: str = "openai", keyword_func=None
  24. ) -> dict:
  25. return {
  26. "addon_params": {"language": "en"},
  27. "tokenizer": _FakeTokenizer(),
  28. "role_llm_funcs": {"keyword": keyword_func} if keyword_func else {},
  29. "llm_cache_identities": {
  30. "keyword": {
  31. "role": "keyword",
  32. "binding": binding,
  33. "model": model,
  34. "host": "https://api.example.com/v1",
  35. }
  36. },
  37. }
  38. @pytest.mark.offline
  39. def test_parse_keywords_payload_accepts_model_like_objects():
  40. is_valid, hl_keywords, ll_keywords = _parse_keywords_payload(_FakeKeywordModel())
  41. assert is_valid is True
  42. assert hl_keywords == ["AI"]
  43. assert ll_keywords == ["RAG", "Graph"]
  44. @pytest.mark.offline
  45. def test_parse_keywords_payload_extracts_json_from_wrapped_text():
  46. result = """
  47. analysis first
  48. {"high_level_keywords":"AI, Agents","low_level_keywords":["RAG","LightRAG"]}
  49. trailing note
  50. """
  51. is_valid, hl_keywords, ll_keywords = _parse_keywords_payload(result)
  52. assert is_valid is True
  53. assert hl_keywords == ["AI", "Agents"]
  54. assert ll_keywords == ["RAG", "LightRAG"]
  55. @pytest.mark.offline
  56. def test_parse_keywords_payload_warns_when_json_repair_is_used():
  57. broken_result = (
  58. '{"high_level_keywords":"AI, Agents","low_level_keywords":["RAG","LightRAG"]'
  59. )
  60. with patch("lightrag.operate.logger.warning") as mocked_warning:
  61. is_valid, hl_keywords, ll_keywords = _parse_keywords_payload(broken_result)
  62. assert is_valid is True
  63. assert hl_keywords == ["AI", "Agents"]
  64. assert ll_keywords == ["RAG", "LightRAG"]
  65. mocked_warning.assert_called_once()
  66. assert (
  67. "Keyword extraction response required JSON repair"
  68. in mocked_warning.call_args[0][0]
  69. )
  70. @pytest.mark.offline
  71. @pytest.mark.asyncio
  72. async def test_extract_keywords_only_accepts_empty_keyword_cache_without_requery():
  73. async def should_not_run(*_args, **_kwargs):
  74. raise AssertionError(
  75. "keyword LLM should not be called on a valid empty cache hit"
  76. )
  77. param = QueryParam()
  78. global_config = _keyword_global_config("model-a", keyword_func=should_not_run)
  79. with patch(
  80. "lightrag.operate.handle_cache",
  81. return_value=('{"high_level_keywords":[],"low_level_keywords":[]}', None),
  82. ):
  83. hl_keywords, ll_keywords = await extract_keywords_only(
  84. "hello",
  85. param,
  86. global_config,
  87. hashing_kv=None,
  88. )
  89. assert hl_keywords == []
  90. assert ll_keywords == []
  91. @pytest.mark.offline
  92. @pytest.mark.asyncio
  93. async def test_extract_keywords_only_partitions_cache_by_keyword_llm_identity():
  94. cache = _FakeKVStorage()
  95. calls = 0
  96. async def keyword_model(*_args, **_kwargs):
  97. nonlocal calls
  98. calls += 1
  99. return (
  100. '{"high_level_keywords":["model-'
  101. + str(calls)
  102. + '"],"low_level_keywords":["rag"]}'
  103. )
  104. param = QueryParam()
  105. first_hl, first_ll = await extract_keywords_only(
  106. "same query",
  107. param,
  108. _keyword_global_config("model-a", keyword_func=keyword_model),
  109. hashing_kv=cache,
  110. )
  111. second_hl, second_ll = await extract_keywords_only(
  112. "same query",
  113. param,
  114. _keyword_global_config("model-b", keyword_func=keyword_model),
  115. hashing_kv=cache,
  116. )
  117. assert first_hl == ["model-1"]
  118. assert first_ll == ["rag"]
  119. assert second_hl == ["model-2"]
  120. assert second_ll == ["rag"]
  121. assert calls == 2
  122. assert len(cache._store) == 2