test_postgres_age_quote_fix.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. """
  2. Unit tests for PGGraphStorage.get_nodes_edges_batch and get_node_edges
  3. with special characters in entity names.
  4. Verifies the fix for KeyError when entity names contain double quotes (PR #2872)
  5. and the follow-up Option C refactor to parameterized Cypher queries.
  6. The root cause: AGE returns the original un-escaped entity_id, but the edges_norm
  7. dict was previously keyed with the normalized (escaped) ID, causing a KeyError on lookup.
  8. The Option C fix: use $node_ids / $entity_id parameters instead of string interpolation,
  9. eliminating the need for _normalize_node_id in these read paths entirely.
  10. """
  11. import json
  12. import pytest
  13. from unittest.mock import MagicMock, patch
  14. from lightrag.kg.postgres_impl import PGGraphStorage
  15. # ---------------------------------------------------------------------------
  16. # Helpers
  17. # ---------------------------------------------------------------------------
  18. def make_graph_storage() -> PGGraphStorage:
  19. """Construct a PGGraphStorage instance with a mocked _query method."""
  20. storage = PGGraphStorage.__new__(PGGraphStorage)
  21. storage.workspace = "test_ws"
  22. storage.namespace = "test_graph"
  23. storage.graph_name = "test_graph"
  24. storage.db = MagicMock()
  25. return storage
  26. # ---------------------------------------------------------------------------
  27. # _normalize_node_id (still used by write paths: remove_nodes, upsert_node, etc.)
  28. # ---------------------------------------------------------------------------
  29. def test_normalize_plain_id():
  30. assert PGGraphStorage._normalize_node_id("Alice") == "Alice"
  31. def test_normalize_double_quote():
  32. assert PGGraphStorage._normalize_node_id('John "Smith"') == 'John \\"Smith\\"'
  33. def test_normalize_backslash():
  34. assert PGGraphStorage._normalize_node_id("C:\\path") == "C:\\\\path"
  35. def test_normalize_both_special_chars():
  36. assert (
  37. PGGraphStorage._normalize_node_id('say \\"hello\\"')
  38. == 'say \\\\\\"hello\\\\\\"'
  39. )
  40. # ---------------------------------------------------------------------------
  41. # get_node_edges — parameterized query (Option C)
  42. # ---------------------------------------------------------------------------
  43. @pytest.mark.asyncio
  44. async def test_get_node_edges_passes_original_id_as_parameter():
  45. """entity_id must be passed as a JSON parameter, not interpolated into Cypher."""
  46. storage = make_graph_storage()
  47. entity = 'John "Smith"'
  48. captured_params: list[dict] = []
  49. async def fake_query(sql, **kwargs):
  50. if kwargs.get("params"):
  51. captured_params.append(json.loads(list(kwargs["params"].values())[0]))
  52. return []
  53. with patch.object(storage, "_query", side_effect=fake_query):
  54. await storage.get_node_edges(entity)
  55. assert len(captured_params) == 1
  56. assert captured_params[0]["entity_id"] == entity
  57. @pytest.mark.asyncio
  58. async def test_get_node_edges_cypher_uses_parameter_syntax():
  59. """The SQL sent to _query must use $1::agtype, not a hardcoded escaped string."""
  60. storage = make_graph_storage()
  61. entity = 'John "Smith"'
  62. captured_sql: list[str] = []
  63. async def fake_query(sql, **kwargs):
  64. captured_sql.append(sql)
  65. return []
  66. with patch.object(storage, "_query", side_effect=fake_query):
  67. await storage.get_node_edges(entity)
  68. assert len(captured_sql) == 1
  69. assert "$1::agtype" in captured_sql[0]
  70. # Entity name must NOT appear literally in the SQL string
  71. assert entity not in captured_sql[0]
  72. assert '\\"' not in captured_sql[0]
  73. @pytest.mark.asyncio
  74. async def test_get_node_edges_returns_edges():
  75. storage = make_graph_storage()
  76. async def fake_query(_sql, **_kwargs):
  77. return [
  78. {"source_id": "Alice", "connected_id": "Bob"},
  79. {"source_id": "Alice", "connected_id": None},
  80. ]
  81. with patch.object(storage, "_query", side_effect=fake_query):
  82. result = await storage.get_node_edges("Alice")
  83. assert result == [("Alice", "Bob")]
  84. # ---------------------------------------------------------------------------
  85. # get_nodes_edges_batch — parameterized query (Option C)
  86. # ---------------------------------------------------------------------------
  87. @pytest.mark.asyncio
  88. async def test_get_nodes_edges_batch_passes_original_ids_as_parameter():
  89. """node_ids batch must be passed as a JSON parameter, not interpolated."""
  90. storage = make_graph_storage()
  91. entities = ['John "Smith"', "Alice", "O\\Brien"]
  92. captured_params: list[dict] = []
  93. async def fake_query(_sql, **kwargs):
  94. if kwargs.get("params"):
  95. captured_params.append(json.loads(list(kwargs["params"].values())[0]))
  96. return []
  97. with patch.object(storage, "_query", side_effect=fake_query):
  98. await storage.get_nodes_edges_batch(entities)
  99. assert len(captured_params) == 2 # outgoing + incoming
  100. assert captured_params[0]["node_ids"] == entities
  101. assert captured_params[1]["node_ids"] == entities
  102. @pytest.mark.asyncio
  103. async def test_get_nodes_edges_batch_cypher_uses_parameter_syntax():
  104. """The SQL must use $1::agtype, not hardcoded escaped entity names."""
  105. storage = make_graph_storage()
  106. entity = 'John "Smith"'
  107. captured_sql: list[str] = []
  108. async def fake_query(sql, **_kwargs):
  109. captured_sql.append(sql)
  110. return []
  111. with patch.object(storage, "_query", side_effect=fake_query):
  112. await storage.get_nodes_edges_batch([entity])
  113. assert len(captured_sql) == 2
  114. for sql in captured_sql:
  115. assert "$1::agtype" in sql
  116. assert entity not in sql
  117. assert '\\"' not in sql
  118. @pytest.mark.asyncio
  119. async def test_get_nodes_edges_batch_with_quoted_entity():
  120. """
  121. AGE returns the original un-escaped node_id in query results.
  122. The result dict must be keyed by the original ID.
  123. """
  124. storage = make_graph_storage()
  125. entity = 'John "Smith"'
  126. async def fake_query(sql, **_kwargs):
  127. if "OPTIONAL MATCH (n:base)-[]->" in sql:
  128. return [{"node_id": entity, "connected_id": "Alice"}]
  129. if "OPTIONAL MATCH (n:base)<-[]-" in sql:
  130. return [{"node_id": entity, "connected_id": "Bob"}]
  131. return []
  132. with patch.object(storage, "_query", side_effect=fake_query):
  133. result = await storage.get_nodes_edges_batch([entity])
  134. assert entity in result
  135. assert (entity, "Alice") in result[entity]
  136. assert ("Bob", entity) in result[entity]
  137. @pytest.mark.asyncio
  138. async def test_get_nodes_edges_batch_plain_entity():
  139. """Entity names without special chars still work correctly."""
  140. storage = make_graph_storage()
  141. entity = "Alice"
  142. async def fake_query(sql, **_kwargs):
  143. if "OPTIONAL MATCH (n:base)-[]->" in sql:
  144. return [{"node_id": entity, "connected_id": "Bob"}]
  145. return []
  146. with patch.object(storage, "_query", side_effect=fake_query):
  147. result = await storage.get_nodes_edges_batch([entity])
  148. assert entity in result
  149. assert (entity, "Bob") in result[entity]
  150. @pytest.mark.asyncio
  151. async def test_get_nodes_edges_batch_no_results():
  152. """Nodes with no edges return an empty list, not a KeyError."""
  153. storage = make_graph_storage()
  154. entity = 'Entity "X"'
  155. async def fake_query(_sql, **_kwargs):
  156. return []
  157. with patch.object(storage, "_query", side_effect=fake_query):
  158. result = await storage.get_nodes_edges_batch([entity])
  159. assert entity in result
  160. assert result[entity] == []
  161. @pytest.mark.asyncio
  162. async def test_get_nodes_edges_batch_deduplication():
  163. """Duplicate input IDs are deduplicated; each maps to the same edge list."""
  164. storage = make_graph_storage()
  165. entity = 'Dup "Entity"'
  166. async def fake_query(sql, **_kwargs):
  167. if "OPTIONAL MATCH (n:base)-[]->" in sql:
  168. return [{"node_id": entity, "connected_id": "Other"}]
  169. return []
  170. with patch.object(storage, "_query", side_effect=fake_query):
  171. result = await storage.get_nodes_edges_batch([entity, entity])
  172. assert result[entity] == [(entity, "Other")]
  173. @pytest.mark.asyncio
  174. async def test_get_nodes_edges_batch_empty_input():
  175. """Empty input returns empty dict without calling _query."""
  176. storage = make_graph_storage()
  177. with patch.object(storage, "_query") as mock_q:
  178. result = await storage.get_nodes_edges_batch([])
  179. assert result == {}
  180. mock_q.assert_not_called()