test_redis_doc_status_lookup.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
"""Unit tests for RedisDocStatusStorage basename / content_hash lookups.

These tests do NOT require a live Redis instance. The Redis client is
substituted with an in-memory fake that provides just enough of the
``redis.asyncio`` surface for ``RedisDocStatusStorage``: ``ping``, ``scan``,
``scan_iter``, ``get``/``set``/``delete``, and ``pipeline()`` with queued
``get``/``set``/``exists``/``delete`` followed by ``execute``. The
connection-pool factory is stubbed out as well. This keeps the suite
offline-safe and fast.
"""
from __future__ import annotations

import json
from unittest.mock import MagicMock

import pytest

from lightrag.base import DocStatus
from lightrag.namespace import NameSpace

# Module-level ``pytestmark`` applies the custom ``offline`` marker to every
# test in this file, because none of them needs a live Redis server.
pytestmark = pytest.mark.offline
  15. class _DummyEmbeddingFunc:
  16. embedding_dim = 1
  17. max_token_size = 1
  18. async def __call__(self, texts, **kwargs):
  19. return [[0.0] for _ in texts]
  20. def _doc(status: str, file_path: str, content_hash: str | None = None) -> dict:
  21. payload = {
  22. "content_summary": f"{status} summary",
  23. "content_length": 10,
  24. "file_path": file_path,
  25. "status": status,
  26. "created_at": "2024-01-01T00:00:00+00:00",
  27. "updated_at": "2024-01-01T00:00:00+00:00",
  28. "metadata": {},
  29. "error_msg": None,
  30. }
  31. if content_hash is not None:
  32. payload["content_hash"] = content_hash
  33. return payload
  34. class _FakePipeline:
  35. """Mimics redis.asyncio pipeline: commands are queued synchronously and
  36. executed in batch via ``await execute()``."""
  37. def __init__(self, store: dict[str, str]):
  38. self._store = store
  39. self._ops: list[tuple] = []
  40. def get(self, key: str) -> None:
  41. self._ops.append(("get", key))
  42. def set(self, key: str, value: str) -> None:
  43. self._ops.append(("set", key, value))
  44. def exists(self, key: str) -> None:
  45. self._ops.append(("exists", key))
  46. def delete(self, key: str) -> None:
  47. self._ops.append(("delete", key))
  48. async def execute(self) -> list:
  49. results = []
  50. for op in self._ops:
  51. kind = op[0]
  52. if kind == "get":
  53. results.append(self._store.get(op[1]))
  54. elif kind == "set":
  55. self._store[op[1]] = op[2]
  56. results.append(True)
  57. elif kind == "exists":
  58. results.append(1 if op[1] in self._store else 0)
  59. elif kind == "delete":
  60. existed = op[1] in self._store
  61. self._store.pop(op[1], None)
  62. results.append(1 if existed else 0)
  63. self._ops.clear()
  64. return results
  65. class _FakeRedis:
  66. """Tiny in-memory stand-in for the bits of ``redis.asyncio.Redis`` that
  67. ``RedisDocStatusStorage`` actually calls."""
  68. def __init__(self):
  69. self.store: dict[str, str] = {}
  70. async def ping(self):
  71. return True
  72. async def scan(self, *args, **kwargs):
  73. # Signature: scan(cursor, match=..., count=...). args holds the cursor
  74. # positional; we ignore it and return single-shot results (cursor=0)
  75. # so callers stop looping.
  76. _ = args
  77. match = kwargs.get("match", "")
  78. if match.endswith("*"):
  79. prefix = match[:-1]
  80. keys = [k for k in self.store if k.startswith(prefix)]
  81. else:
  82. keys = [k for k in self.store if k == match]
  83. return 0, keys
  84. def scan_iter(self, **kwargs):
  85. # Used by is_empty(); returns an async iterator.
  86. match = kwargs.get("match", "")
  87. prefix = match[:-1] if match.endswith("*") else match
  88. keys = [k for k in self.store if k.startswith(prefix)]
  89. async def _aiter():
  90. for k in keys:
  91. yield k
  92. return _aiter()
  93. def pipeline(self):
  94. return _FakePipeline(self.store)
  95. async def get(self, key: str):
  96. return self.store.get(key)
  97. async def set(self, key: str, value: str):
  98. self.store[key] = value
  99. return True
  100. async def delete(self, *keys: str) -> int:
  101. count = 0
  102. for k in keys:
  103. if k in self.store:
  104. self.store.pop(k)
  105. count += 1
  106. return count
  107. @pytest.fixture
  108. def redis_doc_status(monkeypatch):
  109. """Construct RedisDocStatusStorage with its Redis client replaced by a
  110. fake in-memory store. No network I/O occurs."""
  111. fake = _FakeRedis()
  112. # Stub out the connection pool factory so __post_init__ does not invoke
  113. # the real redis-py ConnectionPool.from_url (which is lazy but still
  114. # parses URLs and caches state we don't want).
  115. monkeypatch.setattr(
  116. "lightrag.kg.redis_impl.RedisConnectionManager.get_pool",
  117. lambda redis_url: MagicMock(name="fake_pool"),
  118. )
  119. monkeypatch.setattr(
  120. "lightrag.kg.redis_impl.RedisConnectionManager.release_pool",
  121. lambda redis_url: None,
  122. )
  123. # Swap the Redis client class used in __post_init__ so any call site that
  124. # reaches self._redis hits the fake.
  125. monkeypatch.setattr(
  126. "lightrag.kg.redis_impl.Redis", lambda connection_pool=None, **_: fake
  127. )
  128. from lightrag.kg.redis_impl import RedisDocStatusStorage
  129. storage = RedisDocStatusStorage(
  130. namespace=NameSpace.DOC_STATUS,
  131. global_config={},
  132. embedding_func=_DummyEmbeddingFunc(),
  133. workspace="test",
  134. )
  135. storage._initialized = True # skip the real ping in initialize()
  136. return storage
  137. def _store_raw(storage, doc_id: str, payload: dict) -> None:
  138. """Write a record directly into the fake redis backing store, bypassing
  139. ``upsert`` so we control the serialized shape (e.g. legacy rows without
  140. a content_hash field)."""
  141. key = f"{storage.final_namespace}:{doc_id}"
  142. storage._redis.store[key] = json.dumps(payload)
  143. async def test_get_doc_by_file_basename_returns_tuple_on_hit(redis_doc_status):
  144. _store_raw(redis_doc_status, "doc-1", _doc(DocStatus.PROCESSED.value, "report.pdf"))
  145. result = await redis_doc_status.get_doc_by_file_basename("report.pdf")
  146. assert result is not None
  147. doc_id, doc_data = result
  148. assert doc_id == "doc-1"
  149. assert doc_data["file_path"] == "report.pdf"
  150. async def test_get_doc_by_file_basename_misses_when_not_present(redis_doc_status):
  151. _store_raw(redis_doc_status, "doc-1", _doc(DocStatus.PROCESSED.value, "report.pdf"))
  152. assert await redis_doc_status.get_doc_by_file_basename("other.pdf") is None
  153. async def test_get_doc_by_file_basename_empty_returns_none(redis_doc_status):
  154. _store_raw(redis_doc_status, "doc-1", _doc(DocStatus.PROCESSED.value, "report.pdf"))
  155. assert await redis_doc_status.get_doc_by_file_basename("") is None
  156. async def test_get_doc_by_file_basename_unknown_source_sentinel(redis_doc_status):
  157. # A record whose file_path itself is the sentinel must not be returned by
  158. # a basename lookup for "unknown_source" — otherwise every unsourced doc
  159. # would collide.
  160. _store_raw(
  161. redis_doc_status, "doc-1", _doc(DocStatus.PROCESSED.value, "unknown_source")
  162. )
  163. assert await redis_doc_status.get_doc_by_file_basename("unknown_source") is None
  164. async def test_get_doc_by_content_hash_returns_tuple_on_hit(redis_doc_status):
  165. _store_raw(
  166. redis_doc_status,
  167. "doc-1",
  168. _doc(DocStatus.PROCESSED.value, "report.pdf", content_hash="abc123"),
  169. )
  170. result = await redis_doc_status.get_doc_by_content_hash("abc123")
  171. assert result is not None
  172. doc_id, doc_data = result
  173. assert doc_id == "doc-1"
  174. assert doc_data["content_hash"] == "abc123"
  175. async def test_get_doc_by_content_hash_misses_when_not_present(redis_doc_status):
  176. _store_raw(
  177. redis_doc_status,
  178. "doc-1",
  179. _doc(DocStatus.PROCESSED.value, "report.pdf", content_hash="abc123"),
  180. )
  181. assert await redis_doc_status.get_doc_by_content_hash("zzz999") is None
  182. async def test_get_doc_by_content_hash_empty_returns_none_even_with_legacy_rows(
  183. redis_doc_status,
  184. ):
  185. # Legacy row written before the content_hash field existed; an empty-string
  186. # query must not match it. The early-return guard protects against this.
  187. _store_raw(
  188. redis_doc_status, "doc-legacy", _doc(DocStatus.PROCESSED.value, "old.pdf")
  189. )
  190. assert await redis_doc_status.get_doc_by_content_hash("") is None
  191. async def test_get_doc_by_content_hash_ignores_legacy_rows(redis_doc_status):
  192. # A legacy row (no content_hash field) must not be returned when querying
  193. # any non-empty hash, because doc_data.get("content_hash") is None and
  194. # None != "abc123".
  195. _store_raw(
  196. redis_doc_status, "doc-legacy", _doc(DocStatus.PROCESSED.value, "old.pdf")
  197. )
  198. _store_raw(
  199. redis_doc_status,
  200. "doc-new",
  201. _doc(DocStatus.PROCESSED.value, "new.pdf", content_hash="abc123"),
  202. )
  203. result = await redis_doc_status.get_doc_by_content_hash("abc123")
  204. assert result is not None
  205. doc_id, _ = result
  206. assert doc_id == "doc-new"