test_dimension_mismatch.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. """
  2. Tests for dimension mismatch handling during migration.
  3. This test module verifies that both PostgreSQL and Qdrant storage backends
  4. properly detect and handle vector dimension mismatches when migrating from
  5. legacy collections/tables to new ones with different embedding models.
  6. """
  7. import json
  8. import pytest
  9. from unittest.mock import MagicMock, AsyncMock, patch
  10. from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
  11. from lightrag.kg.postgres_impl import PGVectorStorage
  12. from lightrag.exceptions import DataMigrationError
  13. # Note: Tests should use proper table names that have DDL templates
  14. # Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
  15. # LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
  16. class TestQdrantDimensionMismatch:
  17. """Test suite for Qdrant dimension mismatch handling."""
  18. def test_qdrant_dimension_mismatch_raises_error(self):
  19. """
  20. Test that Qdrant raises DataMigrationError when dimensions don't match.
  21. Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
  22. Expected: DataMigrationError is raised to prevent data corruption.
  23. """
  24. from qdrant_client import models
  25. # Setup mock client
  26. client = MagicMock()
  27. # Mock legacy collection with 1536d vectors
  28. legacy_collection_info = MagicMock()
  29. legacy_collection_info.config.params.vectors.size = 1536
  30. # Setup collection existence checks
  31. def collection_exists_side_effect(name):
  32. if (
  33. name == "lightrag_vdb_chunks"
  34. ): # legacy (matches _find_legacy_collection pattern)
  35. return True
  36. elif name == "lightrag_chunks_model_3072d": # new
  37. return False
  38. return False
  39. client.collection_exists.side_effect = collection_exists_side_effect
  40. client.get_collection.return_value = legacy_collection_info
  41. client.count.return_value.count = 100 # Legacy has data
  42. # Patch _find_legacy_collection to return the legacy collection name
  43. with patch(
  44. "lightrag.kg.qdrant_impl._find_legacy_collection",
  45. return_value="lightrag_vdb_chunks",
  46. ):
  47. # Call setup_collection with 3072d (different from legacy 1536d)
  48. # Should raise DataMigrationError due to dimension mismatch
  49. with pytest.raises(DataMigrationError) as exc_info:
  50. QdrantVectorDBStorage.setup_collection(
  51. client,
  52. "lightrag_chunks_model_3072d",
  53. namespace="chunks",
  54. workspace="test",
  55. vectors_config=models.VectorParams(
  56. size=3072, distance=models.Distance.COSINE
  57. ),
  58. hnsw_config=models.HnswConfigDiff(
  59. payload_m=16,
  60. m=0,
  61. ),
  62. model_suffix="model_3072d",
  63. )
  64. # Verify error message contains dimension information
  65. assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
  66. # Verify new collection was NOT created (error raised before creation)
  67. client.create_collection.assert_not_called()
  68. # Verify migration was NOT attempted
  69. client.scroll.assert_not_called()
  70. client.upsert.assert_not_called()
  71. def test_qdrant_dimension_match_proceed_migration(self):
  72. """
  73. Test that Qdrant proceeds with migration when dimensions match.
  74. Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
  75. Expected: Migration proceeds normally.
  76. """
  77. from qdrant_client import models
  78. client = MagicMock()
  79. # Mock legacy collection with 1536d vectors (matching new)
  80. legacy_collection_info = MagicMock()
  81. legacy_collection_info.config.params.vectors.size = 1536
  82. def collection_exists_side_effect(name):
  83. if name == "lightrag_chunks": # legacy
  84. return True
  85. elif name == "lightrag_chunks_model_1536d": # new
  86. return False
  87. return False
  88. client.collection_exists.side_effect = collection_exists_side_effect
  89. client.get_collection.return_value = legacy_collection_info
  90. # Track whether upsert has been called (migration occurred)
  91. migration_done = {"value": False}
  92. def upsert_side_effect(*args, **kwargs):
  93. migration_done["value"] = True
  94. return MagicMock()
  95. client.upsert.side_effect = upsert_side_effect
  96. # Mock count to return different values based on collection name and migration state
  97. # Before migration: new collection has 0 records
  98. # After migration: new collection has 1 record (matching migrated data)
  99. def count_side_effect(collection_name, **kwargs):
  100. result = MagicMock()
  101. if collection_name == "lightrag_chunks": # legacy
  102. result.count = 1 # Legacy has 1 record
  103. elif collection_name == "lightrag_chunks_model_1536d": # new
  104. # Return 0 before migration, 1 after migration
  105. result.count = 1 if migration_done["value"] else 0
  106. else:
  107. result.count = 0
  108. return result
  109. client.count.side_effect = count_side_effect
  110. # Mock scroll to return sample data (1 record for easier verification)
  111. sample_point = MagicMock()
  112. sample_point.id = "test_id"
  113. sample_point.vector = [0.1] * 1536
  114. sample_point.payload = {"id": "test"}
  115. client.scroll.return_value = ([sample_point], None)
  116. # Mock _find_legacy_collection to return the legacy collection name
  117. with patch(
  118. "lightrag.kg.qdrant_impl._find_legacy_collection",
  119. return_value="lightrag_chunks",
  120. ):
  121. # Call setup_collection with matching 1536d
  122. QdrantVectorDBStorage.setup_collection(
  123. client,
  124. "lightrag_chunks_model_1536d",
  125. namespace="chunks",
  126. workspace="test",
  127. vectors_config=models.VectorParams(
  128. size=1536, distance=models.Distance.COSINE
  129. ),
  130. hnsw_config=models.HnswConfigDiff(
  131. payload_m=16,
  132. m=0,
  133. ),
  134. model_suffix="model_1536d",
  135. )
  136. # Verify migration WAS attempted
  137. client.create_collection.assert_called_once()
  138. client.scroll.assert_called()
  139. client.upsert.assert_called()
  140. class TestPostgresDimensionMismatch:
  141. """Test suite for PostgreSQL dimension mismatch handling."""
  142. async def test_postgres_dimension_mismatch_raises_error_metadata(self):
  143. """
  144. Test that PostgreSQL raises DataMigrationError when dimensions don't match.
  145. Scenario: Legacy table has 1536d vectors, new model expects 3072d.
  146. Expected: DataMigrationError is raised to prevent data corruption.
  147. """
  148. # Setup mock database
  149. db = AsyncMock()
  150. # Mock check_table_exists
  151. async def mock_check_table_exists(table_name):
  152. if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
  153. return True
  154. elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
  155. return False
  156. return False
  157. db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
  158. # Mock table existence and dimension checks
  159. async def query_side_effect(query, params, **kwargs):
  160. if "COUNT(*)" in query:
  161. return {"count": 100} # Legacy has data
  162. elif "SELECT content_vector FROM" in query:
  163. # Return sample vector with 1536 dimensions
  164. return {"content_vector": [0.1] * 1536}
  165. return {}
  166. db.query.side_effect = query_side_effect
  167. db.execute = AsyncMock()
  168. db._create_vector_index = AsyncMock()
  169. # Call setup_table with 3072d (different from legacy 1536d)
  170. # Should raise DataMigrationError due to dimension mismatch
  171. with pytest.raises(DataMigrationError) as exc_info:
  172. await PGVectorStorage.setup_table(
  173. db,
  174. "LIGHTRAG_DOC_CHUNKS_model_3072d",
  175. legacy_table_name="LIGHTRAG_DOC_CHUNKS",
  176. base_table="LIGHTRAG_DOC_CHUNKS",
  177. embedding_dim=3072,
  178. workspace="test",
  179. )
  180. # Verify error message contains dimension information
  181. assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
  182. async def test_postgres_dimension_mismatch_raises_error_sampling(self):
  183. """
  184. Test that PostgreSQL raises error when dimensions don't match (via sampling).
  185. Scenario: Legacy table vector sampling detects 1536d vs expected 3072d.
  186. Expected: DataMigrationError is raised to prevent data corruption.
  187. """
  188. db = AsyncMock()
  189. # Mock check_table_exists
  190. async def mock_check_table_exists(table_name):
  191. if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
  192. return True
  193. elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
  194. return False
  195. return False
  196. db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
  197. # Mock table existence and dimension checks
  198. async def query_side_effect(query, params, **kwargs):
  199. if "information_schema.tables" in query:
  200. if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
  201. return {"exists": True}
  202. elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
  203. return {"exists": False}
  204. elif "COUNT(*)" in query:
  205. return {"count": 100} # Legacy has data
  206. elif "SELECT content_vector FROM" in query:
  207. # Return sample vector with 1536 dimensions as a JSON string
  208. return {"content_vector": json.dumps([0.1] * 1536)}
  209. return {}
  210. db.query.side_effect = query_side_effect
  211. db.execute = AsyncMock()
  212. db._create_vector_index = AsyncMock()
  213. # Call setup_table with 3072d (different from legacy 1536d)
  214. # Should raise DataMigrationError due to dimension mismatch
  215. with pytest.raises(DataMigrationError) as exc_info:
  216. await PGVectorStorage.setup_table(
  217. db,
  218. "LIGHTRAG_DOC_CHUNKS_model_3072d",
  219. legacy_table_name="LIGHTRAG_DOC_CHUNKS",
  220. base_table="LIGHTRAG_DOC_CHUNKS",
  221. embedding_dim=3072,
  222. workspace="test",
  223. )
  224. # Verify error message contains dimension information
  225. assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
  226. async def test_postgres_dimension_match_proceed_migration(self):
  227. """
  228. Test that PostgreSQL proceeds with migration when dimensions match.
  229. Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
  230. Expected: Migration proceeds normally.
  231. """
  232. db = AsyncMock()
  233. # Track migration state
  234. migration_done = {"value": False}
  235. # Define exactly 2 records for consistency
  236. mock_records = [
  237. {
  238. "id": "test1",
  239. "content_vector": [0.1] * 1536,
  240. "workspace": "test",
  241. },
  242. {
  243. "id": "test2",
  244. "content_vector": [0.2] * 1536,
  245. "workspace": "test",
  246. },
  247. ]
  248. # Mock check_table_exists
  249. async def mock_check_table_exists(table_name):
  250. if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy exists
  251. return True
  252. elif table_name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist
  253. return False
  254. return False
  255. db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
  256. async def query_side_effect(query, params, **kwargs):
  257. multirows = kwargs.get("multirows", False)
  258. query_upper = query.upper()
  259. if "information_schema.tables" in query:
  260. if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
  261. return {"exists": True}
  262. elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new
  263. return {"exists": False}
  264. elif "COUNT(*)" in query_upper:
  265. # Return different counts based on table name in query and migration state
  266. if "LIGHTRAG_DOC_CHUNKS_MODEL_1536D" in query_upper:
  267. # After migration: return migrated count, before: return 0
  268. return {
  269. "count": len(mock_records) if migration_done["value"] else 0
  270. }
  271. # Legacy table always has 2 records (matching mock_records)
  272. return {"count": len(mock_records)}
  273. elif "PG_ATTRIBUTE" in query_upper:
  274. return {"vector_dim": 1536} # Legacy has matching 1536d
  275. elif "SELECT" in query_upper and "FROM" in query_upper and multirows:
  276. # Return sample data for migration using keyset pagination
  277. # Handle keyset pagination: params = [workspace, limit] or [workspace, last_id, limit]
  278. if "id >" in query.lower():
  279. # Keyset pagination: params = [workspace, last_id, limit]
  280. last_id = params[1] if len(params) > 1 else None
  281. # Find records after last_id
  282. found_idx = -1
  283. for i, rec in enumerate(mock_records):
  284. if rec["id"] == last_id:
  285. found_idx = i
  286. break
  287. if found_idx >= 0:
  288. return mock_records[found_idx + 1 :]
  289. return []
  290. else:
  291. # First batch: params = [workspace, limit]
  292. return mock_records
  293. return {}
  294. db.query.side_effect = query_side_effect
  295. # Mock _run_with_retry to track when migration happens
  296. migration_executed = []
  297. async def mock_run_with_retry(operation, *args, **kwargs):
  298. migration_executed.append(True)
  299. migration_done["value"] = True
  300. return None
  301. db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
  302. db.execute = AsyncMock()
  303. db._create_vector_index = AsyncMock()
  304. # Call setup_table with matching 1536d
  305. await PGVectorStorage.setup_table(
  306. db,
  307. "LIGHTRAG_DOC_CHUNKS_model_1536d",
  308. legacy_table_name="LIGHTRAG_DOC_CHUNKS",
  309. base_table="LIGHTRAG_DOC_CHUNKS",
  310. embedding_dim=1536,
  311. workspace="test",
  312. )
  313. # Verify migration WAS called (via _run_with_retry for batch operations)
  314. assert len(migration_executed) > 0, "Migration should have been executed"