# test_postgres_cypher_injection.py
"""
Unit tests for Cypher injection prevention in PGGraphStorage write paths.

Verifies that upsert_node and upsert_edge keep entity IDs parameterized while
rendering property maps as safely escaped Cypher literals, which is required by
Apache AGE because ``SET ... += $props`` is not supported.

Also covers ``_normalize_node_id`` escaping and checks that the ``_query``
write path forwards its params to ``db.execute``.
"""
  7. import json
  8. import pytest
  9. from unittest.mock import AsyncMock, MagicMock, patch
  10. from lightrag.kg.postgres_impl import PGGraphStorage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
  14. def make_graph_storage() -> PGGraphStorage:
  15. """Construct a PGGraphStorage instance with a mocked db."""
  16. storage = PGGraphStorage.__new__(PGGraphStorage)
  17. storage.workspace = "test_ws"
  18. storage.namespace = "test_graph"
  19. storage.graph_name = "test_graph"
  20. storage.db = MagicMock()
  21. return storage
  22. class _FakeConnection:
  23. """Captures statements + args passed to a fake asyncpg connection."""
  24. def __init__(self):
  25. self.calls: list[dict] = []
  26. def transaction(self):
  27. return _FakeTransaction()
  28. async def execute(self, sql, *args):
  29. self.calls.append({"sql": sql, "args": args})
  30. return ""
  31. class _FakeTransaction:
  32. async def __aenter__(self):
  33. return self
  34. async def __aexit__(self, exc_type, exc, tb):
  35. return False
  36. async def _capture_upsert_edge(storage: PGGraphStorage, src: str, tgt: str, edge_data):
  37. """Invoke upsert_edge against a fake connection and return the captured calls."""
  38. conn = _FakeConnection()
  39. async def fake_run_with_retry(operation, **_kwargs):
  40. return await operation(conn)
  41. storage.db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
  42. await storage.upsert_edge(src, tgt, edge_data)
  43. return conn.calls
# ---------------------------------------------------------------------------
# upsert_node — parameterized entity IDs, escaped property literals
# ---------------------------------------------------------------------------
  47. @pytest.mark.asyncio
  48. async def test_upsert_node_uses_parameterized_cypher():
  49. """upsert_node must pass entity_id as a Cypher parameter, not interpolate it."""
  50. storage = make_graph_storage()
  51. captured_calls: list[dict] = []
  52. async def fake_query(sql, **kwargs):
  53. captured_calls.append({"sql": sql, **kwargs})
  54. return []
  55. with patch.object(storage, "_query", side_effect=fake_query):
  56. await storage.upsert_node(
  57. "Alice", {"entity_id": "Alice", "description": "A person"}
  58. )
  59. assert len(captured_calls) == 1
  60. call = captured_calls[0]
  61. assert "$1::agtype" in call["sql"]
  62. assert '"Alice"' not in call["sql"].replace("$1::agtype", "")
  63. assert "params" in call
  64. params = json.loads(call["params"]["params"])
  65. assert params["entity_id"] == "Alice"
  66. assert "props" not in params
  67. assert '`description`: "A person"' in call["sql"]
  68. @pytest.mark.asyncio
  69. async def test_upsert_node_injection_payload_in_entity_id():
  70. """A Cypher injection payload in entity_id must be treated as data, not code."""
  71. storage = make_graph_storage()
  72. injection = 'test"}) RETURN n; MATCH (m) DETACH DELETE m; //'
  73. captured_calls: list[dict] = []
  74. async def fake_query(sql, **kwargs):
  75. captured_calls.append({"sql": sql, **kwargs})
  76. return []
  77. with patch.object(storage, "_query", side_effect=fake_query):
  78. await storage.upsert_node(
  79. injection, {"entity_id": injection, "description": "malicious"}
  80. )
  81. call = captured_calls[0]
  82. # The injection payload must NOT appear in the SQL string
  83. assert "DETACH DELETE" not in call["sql"]
  84. assert injection not in call["sql"]
  85. # It must be safely contained in the JSON parameter
  86. params = json.loads(call["params"]["params"])
  87. assert params["entity_id"] == injection
  88. @pytest.mark.asyncio
  89. async def test_upsert_node_special_chars_in_properties():
  90. """Property values with special characters are safely escaped in Cypher."""
  91. storage = make_graph_storage()
  92. captured_calls: list[dict] = []
  93. async def fake_query(sql, **kwargs):
  94. captured_calls.append({"sql": sql, **kwargs})
  95. return []
  96. node_data = {
  97. "entity_id": "test_node",
  98. "description": 'He said "hello" and used a backslash \\',
  99. "notes": "Line1\nLine2\tTabbed",
  100. "formula": "x < 5 && y > 3",
  101. }
  102. with patch.object(storage, "_query", side_effect=fake_query):
  103. await storage.upsert_node("test_node", node_data)
  104. call = captured_calls[0]
  105. assert (
  106. '`description`: "He said \\"hello\\" and used a backslash \\\\"' in call["sql"]
  107. )
  108. assert '`notes`: "Line1\\nLine2\\tTabbed"' in call["sql"]
  109. assert '`formula`: "x < 5 && y > 3"' in call["sql"]
  110. @pytest.mark.asyncio
  111. async def test_upsert_node_unicode_entity_id():
  112. """Unicode entity names are safely parameterized."""
  113. storage = make_graph_storage()
  114. captured_calls: list[dict] = []
  115. async def fake_query(sql, **kwargs):
  116. captured_calls.append({"sql": sql, **kwargs})
  117. return []
  118. unicode_id = "\u4e2d\u6587\u5b9e\u4f53" # Chinese characters
  119. with patch.object(storage, "_query", side_effect=fake_query):
  120. await storage.upsert_node(
  121. unicode_id, {"entity_id": unicode_id, "description": "\u63cf\u8ff0"}
  122. )
  123. call = captured_calls[0]
  124. params = json.loads(call["params"]["params"])
  125. assert params["entity_id"] == unicode_id
  126. assert '`description`: "描述"' in call["sql"]
  127. @pytest.mark.asyncio
  128. async def test_upsert_node_dollar_signs_in_entity_id():
  129. """Dollar signs in entity_id don't break dollar-quoting of the Cypher template."""
  130. storage = make_graph_storage()
  131. captured_calls: list[dict] = []
  132. async def fake_query(sql, **kwargs):
  133. captured_calls.append({"sql": sql, **kwargs})
  134. return []
  135. dollar_id = "price is $100 or $$200$$"
  136. with patch.object(storage, "_query", side_effect=fake_query):
  137. await storage.upsert_node(
  138. dollar_id, {"entity_id": dollar_id, "description": "has dollars"}
  139. )
  140. call = captured_calls[0]
  141. # The dollar signs are in the params, not the SQL template
  142. params = json.loads(call["params"]["params"])
  143. assert params["entity_id"] == dollar_id
  144. @pytest.mark.asyncio
  145. async def test_upsert_node_escapes_backticks_in_property_keys():
  146. """Backticks in property keys must be escaped before inlining the map."""
  147. storage = make_graph_storage()
  148. captured_calls: list[dict] = []
  149. async def fake_query(sql, **kwargs):
  150. captured_calls.append({"sql": sql, **kwargs})
  151. return []
  152. with patch.object(storage, "_query", side_effect=fake_query):
  153. await storage.upsert_node(
  154. "node",
  155. {"entity_id": "node", "danger`key": 'value "quoted"'},
  156. )
  157. assert '`danger``key`: "value \\"quoted\\""' in captured_calls[0]["sql"]
  158. @pytest.mark.asyncio
  159. async def test_upsert_node_requires_entity_id():
  160. """upsert_node still raises ValueError when entity_id is missing."""
  161. storage = make_graph_storage()
  162. with pytest.raises(ValueError, match="entity_id"):
  163. await storage.upsert_node("test", {"description": "no entity_id"})
# ---------------------------------------------------------------------------
# upsert_edge — parameterized entity IDs, escaped property literals
# ---------------------------------------------------------------------------
  167. @pytest.mark.asyncio
  168. async def test_upsert_edge_uses_parameterized_cypher():
  169. """upsert_edge must pass entity IDs as Cypher parameters."""
  170. storage = make_graph_storage()
  171. calls = await _capture_upsert_edge(
  172. storage, "Alice", "Bob", {"weight": "1.0", "description": "knows"}
  173. )
  174. # Two statements run on the connection: advisory lock first, then cypher.
  175. assert len(calls) == 2
  176. lock_sql = calls[0]["sql"]
  177. # Raw node IDs are positional params on the lock, never interpolated.
  178. assert "Alice" not in lock_sql
  179. assert "Bob" not in lock_sql
  180. # graph_name flows as $1, the endpoint pair as $2/$3.
  181. assert calls[0]["args"] == ("test_graph", "Alice", "Bob")
  182. cypher_call = calls[1]
  183. cypher_sql = cypher_call["sql"]
  184. assert "$1::agtype" in cypher_sql
  185. assert '"Alice"' not in cypher_sql.replace("$1::agtype", "")
  186. assert '"Bob"' not in cypher_sql.replace("$1::agtype", "")
  187. # Cypher params arrive as a single positional agtype JSON arg.
  188. params = json.loads(cypher_call["args"][0])
  189. assert params["src_id"] == "Alice"
  190. assert params["tgt_id"] == "Bob"
  191. assert "props" not in params
  192. assert '`weight`: "1.0"' in cypher_sql
  193. assert '`description`: "knows"' in cypher_sql
  194. @pytest.mark.asyncio
  195. async def test_upsert_edge_injection_payload():
  196. """Injection payloads in edge entity IDs are safely parameterized."""
  197. storage = make_graph_storage()
  198. injection_src = 'src"}) MATCH (x) DETACH DELETE x; //'
  199. injection_tgt = 'tgt"})-[r]-() DELETE r; //'
  200. calls = await _capture_upsert_edge(
  201. storage, injection_src, injection_tgt, {"description": "edge"}
  202. )
  203. # Injection payloads must never appear in either SQL template — they only
  204. # flow through positional params.
  205. for call in calls:
  206. assert "DETACH DELETE" not in call["sql"]
  207. assert "DELETE r" not in call["sql"]
  208. assert injection_src not in call["sql"]
  209. assert injection_tgt not in call["sql"]
  210. # Lock statement passes graph_name + raw IDs as positional params.
  211. assert calls[0]["args"] == ("test_graph", injection_src, injection_tgt)
  212. # Cypher params arrive as a single positional agtype JSON arg.
  213. params = json.loads(calls[1]["args"][0])
  214. assert params["src_id"] == injection_src
  215. assert params["tgt_id"] == injection_tgt
  216. @pytest.mark.asyncio
  217. async def test_upsert_edge_unicode_entity_ids():
  218. """Unicode entity IDs in edges are safely parameterized."""
  219. storage = make_graph_storage()
  220. src = "\u5317\u4eac"
  221. tgt = "\u4e0a\u6d77"
  222. calls = await _capture_upsert_edge(
  223. storage, src, tgt, {"description": "\u8def\u7ebf"}
  224. )
  225. # Lock statement carries graph_name + raw IDs as positional params, not
  226. # interpolated.
  227. assert calls[0]["args"] == ("test_graph", src, tgt)
  228. assert src not in calls[0]["sql"]
  229. assert tgt not in calls[0]["sql"]
  230. # Cypher params parsed from the positional agtype JSON arg.
  231. cypher_sql = calls[1]["sql"]
  232. params = json.loads(calls[1]["args"][0])
  233. assert params["src_id"] == src
  234. assert params["tgt_id"] == tgt
  235. assert '`description`: "路线"' in cypher_sql
# ---------------------------------------------------------------------------
# _normalize_node_id — defence-in-depth for remaining interpolation paths
# ---------------------------------------------------------------------------
  239. def test_normalize_node_id_strips_null_bytes():
  240. """Null bytes are stripped to prevent string truncation."""
  241. assert PGGraphStorage._normalize_node_id("before\x00after") == "beforeafter"
  242. def test_normalize_node_id_escapes_backslash_and_quote():
  243. """Backslashes and double quotes are escaped."""
  244. assert PGGraphStorage._normalize_node_id('a\\"b') == 'a\\\\\\"b'
  245. def test_normalize_node_id_injection_payload():
  246. """Injection payload is escaped so it cannot break out of Cypher string."""
  247. payload = 'test"}) RETURN n; MATCH (m) DETACH DELETE m; //'
  248. normalized = PGGraphStorage._normalize_node_id(payload)
  249. # The double quote must be escaped
  250. assert '\\"' in normalized
  251. # The escaped string must not contain an unescaped double quote
  252. # (remove all escaped quotes and check no raw ones remain)
  253. unescaped = normalized.replace('\\"', "")
  254. assert '"' not in unescaped
# ---------------------------------------------------------------------------
# _query write path passes params to db.execute
# ---------------------------------------------------------------------------
  258. @pytest.mark.asyncio
  259. async def test_query_write_path_passes_params():
  260. """When readonly=False, _query must forward params to db.execute."""
  261. storage = make_graph_storage()
  262. captured_execute_kwargs: list[dict] = []
  263. async def fake_execute(sql, **kwargs):
  264. captured_execute_kwargs.append(kwargs)
  265. return None
  266. storage.db.execute = fake_execute
  267. test_params = {"params": json.dumps({"entity_id": "test"})}
  268. await storage._query(
  269. "SELECT 1",
  270. readonly=False,
  271. upsert=True,
  272. params=test_params,
  273. )
  274. assert len(captured_execute_kwargs) == 1
  275. assert captured_execute_kwargs[0]["data"] == test_params