| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- """
- Unit tests for PGGraphStorage.get_nodes_edges_batch and get_node_edges
- with special characters in entity names.
- Verifies the fix for KeyError when entity names contain double quotes (PR #2872)
- and the follow-up Option C refactor to parameterized Cypher queries.
- The root cause: AGE returns the original un-escaped entity_id, but the edges_norm
- dict was previously keyed with the normalized (escaped) ID, causing a KeyError on lookup.
- The Option C fix: use $node_ids / $entity_id parameters instead of string interpolation,
- eliminating the need for _normalize_node_id in these read paths entirely.
- """
- import json
- import pytest
- from unittest.mock import MagicMock, patch
- from lightrag.kg.postgres_impl import PGGraphStorage
- # ---------------------------------------------------------------------------
- # Helpers
- # ---------------------------------------------------------------------------
def make_graph_storage() -> PGGraphStorage:
    """Build a bare PGGraphStorage for tests without running ``__init__``.

    The instance is allocated with ``__new__`` and given fixed workspace,
    namespace and graph-name values plus a ``MagicMock`` standing in for
    ``db``. ``_query`` is not replaced here; each test patches it itself.
    """
    instance = PGGraphStorage.__new__(PGGraphStorage)
    attributes = {
        "workspace": "test_ws",
        "namespace": "test_graph",
        "graph_name": "test_graph",
        "db": MagicMock(),
    }
    for attr_name, attr_value in attributes.items():
        setattr(instance, attr_name, attr_value)
    return instance
- # ---------------------------------------------------------------------------
- # _normalize_node_id (still used by write paths: remove_nodes, upsert_node, etc.)
- # ---------------------------------------------------------------------------
def test_normalize_plain_id():
    """An ID without special characters comes back unchanged."""
    normalized = PGGraphStorage._normalize_node_id("Alice")
    assert normalized == "Alice"
def test_normalize_double_quote():
    """Each double quote in the ID is prefixed with a backslash."""
    raw_id = 'John "Smith"'
    expected = 'John \\"Smith\\"'
    assert PGGraphStorage._normalize_node_id(raw_id) == expected
def test_normalize_backslash():
    """Each backslash in the ID is doubled."""
    raw_id = "C:\\path"
    expected = "C:\\\\path"
    assert PGGraphStorage._normalize_node_id(raw_id) == expected
def test_normalize_both_special_chars():
    """An ID holding backslashes and quotes has both kinds escaped."""
    raw_id = 'say \\"hello\\"'
    expected = 'say \\\\\\"hello\\\\\\"'
    assert PGGraphStorage._normalize_node_id(raw_id) == expected
- # ---------------------------------------------------------------------------
- # get_node_edges — parameterized query (Option C)
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_node_edges_passes_original_id_as_parameter():
    """The un-escaped entity_id must reach _query inside the JSON params payload."""
    graph = make_graph_storage()
    quoted_name = 'John "Smith"'
    decoded_payloads: list[dict] = []

    async def record_params(_statement, **call_kwargs):
        bound = call_kwargs.get("params")
        if bound:
            # The first bound value is the JSON document carrying the arguments.
            first_value = next(iter(bound.values()))
            decoded_payloads.append(json.loads(first_value))
        return []

    with patch.object(graph, "_query", side_effect=record_params):
        await graph.get_node_edges(quoted_name)

    assert len(decoded_payloads) == 1
    payload = decoded_payloads[0]
    assert payload["entity_id"] == quoted_name
@pytest.mark.asyncio
async def test_get_node_edges_cypher_uses_parameter_syntax():
    """The SQL handed to _query must reference $1::agtype and never embed the name."""
    graph = make_graph_storage()
    quoted_name = 'John "Smith"'
    seen_sql: list[str] = []

    async def record_sql(statement, **_call_kwargs):
        seen_sql.append(statement)
        return []

    with patch.object(graph, "_query", side_effect=record_sql):
        await graph.get_node_edges(quoted_name)

    assert len(seen_sql) == 1
    sql_text = seen_sql[0]
    assert "$1::agtype" in sql_text
    # Neither the raw name nor an escaped quote may be present in the SQL text.
    assert quoted_name not in sql_text
    assert '\\"' not in sql_text
@pytest.mark.asyncio
async def test_get_node_edges_returns_edges():
    """Rows become (source, connected) tuples; a row whose connected_id is None is dropped."""
    graph = make_graph_storage()
    canned_rows = [
        {"source_id": "Alice", "connected_id": "Bob"},
        {"source_id": "Alice", "connected_id": None},
    ]

    async def serve_rows(_statement, **_call_kwargs):
        # Hand out a fresh list on every call, as the inline literal did.
        return list(canned_rows)

    with patch.object(graph, "_query", side_effect=serve_rows):
        edges = await graph.get_node_edges("Alice")

    assert edges == [("Alice", "Bob")]
- # ---------------------------------------------------------------------------
- # get_nodes_edges_batch — parameterized query (Option C)
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_passes_original_ids_as_parameter():
    """The un-escaped ID list must reach _query inside the JSON params payload."""
    graph = make_graph_storage()
    names = ['John "Smith"', "Alice", "O\\Brien"]
    decoded_payloads: list[dict] = []

    async def record_params(_statement, **call_kwargs):
        bound = call_kwargs.get("params")
        if bound:
            # The first bound value is the JSON document carrying the arguments.
            first_value = next(iter(bound.values()))
            decoded_payloads.append(json.loads(first_value))
        return []

    with patch.object(graph, "_query", side_effect=record_params):
        await graph.get_nodes_edges_batch(names)

    assert len(decoded_payloads) == 2  # outgoing + incoming
    first_call, second_call = decoded_payloads
    assert first_call["node_ids"] == names
    assert second_call["node_ids"] == names
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_cypher_uses_parameter_syntax():
    """Both batch queries must reference $1::agtype and never embed the name."""
    graph = make_graph_storage()
    quoted_name = 'John "Smith"'
    seen_sql: list[str] = []

    async def record_sql(statement, **_call_kwargs):
        seen_sql.append(statement)
        return []

    with patch.object(graph, "_query", side_effect=record_sql):
        await graph.get_nodes_edges_batch([quoted_name])

    assert len(seen_sql) == 2
    for sql_text in seen_sql:
        assert "$1::agtype" in sql_text
        assert quoted_name not in sql_text
        assert '\\"' not in sql_text
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_with_quoted_entity():
    """
    Query rows carry the original, un-escaped node_id.
    The returned mapping must therefore be keyed by that original ID.
    """
    graph = make_graph_storage()
    quoted_name = 'John "Smith"'
    # Cypher fragment -> neighbour reported for it; checked in insertion order.
    neighbour_by_pattern = {
        "OPTIONAL MATCH (n:base)-[]->": "Alice",
        "OPTIONAL MATCH (n:base)<-[]-": "Bob",
    }

    async def serve_rows(statement, **_call_kwargs):
        for pattern, neighbour in neighbour_by_pattern.items():
            if pattern in statement:
                return [{"node_id": quoted_name, "connected_id": neighbour}]
        return []

    with patch.object(graph, "_query", side_effect=serve_rows):
        edges_by_node = await graph.get_nodes_edges_batch([quoted_name])

    assert quoted_name in edges_by_node
    node_edges = edges_by_node[quoted_name]
    assert (quoted_name, "Alice") in node_edges
    assert ("Bob", quoted_name) in node_edges
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_plain_entity():
    """A name free of special characters is handled the same way."""
    graph = make_graph_storage()
    plain_name = "Alice"
    outgoing_pattern = "OPTIONAL MATCH (n:base)-[]->"

    async def serve_rows(statement, **_call_kwargs):
        if outgoing_pattern not in statement:
            return []
        return [{"node_id": plain_name, "connected_id": "Bob"}]

    with patch.object(graph, "_query", side_effect=serve_rows):
        edges_by_node = await graph.get_nodes_edges_batch([plain_name])

    assert plain_name in edges_by_node
    assert (plain_name, "Bob") in edges_by_node[plain_name]
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_no_results():
    """A node without edges maps to an empty list rather than raising KeyError."""
    graph = make_graph_storage()
    quoted_name = 'Entity "X"'

    async def serve_nothing(_statement, **_call_kwargs):
        return []

    with patch.object(graph, "_query", side_effect=serve_nothing):
        edges_by_node = await graph.get_nodes_edges_batch([quoted_name])

    assert quoted_name in edges_by_node
    assert edges_by_node[quoted_name] == []
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_deduplication():
    """Passing the same ID twice yields a single entry with a single edge list."""
    graph = make_graph_storage()
    quoted_name = 'Dup "Entity"'
    outgoing_pattern = "OPTIONAL MATCH (n:base)-[]->"

    async def serve_rows(statement, **_call_kwargs):
        if outgoing_pattern not in statement:
            return []
        return [{"node_id": quoted_name, "connected_id": "Other"}]

    with patch.object(graph, "_query", side_effect=serve_rows):
        edges_by_node = await graph.get_nodes_edges_batch(
            [quoted_name, quoted_name]
        )

    assert edges_by_node[quoted_name] == [(quoted_name, "Other")]
@pytest.mark.asyncio
async def test_get_nodes_edges_batch_empty_input():
    """An empty ID list yields an empty dict and never touches _query."""
    graph = make_graph_storage()

    with patch.object(graph, "_query") as query_mock:
        edges_by_node = await graph.get_nodes_edges_batch([])

    assert edges_by_node == {}
    query_mock.assert_not_called()
|