| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
"""Atomicity coverage for ``lightrag.utils.write_json``.

The two storages built on this function (``JsonDocStatusStorage``,
``JsonKVStorage``) inherit their crash safety from it, so the contract is
pinned down here:

- A crash during the rename leaves the prior snapshot intact.
- The sanitize fallback also lands atomically (one tmp, one rename).
- Concurrent writers to the same path each publish a complete payload and
  leave no tmp files behind.
"""
- import json
- import os
- import threading
- from unittest.mock import patch
- import pytest
- from lightrag.utils import write_json
@pytest.mark.offline
def test_write_json_publishes_clean_payload(tmp_path):
    """A normal write publishes the exact payload and leaves no tmp files.

    Also asserts that ``write_json`` reports ``needs_reload`` as ``False``
    for this plain-JSON payload.
    """
    target = str(tmp_path / "kv.json")

    needs_reload = write_json({"a": 1, "b": "hello"}, target)
    assert needs_reload is False

    # Read through a context manager so the handle is closed deterministically
    # (a bare ``open()`` leaks the descriptor and emits a ResourceWarning).
    with open(target, encoding="utf-8") as fh:
        assert json.load(fh) == {"a": 1, "b": "hello"}

    # The atomic write stages data in a ``*.tmp.*`` sibling; none may survive.
    leftovers = [p for p in os.listdir(tmp_path) if ".tmp." in p]
    assert leftovers == [], f"Unexpected tmp residue: {leftovers}"
@pytest.mark.offline
def test_write_json_replace_crash_preserves_prior_snapshot(tmp_path):
    """A failure at the rename step must not damage the published file.

    ``os.replace`` is forced to raise during a second write. The error must
    propagate to the caller, the first snapshot must still read back
    unchanged, and the tmp file staged for the aborted write must be removed.
    """
    target = str(tmp_path / "kv.json")
    write_json({"v": 1}, target)

    # Patching an attribute of ``lightrag.file_atomic.os`` mutates the shared
    # ``os`` module object (assuming file_atomic does a plain ``import os``),
    # so ``os.replace`` is replaced process-wide only inside this block.
    with patch(
        "lightrag.file_atomic.os.replace",
        side_effect=OSError("simulated crash"),
    ):
        with pytest.raises(OSError, match="simulated crash"):
            write_json({"v": 2}, target)

    # Close the handle explicitly instead of leaking it via bare ``open()``.
    with open(target, encoding="utf-8") as fh:
        assert json.load(fh) == {"v": 1}

    leftovers = [p for p in os.listdir(tmp_path) if ".tmp." in p]
    assert leftovers == [], f"write_json must clean tmp on crash, got {leftovers}"
@pytest.mark.offline
def test_write_json_concurrent_writers_land_intact(tmp_path):
    """Multiple threads racing on the same destination must each rename
    cleanly. The final file must be valid JSON holding exactly one writer's
    payload, and no tmp files may be left behind."""
    n_writers = 5
    target = str(tmp_path / "kv.json")
    errors: list[BaseException] = []

    # The barrier releases all writers together so their write_json calls
    # overlap. The timeout turns a thread that never arrives into a
    # BrokenBarrierError (collected in ``errors``) instead of a hung run.
    barrier = threading.Barrier(n_writers, timeout=10)

    def writer(tid: int) -> None:
        try:
            barrier.wait()
            write_json({"writer": tid}, target)
        except BaseException as exc:  # forward every failure to the main thread
            errors.append(exc)

    threads = [
        threading.Thread(target=writer, args=(i,)) for i in range(n_writers)
    ]
    for t in threads:
        t.start()
    for t in threads:
        # Bounded join: a deadlocked writer fails the test instead of
        # blocking the whole suite indefinitely.
        t.join(timeout=30)
    stuck = [t.name for t in threads if t.is_alive()]
    assert stuck == [], f"Writer threads did not finish: {stuck}"

    assert errors == [], f"Concurrent writers raised: {errors}"

    with open(target, encoding="utf-8") as fh:
        payload = json.load(fh)
    assert payload.keys() == {"writer"}
    assert payload["writer"] in range(n_writers)

    leftovers = [p for p in os.listdir(tmp_path) if ".tmp." in p]
    assert leftovers == [], f"Unexpected tmp residue: {leftovers}"
|