test_mongo_storage.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import pytest
  2. from types import SimpleNamespace
  3. from unittest.mock import AsyncMock, Mock
  4. pytest.importorskip(
  5. "pymongo",
  6. reason="pymongo is required for Mongo storage tests",
  7. )
  8. from pymongo.errors import PyMongoError
  9. from lightrag.kg.mongo_impl import MongoDocStatusStorage, MongoGraphStorage
  10. pytestmark = pytest.mark.offline
  11. class _AsyncCursor:
  12. def __init__(self, docs):
  13. self._docs = list(docs)
  14. def limit(self, n: int):
  15. self._docs = self._docs[:n]
  16. return self
  17. def __aiter__(self):
  18. self._iter = iter(self._docs)
  19. return self
  20. async def __anext__(self):
  21. try:
  22. return next(self._iter)
  23. except StopIteration:
  24. raise StopAsyncIteration
  25. class TestMongoGraphStorage:
  26. def _make_storage(self):
  27. storage = MongoGraphStorage.__new__(MongoGraphStorage)
  28. storage.workspace = "test"
  29. storage.global_config = {"max_graph_nodes": 1000}
  30. storage._edge_collection_name = "test_edges"
  31. storage.collection = SimpleNamespace()
  32. storage.edge_collection = SimpleNamespace()
  33. return storage
  34. @pytest.mark.asyncio
  35. async def test_get_knowledge_graph_all_backfills_isolated_nodes_when_truncated(
  36. self,
  37. ):
  38. storage = self._make_storage()
  39. storage.collection.count_documents = AsyncMock(return_value=5)
  40. storage.edge_collection.aggregate = AsyncMock(
  41. return_value=_AsyncCursor(
  42. [{"_id": "A", "degree": 1}, {"_id": "B", "degree": 1}]
  43. )
  44. )
  45. def collection_find_side_effect(query, projection=None):
  46. if query == {"_id": {"$nin": ["A", "B"]}}:
  47. return _AsyncCursor(
  48. [
  49. {"_id": "C", "entity_type": "person"},
  50. {"_id": "D", "entity_type": "person"},
  51. {"_id": "E", "entity_type": "person"},
  52. ]
  53. )
  54. if query == {"_id": {"$in": ["A", "B", "C", "D"]}}:
  55. return _AsyncCursor(
  56. [
  57. {"_id": "B", "entity_type": "person"},
  58. {"_id": "D", "entity_type": "person"},
  59. {"_id": "A", "entity_type": "person"},
  60. {"_id": "C", "entity_type": "person"},
  61. ]
  62. )
  63. raise AssertionError(f"Unexpected node query: {query}")
  64. storage.collection.find = Mock(side_effect=collection_find_side_effect)
  65. storage.edge_collection.find = Mock(
  66. return_value=_AsyncCursor(
  67. [
  68. {
  69. "source_node_id": "A",
  70. "target_node_id": "B",
  71. "relationship": "knows",
  72. }
  73. ]
  74. )
  75. )
  76. result = await storage.get_knowledge_graph_all_by_degree(
  77. max_depth=2, max_nodes=4
  78. )
  79. assert result.is_truncated is True
  80. assert [node.id for node in result.nodes] == ["A", "B", "C", "D"]
  81. assert len(result.edges) == 1
  82. assert result.edges[0].source == "A"
  83. assert result.edges[0].target == "B"
  84. class TestMongoDocStatusLookup:
  85. """Cover the Mongo-native overrides for basename / content_hash lookups."""
  86. def _make_storage(self):
  87. storage = MongoDocStatusStorage.__new__(MongoDocStatusStorage)
  88. storage.workspace = "test"
  89. storage.global_config = {}
  90. storage._collection_name = "test_doc_status"
  91. storage._data = SimpleNamespace()
  92. return storage
  93. @pytest.mark.asyncio
  94. async def test_get_doc_by_file_basename_returns_tuple_on_hit(self):
  95. storage = self._make_storage()
  96. storage._data.find_one = AsyncMock(
  97. return_value={
  98. "_id": "doc-1",
  99. "file_path": "report.pdf",
  100. "status": "processed",
  101. }
  102. )
  103. result = await storage.get_doc_by_file_basename("report.pdf")
  104. assert result is not None
  105. doc_id, doc = result
  106. assert doc_id == "doc-1"
  107. assert doc["file_path"] == "report.pdf"
  108. storage._data.find_one.assert_awaited_once_with({"file_path": "report.pdf"})
  109. @pytest.mark.asyncio
  110. async def test_get_doc_by_file_basename_empty_returns_none_without_query(self):
  111. storage = self._make_storage()
  112. storage._data.find_one = AsyncMock()
  113. assert await storage.get_doc_by_file_basename("") is None
  114. storage._data.find_one.assert_not_called()
  115. @pytest.mark.asyncio
  116. async def test_get_doc_by_file_basename_unknown_source_sentinel(self):
  117. # Lookup for the sentinel must not match real rows that happen to have
  118. # file_path == "unknown_source".
  119. storage = self._make_storage()
  120. storage._data.find_one = AsyncMock()
  121. assert await storage.get_doc_by_file_basename("unknown_source") is None
  122. storage._data.find_one.assert_not_called()
  123. @pytest.mark.asyncio
  124. async def test_get_doc_by_file_basename_miss_returns_none(self):
  125. storage = self._make_storage()
  126. storage._data.find_one = AsyncMock(return_value=None)
  127. assert await storage.get_doc_by_file_basename("missing.pdf") is None
  128. @pytest.mark.asyncio
  129. async def test_get_doc_by_content_hash_returns_tuple_on_hit(self):
  130. storage = self._make_storage()
  131. storage._data.find_one = AsyncMock(
  132. return_value={
  133. "_id": "doc-1",
  134. "file_path": "report.pdf",
  135. "content_hash": "abc123",
  136. "status": "processed",
  137. }
  138. )
  139. result = await storage.get_doc_by_content_hash("abc123")
  140. assert result is not None
  141. doc_id, doc = result
  142. assert doc_id == "doc-1"
  143. assert doc["content_hash"] == "abc123"
  144. storage._data.find_one.assert_awaited_once_with({"content_hash": "abc123"})
  145. @pytest.mark.asyncio
  146. async def test_get_doc_by_content_hash_empty_returns_none_without_query(self):
  147. # Empty hash must short-circuit so it cannot match legacy rows missing
  148. # the field via accidental coercion.
  149. storage = self._make_storage()
  150. storage._data.find_one = AsyncMock()
  151. assert await storage.get_doc_by_content_hash("") is None
  152. storage._data.find_one.assert_not_called()
  153. @pytest.mark.asyncio
  154. async def test_get_doc_by_content_hash_miss_returns_none(self):
  155. storage = self._make_storage()
  156. storage._data.find_one = AsyncMock(return_value=None)
  157. assert await storage.get_doc_by_content_hash("zzz999") is None
  158. @pytest.mark.asyncio
  159. async def test_lookup_swallows_pymongo_error_and_returns_none(self):
  160. # PyMongoError must not propagate to the caller; the dedup path treats
  161. # a storage failure as "no match" and the error is logged instead.
  162. storage = self._make_storage()
  163. storage._data.find_one = AsyncMock(side_effect=PyMongoError("boom"))
  164. assert await storage.get_doc_by_file_basename("report.pdf") is None
  165. assert await storage.get_doc_by_content_hash("abc123") is None