test_pipeline_cancellation.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. """Offline tests for /cancel_pipeline propagation into PARSE and ANALYZE.
  2. Tests target the worker-level cancellation contract added alongside the
  3. existing PROCESS-stage support:
  4. * ``_parse_worker`` and ``_analyze_worker`` check ``cancellation_requested``
  5. at the top of every loop iteration, drain queued items as FAILED with a
  6. ``"User cancelled during {stage}: ..."`` ``error_msg``, and ``task_done()``
  7. each one so ``q.join()`` in ``_run_pipeline_batch`` returns.
  8. * ``analyze_multimodal`` fails fast: the first item that raises (or a
  9. ``cancellation_requested`` flip observed by the poll loop) cancels every
  10. still-running sibling task, preserves already-completed item results in
  11. the sidecar, and re-raises the original exception type.
  12. Tests construct ``_BatchRunContext`` and call worker methods directly to
  13. avoid the cross-task races inherent in driving the full
  14. ``apipeline_process_enqueue_documents`` entry point.
  15. """
  16. from __future__ import annotations
  17. import asyncio
  18. import json
  19. import logging
  20. import time
  21. from datetime import datetime, timezone
  22. from pathlib import Path
  23. from typing import Any
  24. from unittest.mock import AsyncMock
  25. import numpy as np
  26. import pytest
  27. from lightrag import LightRAG, ROLES, RoleLLMConfig
  28. from lightrag.base import DocProcessingStatus, DocStatus
  29. from lightrag.exceptions import MultimodalAnalysisError, PipelineCancelledException
  30. from lightrag.kg.shared_storage import get_namespace_data, get_namespace_lock
  31. from lightrag.pipeline import _BatchRunContext
  32. from lightrag.utils import EmbeddingFunc, Tokenizer
  33. pytestmark = pytest.mark.offline
  34. class _SimpleTokenizerImpl:
  35. def encode(self, content: str) -> list[int]:
  36. return [ord(ch) for ch in content]
  37. def decode(self, tokens: list[int]) -> str:
  38. return "".join(chr(t) for t in tokens)
  39. async def _mock_embedding(texts: list[str]) -> np.ndarray:
  40. return np.random.rand(len(texts), 8)
  41. async def _noop_llm(prompt, **kwargs): # pragma: no cover - never invoked
  42. return ""
  43. def _build_rag(tmp_path: Path, *, vlm_func=None) -> LightRAG:
  44. role_configs = {}
  45. for spec in ROLES:
  46. if spec.name == "vlm" and vlm_func is not None:
  47. role_configs[spec.name] = RoleLLMConfig(func=vlm_func)
  48. else:
  49. role_configs[spec.name] = RoleLLMConfig()
  50. return LightRAG(
  51. working_dir=str(tmp_path),
  52. workspace=f"cancel-{tmp_path.name}",
  53. llm_model_func=vlm_func or _noop_llm,
  54. embedding_func=EmbeddingFunc(
  55. embedding_dim=8,
  56. max_token_size=1024,
  57. func=_mock_embedding,
  58. ),
  59. tokenizer=Tokenizer("mock-tokenizer", _SimpleTokenizerImpl()),
  60. vlm_process_enable=True,
  61. role_llm_configs=role_configs,
  62. )
  63. async def _shutdown_role_workers(rag: LightRAG) -> None:
  64. """Explicitly shut down each role wrapper's priority-queue workers.
  65. finalize_storages() only finalizes storages — it does NOT touch the
  66. per-role priority_limit worker pools. If a test triggered any role
  67. LLM calls whose worker is still in ``await asyncio.sleep(...)`` when
  68. pytest closes the function-scoped event loop, the leaked worker
  69. tasks raise "Task was destroyed but it is pending" / "Event loop is
  70. closed" and (worse, observed on macOS Python 3.12) prevent the
  71. pytest process from exiting cleanly. Call this before
  72. ``finalize_storages()`` to drain workers under a live loop first.
  73. """
  74. for func in rag.role_llm_funcs.values():
  75. try:
  76. await rag._shutdown_llm_wrapper(func)
  77. except Exception as exc:
  78. logging.getLogger("lightrag").warning(
  79. f"role worker shutdown raised during test teardown: {exc}"
  80. )
  81. async def _make_ctx(rag: LightRAG) -> tuple[_BatchRunContext, dict, Any]:
  82. """Build a fresh _BatchRunContext bound to the RAG's workspace.
  83. The pipeline_status dict and lock come from the same shared_storage
  84. keyspace that production code uses, so worker reads of the
  85. cancellation flag observe whatever the test writes.
  86. """
  87. pipeline_status = await get_namespace_data(
  88. "pipeline_status", workspace=rag.workspace
  89. )
  90. pipeline_status_lock = get_namespace_lock(
  91. "pipeline_status", workspace=rag.workspace
  92. )
  93. pipeline_status.clear()
  94. pipeline_status.update(
  95. {
  96. "busy": True,
  97. "history_messages": [],
  98. "latest_message": "",
  99. "cancellation_requested": False,
  100. }
  101. )
  102. ctx = _BatchRunContext(
  103. pipeline_status=pipeline_status,
  104. pipeline_status_lock=pipeline_status_lock,
  105. semaphore=asyncio.Semaphore(2),
  106. total_files=0,
  107. q_native=asyncio.Queue(),
  108. q_mineru=asyncio.Queue(),
  109. q_docling=asyncio.Queue(),
  110. q_analyze=asyncio.Queue(),
  111. q_process=asyncio.Queue(),
  112. )
  113. return ctx, pipeline_status, pipeline_status_lock
  114. def _make_status_doc(doc_id: str) -> DocProcessingStatus:
  115. now = datetime.now(timezone.utc).isoformat()
  116. return DocProcessingStatus(
  117. content_summary=f"summary-{doc_id}",
  118. content_length=10,
  119. file_path=f"{doc_id}.pdf",
  120. status=DocStatus.PENDING,
  121. created_at=now,
  122. updated_at=now,
  123. track_id=None,
  124. content_hash=f"hash-{doc_id}",
  125. )
  126. async def _run_worker_until_drained(
  127. worker_coro_factory,
  128. queue: asyncio.Queue,
  129. *,
  130. timeout: float = 2.0,
  131. ) -> None:
  132. """Spin up the worker, await q.join(), then cancel the worker — same
  133. teardown sequence as ``_run_pipeline_batch``."""
  134. worker = asyncio.create_task(worker_coro_factory())
  135. try:
  136. await asyncio.wait_for(queue.join(), timeout=timeout)
  137. finally:
  138. worker.cancel()
  139. await asyncio.gather(worker, return_exceptions=True)
  140. @pytest.mark.asyncio
  141. async def test_parse_worker_drains_queue_when_cancelled_before_start(tmp_path):
  142. """Cancellation set BEFORE the worker pulls any item: parser must not
  143. run, every queued doc is FAILED with a friendly message, q.join()
  144. returns quickly."""
  145. rag = _build_rag(tmp_path)
  146. await rag.initialize_storages()
  147. try:
  148. ctx, pipeline_status, _ = await _make_ctx(rag)
  149. rag.parse_native = AsyncMock(
  150. side_effect=AssertionError("parse_native must not be called")
  151. )
  152. for i in range(3):
  153. doc_id = f"doc-{i}"
  154. await rag.full_docs.upsert(
  155. {doc_id: {"content": "hello", "file_path": f"{doc_id}.pdf"}}
  156. )
  157. await rag.doc_status.upsert(
  158. {
  159. doc_id: {
  160. "status": DocStatus.PENDING.value,
  161. "content_summary": f"sum-{doc_id}",
  162. "content_length": 5,
  163. "file_path": f"{doc_id}.pdf",
  164. "created_at": datetime.now(timezone.utc).isoformat(),
  165. "updated_at": datetime.now(timezone.utc).isoformat(),
  166. "track_id": "t",
  167. }
  168. }
  169. )
  170. await ctx.q_native.put((doc_id, _make_status_doc(doc_id)))
  171. pipeline_status["cancellation_requested"] = True
  172. start = time.monotonic()
  173. await _run_worker_until_drained(
  174. lambda: rag._parse_worker("native", ctx.q_native, ctx),
  175. ctx.q_native,
  176. )
  177. elapsed = time.monotonic() - start
  178. assert elapsed < 1.0, f"queue drain should be fast, took {elapsed:.2f}s"
  179. assert rag.parse_native.await_count == 0
  180. cancel_messages = [
  181. m
  182. for m in pipeline_status["history_messages"]
  183. if "User cancelled during parse" in m
  184. ]
  185. assert len(cancel_messages) == 3
  186. for i in range(3):
  187. doc_id = f"doc-{i}"
  188. row = await rag.doc_status.get_by_id(doc_id)
  189. assert row is not None
  190. assert row.get("status") == DocStatus.FAILED.value
  191. assert "User cancelled during parse" in (row.get("error_msg") or "")
  192. finally:
  193. await rag.finalize_storages()
  194. @pytest.mark.asyncio
  195. async def test_analyze_worker_drains_queue_when_cancelled_before_start(tmp_path):
  196. """ANALYZE-worker symmetric to the PARSE test above."""
  197. rag = _build_rag(tmp_path)
  198. await rag.initialize_storages()
  199. try:
  200. ctx, pipeline_status, _ = await _make_ctx(rag)
  201. rag.analyze_multimodal = AsyncMock(
  202. side_effect=AssertionError("analyze_multimodal must not be called")
  203. )
  204. for i in range(3):
  205. doc_id = f"doc-{i}"
  206. await rag.doc_status.upsert(
  207. {
  208. doc_id: {
  209. "status": DocStatus.ANALYZING.value,
  210. "content_summary": f"sum-{doc_id}",
  211. "content_length": 5,
  212. "file_path": f"{doc_id}.pdf",
  213. "created_at": datetime.now(timezone.utc).isoformat(),
  214. "updated_at": datetime.now(timezone.utc).isoformat(),
  215. "track_id": "t",
  216. }
  217. }
  218. )
  219. await ctx.q_analyze.put(
  220. (doc_id, _make_status_doc(doc_id), {"content": "x"})
  221. )
  222. pipeline_status["cancellation_requested"] = True
  223. start = time.monotonic()
  224. await _run_worker_until_drained(
  225. lambda: rag._analyze_worker(ctx),
  226. ctx.q_analyze,
  227. )
  228. elapsed = time.monotonic() - start
  229. assert elapsed < 1.0, f"queue drain should be fast, took {elapsed:.2f}s"
  230. assert rag.analyze_multimodal.await_count == 0
  231. cancel_messages = [
  232. m
  233. for m in pipeline_status["history_messages"]
  234. if "User cancelled during analyze" in m
  235. ]
  236. assert len(cancel_messages) == 3
  237. for i in range(3):
  238. row = await rag.doc_status.get_by_id(f"doc-{i}")
  239. assert row is not None
  240. assert row.get("status") == DocStatus.FAILED.value
  241. assert "User cancelled during analyze" in (row.get("error_msg") or "")
  242. finally:
  243. await rag.finalize_storages()
# Three-item drawing sidecar fixture used by the in-flight cancellation test
# (the fail-fast test builds an equivalent sidecar inline). Three items allow
# one slow, one fast-failing, and one slow-but-successful task, so
# partial-result preservation can be observed.
  247. def _write_three_item_sidecar(tmp_path: Path) -> tuple[str, dict, Path]:
  248. parsed_dir = tmp_path / "parsed"
  249. parsed_dir.mkdir(exist_ok=True)
  250. blocks_path = parsed_dir / "doc.blocks.jsonl"
  251. blocks_path.write_text(
  252. json.dumps({"type": "meta", "doc_id": "doc-1"}) + "\n",
  253. encoding="utf-8",
  254. )
  255. sidecar_path = parsed_dir / "doc.drawings.json"
  256. sidecar_path.write_text(
  257. json.dumps(
  258. {
  259. "drawings": {
  260. "im-A": {"caption": "A", "path": "ignored-A"},
  261. "im-B": {"caption": "B", "path": "ignored-B"},
  262. "im-C": {"caption": "C", "path": "ignored-C"},
  263. }
  264. }
  265. ),
  266. encoding="utf-8",
  267. )
  268. parsed_data = {"blocks_path": str(blocks_path)}
  269. return "doc-1", parsed_data, sidecar_path
  270. @pytest.mark.asyncio
  271. async def test_analyze_multimodal_inflight_cancellation_polls_flag(
  272. tmp_path, monkeypatch
  273. ):
  274. """User sets cancellation_requested while VLM tasks are running.
  275. analyze_multimodal should observe the flag at the next poll boundary
  276. (≤ 0.5s), cancel pending tasks, write the sidecar with partial
  277. results, and raise PipelineCancelledException."""
  278. async def slow_vlm(prompt, **kwargs):
  279. # 1.2s is short enough that even when the priority-queue worker
  280. # finishes the in-flight call after we've already raised (the
  281. # role wrapper does not propagate outer-future cancellation to
  282. # the worker), the post-analyze cleanup is bounded.
  283. await asyncio.sleep(1.2)
  284. return json.dumps(
  285. {"name": "x", "type": "Chart", "description": "should not arrive"}
  286. )
  287. rag = _build_rag(tmp_path, vlm_func=slow_vlm)
  288. await rag.initialize_storages()
  289. try:
  290. doc_id, parsed_data, sidecar_path = _write_three_item_sidecar(tmp_path)
  291. # Bypass image-bytes validation: _analyze_drawing normally reads
  292. # and validates the image file. Replace with a controlled mock so
  293. # the only async work is the (slow_vlm) call we manage above.
  294. async def fake_analyze_drawing(item_id, item, sidecar_dir):
  295. await slow_vlm("dummy") # honors the cancellation timing
  296. return (
  297. {
  298. "name": item_id,
  299. "type": "Chart",
  300. "description": "ok",
  301. "status": "success",
  302. "analyze_time": int(time.time()),
  303. },
  304. f"cache-{item_id}",
  305. )
  306. # analyze_multimodal defines _analyze_drawing as a local closure,
  307. # so we can't monkeypatch it directly. Instead patch the helper
  308. # it relies on (slow_vlm via the role wrapper); we accept the
  309. # closure's image pre-validation and supply a minimal PNG fixture.
  310. from .test_pipeline_analyze_multimodal import PNG_BYTES
  311. for letter in ("A", "B", "C"):
  312. (tmp_path / "parsed" / f"im-{letter}.png").write_bytes(PNG_BYTES)
  313. sidecar_path.write_text(
  314. json.dumps(
  315. {
  316. "drawings": {
  317. f"im-{letter}": {
  318. "caption": letter,
  319. "path": str(tmp_path / "parsed" / f"im-{letter}.png"),
  320. }
  321. for letter in ("A", "B", "C")
  322. }
  323. }
  324. ),
  325. encoding="utf-8",
  326. )
  327. # Use plain dict + asyncio.Lock so the poll loop's lock
  328. # acquisition has no chance of contending with the real
  329. # NamespaceLock used during LightRAG initialization paths.
  330. pipeline_status: dict = {
  331. "busy": True,
  332. "history_messages": [],
  333. "latest_message": "",
  334. "cancellation_requested": False,
  335. }
  336. pipeline_status_lock = asyncio.Lock()
  337. async def flip_after(delay: float):
  338. await asyncio.sleep(delay)
  339. async with pipeline_status_lock:
  340. pipeline_status["cancellation_requested"] = True
  341. flipper = asyncio.create_task(flip_after(0.1))
  342. start = time.monotonic()
  343. with pytest.raises(PipelineCancelledException):
  344. await asyncio.wait_for(
  345. rag.analyze_multimodal(
  346. doc_id=doc_id,
  347. file_path="fixture.pdf",
  348. parsed_data=parsed_data,
  349. process_options="i",
  350. pipeline_status=pipeline_status,
  351. pipeline_status_lock=pipeline_status_lock,
  352. ),
  353. timeout=15.0,
  354. )
  355. elapsed = time.monotonic() - start
  356. await flipper
  357. # Must cancel well before the 1.2s sleep — poll interval is 0.5s.
  358. assert elapsed < 1.0, f"in-flight cancel took {elapsed:.2f}s (>1.0s)"
  359. payload = json.loads(sidecar_path.read_text(encoding="utf-8"))
  360. # Sidecar should have been written even though we raised — every
  361. # item carries a llm_analyze_result entry (cancelled / failure).
  362. for letter in ("A", "B", "C"):
  363. item = payload["drawings"][f"im-{letter}"]
  364. assert "llm_analyze_result" in item
  365. assert item["llm_analyze_result"]["status"] in ("failure", "success")
  366. finally:
  367. await _shutdown_role_workers(rag)
  368. await rag.finalize_storages()
  369. @pytest.mark.asyncio
  370. async def test_analyze_multimodal_fail_fast_preserves_successes(tmp_path):
  371. """One item raises quickly; one already completed; one would have
  372. taken longer. analyze_multimodal must not wait for the slow item,
  373. must preserve the completed item's result in the sidecar, and must
  374. raise MultimodalAnalysisError (not PipelineCancelledException)."""
  375. from .test_pipeline_analyze_multimodal import PNG_BYTES
  376. parsed_dir = tmp_path / "parsed"
  377. parsed_dir.mkdir()
  378. for letter in ("A", "B", "C"):
  379. (parsed_dir / f"im-{letter}.png").write_bytes(PNG_BYTES)
  380. blocks_path = parsed_dir / "doc.blocks.jsonl"
  381. blocks_path.write_text(
  382. json.dumps({"type": "meta", "doc_id": "doc-1"}) + "\n",
  383. encoding="utf-8",
  384. )
  385. sidecar_path = parsed_dir / "doc.drawings.json"
  386. sidecar_path.write_text(
  387. json.dumps(
  388. {
  389. "drawings": {
  390. f"im-{letter}": {
  391. "caption": letter,
  392. "path": str(parsed_dir / f"im-{letter}.png"),
  393. }
  394. for letter in ("A", "B", "C")
  395. }
  396. }
  397. ),
  398. encoding="utf-8",
  399. )
  400. parsed_data = {"blocks_path": str(blocks_path)}
  401. # Per-call behaviour: call 1 succeeds quickly (~0.05s), call 2 fails
  402. # quickly (~0.1s), call 3 would take 5s — we want to prove fail-fast
  403. # cancels call 3 rather than wait. Ordering by call_count rather than
  404. # by item identifier because the VLM role wrapper does not surface
  405. # the item filename in its kwargs (only image_inputs bytes).
  406. call_count = {"n": 0}
  407. call_lock = asyncio.Lock()
  408. async def vlm_func(prompt, **kwargs):
  409. async with call_lock:
  410. call_count["n"] += 1
  411. seq = call_count["n"]
  412. if seq == 1:
  413. await asyncio.sleep(0.05)
  414. return json.dumps({"name": "first", "type": "Chart", "description": "ok"})
  415. if seq == 2:
  416. await asyncio.sleep(0.1)
  417. raise MultimodalAnalysisError("forced failure")
  418. # 1.2s instead of 5s: still proves fail-fast doesn't wait (test
  419. # checks elapsed < 0.8s) but keeps post-analyze cleanup bounded
  420. # since the worker keeps running this sleep until completion.
  421. await asyncio.sleep(1.2)
  422. return json.dumps({"name": "late", "type": "Chart", "description": "late"})
  423. rag = _build_rag(tmp_path, vlm_func=vlm_func)
  424. await rag.initialize_storages()
  425. try:
  426. pipeline_status: dict = {
  427. "busy": True,
  428. "history_messages": [],
  429. "latest_message": "",
  430. "cancellation_requested": False,
  431. }
  432. pipeline_status_lock = asyncio.Lock()
  433. start = time.monotonic()
  434. with pytest.raises(MultimodalAnalysisError):
  435. await asyncio.wait_for(
  436. rag.analyze_multimodal(
  437. doc_id="doc-1",
  438. file_path="fixture.pdf",
  439. parsed_data=parsed_data,
  440. process_options="i",
  441. pipeline_status=pipeline_status,
  442. pipeline_status_lock=pipeline_status_lock,
  443. ),
  444. timeout=15.0,
  445. )
  446. elapsed = time.monotonic() - start
  447. # Without fail-fast we'd have waited for the 1.2s sleep on the
  448. # third call. 0.8s gives the second-call failure path room
  449. # while still catching any regression that waits for call 3.
  450. assert elapsed < 0.8, f"fail-fast still waited {elapsed:.2f}s for slow task"
  451. payload = json.loads(sidecar_path.read_text(encoding="utf-8"))
  452. statuses = sorted(
  453. payload["drawings"][f"im-{letter}"]["llm_analyze_result"]["status"]
  454. for letter in ("A", "B", "C")
  455. )
  456. # Three items → one success (call 1), one failure (call 2), and
  457. # one cancelled (call 3 was killed by fail-fast). All represented
  458. # as failure status_strings except for the success.
  459. assert statuses == ["failure", "failure", "success"]
  460. # Find which item ended up cancelled — its message must say so.
  461. cancelled_items = [
  462. r["message"]
  463. for r in (
  464. payload["drawings"][f"im-{letter}"]["llm_analyze_result"]
  465. for letter in ("A", "B", "C")
  466. )
  467. if r["status"] == "failure" and "cancelled" in r["message"]
  468. ]
  469. assert len(cancelled_items) == 1
  470. forced_items = [
  471. r["message"]
  472. for r in (
  473. payload["drawings"][f"im-{letter}"]["llm_analyze_result"]
  474. for letter in ("A", "B", "C")
  475. )
  476. if r["status"] == "failure" and "forced failure" in r["message"]
  477. ]
  478. assert len(forced_items) == 1
  479. finally:
  480. await _shutdown_role_workers(rag)
  481. await rag.finalize_storages()
  482. @pytest.mark.asyncio
  483. async def test_analyze_multimodal_pre_schedule_cancellation_skips_task_creation(
  484. tmp_path, monkeypatch
  485. ):
  486. """``cancellation_requested`` is already True when analyze_multimodal
  487. enters the sidecar processing loop. The pre-schedule check must
  488. raise immediately, before any per-item VLM task is even constructed
  489. — not merely cancel them before the scheduler yields. Covers the
  490. small window between ``_analyze_worker``'s boundary check and the
  491. per-sidecar task spawn that the polling loop alone would miss.
  492. Asserts both ``vlm_invocations == 0`` (no work executed) AND that
  493. ``asyncio.create_task`` was never called for any
  494. ``_run_with_progress_log`` coroutine — distinguishing the
  495. early-raise implementation from a poll-then-cancel implementation
  496. that would still construct and immediately cancel each task.
  497. """
  498. from .test_pipeline_analyze_multimodal import PNG_BYTES
  499. parsed_dir = tmp_path / "parsed"
  500. parsed_dir.mkdir()
  501. image_path = parsed_dir / "im-X.png"
  502. image_path.write_bytes(PNG_BYTES)
  503. blocks_path = parsed_dir / "doc.blocks.jsonl"
  504. blocks_path.write_text(
  505. json.dumps({"type": "meta", "doc_id": "doc-1"}) + "\n",
  506. encoding="utf-8",
  507. )
  508. sidecar_path = parsed_dir / "doc.drawings.json"
  509. sidecar_path.write_text(
  510. json.dumps({"drawings": {"im-X": {"caption": "X", "path": str(image_path)}}}),
  511. encoding="utf-8",
  512. )
  513. parsed_data = {"blocks_path": str(blocks_path)}
  514. vlm_invocations = 0
  515. async def tripwire_vlm(prompt, **kwargs):
  516. nonlocal vlm_invocations
  517. vlm_invocations += 1
  518. return json.dumps(
  519. {"name": "X", "type": "Chart", "description": "must not be called"}
  520. )
  521. # Spy on asyncio.create_task to count per-item tasks spawned by
  522. # analyze_multimodal. The per-item coroutine is _run_with_progress_log
  523. # (a closure defined inside analyze_multimodal), so filter by qualname.
  524. progress_log_tasks_created = 0
  525. original_create_task = asyncio.create_task
  526. def spy_create_task(coro, *args, **kwargs):
  527. nonlocal progress_log_tasks_created
  528. name = getattr(coro, "__qualname__", "") or getattr(
  529. getattr(coro, "cr_code", None), "co_qualname", ""
  530. )
  531. if "_run_with_progress_log" in name:
  532. progress_log_tasks_created += 1
  533. return original_create_task(coro, *args, **kwargs)
  534. monkeypatch.setattr(asyncio, "create_task", spy_create_task)
  535. rag = _build_rag(tmp_path, vlm_func=tripwire_vlm)
  536. await rag.initialize_storages()
  537. try:
  538. pipeline_status: dict = {
  539. "busy": True,
  540. "history_messages": [],
  541. "latest_message": "",
  542. "cancellation_requested": True, # set BEFORE the call
  543. }
  544. pipeline_status_lock = asyncio.Lock()
  545. with pytest.raises(PipelineCancelledException):
  546. await rag.analyze_multimodal(
  547. doc_id="doc-1",
  548. file_path="fixture.pdf",
  549. parsed_data=parsed_data,
  550. process_options="i",
  551. pipeline_status=pipeline_status,
  552. pipeline_status_lock=pipeline_status_lock,
  553. )
  554. # Stronger than "no work ran": the per-item task object was
  555. # never even constructed. A poll-then-cancel implementation
  556. # would still spawn and cancel — this assertion rules that out.
  557. assert progress_log_tasks_created == 0
  558. assert vlm_invocations == 0
  559. finally:
  560. await _shutdown_role_workers(rag)
  561. await rag.finalize_storages()