test_atomic_write_write_json.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. """Atomicity coverage for ``lightrag.utils.write_json``.
  2. The two storages that ride on this function (``JsonDocStatusStorage``,
  3. ``JsonKVStorage``) inherit crash safety from it, so the contract lives here:
  4. - A crash during the rename leaves the prior snapshot intact.
  5. - The sanitize fallback also lands atomically (one tmp, one rename).
  6. """
  7. import json
  8. import os
  9. import threading
  10. from unittest.mock import patch
  11. import pytest
  12. from lightrag.utils import write_json
  13. @pytest.mark.offline
  14. def test_write_json_publishes_clean_payload(tmp_path):
  15. target = str(tmp_path / "kv.json")
  16. needs_reload = write_json({"a": 1, "b": "hello"}, target)
  17. assert needs_reload is False
  18. assert json.load(open(target)) == {"a": 1, "b": "hello"}
  19. assert [p for p in os.listdir(tmp_path) if ".tmp." in p] == []
  20. @pytest.mark.offline
  21. def test_write_json_replace_crash_preserves_prior_snapshot(tmp_path):
  22. target = str(tmp_path / "kv.json")
  23. write_json({"v": 1}, target)
  24. with patch(
  25. "lightrag.file_atomic.os.replace",
  26. side_effect=OSError("simulated crash"),
  27. ):
  28. with pytest.raises(OSError, match="simulated crash"):
  29. write_json({"v": 2}, target)
  30. assert json.load(open(target)) == {"v": 1}
  31. leftovers = [p for p in os.listdir(tmp_path) if ".tmp." in p]
  32. assert leftovers == [], f"write_json must clean tmp on crash, got {leftovers}"
  33. @pytest.mark.offline
  34. def test_write_json_concurrent_writers_land_intact(tmp_path):
  35. """Multiple threads racing on the same destination must each rename
  36. cleanly. The final file must be valid JSON (one writer's payload)."""
  37. target = str(tmp_path / "kv.json")
  38. errors: list[BaseException] = []
  39. barrier = threading.Barrier(5)
  40. def writer(tid: int) -> None:
  41. try:
  42. barrier.wait()
  43. write_json({"writer": tid}, target)
  44. except BaseException as exc:
  45. errors.append(exc)
  46. threads = [threading.Thread(target=writer, args=(i,)) for i in range(5)]
  47. for t in threads:
  48. t.start()
  49. for t in threads:
  50. t.join()
  51. assert errors == [], f"Concurrent writers raised: {errors}"
  52. payload = json.load(open(target))
  53. assert payload.keys() == {"writer"}
  54. assert payload["writer"] in range(5)
  55. leftovers = [p for p in os.listdir(tmp_path) if ".tmp." in p]
  56. assert leftovers == [], f"Unexpected tmp residue: {leftovers}"