test_qdrant_upsert_batching.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import asyncio
  2. from unittest.mock import MagicMock
  3. import numpy as np
  4. import pytest
  5. pytest.importorskip(
  6. "qdrant_client",
  7. reason="qdrant-client is required for Qdrant storage tests",
  8. )
  9. from qdrant_client import models # noqa: E402
  10. from lightrag.kg.qdrant_impl import QdrantVectorDBStorage # noqa: E402
  def _make_point(point_id: str, content: str) -> models.PointStruct:
      """Return a small Qdrant point with a fixed 3-element vector.

      The payload echoes the point id and carries ``content``, so callers can
      control the serialized size through the length of ``content``.
      """
      payload = {"id": point_id, "content": content}
      return models.PointStruct(id=point_id, vector=[0.1, 0.2, 0.3], payload=payload)
  def test_build_upsert_batches_respects_point_limit():
      """Five small points under a 2-point cap are split into batches of 2, 2 and 1."""
      # 1 MiB is far above the size of five 10-char payloads, so the point cap
      # (not the byte cap) is what drives the split here.
      generous_payload_limit = 1024 * 1024
      points = [_make_point(str(idx), "x" * 10) for idx in range(5)]

      batches = QdrantVectorDBStorage._build_upsert_batches(
          points,
          max_payload_bytes=generous_payload_limit,
          max_points_per_batch=2,
      )

      batch_sizes = [len(batch_points) for batch_points, _batch_bytes in batches]
      assert batch_sizes == [2, 2, 1]
  def test_build_upsert_batches_exact_payload_boundary_no_split():
      """A batch whose serialized size equals the byte limit is kept whole.

      Both sides of the boundary are pinned:
      - at exactly the combined size, the two points share one batch;
      - one byte below it, they must be split into two single-point batches.
      Without the second check, an implementation that never splits on bytes
      would still pass this test.
      """
      point_a = _make_point("a", "x" * 32)
      point_b = _make_point("b", "y" * 32)
      size_a = QdrantVectorDBStorage._estimate_point_payload_bytes(point_a)
      size_b = QdrantVectorDBStorage._estimate_point_payload_bytes(point_b)
      # JSON array envelope: [] => 2 bytes, and comma between two elements => 1 byte
      exact_limit = 2 + size_a + 1 + size_b

      batches = QdrantVectorDBStorage._build_upsert_batches(
          [point_a, point_b],
          max_payload_bytes=exact_limit,
          max_points_per_batch=128,
      )
      assert len(batches) == 1
      assert len(batches[0][0]) == 2
      assert batches[0][1] == exact_limit

      # One byte under the combined size: the pair no longer fits together.
      # Each point alone (2 + size bytes) still fits, so we expect two
      # one-point batches with the same envelope accounting as above.
      split_batches = QdrantVectorDBStorage._build_upsert_batches(
          [point_a, point_b],
          max_payload_bytes=exact_limit - 1,
          max_points_per_batch=128,
      )
      assert [len(batch_points) for batch_points, _ in split_batches] == [1, 1]
      assert [batch_bytes for _, batch_bytes in split_batches] == [
          2 + size_a,
          2 + size_b,
      ]
  def test_build_upsert_batches_raises_for_single_oversized_point():
      """A lone point that cannot fit even in its own batch raises ValueError."""
      point = _make_point("oversized", "x" * 64)
      point_size = QdrantVectorDBStorage._estimate_point_payload_bytes(point)

      # A one-point batch is the point wrapped in a JSON array, i.e. "[" + point + "]",
      # which takes point_size + 2 bytes. The limit below is exactly one byte short.
      single_point_batch_bytes = 2 + point_size
      too_small_limit = single_point_batch_bytes - 1

      with pytest.raises(ValueError, match="Single Qdrant point exceeds payload limit"):
          QdrantVectorDBStorage._build_upsert_batches(
              [point],
              max_payload_bytes=too_small_limit,
              max_points_per_batch=128,
          )
  @pytest.mark.asyncio
  async def test_flush_fail_fast_stops_on_first_failed_batch():
      """Flush-time fail-fast: once any batch raises, subsequent batches are skipped.

      Mirrors the pre-deferred-embedding `upsert()` contract: the failure
      bubbles out of `_flush_pending_vector_ops`, and the buffer is preserved
      so the next flush can retry.
      """
      # Bypass __init__ and wire up the attributes by hand; the Qdrant client is a mock.
      storage = QdrantVectorDBStorage.__new__(QdrantVectorDBStorage)

      # Naming / identity.
      storage.workspace = "test_ws"
      storage.effective_workspace = "test_ws"
      storage.namespace = "chunks"
      storage.final_namespace = "test_collection"
      storage.meta_fields = {"content"}

      # Batching limits: at most 2 points per upsert call.
      storage._max_batch_size = 16
      storage._max_upsert_payload_bytes = 1024 * 1024
      storage._max_upsert_points_per_batch = 2

      # Client mock, pending-op buffers and flush lock.
      storage._client = MagicMock()
      storage._pending_vector_docs = {}
      storage._pending_vector_deletes = set()
      storage._flush_lock = asyncio.Lock()

      async def fake_embedding_func(texts, **kwargs):
          # Deterministic 2-element vectors: [len(text), 0.0] for each input text.
          rows = [[float(len(text)), 0.0] for text in texts]
          return np.array(rows, dtype=np.float32)

      storage.embedding_func = fake_embedding_func

      upsert_mock = storage._client.upsert
      # Batch 1 succeeds, batch 2 fails, batch 3 would succeed but must never be sent.
      upsert_mock.side_effect = [None, RuntimeError("batch failed"), None]

      docs = {f"chunk-{idx}": {"content": f"content-{idx}"} for idx in range(5)}

      # `upsert` only buffers; the failure surfaces from `_flush_pending_vector_ops`.
      await storage.upsert(docs)
      assert len(storage._pending_vector_docs) == 5

      with pytest.raises(RuntimeError, match="batch failed"):
          await storage._flush_pending_vector_ops()

      # 5 items at 2 points per batch means 3 planned batches; the flush must stop
      # right after batch #2 raises, so exactly two full batches were sent.
      sent_batch_sizes = [
          len(recorded_call.kwargs["points"])
          for recorded_call in upsert_mock.call_args_list
      ]
      assert sent_batch_sizes == [2, 2]

      # Buffer is preserved so the next flush can retry.
      assert len(storage._pending_vector_docs) == 5
  83. assert len(second_call.kwargs["points"]) == 2
  84. # Buffer is preserved so the next flush can retry.
  85. assert len(storage._pending_vector_docs) == 5