test_postgres_upsert_edge_cypher.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. """
  2. Unit tests for PGGraphStorage.upsert_edge Cypher query generation.
  3. Verifies the Cypher query sent to AGE uses the OPTIONAL MATCH + DELETE +
  4. CREATE pattern with inline edge properties — the only reliable way to write
  5. edge properties in Apache AGE (SET r += {...}, ON CREATE/ON MATCH SET, and
  6. SET r.key = value all silently fail for DIRECTED edges).
  7. """
  8. import json
  9. import pytest
  10. from unittest.mock import AsyncMock, MagicMock
  11. import asyncpg
  12. from tenacity import wait_none
  13. from lightrag.kg.postgres_impl import (
  14. PGGraphQueryException,
  15. PGGraphStorage,
  16. _is_transient_graph_write_error,
  17. )
  18. # ---------------------------------------------------------------------------
  19. # Helpers
  20. # ---------------------------------------------------------------------------
  21. def make_graph_storage() -> PGGraphStorage:
  22. """Construct a PGGraphStorage instance with a mocked db."""
  23. storage = PGGraphStorage.__new__(PGGraphStorage)
  24. storage.workspace = "test_ws"
  25. storage.namespace = "test_graph"
  26. storage.graph_name = "test_graph"
  27. storage.db = MagicMock()
  28. return storage
  29. class _FakeConnection:
  30. """Captures statements + args passed to a fake asyncpg connection."""
  31. def __init__(self):
  32. self.calls: list[dict] = []
  33. def transaction(self):
  34. return _FakeTransaction()
  35. async def execute(self, sql, *args):
  36. self.calls.append({"sql": sql, "args": args})
  37. return ""
  38. class _FakeTransaction:
  39. async def __aenter__(self):
  40. return self
  41. async def __aexit__(self, exc_type, exc, tb):
  42. return False
  43. async def _capture_upsert_edge(storage: PGGraphStorage, src: str, tgt: str, edge_data):
  44. """Invoke upsert_edge against a fake connection and return the captured calls."""
  45. conn = _FakeConnection()
  46. async def fake_run_with_retry(operation, **_kwargs):
  47. return await operation(conn)
  48. storage.db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
  49. await storage.upsert_edge(src, tgt, edge_data)
  50. return conn.calls
  51. # ---------------------------------------------------------------------------
  52. # upsert_edge: Cypher query correctness
  53. # ---------------------------------------------------------------------------
  54. @pytest.mark.asyncio
  55. async def test_upsert_edge_uses_delete_create_not_set():
  56. """Cypher must use OPTIONAL MATCH + DELETE + CREATE — not SET-based update.
  57. Apache AGE silently drops edge properties written via SET r += {...},
  58. SET r.key = value, and ON CREATE/ON MATCH SET. The only reliable pattern
  59. is to delete any existing edge and CREATE a new one with inline props.
  60. """
  61. storage = make_graph_storage()
  62. calls = await _capture_upsert_edge(
  63. storage, "NodeA", "NodeB", {"weight": "1.0", "description": "test edge"}
  64. )
  65. # The cypher statement is the second one (after the lock acquisition).
  66. cypher_sql = calls[1]["sql"]
  67. # The new query must not contain any SET-based edge update — those silently
  68. # fail against AGE. All edge props live inline in the CREATE clause.
  69. assert (
  70. "SET r" not in cypher_sql
  71. ), f"Edge SET clauses are silently dropped by AGE: {cypher_sql}"
  72. assert "ON CREATE SET" not in cypher_sql
  73. assert "ON MATCH SET" not in cypher_sql
  74. @pytest.mark.asyncio
  75. async def test_upsert_edge_contains_optional_match_delete_create():
  76. """The Cypher query must use OPTIONAL MATCH + DELETE + CREATE with inline props."""
  77. storage = make_graph_storage()
  78. calls = await _capture_upsert_edge(storage, "Alice", "Bob", {"weight": "0.5"})
  79. cypher_sql = calls[1]["sql"]
  80. assert "OPTIONAL MATCH (source)-[old:DIRECTED]-(target)" in cypher_sql
  81. assert "DELETE old" in cypher_sql
  82. assert "CREATE (source)-[r:DIRECTED" in cypher_sql
  83. assert "]->(target)" in cypher_sql
  84. # Edge properties must be inlined into the CREATE clause as a literal map.
  85. assert '`weight`: "0.5"' in cypher_sql
  86. assert "RETURN r" in cypher_sql
  87. @pytest.mark.asyncio
  88. async def test_upsert_edge_handles_empty_props():
  89. """Empty edge_data must inline an empty literal map, not crash."""
  90. storage = make_graph_storage()
  91. calls = await _capture_upsert_edge(storage, "Alice", "Bob", {})
  92. cypher_sql = calls[1]["sql"]
  93. assert "CREATE (source)-[r:DIRECTED {}]->(target)" in cypher_sql
  94. @pytest.mark.asyncio
  95. async def test_upsert_edge_uses_parameterized_match_ids():
  96. """Source and target IDs must flow through Cypher parameters as agtype JSON."""
  97. storage = make_graph_storage()
  98. calls = await _capture_upsert_edge(storage, "Node A", "Node B", {"weight": "1.0"})
  99. cypher_call = calls[1]
  100. cypher_sql = cypher_call["sql"]
  101. assert "entity_id: $src_id" in cypher_sql
  102. assert "entity_id: $tgt_id" in cypher_sql
  103. # Cypher params arrive as a single positional agtype JSON arg.
  104. params_json = cypher_call["args"][0]
  105. params = json.loads(params_json)
  106. assert params["src_id"] == "Node A"
  107. assert params["tgt_id"] == "Node B"
  108. @pytest.mark.asyncio
  109. async def test_upsert_edge_serialises_with_advisory_lock():
  110. """Concurrent upserts on the same edge must be serialised via pg_advisory_xact_lock.
  111. OPTIONAL MATCH + DELETE + CREATE is not atomic on its own: two transactions
  112. could both pass the OPTIONAL MATCH and both run CREATE, leaving duplicate
  113. DIRECTED rows. We open a transaction and acquire a transaction-scoped
  114. advisory lock keyed on (graph_name, ordered (src_id, tgt_id)) before running
  115. the cypher upsert, so concurrent upserts of the same logical edge run
  116. serially without serialising independent graphs.
  117. AGE refuses to plan a join against a cypher() call that contains a CREATE
  118. clause, so the lock cannot live in a CTE — it must be a separate statement
  119. on the same connection inside an explicit transaction.
  120. """
  121. storage = make_graph_storage()
  122. calls = await _capture_upsert_edge(storage, "Alice", "Bob", {"weight": "0.5"})
  123. # Two statements: lock first, then cypher upsert.
  124. assert len(calls) == 2
  125. lock_sql = calls[0]["sql"]
  126. assert "pg_advisory_xact_lock" in lock_sql
  127. # graph_name flows as $1 so independent AGE graphs in the same DB do not
  128. # serialise each other's edges.
  129. assert "$1::text || E'\\x01' ||" in lock_sql
  130. # Key must be order-independent for (src, tgt) so {A, B} and {B, A} collide
  131. # on the same lock (the OPTIONAL MATCH is undirected).
  132. assert "LEAST($2::text, $3::text)" in lock_sql
  133. assert "GREATEST($2::text, $3::text)" in lock_sql
  134. # Raw graph_name + node IDs flow as positional params — never interpolated.
  135. assert "test_graph" not in lock_sql
  136. assert "Alice" not in lock_sql and "Bob" not in lock_sql
  137. assert calls[0]["args"] == ("test_graph", "Alice", "Bob")
  138. # The cypher statement must not contain the lock — that would cause AGE to
  139. # reject the plan with "cypher create clause cannot be rescanned".
  140. cypher_sql = calls[1]["sql"]
  141. assert "pg_advisory_xact_lock" not in cypher_sql
  142. @pytest.mark.asyncio
  143. async def test_upsert_edge_lock_key_includes_graph_name():
  144. """Advisory lock key must include graph_name so independent AGE graphs in
  145. the same PostgreSQL database don't serialise each other's edges.
  146. Regression for the codex review on PR #3056: pg_advisory_xact_lock is
  147. database-wide, so hashing only (src, tgt) would make {Alice, Bob} in
  148. `graph_a` block {Alice, Bob} in `graph_b` even though they touch different
  149. AGE graph tables.
  150. """
  151. storage_a = make_graph_storage()
  152. storage_a.graph_name = "graph_a"
  153. storage_b = make_graph_storage()
  154. storage_b.graph_name = "graph_b"
  155. calls_a = await _capture_upsert_edge(storage_a, "Alice", "Bob", {})
  156. calls_b = await _capture_upsert_edge(storage_b, "Alice", "Bob", {})
  157. # graph_name flows as the first positional arg into the lock SQL so the
  158. # hashed lock key differs between graphs even when (src, tgt) match.
  159. assert calls_a[0]["args"] == ("graph_a", "Alice", "Bob")
  160. assert calls_b[0]["args"] == ("graph_b", "Alice", "Bob")
  161. # And the lock template references graph_name as $1, with the ID pair as
  162. # $2/$3 — keep the param order pinned so future refactors don't silently
  163. # swap them.
  164. lock_sql = calls_a[0]["sql"]
  165. assert "$1::text" in lock_sql
  166. assert "LEAST($2::text, $3::text)" in lock_sql
  167. assert "GREATEST($2::text, $3::text)" in lock_sql
  168. @pytest.mark.asyncio
  169. async def test_upsert_edge_lock_key_is_endpoint_order_independent():
  170. """{A, B} and {B, A} must produce the same advisory lock key.
  171. The lock SQL itself is identical across both call directions; only the
  172. positional args differ. LEAST/GREATEST inside the template then normalises
  173. them to the same hash input, so concurrent {A,B} and {B,A} writes collide
  174. on a single lock (matching the undirected OPTIONAL MATCH).
  175. """
  176. storage = make_graph_storage()
  177. forward = await _capture_upsert_edge(storage, "Alice", "Bob", {})
  178. reverse = await _capture_upsert_edge(storage, "Bob", "Alice", {})
  179. # Same lock SQL template for both directions.
  180. assert forward[0]["sql"] == reverse[0]["sql"]
  181. # graph_name first, then the endpoint pair in whatever order the caller
  182. # passed — LEAST/GREATEST canonicalises inside the SQL.
  183. assert forward[0]["args"][0] == reverse[0]["args"][0] == "test_graph"
  184. assert (
  185. set(forward[0]["args"][1:])
  186. == set(reverse[0]["args"][1:])
  187. == {
  188. "Alice",
  189. "Bob",
  190. }
  191. )
  192. @pytest.mark.asyncio
  193. async def test_upsert_edge_wraps_transient_errors_for_retry(monkeypatch):
  194. """Query-level transient errors must be wrapped in PGGraphQueryException so
  195. the outer @retry predicate can identify them and retry.
  196. Regression: when upsert_edge was moved off self._query onto
  197. self.db._run_with_retry (to run the advisory lock + cypher in one
  198. transaction), the _query exception-wrapping path was bypassed. A raw
  199. asyncpg.DeadlockDetectedError surfacing from connection.execute would
  200. therefore fail _is_transient_graph_write_error's first guard
  201. (isinstance(exc, PGGraphQueryException)) and skip the retry loop, silently
  202. degrading concurrent ingestion under contention. This test pins the
  203. wrapping back in place and asserts the retry loop actually fires.
  204. """
  205. # Make the retry loop fire with zero backoff so the test stays fast.
  206. monkeypatch.setattr(PGGraphStorage.upsert_edge.retry, "wait", wait_none())
  207. storage = make_graph_storage()
  208. deadlock = asyncpg.exceptions.DeadlockDetectedError("simulated deadlock")
  209. call_count = 0
  210. async def fake_run_with_retry(_operation, **_kwargs):
  211. nonlocal call_count
  212. call_count += 1
  213. raise deadlock
  214. storage.db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
  215. with pytest.raises(PGGraphQueryException) as excinfo:
  216. await storage.upsert_edge("Alice", "Bob", {"weight": "1.0"})
  217. # The original asyncpg exception is preserved as __cause__ so the predicate
  218. # can introspect it via exc.__cause__.
  219. assert excinfo.value.__cause__ is deadlock
  220. # And the predicate now recognises this exception as retryable.
  221. assert _is_transient_graph_write_error(excinfo.value) is True
  222. # Retried up to stop_after_attempt(3) — proves the wrapping actually
  223. # engages the @retry loop rather than failing fast on the first attempt.
  224. assert call_count == 3
  225. @pytest.mark.asyncio
  226. async def test_upsert_edge_does_not_retry_non_transient_errors(monkeypatch):
  227. """Non-transient errors must not be retried by the @retry loop.
  228. The wrapping in upsert_edge unconditionally re-raises as
  229. PGGraphQueryException, but _is_transient_graph_write_error only returns
  230. True for a small set of asyncpg transient causes. A plain ValueError
  231. bubbling out of _run_with_retry should fail fast, not loop 3 times.
  232. """
  233. monkeypatch.setattr(PGGraphStorage.upsert_edge.retry, "wait", wait_none())
  234. storage = make_graph_storage()
  235. boom = ValueError("not a transient db error")
  236. call_count = 0
  237. async def fake_run_with_retry(_operation, **_kwargs):
  238. nonlocal call_count
  239. call_count += 1
  240. raise boom
  241. storage.db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
  242. with pytest.raises(PGGraphQueryException) as excinfo:
  243. await storage.upsert_edge("Alice", "Bob", {"weight": "1.0"})
  244. assert excinfo.value.__cause__ is boom
  245. assert _is_transient_graph_write_error(excinfo.value) is False
  246. assert call_count == 1
  247. @pytest.mark.asyncio
  248. async def test_upsert_edges_batch_iterates_in_sorted_order():
  249. """upsert_edges_batch calls upsert_edge in canonical (LEAST, GREATEST)
  250. order regardless of insertion order.
  251. upsert_edge opens an independent transaction per call and releases the
  252. advisory lock on commit, so this iteration order is not a deadlock fix
  253. — but a deterministic order matches the dedup key already used above
  254. and keeps logs / replays reproducible across callers.
  255. """
  256. storage = make_graph_storage()
  257. captured: list[tuple[str, str]] = []
  258. async def fake_upsert_edge(src, tgt, edge_data): # noqa: ARG001
  259. captured.append((src, tgt))
  260. storage.upsert_edge = AsyncMock(side_effect=fake_upsert_edge)
  261. # Insertion order intentionally non-canonical: B-A, C-A, D-A.
  262. await storage.upsert_edges_batch(
  263. [
  264. ("B", "A", {"weight": "1"}),
  265. ("C", "A", {"weight": "2"}),
  266. ("D", "A", {"weight": "3"}),
  267. ]
  268. )
  269. # Edge keys after canonicalisation: (A,B), (A,C), (A,D). The values
  270. # preserve the caller's original orientation per pair, but the iteration
  271. # visits them in sorted-key order.
  272. canonical_keys = [tuple(sorted(pair)) for pair in captured]
  273. assert canonical_keys == sorted(canonical_keys)
  274. assert canonical_keys == [("A", "B"), ("A", "C"), ("A", "D")]
  275. @pytest.mark.asyncio
  276. async def test_upsert_edges_batch_dedupes_last_write_wins():
  277. """Reciprocal duplicates collapse to a single upsert with the LATEST
  278. edge_data, regardless of which orientation arrives last — preserves the
  279. historical serial-fallback semantics documented on the method."""
  280. storage = make_graph_storage()
  281. captured: list[tuple[str, str, dict]] = []
  282. async def fake_upsert_edge(src, tgt, edge_data):
  283. captured.append((src, tgt, edge_data))
  284. storage.upsert_edge = AsyncMock(side_effect=fake_upsert_edge)
  285. await storage.upsert_edges_batch(
  286. [
  287. ("A", "B", {"weight": "first"}),
  288. ("B", "A", {"weight": "second"}), # reciprocal, wins
  289. ]
  290. )
  291. assert len(captured) == 1
  292. # Orientation = last write's caller order; edge_data = last write's payload.
  293. assert captured[0] == ("B", "A", {"weight": "second"})