test_memgraph_storage.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import pytest
  2. from lightrag.kg.memgraph_impl import MemgraphStorage
# Module-wide marker applied to every test below. The tests swap the
# storage's driver for the in-process ``_FakeDriver`` instead of talking to a
# server; the ``offline`` marker presumably lets test runs select them without
# a live Memgraph instance (confirm against the project's pytest config).
pytestmark = pytest.mark.offline
  4. class _FakeNode(dict):
  5. def __init__(self, node_id: int, entity_id: str, **properties):
  6. super().__init__(entity_id=entity_id, **properties)
  7. self.id = node_id
  8. class _FakeResult:
  9. def __init__(self, record):
  10. self._record = record
  11. async def single(self):
  12. return self._record
  13. async def consume(self):
  14. return None
  15. class _FakeSession:
  16. def __init__(self, record, calls):
  17. self._record = record
  18. self._calls = calls
  19. async def __aenter__(self):
  20. return self
  21. async def __aexit__(self, exc_type, exc, tb):
  22. return False
  23. async def run(self, query, parameters=None, **kwargs):
  24. if parameters is None:
  25. parameters = kwargs
  26. self._calls.append((query, parameters))
  27. return _FakeResult(self._record)
  28. class _FakeDriver:
  29. def __init__(self, record, calls):
  30. self._record = record
  31. self._calls = calls
  32. def session(self, **kwargs):
  33. return _FakeSession(self._record, self._calls)
  34. def _make_storage(record):
  35. calls = []
  36. storage = MemgraphStorage(
  37. namespace="chunk_entity_relation",
  38. global_config={"max_graph_nodes": 1000},
  39. embedding_func=None,
  40. workspace="test",
  41. )
  42. storage._driver = _FakeDriver(record, calls)
  43. storage._DATABASE = "memgraph"
  44. return storage, calls
  45. @pytest.mark.asyncio
  46. async def test_get_knowledge_graph_preserves_isolated_start_node():
  47. start_node = _FakeNode(1, "Start", description="isolated")
  48. storage, calls = _make_storage(
  49. {
  50. "node_info": [{"node": start_node}],
  51. "relationships": [],
  52. "is_truncated": False,
  53. }
  54. )
  55. result = await storage.get_knowledge_graph("Start", max_depth=0, max_nodes=1)
  56. # Verify result data: isolated node must appear with correct labels and properties
  57. assert len(result.nodes) == 1
  58. assert result.nodes[0].labels == ["Start"]
  59. assert result.nodes[0].properties["entity_id"] == "Start"
  60. assert result.edges == []
  61. assert result.is_truncated is False
  62. # Verify query parameters: max_other_nodes must reserve a slot for the start node
  63. assert len(calls) == 1
  64. _, params = calls[0]
  65. assert params["entity_id"] == "Start"
  66. assert params["max_nodes"] == 1
  67. assert (
  68. params["max_other_nodes"] == 0
  69. ) # max_nodes - 1 = 0, start node occupies the only slot
  70. @pytest.mark.asyncio
  71. async def test_get_knowledge_graph_reserves_capacity_for_start_node_when_truncating():
  72. start_node = _FakeNode(1, "Start")
  73. storage, calls = _make_storage(
  74. {
  75. "node_info": [{"node": start_node}],
  76. "relationships": [],
  77. "is_truncated": True,
  78. }
  79. )
  80. result = await storage.get_knowledge_graph("Start", max_depth=2, max_nodes=2)
  81. # Verify truncation is reflected in result
  82. assert result.is_truncated is True
  83. assert len(result.nodes) == 1
  84. assert result.edges == []
  85. # Verify max_other_nodes leaves exactly one slot for the start node
  86. assert len(calls) == 1
  87. _, params = calls[0]
  88. assert params["max_nodes"] == 2
  89. assert (
  90. params["max_other_nodes"] == 1
  91. ) # max_nodes - 1 = 1, start node always included
  92. @pytest.mark.asyncio
  93. async def test_get_knowledge_graph_max_nodes_zero_does_not_underflow():
  94. """max_other_nodes must not go negative when max_nodes=0."""
  95. storage, calls = _make_storage(
  96. {
  97. "node_info": [],
  98. "relationships": [],
  99. "is_truncated": False,
  100. }
  101. )
  102. await storage.get_knowledge_graph("Start", max_depth=1, max_nodes=0)
  103. _, params = calls[0]
  104. assert params["max_other_nodes"] == 0 # max(0 - 1, 0) = 0, no underflow