| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- """
- Unit tests for Cypher injection prevention in PGGraphStorage write paths.
- Verifies that upsert_node and upsert_edge keep entity IDs parameterized while
- rendering property maps as safely escaped Cypher literals, which is required by
- Apache AGE because ``SET ... += $props`` is not supported.
- """
- import json
- import pytest
- from unittest.mock import AsyncMock, MagicMock, patch
- from lightrag.kg.postgres_impl import PGGraphStorage
- # ---------------------------------------------------------------------------
- # Helpers
- # ---------------------------------------------------------------------------
def make_graph_storage() -> PGGraphStorage:
    """Return a PGGraphStorage built without __init__ and backed by a MagicMock db."""
    # __new__ skips the real constructor, so every attribute the tests rely on
    # is attached by hand below.
    instance = PGGraphStorage.__new__(PGGraphStorage)
    attributes = {
        "workspace": "test_ws",
        "namespace": "test_graph",
        "graph_name": "test_graph",
        "db": MagicMock(),
    }
    for attr_name, attr_value in attributes.items():
        setattr(instance, attr_name, attr_value)
    return instance
class _FakeConnection:
    """Stand-in for an asyncpg connection that records every execute() call."""

    def __init__(self):
        # One {"sql": ..., "args": ...} entry per execute() call, in call order.
        self.calls: list[dict] = list()

    def transaction(self):
        """Hand back a fresh no-op transaction context manager."""
        txn = _FakeTransaction()
        return txn

    async def execute(self, sql, *args):
        """Record the statement and its positional args, then return an empty string."""
        record = dict(sql=sql, args=args)
        self.calls.append(record)
        return ""
class _FakeTransaction:
    """No-op async context manager returned by _FakeConnection.transaction()."""

    async def __aenter__(self):
        """Enter the fake transaction; the manager itself is the bound value."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the fake transaction without suppressing any exception."""
        # A falsy return tells `async with` to re-raise whatever was raised
        # inside the block.
        suppress = False
        return suppress
async def _capture_upsert_edge(storage: PGGraphStorage, src: str, tgt: str, edge_data):
    """Run upsert_edge against a _FakeConnection and return its recorded execute calls.

    ``storage.db._run_with_retry`` is replaced by an AsyncMock whose side effect
    simply awaits the supplied operation with the fake connection, ignoring any
    keyword arguments it is given.
    """
    fake_conn = _FakeConnection()

    async def fake_run_with_retry(operation, **_kwargs):
        outcome = await operation(fake_conn)
        return outcome

    retry_mock = AsyncMock(side_effect=fake_run_with_retry)
    setattr(storage.db, "_run_with_retry", retry_mock)
    await storage.upsert_edge(src, tgt, edge_data)
    return fake_conn.calls
- # ---------------------------------------------------------------------------
- # upsert_node — parameterized Cypher
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_upsert_node_uses_parameterized_cypher():
    """entity_id is bound as a Cypher parameter while properties are inlined as literals."""
    recorded: list[dict] = []

    async def fake_query(sql, **kwargs):
        entry = {"sql": sql}
        entry.update(kwargs)
        recorded.append(entry)
        return []

    storage = make_graph_storage()
    node_data = {"entity_id": "Alice", "description": "A person"}
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node("Alice", node_data)

    assert len(recorded) == 1
    only_call = recorded[0]
    sql_text = only_call["sql"]

    # The template references the parameter placeholder and never the quoted ID.
    assert "$1::agtype" in sql_text
    assert '"Alice"' not in sql_text.replace("$1::agtype", "")

    # The ID travels inside the JSON-encoded params; the property map does not.
    assert "params" in only_call
    bound = json.loads(only_call["params"]["params"])
    assert bound["entity_id"] == "Alice"
    assert "props" not in bound

    # Properties show up in the SQL as a backtick-keyed, double-quoted literal.
    assert '`description`: "A person"' in sql_text
@pytest.mark.asyncio
async def test_upsert_node_injection_payload_in_entity_id():
    """An entity_id carrying a Cypher injection payload is handled as data only."""
    payload = 'test"}) RETURN n; MATCH (m) DETACH DELETE m; //'
    seen: list[dict] = []

    async def fake_query(sql, **kwargs):
        entry = {"sql": sql}
        entry.update(kwargs)
        seen.append(entry)
        return []

    storage = make_graph_storage()
    node_data = {"entity_id": payload, "description": "malicious"}
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node(payload, node_data)

    first = seen[0]
    sql_text = first["sql"]

    # Neither the destructive clause nor the raw payload may reach the SQL text.
    assert "DETACH DELETE" not in sql_text
    assert payload not in sql_text

    # The payload must round-trip unchanged through the JSON parameter.
    bound = json.loads(first["params"]["params"])
    assert bound["entity_id"] == payload
@pytest.mark.asyncio
async def test_upsert_node_special_chars_in_properties():
    """Quotes, backslashes and control characters in property values are escaped."""
    node_data = {
        "entity_id": "test_node",
        "description": 'He said "hello" and used a backslash \\',
        "notes": "Line1\nLine2\tTabbed",
        "formula": "x < 5 && y > 3",
    }
    seen: list[dict] = []

    async def fake_query(sql, **kwargs):
        entry = {"sql": sql}
        entry.update(kwargs)
        seen.append(entry)
        return []

    storage = make_graph_storage()
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node("test_node", node_data)

    sql_text = seen[0]["sql"]

    # Double quotes and the trailing backslash are each backslash-escaped.
    expected_description = (
        '`description`: "He said \\"hello\\" and used a backslash \\\\"'
    )
    assert expected_description in sql_text
    # Newline and tab appear as their two-character escape sequences.
    assert '`notes`: "Line1\\nLine2\\tTabbed"' in sql_text
    # Characters with no special meaning inside the literal pass through as-is.
    assert '`formula`: "x < 5 && y > 3"' in sql_text
@pytest.mark.asyncio
async def test_upsert_node_unicode_entity_id():
    """A non-ASCII entity_id round-trips through the params; its description is inlined."""
    unicode_id = "\u4e2d\u6587\u5b9e\u4f53"  # Chinese characters
    seen: list[dict] = []

    async def fake_query(sql, **kwargs):
        entry = {"sql": sql}
        entry.update(kwargs)
        seen.append(entry)
        return []

    storage = make_graph_storage()
    node_data = {"entity_id": unicode_id, "description": "\u63cf\u8ff0"}
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node(unicode_id, node_data)

    first = seen[0]
    bound = json.loads(first["params"]["params"])
    assert bound["entity_id"] == unicode_id
    # The description is expected verbatim (not \u-escaped) in the SQL text.
    assert '`description`: "描述"' in first["sql"]
@pytest.mark.asyncio
async def test_upsert_node_dollar_signs_in_entity_id():
    """Dollar signs in entity_id don't break dollar-quoting of the Cypher template.

    The ID must travel only through the JSON-encoded params. This test checks
    both halves of that claim: the ID is absent from the SQL text, and it
    round-trips unchanged through the params.
    """
    storage = make_graph_storage()
    captured_calls: list[dict] = []

    async def fake_query(sql, **kwargs):
        captured_calls.append({"sql": sql, **kwargs})
        return []

    dollar_id = "price is $100 or $$200$$"
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node(
            dollar_id, {"entity_id": dollar_id, "description": "has dollars"}
        )

    call = captured_calls[0]
    # The ID (and in particular its `$$...$$` fragment) must not be spliced
    # into the SQL template, otherwise it could interfere with dollar-quoting.
    assert dollar_id not in call["sql"]
    assert "$$200$$" not in call["sql"]
    # The dollar signs are in the params, not the SQL template
    params = json.loads(call["params"]["params"])
    assert params["entity_id"] == dollar_id
@pytest.mark.asyncio
async def test_upsert_node_escapes_backticks_in_property_keys():
    """A backtick inside a property key is doubled when the map is inlined."""
    seen: list[dict] = []

    async def fake_query(sql, **kwargs):
        entry = {"sql": sql}
        entry.update(kwargs)
        seen.append(entry)
        return []

    storage = make_graph_storage()
    node_data = {"entity_id": "node", "danger`key": 'value "quoted"'}
    with patch.object(storage, "_query", side_effect=fake_query):
        await storage.upsert_node("node", node_data)

    sql_text = seen[0]["sql"]
    # Key: backtick doubled inside the backtick-quoted name.
    # Value: embedded double quotes backslash-escaped.
    assert '`danger``key`: "value \\"quoted\\""' in sql_text
@pytest.mark.asyncio
async def test_upsert_node_requires_entity_id():
    """Node data lacking an entity_id key makes upsert_node raise ValueError."""
    storage = make_graph_storage()
    incomplete_node = {"description": "no entity_id"}
    # The error message is expected to mention the missing field by name.
    expect_error = pytest.raises(ValueError, match="entity_id")
    with expect_error:
        await storage.upsert_node("test", incomplete_node)
- # ---------------------------------------------------------------------------
- # upsert_edge — parameterized Cypher
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_upsert_edge_uses_parameterized_cypher():
    """Edge endpoints are bound as parameters; edge properties are inlined literals."""
    storage = make_graph_storage()
    edge_props = {"weight": "1.0", "description": "knows"}
    statements = await _capture_upsert_edge(storage, "Alice", "Bob", edge_props)

    # Exactly two statements hit the connection: the advisory lock, then the cypher.
    assert len(statements) == 2
    lock_stmt, cypher_stmt = statements

    # The lock statement never interpolates the raw node IDs ...
    lock_text = lock_stmt["sql"]
    assert "Alice" not in lock_text
    assert "Bob" not in lock_text
    # ... it receives graph_name as $1 and the endpoint pair as $2/$3.
    assert lock_stmt["args"] == ("test_graph", "Alice", "Bob")

    cypher_text = cypher_stmt["sql"]
    assert "$1::agtype" in cypher_text
    without_placeholder = cypher_text.replace("$1::agtype", "")
    assert '"Alice"' not in without_placeholder
    assert '"Bob"' not in without_placeholder

    # The cypher parameters are a single positional JSON (agtype) argument.
    bound = json.loads(cypher_stmt["args"][0])
    assert bound["src_id"] == "Alice"
    assert bound["tgt_id"] == "Bob"
    assert "props" not in bound

    # Edge properties are rendered directly into the cypher text.
    assert '`weight`: "1.0"' in cypher_text
    assert '`description`: "knows"' in cypher_text
@pytest.mark.asyncio
async def test_upsert_edge_injection_payload():
    """Injection payloads used as edge endpoints only ever travel as parameters."""
    injection_src = 'src"}) MATCH (x) DETACH DELETE x; //'
    injection_tgt = 'tgt"})-[r]-() DELETE r; //'
    storage = make_graph_storage()
    statements = await _capture_upsert_edge(
        storage, injection_src, injection_tgt, {"description": "edge"}
    )

    # None of these fragments may show up in any SQL template that was
    # executed; the payloads are only allowed in positional params.
    forbidden = ("DETACH DELETE", "DELETE r", injection_src, injection_tgt)
    for statement in statements:
        for fragment in forbidden:
            assert fragment not in statement["sql"]

    lock_stmt = statements[0]
    cypher_stmt = statements[1]

    # The lock statement gets graph_name plus the raw IDs as positional params.
    assert lock_stmt["args"] == ("test_graph", injection_src, injection_tgt)

    # The cypher statement gets one positional JSON (agtype) argument.
    bound = json.loads(cypher_stmt["args"][0])
    assert bound["src_id"] == injection_src
    assert bound["tgt_id"] == injection_tgt
@pytest.mark.asyncio
async def test_upsert_edge_unicode_entity_ids():
    """Non-ASCII edge endpoints are bound as parameters, never interpolated."""
    src = "\u5317\u4eac"
    tgt = "\u4e0a\u6d77"
    storage = make_graph_storage()
    statements = await _capture_upsert_edge(
        storage, src, tgt, {"description": "\u8def\u7ebf"}
    )
    lock_stmt = statements[0]
    cypher_stmt = statements[1]

    # Lock statement: graph_name + raw IDs as positional params, absent from the SQL.
    assert lock_stmt["args"] == ("test_graph", src, tgt)
    for endpoint in (src, tgt):
        assert endpoint not in lock_stmt["sql"]

    # Cypher statement: IDs come back out of the positional JSON (agtype) argument.
    cypher_text = cypher_stmt["sql"]
    bound = json.loads(cypher_stmt["args"][0])
    assert bound["src_id"] == src
    assert bound["tgt_id"] == tgt
    assert '`description`: "路线"' in cypher_text
- # ---------------------------------------------------------------------------
- # _normalize_node_id — defence-in-depth for remaining interpolation paths
- # ---------------------------------------------------------------------------
def test_normalize_node_id_strips_null_bytes():
    """A NUL byte embedded in a node ID is removed by _normalize_node_id."""
    normalized = PGGraphStorage._normalize_node_id("before\x00after")
    assert normalized == "beforeafter"
def test_normalize_node_id_escapes_backslash_and_quote():
    """_normalize_node_id backslash-escapes both backslashes and double quotes."""
    # Input is the three characters  a \ " b  minus the spaces; the expected
    # output doubles the backslash and prefixes the quote with a backslash.
    normalized = PGGraphStorage._normalize_node_id('a\\"b')
    assert normalized == 'a\\\\\\"b'
def test_normalize_node_id_injection_payload():
    """After normalization, every double quote in an injection payload is escaped."""
    payload = 'test"}) RETURN n; MATCH (m) DETACH DELETE m; //'
    escaped = PGGraphStorage._normalize_node_id(payload)

    # At least one escaped quote must be present ...
    assert '\\"' in escaped

    # ... and once all escaped quotes are removed, no bare quote may be left.
    leftovers = escaped.replace('\\"', "")
    assert '"' not in leftovers
- # ---------------------------------------------------------------------------
- # _query write path passes params to db.execute
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_query_write_path_passes_params():
    """With readonly=False, _query hands its params to db.execute as ``data``."""
    execute_kwargs_log: list[dict] = []

    async def fake_execute(sql, **kwargs):
        # Only the keyword arguments matter here; the SQL text is ignored.
        execute_kwargs_log.append(kwargs)

    storage = make_graph_storage()
    storage.db.execute = fake_execute

    test_params = {"params": json.dumps({"entity_id": "test"})}
    await storage._query(
        "SELECT 1",
        readonly=False,
        upsert=True,
        params=test_params,
    )

    assert len(execute_kwargs_log) == 1
    forwarded = execute_kwargs_log[0]
    assert forwarded["data"] == test_params
|