test_faiss_meta_inconsistency.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """
  2. Regression tests for Faiss meta/index inconsistency handling.
  3. Verifies that FaissVectorDBStorage gracefully handles cases where
  4. meta.json has more rows than the Faiss index (e.g., after a crash
  5. during save), and that delete/upsert operations don't crash.
  6. """
  7. import json
  8. import os
  9. import tempfile
  10. import numpy as np
  11. import pytest
  12. faiss = pytest.importorskip("faiss")
  @pytest.mark.offline
  class TestFaissMetaInconsistency:
      """Check that metadata rows without a matching Faiss vector are tolerated.

      The tests do not call the storage class itself.  Each one re-implements,
      inline, the logic it names (``_load_faiss_index``, ``_remove_faiss_ids``,
      ``_save_faiss_index``) against a real Faiss index and asserts on the result.
      NOTE(review): the inline copies must be kept in sync with the real
      implementation by hand -- confirm they still match it.
      """

      @staticmethod
      def _make_unit_vectors(n_vectors, dim, seed=0):
          """Return ``n_vectors`` L2-normalised float32 rows of length ``dim``.

          A fixed ``seed`` keeps the fixture identical between runs, so a
          failure can be reproduced exactly.
          """
          rng = np.random.default_rng(seed)
          vectors = rng.random((n_vectors, dim), dtype=np.float32)
          # Unit length, so inner product (IndexFlatIP) equals cosine similarity.
          norms = np.linalg.norm(vectors, axis=1, keepdims=True)
          return vectors / norms

      def _create_index_and_meta(self, tmp_dir, dim=4, n_vectors=3, n_extra_meta=2):
          """Write a Faiss index plus a meta.json that has extra, stale rows.

          The index holds ``n_vectors`` vectors while the metadata file holds
          ``n_vectors + n_extra_meta`` entries, simulating a crash where the
          metadata was written but the index was not fully updated.

          Args:
              tmp_dir: Directory in which both files are created.
              dim: Dimensionality of the vectors.
              n_vectors: Number of vectors actually stored in the index.
              n_extra_meta: Number of metadata rows with no vector behind them.

          Returns:
              Tuple ``(index_file, meta_file, vectors)`` where ``vectors`` is
              the array that was added to the index.
          """
          index_file = os.path.join(tmp_dir, "faiss_index_test.index")
          meta_file = index_file + ".meta.json"

          # Build the real index with n_vectors rows.
          vectors = self._make_unit_vectors(n_vectors, dim)
          index = faiss.IndexFlatIP(dim)
          index.add(vectors)
          faiss.write_index(index, index_file)

          # Valid rows first, then rows whose fid lies beyond index.ntotal.
          meta = {
              str(i): {"__id__": f"id_{i}", "content": f"text_{i}"}
              for i in range(n_vectors)
          }
          meta.update(
              {
                  str(i): {"__id__": f"stale_{i}", "content": f"stale_{i}"}
                  for i in range(n_vectors, n_vectors + n_extra_meta)
              }
          )
          with open(meta_file, "w", encoding="utf-8") as f:
              json.dump(meta, f)

          return index_file, meta_file, vectors

      def test_load_skips_invalid_metadata_rows(self):
          """Rows whose fid is beyond ``index.ntotal`` are skipped, not fatal."""
          with tempfile.TemporaryDirectory() as tmp_dir:
              dim = 4
              n_vectors = 3
              n_extra = 2
              index_file, meta_file, vectors = self._create_index_and_meta(
                  tmp_dir, dim=dim, n_vectors=n_vectors, n_extra_meta=n_extra
              )

              index = faiss.read_index(index_file)
              with open(meta_file, "r", encoding="utf-8") as f:
                  stored_dict = json.load(f)
              assert len(stored_dict) == n_vectors + n_extra

              # Simulate the load logic from _load_faiss_index.
              id_to_meta = {}
              skipped = 0
              for fid_str, meta in stored_dict.items():
                  fid = int(fid_str)
                  if fid >= index.ntotal:
                      # No vector behind this row: drop it instead of crashing.
                      skipped += 1
                      continue
                  if "__vector__" not in meta:
                      meta["__vector__"] = index.reconstruct(fid).tolist()
                  id_to_meta[fid] = meta

              assert len(id_to_meta) == n_vectors
              assert skipped == n_extra
              # Only the rows backed by a vector may survive the load.
              kept_ids = {meta["__id__"] for meta in id_to_meta.values()}
              assert kept_ids == {f"id_{i}" for i in range(n_vectors)}

              # Reconstructed vectors must match what was added to the index.
              for fid in range(n_vectors):
                  reconstructed = np.array(
                      id_to_meta[fid]["__vector__"], dtype=np.float32
                  )
                  np.testing.assert_allclose(reconstructed, vectors[fid], atol=1e-6)

      def test_remove_with_missing_vector_uses_reconstruct(self):
          """Rebuilding after a removal reconstructs vectors absent from metadata."""
          dim = 4
          n_vectors = 3
          vectors = self._make_unit_vectors(n_vectors, dim)
          index = faiss.IndexFlatIP(dim)
          index.add(vectors)

          # Metadata WITHOUT __vector__ (as stored on disk after our PR).
          id_to_meta = {
              i: {"__id__": f"id_{i}", "content": f"text_{i}"}
              for i in range(n_vectors)
          }

          # Simulate rebuild logic from _remove_faiss_ids (remove fid=1).
          removed_fids = {1}
          keep_fids = [fid for fid in id_to_meta if fid not in removed_fids]

          vectors_to_keep = []
          new_id_to_meta = {}
          for old_fid in keep_fids:
              vec_meta = id_to_meta[old_fid]
              if "__vector__" in vec_meta:
                  vec = vec_meta["__vector__"]
              elif old_fid < index.ntotal:
                  vec = index.reconstruct(old_fid).tolist()
                  vec_meta["__vector__"] = vec
              else:
                  # Neither cached nor reconstructible: the row is dropped.
                  continue
              # Number the new fid by position in vectors_to_keep so the keys
              # stay dense and aligned with the vectors even if a row above
              # was skipped (an enumerate() counter would leave a gap).
              new_id_to_meta[len(vectors_to_keep)] = vec_meta
              vectors_to_keep.append(vec)

          assert len(vectors_to_keep) == 2
          assert len(new_id_to_meta) == 2
          assert list(new_id_to_meta) == [0, 1]
          assert [m["__id__"] for m in new_id_to_meta.values()] == ["id_0", "id_2"]

          # The kept vectors are the originals at fid 0 and fid 2.
          np.testing.assert_allclose(
              np.array(vectors_to_keep[0], dtype=np.float32), vectors[0], atol=1e-6
          )
          np.testing.assert_allclose(
              np.array(vectors_to_keep[1], dtype=np.float32), vectors[2], atol=1e-6
          )

      def test_atomic_save_meta(self):
          """Temp-file + ``os.replace`` leaves the new meta.json and no .tmp file."""
          with tempfile.TemporaryDirectory() as tmp_dir:
              meta_file = os.path.join(tmp_dir, "test.meta.json")
              tmp_meta_file = meta_file + ".tmp"
              serializable_dict = {"0": {"__id__": "id_0", "content": "text_0"}}

              # Start from an existing meta file so the replace is a real
              # overwrite rather than a plain rename into an empty slot.
              with open(meta_file, "w", encoding="utf-8") as f:
                  json.dump({"0": {"__id__": "old", "content": "old"}}, f)

              # Simulate the atomic write: write the temp file, then swap it in.
              with open(tmp_meta_file, "w", encoding="utf-8") as f:
                  json.dump(serializable_dict, f)
              os.replace(tmp_meta_file, meta_file)

              assert os.path.exists(meta_file)
              assert not os.path.exists(tmp_meta_file)
              with open(meta_file, "r", encoding="utf-8") as f:
                  loaded = json.load(f)
              assert loaded == serializable_dict