test_extract_entities.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. """Tests for entity extraction gleaning token limit guard."""
  2. import logging
  3. from unittest.mock import AsyncMock
  4. import pytest
  5. from lightrag.utils import Tokenizer, TokenizerInterface
  6. @pytest.fixture
  7. def _propagate_lightrag_logger(monkeypatch):
  8. """``lightrag.utils.logger`` sets ``propagate = False`` to avoid noisy
  9. test output; restore propagation locally so ``caplog`` can capture
  10. WARNING records emitted from inside ``lightrag.operate``."""
  11. monkeypatch.setattr(logging.getLogger("lightrag"), "propagate", True)
  12. class DummyTokenizer(TokenizerInterface):
  13. """Simple 1:1 character-to-token mapping for testing."""
  14. def encode(self, content: str):
  15. return [ord(ch) for ch in content]
  16. def decode(self, tokens):
  17. return "".join(chr(token) for token in tokens)
  18. def _make_global_config(
  19. entity_extract_max_gleaning: int = 1,
  20. ) -> dict:
  21. """Build a minimal global_config dict for extract_entities."""
  22. tokenizer = Tokenizer("dummy", DummyTokenizer())
  23. extract_func = AsyncMock(return_value="")
  24. return {
  25. "llm_model_func": extract_func,
  26. "role_llm_funcs": {
  27. "extract": extract_func,
  28. "keyword": extract_func,
  29. "query": extract_func,
  30. "vlm": extract_func,
  31. },
  32. "entity_extract_max_gleaning": entity_extract_max_gleaning,
  33. "entity_extract_max_records": 100,
  34. "entity_extract_max_entities": 40,
  35. "addon_params": {},
  36. "tokenizer": tokenizer,
  37. "llm_model_max_async": 1,
  38. }
  39. # Minimal valid extraction result that _process_extraction_result can parse
  40. _EXTRACTION_RESULT = (
  41. "(entity<|#|>TEST_ENTITY<|#|>CONCEPT<|#|>A test entity)<|COMPLETE|>"
  42. )
  43. def _make_chunks(content: str = "Test content.") -> dict[str, dict]:
  44. return {
  45. "chunk-001": {
  46. "tokens": len(content),
  47. "content": content,
  48. "full_doc_id": "doc-001",
  49. "chunk_order_index": 0,
  50. }
  51. }
  52. @pytest.mark.offline
  53. @pytest.mark.asyncio
  54. async def test_gleaning_skipped_when_tokens_exceed_limit(
  55. monkeypatch, caplog, _propagate_lightrag_logger
  56. ):
  57. """Gleaning must be skipped (with a WARNING) when the projected
  58. gleaning input — system + history(user+assistant) + continue prompt —
  59. exceeds ``MAX_EXTRACT_INPUT_TOKENS``. This prevents
  60. ``context_length_exceeded`` errors from the LLM provider on the second
  61. round when the initial response was long.
  62. """
  63. from lightrag.operate import extract_entities
  64. # 10 tokens cannot fit any realistic prompt — guard must trip.
  65. monkeypatch.setenv("MAX_EXTRACT_INPUT_TOKENS", "10")
  66. global_config = _make_global_config(entity_extract_max_gleaning=1)
  67. llm_func = global_config["llm_model_func"]
  68. llm_func.return_value = _EXTRACTION_RESULT
  69. with caplog.at_level("WARNING", logger="lightrag"):
  70. await extract_entities(
  71. chunks=_make_chunks(),
  72. global_config=global_config,
  73. )
  74. # Only the initial extraction round ran; gleaning was skipped.
  75. assert llm_func.await_count == 1
  76. warnings_emitted = [
  77. rec.getMessage()
  78. for rec in caplog.records
  79. if rec.levelname == "WARNING"
  80. and rec.getMessage().startswith("Gleaning stopped for chunk chunk-001:")
  81. ]
  82. assert warnings_emitted, (
  83. "expected a WARNING log explaining gleaning was skipped due to "
  84. "token limit; got: "
  85. f"{[r.getMessage() for r in caplog.records]}"
  86. )
  87. # Message must surface both the measured token count and the limit so
  88. # operators can size MAX_EXTRACT_INPUT_TOKENS appropriately.
  89. msg = warnings_emitted[0]
  90. assert "exceeded limit (10)" in msg
  91. assert "Input tokens (" in msg
  92. @pytest.mark.offline
  93. @pytest.mark.asyncio
  94. async def test_gleaning_proceeds_when_tokens_within_limit(monkeypatch):
  95. """Gleaning runs normally when the projected input fits the cap."""
  96. from lightrag.operate import extract_entities
  97. monkeypatch.setenv("MAX_EXTRACT_INPUT_TOKENS", "999999")
  98. global_config = _make_global_config(entity_extract_max_gleaning=1)
  99. llm_func = global_config["llm_model_func"]
  100. llm_func.return_value = _EXTRACTION_RESULT
  101. await extract_entities(
  102. chunks=_make_chunks(),
  103. global_config=global_config,
  104. )
  105. # Both rounds run: initial extraction + one gleaning pass.
  106. assert llm_func.await_count == 2
  107. @pytest.mark.offline
  108. @pytest.mark.asyncio
  109. async def test_no_gleaning_when_max_gleaning_zero(monkeypatch):
  110. """``entity_extract_max_gleaning=0`` disables gleaning regardless of
  111. token budget — the guard is downstream of the feature flag."""
  112. from lightrag.operate import extract_entities
  113. monkeypatch.setenv("MAX_EXTRACT_INPUT_TOKENS", "999999")
  114. global_config = _make_global_config(entity_extract_max_gleaning=0)
  115. llm_func = global_config["llm_model_func"]
  116. llm_func.return_value = _EXTRACTION_RESULT
  117. await extract_entities(
  118. chunks=_make_chunks(),
  119. global_config=global_config,
  120. )
  121. assert llm_func.await_count == 1
  122. @pytest.mark.offline
  123. @pytest.mark.asyncio
  124. async def test_gleaning_guard_disabled_when_max_tokens_zero(monkeypatch):
  125. """Setting ``MAX_EXTRACT_INPUT_TOKENS=0`` opts out of the guard so
  126. gleaning always runs regardless of input size — useful for callers
  127. whose provider has no hard input ceiling."""
  128. from lightrag.operate import extract_entities
  129. monkeypatch.setenv("MAX_EXTRACT_INPUT_TOKENS", "0")
  130. global_config = _make_global_config(entity_extract_max_gleaning=1)
  131. llm_func = global_config["llm_model_func"]
  132. llm_func.return_value = _EXTRACTION_RESULT
  133. await extract_entities(
  134. chunks=_make_chunks(),
  135. global_config=global_config,
  136. )
  137. # Guard disabled → gleaning still runs even with tight projected input.
  138. assert llm_func.await_count == 2