| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- import pytest
- from lightrag.kg.memgraph_impl import MemgraphStorage
- pytestmark = pytest.mark.offline
- class _FakeNode(dict):
- def __init__(self, node_id: int, entity_id: str, **properties):
- super().__init__(entity_id=entity_id, **properties)
- self.id = node_id
- class _FakeResult:
- def __init__(self, record):
- self._record = record
- async def single(self):
- return self._record
- async def consume(self):
- return None
- class _FakeSession:
- def __init__(self, record, calls):
- self._record = record
- self._calls = calls
- async def __aenter__(self):
- return self
- async def __aexit__(self, exc_type, exc, tb):
- return False
- async def run(self, query, parameters=None, **kwargs):
- if parameters is None:
- parameters = kwargs
- self._calls.append((query, parameters))
- return _FakeResult(self._record)
- class _FakeDriver:
- def __init__(self, record, calls):
- self._record = record
- self._calls = calls
- def session(self, **kwargs):
- return _FakeSession(self._record, self._calls)
- def _make_storage(record):
- calls = []
- storage = MemgraphStorage(
- namespace="chunk_entity_relation",
- global_config={"max_graph_nodes": 1000},
- embedding_func=None,
- workspace="test",
- )
- storage._driver = _FakeDriver(record, calls)
- storage._DATABASE = "memgraph"
- return storage, calls
- @pytest.mark.asyncio
- async def test_get_knowledge_graph_preserves_isolated_start_node():
- start_node = _FakeNode(1, "Start", description="isolated")
- storage, calls = _make_storage(
- {
- "node_info": [{"node": start_node}],
- "relationships": [],
- "is_truncated": False,
- }
- )
- result = await storage.get_knowledge_graph("Start", max_depth=0, max_nodes=1)
- # Verify result data: isolated node must appear with correct labels and properties
- assert len(result.nodes) == 1
- assert result.nodes[0].labels == ["Start"]
- assert result.nodes[0].properties["entity_id"] == "Start"
- assert result.edges == []
- assert result.is_truncated is False
- # Verify query parameters: max_other_nodes must reserve a slot for the start node
- assert len(calls) == 1
- _, params = calls[0]
- assert params["entity_id"] == "Start"
- assert params["max_nodes"] == 1
- assert (
- params["max_other_nodes"] == 0
- ) # max_nodes - 1 = 0, start node occupies the only slot
- @pytest.mark.asyncio
- async def test_get_knowledge_graph_reserves_capacity_for_start_node_when_truncating():
- start_node = _FakeNode(1, "Start")
- storage, calls = _make_storage(
- {
- "node_info": [{"node": start_node}],
- "relationships": [],
- "is_truncated": True,
- }
- )
- result = await storage.get_knowledge_graph("Start", max_depth=2, max_nodes=2)
- # Verify truncation is reflected in result
- assert result.is_truncated is True
- assert len(result.nodes) == 1
- assert result.edges == []
- # Verify max_other_nodes leaves exactly one slot for the start node
- assert len(calls) == 1
- _, params = calls[0]
- assert params["max_nodes"] == 2
- assert (
- params["max_other_nodes"] == 1
- ) # max_nodes - 1 = 1, start node always included
- @pytest.mark.asyncio
- async def test_get_knowledge_graph_max_nodes_zero_does_not_underflow():
- """max_other_nodes must not go negative when max_nodes=0."""
- storage, calls = _make_storage(
- {
- "node_info": [],
- "relationships": [],
- "is_truncated": False,
- }
- )
- await storage.get_knowledge_graph("Start", max_depth=1, max_nodes=0)
- _, params = calls[0]
- assert params["max_other_nodes"] == 0 # max(0 - 1, 0) = 0, no underflow
|