test_rerank_chunking.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. """
  2. Unit tests for rerank document chunking functionality.
  3. Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
  4. in lightrag/rerank.py to ensure proper document splitting and score aggregation.
  5. """
  6. import pytest
  7. from unittest.mock import Mock, patch, AsyncMock
  8. from lightrag.rerank import (
  9. chunk_documents_for_rerank,
  10. aggregate_chunk_scores,
  11. cohere_rerank,
  12. )
  13. class TestChunkDocumentsForRerank:
  14. """Test suite for chunk_documents_for_rerank function"""
  15. def test_no_chunking_needed_for_short_docs(self):
  16. """Documents shorter than max_tokens should not be chunked"""
  17. documents = [
  18. "Short doc 1",
  19. "Short doc 2",
  20. "Short doc 3",
  21. ]
  22. chunked_docs, doc_indices = chunk_documents_for_rerank(
  23. documents, max_tokens=100, overlap_tokens=10
  24. )
  25. # No chunking should occur
  26. assert len(chunked_docs) == 3
  27. assert chunked_docs == documents
  28. assert doc_indices == [0, 1, 2]
  29. def test_chunking_with_character_fallback(self):
  30. """Test chunking falls back to character-based when tokenizer unavailable"""
  31. # Create a very long document that exceeds character limit
  32. long_doc = "a" * 2000 # 2000 characters
  33. documents = [long_doc, "short doc"]
  34. with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
  35. chunked_docs, doc_indices = chunk_documents_for_rerank(
  36. documents,
  37. max_tokens=100, # 100 tokens = ~400 chars
  38. overlap_tokens=10, # 10 tokens = ~40 chars
  39. )
  40. # First doc should be split into chunks, second doc stays whole
  41. assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc
  42. assert chunked_docs[-1] == "short doc" # Last chunk is the short doc
  43. # Verify doc_indices maps chunks to correct original document
  44. assert doc_indices[-1] == 1 # Last chunk maps to document 1
  45. def test_chunking_with_tiktoken_tokenizer(self):
  46. """Test chunking with actual tokenizer"""
  47. # Create document with known token count
  48. # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
  49. long_doc = " ".join([f"word{i}" for i in range(200)])
  50. documents = [long_doc, "short"]
  51. chunked_docs, doc_indices = chunk_documents_for_rerank(
  52. documents, max_tokens=50, overlap_tokens=10
  53. )
  54. # Long doc should be split, short doc should remain
  55. assert len(chunked_docs) > 2
  56. assert doc_indices[-1] == 1 # Last chunk is from second document
  57. # Verify overlapping chunks contain overlapping content
  58. if len(chunked_docs) > 2:
  59. # Check that consecutive chunks from same doc have some overlap
  60. for i in range(len(doc_indices) - 1):
  61. if doc_indices[i] == doc_indices[i + 1] == 0:
  62. # Both chunks from first doc, should have overlap
  63. chunk1_words = chunked_docs[i].split()
  64. chunk2_words = chunked_docs[i + 1].split()
  65. # At least one word should be common due to overlap
  66. assert any(word in chunk2_words for word in chunk1_words[-5:])
  67. def test_empty_documents(self):
  68. """Test handling of empty document list"""
  69. documents = []
  70. chunked_docs, doc_indices = chunk_documents_for_rerank(documents)
  71. assert chunked_docs == []
  72. assert doc_indices == []
  73. def test_single_document_chunking(self):
  74. """Test chunking of a single long document"""
  75. # Create document with ~100 tokens
  76. long_doc = " ".join([f"token{i}" for i in range(100)])
  77. documents = [long_doc]
  78. chunked_docs, doc_indices = chunk_documents_for_rerank(
  79. documents, max_tokens=30, overlap_tokens=5
  80. )
  81. # Should create multiple chunks
  82. assert len(chunked_docs) > 1
  83. # All chunks should map to document 0
  84. assert all(idx == 0 for idx in doc_indices)
  85. class TestAggregateChunkScores:
  86. """Test suite for aggregate_chunk_scores function"""
  87. def test_no_chunking_simple_aggregation(self):
  88. """Test aggregation when no chunking occurred (1:1 mapping)"""
  89. chunk_results = [
  90. {"index": 0, "relevance_score": 0.9},
  91. {"index": 1, "relevance_score": 0.7},
  92. {"index": 2, "relevance_score": 0.5},
  93. ]
  94. doc_indices = [0, 1, 2] # 1:1 mapping
  95. num_original_docs = 3
  96. aggregated = aggregate_chunk_scores(
  97. chunk_results, doc_indices, num_original_docs, aggregation="max"
  98. )
  99. # Results should be sorted by score
  100. assert len(aggregated) == 3
  101. assert aggregated[0]["index"] == 0
  102. assert aggregated[0]["relevance_score"] == 0.9
  103. assert aggregated[1]["index"] == 1
  104. assert aggregated[1]["relevance_score"] == 0.7
  105. assert aggregated[2]["index"] == 2
  106. assert aggregated[2]["relevance_score"] == 0.5
  107. def test_max_aggregation_with_chunks(self):
  108. """Test max aggregation strategy with multiple chunks per document"""
  109. # 5 chunks: first 3 from doc 0, last 2 from doc 1
  110. chunk_results = [
  111. {"index": 0, "relevance_score": 0.5},
  112. {"index": 1, "relevance_score": 0.8},
  113. {"index": 2, "relevance_score": 0.6},
  114. {"index": 3, "relevance_score": 0.7},
  115. {"index": 4, "relevance_score": 0.4},
  116. ]
  117. doc_indices = [0, 0, 0, 1, 1]
  118. num_original_docs = 2
  119. aggregated = aggregate_chunk_scores(
  120. chunk_results, doc_indices, num_original_docs, aggregation="max"
  121. )
  122. # Should take max score for each document
  123. assert len(aggregated) == 2
  124. assert aggregated[0]["index"] == 0
  125. assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6
  126. assert aggregated[1]["index"] == 1
  127. assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4
  128. def test_mean_aggregation_with_chunks(self):
  129. """Test mean aggregation strategy"""
  130. chunk_results = [
  131. {"index": 0, "relevance_score": 0.6},
  132. {"index": 1, "relevance_score": 0.8},
  133. {"index": 2, "relevance_score": 0.4},
  134. ]
  135. doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1
  136. num_original_docs = 2
  137. aggregated = aggregate_chunk_scores(
  138. chunk_results, doc_indices, num_original_docs, aggregation="mean"
  139. )
  140. assert len(aggregated) == 2
  141. assert aggregated[0]["index"] == 0
  142. assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2
  143. assert aggregated[1]["index"] == 1
  144. assert aggregated[1]["relevance_score"] == 0.4
  145. def test_first_aggregation_with_chunks(self):
  146. """Test first aggregation strategy"""
  147. chunk_results = [
  148. {"index": 0, "relevance_score": 0.6},
  149. {"index": 1, "relevance_score": 0.8},
  150. {"index": 2, "relevance_score": 0.4},
  151. ]
  152. doc_indices = [0, 0, 1]
  153. num_original_docs = 2
  154. aggregated = aggregate_chunk_scores(
  155. chunk_results, doc_indices, num_original_docs, aggregation="first"
  156. )
  157. assert len(aggregated) == 2
  158. # First should use first score seen for each doc
  159. assert aggregated[0]["index"] == 0
  160. assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0
  161. assert aggregated[1]["index"] == 1
  162. assert aggregated[1]["relevance_score"] == 0.4
  163. def test_empty_chunk_results(self):
  164. """Test handling of empty results"""
  165. aggregated = aggregate_chunk_scores([], [], 3, aggregation="max")
  166. assert aggregated == []
  167. def test_documents_with_no_scores(self):
  168. """Test when some documents have no chunks/scores"""
  169. chunk_results = [
  170. {"index": 0, "relevance_score": 0.9},
  171. {"index": 1, "relevance_score": 0.7},
  172. ]
  173. doc_indices = [0, 0] # Both chunks from document 0
  174. num_original_docs = 3 # But we have 3 documents total
  175. aggregated = aggregate_chunk_scores(
  176. chunk_results, doc_indices, num_original_docs, aggregation="max"
  177. )
  178. # Only doc 0 should appear in results
  179. assert len(aggregated) == 1
  180. assert aggregated[0]["index"] == 0
  181. def test_unknown_aggregation_strategy(self):
  182. """Test that unknown strategy falls back to max"""
  183. chunk_results = [
  184. {"index": 0, "relevance_score": 0.6},
  185. {"index": 1, "relevance_score": 0.8},
  186. ]
  187. doc_indices = [0, 0]
  188. num_original_docs = 1
  189. # Use invalid strategy
  190. aggregated = aggregate_chunk_scores(
  191. chunk_results, doc_indices, num_original_docs, aggregation="invalid"
  192. )
  193. # Should fall back to max
  194. assert aggregated[0]["relevance_score"] == 0.8
  195. @pytest.mark.offline
  196. class TestTopNWithChunking:
  197. """Tests for top_n behavior when chunking is enabled (Bug fix verification)"""
  198. @pytest.mark.asyncio
  199. async def test_top_n_limits_documents_not_chunks(self):
  200. """
  201. Test that top_n correctly limits documents (not chunks) when chunking is enabled.
  202. Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
  203. return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
  204. fewer than 5 documents would be returned.
  205. Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
  206. """
  207. # Setup: 5 documents, each producing multiple chunks when chunked
  208. # Using small max_tokens to force chunking
  209. long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)]
  210. query = "test query"
  211. # First, determine how many chunks will be created by actual chunking
  212. _, doc_indices = chunk_documents_for_rerank(
  213. long_docs, max_tokens=50, overlap_tokens=10
  214. )
  215. num_chunks = len(doc_indices)
  216. # Mock API returns scores for ALL chunks (simulating disabled API-level top_n)
  217. # Give different scores to ensure doc 0 gets highest, doc 1 second, etc.
  218. # Assign scores based on original document index (lower doc index = higher score)
  219. mock_chunk_scores = []
  220. for i in range(num_chunks):
  221. original_doc = doc_indices[i]
  222. # Higher score for lower doc index, with small variation per chunk
  223. base_score = 0.9 - (original_doc * 0.1)
  224. mock_chunk_scores.append({"index": i, "relevance_score": base_score})
  225. mock_response = Mock()
  226. mock_response.status = 200
  227. mock_response.json = AsyncMock(return_value={"results": mock_chunk_scores})
  228. mock_response.request_info = None
  229. mock_response.history = None
  230. mock_response.headers = {}
  231. mock_response.__aenter__ = AsyncMock(return_value=mock_response)
  232. mock_response.__aexit__ = AsyncMock(return_value=None)
  233. mock_session = Mock()
  234. mock_session.post = Mock(return_value=mock_response)
  235. mock_session.__aenter__ = AsyncMock(return_value=mock_session)
  236. mock_session.__aexit__ = AsyncMock(return_value=None)
  237. with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
  238. result = await cohere_rerank(
  239. query=query,
  240. documents=long_docs,
  241. api_key="test-key",
  242. base_url="http://test.com/rerank",
  243. enable_chunking=True,
  244. max_tokens_per_doc=50, # Match chunking above
  245. top_n=3, # Request top 3 documents
  246. )
  247. # Verify: should get exactly 3 documents (not unlimited chunks)
  248. assert len(result) == 3
  249. # All results should have valid document indices (0-4)
  250. assert all(0 <= r["index"] < 5 for r in result)
  251. # Results should be sorted by score (descending)
  252. assert all(
  253. result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
  254. for i in range(len(result) - 1)
  255. )
  256. # The top 3 docs should be 0, 1, 2 (highest scores)
  257. result_indices = [r["index"] for r in result]
  258. assert set(result_indices) == {0, 1, 2}
  259. @pytest.mark.asyncio
  260. async def test_api_receives_no_top_n_when_chunking_enabled(self):
  261. """
  262. Test that the API request does NOT include top_n when chunking is enabled.
  263. This ensures all chunk scores are retrieved for proper aggregation.
  264. """
  265. documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
  266. query = "test query"
  267. captured_payload = {}
  268. mock_response = Mock()
  269. mock_response.status = 200
  270. mock_response.json = AsyncMock(
  271. return_value={
  272. "results": [
  273. {"index": 0, "relevance_score": 0.9},
  274. {"index": 1, "relevance_score": 0.8},
  275. {"index": 2, "relevance_score": 0.7},
  276. ]
  277. }
  278. )
  279. mock_response.request_info = None
  280. mock_response.history = None
  281. mock_response.headers = {}
  282. mock_response.__aenter__ = AsyncMock(return_value=mock_response)
  283. mock_response.__aexit__ = AsyncMock(return_value=None)
  284. def capture_post(*args, **kwargs):
  285. captured_payload.update(kwargs.get("json", {}))
  286. return mock_response
  287. mock_session = Mock()
  288. mock_session.post = Mock(side_effect=capture_post)
  289. mock_session.__aenter__ = AsyncMock(return_value=mock_session)
  290. mock_session.__aexit__ = AsyncMock(return_value=None)
  291. with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
  292. await cohere_rerank(
  293. query=query,
  294. documents=documents,
  295. api_key="test-key",
  296. base_url="http://test.com/rerank",
  297. enable_chunking=True,
  298. max_tokens_per_doc=30,
  299. top_n=1, # User wants top 1 document
  300. )
  301. # Verify: API payload should NOT have top_n (disabled for chunking)
  302. assert "top_n" not in captured_payload
  303. @pytest.mark.asyncio
  304. async def test_top_n_not_modified_when_chunking_disabled(self):
  305. """
  306. Test that top_n is passed through to API when chunking is disabled.
  307. """
  308. documents = ["doc1", "doc2"]
  309. query = "test query"
  310. captured_payload = {}
  311. mock_response = Mock()
  312. mock_response.status = 200
  313. mock_response.json = AsyncMock(
  314. return_value={
  315. "results": [
  316. {"index": 0, "relevance_score": 0.9},
  317. ]
  318. }
  319. )
  320. mock_response.request_info = None
  321. mock_response.history = None
  322. mock_response.headers = {}
  323. mock_response.__aenter__ = AsyncMock(return_value=mock_response)
  324. mock_response.__aexit__ = AsyncMock(return_value=None)
  325. def capture_post(*args, **kwargs):
  326. captured_payload.update(kwargs.get("json", {}))
  327. return mock_response
  328. mock_session = Mock()
  329. mock_session.post = Mock(side_effect=capture_post)
  330. mock_session.__aenter__ = AsyncMock(return_value=mock_session)
  331. mock_session.__aexit__ = AsyncMock(return_value=None)
  332. with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
  333. await cohere_rerank(
  334. query=query,
  335. documents=documents,
  336. api_key="test-key",
  337. base_url="http://test.com/rerank",
  338. enable_chunking=False, # Chunking disabled
  339. top_n=1,
  340. )
  341. # Verify: API payload should have top_n when chunking is disabled
  342. assert captured_payload.get("top_n") == 1
  343. @pytest.mark.offline
  344. class TestCohereRerankChunking:
  345. """Integration tests for cohere_rerank with chunking enabled"""
  346. @pytest.mark.asyncio
  347. async def test_cohere_rerank_with_chunking_disabled(self):
  348. """Test that chunking can be disabled"""
  349. documents = ["doc1", "doc2"]
  350. query = "test query"
  351. # Mock the generic_rerank_api
  352. with patch(
  353. "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
  354. ) as mock_api:
  355. mock_api.return_value = [
  356. {"index": 0, "relevance_score": 0.9},
  357. {"index": 1, "relevance_score": 0.7},
  358. ]
  359. result = await cohere_rerank(
  360. query=query,
  361. documents=documents,
  362. api_key="test-key",
  363. enable_chunking=False,
  364. max_tokens_per_doc=100,
  365. )
  366. # Verify generic_rerank_api was called with correct parameters
  367. mock_api.assert_called_once()
  368. call_kwargs = mock_api.call_args[1]
  369. assert call_kwargs["enable_chunking"] is False
  370. assert call_kwargs["max_tokens_per_doc"] == 100
  371. # Result should mirror mocked scores
  372. assert len(result) == 2
  373. assert result[0]["index"] == 0
  374. assert result[0]["relevance_score"] == 0.9
  375. assert result[1]["index"] == 1
  376. assert result[1]["relevance_score"] == 0.7
  377. @pytest.mark.asyncio
  378. async def test_cohere_rerank_with_chunking_enabled(self):
  379. """Test that chunking parameters are passed through"""
  380. documents = ["doc1", "doc2"]
  381. query = "test query"
  382. with patch(
  383. "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
  384. ) as mock_api:
  385. mock_api.return_value = [
  386. {"index": 0, "relevance_score": 0.9},
  387. {"index": 1, "relevance_score": 0.7},
  388. ]
  389. result = await cohere_rerank(
  390. query=query,
  391. documents=documents,
  392. api_key="test-key",
  393. enable_chunking=True,
  394. max_tokens_per_doc=480,
  395. )
  396. # Verify parameters were passed
  397. call_kwargs = mock_api.call_args[1]
  398. assert call_kwargs["enable_chunking"] is True
  399. assert call_kwargs["max_tokens_per_doc"] == 480
  400. # Result should mirror mocked scores
  401. assert len(result) == 2
  402. assert result[0]["index"] == 0
  403. assert result[0]["relevance_score"] == 0.9
  404. assert result[1]["index"] == 1
  405. assert result[1]["relevance_score"] == 0.7
  406. @pytest.mark.asyncio
  407. async def test_cohere_rerank_default_parameters(self):
  408. """Test default parameter values for cohere_rerank"""
  409. documents = ["doc1"]
  410. query = "test"
  411. with patch(
  412. "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
  413. ) as mock_api:
  414. mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]
  415. result = await cohere_rerank(
  416. query=query, documents=documents, api_key="test-key"
  417. )
  418. # Verify default values
  419. call_kwargs = mock_api.call_args[1]
  420. assert call_kwargs["enable_chunking"] is False
  421. assert call_kwargs["max_tokens_per_doc"] == 4096
  422. assert call_kwargs["model"] == "rerank-v3.5"
  423. # Result should mirror mocked scores
  424. assert len(result) == 1
  425. assert result[0]["index"] == 0
  426. assert result[0]["relevance_score"] == 0.9
  427. @pytest.mark.offline
  428. class TestEndToEndChunking:
  429. """End-to-end tests for chunking workflow"""
  430. @pytest.mark.asyncio
  431. async def test_end_to_end_chunking_workflow(self):
  432. """Test complete chunking workflow from documents to aggregated results"""
  433. # Create documents where first one needs chunking
  434. long_doc = " ".join([f"word{i}" for i in range(100)])
  435. documents = [long_doc, "short doc"]
  436. query = "test query"
  437. # Mock the HTTP call inside generic_rerank_api
  438. mock_response = Mock()
  439. mock_response.status = 200
  440. mock_response.json = AsyncMock(
  441. return_value={
  442. "results": [
  443. {"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0
  444. {"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0
  445. {"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0
  446. {"index": 3, "relevance_score": 0.7}, # doc 1 (short)
  447. ]
  448. }
  449. )
  450. mock_response.request_info = None
  451. mock_response.history = None
  452. mock_response.headers = {}
  453. # Make mock_response an async context manager (for `async with session.post() as response`)
  454. mock_response.__aenter__ = AsyncMock(return_value=mock_response)
  455. mock_response.__aexit__ = AsyncMock(return_value=None)
  456. mock_session = Mock()
  457. # session.post() returns an async context manager, so return mock_response which is now one
  458. mock_session.post = Mock(return_value=mock_response)
  459. mock_session.__aenter__ = AsyncMock(return_value=mock_session)
  460. mock_session.__aexit__ = AsyncMock(return_value=None)
  461. with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
  462. result = await cohere_rerank(
  463. query=query,
  464. documents=documents,
  465. api_key="test-key",
  466. base_url="http://test.com/rerank",
  467. enable_chunking=True,
  468. max_tokens_per_doc=30, # Force chunking of long doc
  469. )
  470. # Should get 2 results (one per original document)
  471. # The long doc's chunks should be aggregated
  472. assert len(result) <= len(documents)
  473. # Results should be sorted by score
  474. assert all(
  475. result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
  476. for i in range(len(result) - 1)
  477. )