test_paragraph_semantic_merge_and_fallback.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """Regression tests for paragraph-semantic Stage D merging and the top-level R fallback."""
  2. import pytest
  3. from lightrag.chunker.paragraph_semantic import (
  4. _merge_small_blocks,
  5. chunking_by_paragraph_semantic,
  6. )
  7. from lightrag.utils import Tokenizer, TokenizerInterface
  8. class _CharTokenizer(TokenizerInterface):
  9. """1:1 character-to-token mapping — keeps math obvious in assertions."""
  10. def encode(self, content: str):
  11. return [ord(ch) for ch in content]
  12. def decode(self, tokens):
  13. return "".join(chr(t) for t in tokens)
  14. def _make_tokenizer() -> Tokenizer:
  15. return Tokenizer(model_name="char", tokenizer=_CharTokenizer())
  16. def _make_block(text: str, *, tokenizer: Tokenizer, level: int = 1) -> dict:
  17. return {
  18. "heading": "H",
  19. "parent_headings": [],
  20. "level": level,
  21. "paragraphs": [{"text": text, "is_table": False}],
  22. "content": text,
  23. "tokens": len(tokenizer.encode(text)),
  24. "table_chunk_role": "none",
  25. }
  26. @pytest.mark.offline
  27. def test_tail_absorption_rejects_when_separator_pushes_over_cap():
  28. # Tail absorption joins blocks with ``"\n\n"`` but the original
  29. # predicate only summed per-block tokens. With cur=99 and tail=1
  30. # the raw sum equals target_max=100, but the actual joined
  31. # ``"x"*99 + "\n\n" + "y"*1`` measures 102 tokens — the absorbed
  32. # block silently overflowed before the fix re-measured the joined
  33. # content.
  34. tokenizer = _make_tokenizer()
  35. blocks = [
  36. _make_block("x" * 99, tokenizer=tokenizer),
  37. _make_block("y" * 1, tokenizer=tokenizer),
  38. ]
  39. merged = _merge_small_blocks(
  40. blocks,
  41. tokenizer=tokenizer,
  42. target_max=100,
  43. target_ideal=80,
  44. small_tail_threshold=12,
  45. )
  46. assert all(b["tokens"] <= 100 for b in merged), [b["tokens"] for b in merged]
  47. @pytest.mark.offline
  48. def test_tail_absorption_still_fires_when_joined_size_fits():
  49. # Sanity check: when the joined content (including separators)
  50. # genuinely fits target_max, absorption still happens. cur=80 +
  51. # "\n\n" (2 tokens) + tail=1 = 83 ≤ 100.
  52. tokenizer = _make_tokenizer()
  53. blocks = [
  54. _make_block("x" * 80, tokenizer=tokenizer),
  55. _make_block("y" * 1, tokenizer=tokenizer),
  56. ]
  57. merged = _merge_small_blocks(
  58. blocks,
  59. tokenizer=tokenizer,
  60. target_max=100,
  61. target_ideal=80,
  62. small_tail_threshold=12,
  63. )
  64. assert len(merged) == 1
  65. assert merged[0]["tokens"] == 83
  66. assert merged[0]["content"] == "x" * 80 + "\n\n" + "y" * 1
  67. @pytest.mark.offline
  68. def test_paragraph_semantic_fallback_passes_configured_recursive_overlap(monkeypatch):
  69. # When ``blocks_path`` is missing, paragraph-semantic chunking
  70. # delegates to ``chunking_by_recursive_character``. P now permits
  71. # overlap for long text under one JSONL row, so the fallback must
  72. # pass through the configured overlap rather than forcing zero.
  73. captured: dict[str, object] = {}
  74. def fake_chunker(
  75. tokenizer,
  76. content,
  77. chunk_token_size: int = 1200,
  78. *,
  79. chunk_overlap_token_size: int = 100,
  80. separators=None,
  81. ):
  82. captured["chunk_overlap_token_size"] = chunk_overlap_token_size
  83. captured["chunk_token_size"] = chunk_token_size
  84. return [
  85. {
  86. "tokens": len(tokenizer.encode(content)),
  87. "content": content,
  88. "chunk_order_index": 0,
  89. }
  90. ]
  91. import lightrag.chunker.recursive_character as rc_mod
  92. monkeypatch.setattr(rc_mod, "chunking_by_recursive_character", fake_chunker)
  93. tokenizer = _make_tokenizer()
  94. chunking_by_paragraph_semantic(
  95. tokenizer,
  96. "fallback corpus",
  97. chunk_token_size=500,
  98. blocks_path=None,
  99. chunk_overlap_token_size=37,
  100. )
  101. assert (
  102. captured.get("chunk_overlap_token_size") == 37
  103. ), "P→R fallback must pass the configured chunk_overlap_token_size"
  104. assert captured.get("chunk_token_size") == 500