test_milvus_index_creation.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
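"""
Tests for Milvus index creation behavior.

This test suite validates:
1. P1: build_index_params uses the compatibility helper.
2. P2: Vector index creation failures are surfaced to callers.
3. Server-version probing happens only for index types that need it (HNSW_SQ).
4. initialize() bootstraps the configured Milvus database when it is missing.
5. Existing collections that lack a vector index are repaired, and repair
   failures are reported with a precise error.
"""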
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from lightrag.kg.milvus_impl import MilvusIndexConfig, MilvusVectorDBStorage
  11. @pytest.mark.offline
  12. class TestMilvusIndexCreation:
  13. """Test index creation behavior and error handling"""
  14. def test_vector_index_creation_failure_is_raised(self):
  15. """Test that vector index creation failures are raised to the caller (P2 fix)"""
  16. # Setup storage instance
  17. mock_embedding_func = MagicMock()
  18. mock_embedding_func.embedding_dim = 128
  19. storage = MilvusVectorDBStorage(
  20. namespace="test_entities",
  21. workspace="test_workspace",
  22. global_config={
  23. "embedding_batch_num": 100,
  24. "vector_db_storage_cls_kwargs": {
  25. "cosine_better_than_threshold": 0.3,
  26. "index_type": "HNSW",
  27. },
  28. },
  29. embedding_func=mock_embedding_func,
  30. meta_fields=set(),
  31. )
  32. # Mock the client and _get_index_params
  33. mock_client = MagicMock()
  34. mock_index_params = MagicMock()
  35. storage._client = mock_client
  36. storage.final_namespace = "test_entities"
  37. # Mock _get_index_params to return a valid IndexParams
  38. with patch.object(storage, "_get_index_params", return_value=mock_index_params):
  39. # Mock build_index_params to return the mock_index_params
  40. with patch.object(
  41. storage.index_config,
  42. "build_index_params",
  43. return_value=mock_index_params,
  44. ):
  45. # Mock create_index to raise an exception (simulating index creation failure)
  46. mock_client.create_index.side_effect = Exception(
  47. "Index creation failed"
  48. )
  49. # Verify that the exception is raised (not caught and logged)
  50. with pytest.raises(Exception, match="Index creation failed"):
  51. storage._create_indexes_after_collection()
  52. def test_scalar_index_creation_failure_is_logged_not_raised(self):
  53. """Test that scalar index creation failures are logged but not raised (existing behavior)"""
  54. # Setup storage instance
  55. mock_embedding_func = MagicMock()
  56. mock_embedding_func.embedding_dim = 128
  57. storage = MilvusVectorDBStorage(
  58. namespace="test_entities",
  59. workspace="test_workspace",
  60. global_config={
  61. "embedding_batch_num": 100,
  62. "vector_db_storage_cls_kwargs": {
  63. "cosine_better_than_threshold": 0.3,
  64. "index_type": "AUTOINDEX", # No custom vector index
  65. },
  66. },
  67. embedding_func=mock_embedding_func,
  68. meta_fields=set(),
  69. )
  70. # Mock the client and _get_index_params
  71. mock_client = MagicMock()
  72. mock_index_params = MagicMock()
  73. storage._client = mock_client
  74. storage.final_namespace = "test_entities"
  75. # Mock _get_index_params to return a valid IndexParams for scalar indexes
  76. with patch.object(storage, "_get_index_params", return_value=mock_index_params):
  77. # Let vector AUTOINDEX creation succeed, then fail on scalar index creation
  78. mock_client.create_index.side_effect = [
  79. None,
  80. Exception("Scalar index creation failed"),
  81. ]
  82. # Verify that the function completes without raising (scalar index failures are logged)
  83. # This should not raise an exception
  84. storage._create_indexes_after_collection()
  85. # The function should complete successfully even though scalar index creation failed
  86. def test_build_index_params_uses_passed_index_params(self):
  87. """Test that build_index_params uses the passed index_params parameter (P1 fix)"""
  88. config = MilvusIndexConfig(
  89. index_type="HNSW",
  90. metric_type="COSINE",
  91. hnsw_m=32,
  92. hnsw_ef_construction=256,
  93. )
  94. mock_index_params = MagicMock()
  95. # Call build_index_params with the mock_index_params
  96. result = config.build_index_params(mock_index_params)
  97. # Verify that it used the passed index_params
  98. assert result == mock_index_params
  99. mock_index_params.add_index.assert_called_once()
  100. def test_build_index_params_raises_when_index_params_is_none_for_custom_type(self):
  101. """Test that build_index_params raises RuntimeError when index_params is None for custom types (P1 fix)"""
  102. config = MilvusIndexConfig(
  103. index_type="HNSW",
  104. metric_type="COSINE",
  105. )
  106. # Call with None (simulating compatibility helper returning None)
  107. # Should raise RuntimeError for non-AUTOINDEX types
  108. with pytest.raises(RuntimeError, match="IndexParams not available"):
  109. config.build_index_params(None)
  110. def test_build_index_params_returns_none_for_autoindex_when_index_params_is_none(
  111. self,
  112. ):
  113. """Test AUTOINDEX falls back to direct API parameters when IndexParams is unavailable."""
  114. config = MilvusIndexConfig(
  115. index_type="AUTOINDEX",
  116. metric_type="COSINE",
  117. )
  118. # AUTOINDEX should still produce direct API parameters
  119. result = config.build_index_params(None)
  120. assert result == {
  121. "field_name": "vector",
  122. "index_type": "AUTOINDEX",
  123. "metric_type": "COSINE",
  124. "params": {},
  125. }
  126. def test_build_index_params_autoindex_uses_index_params_object(self):
  127. """Test AUTOINDEX still creates an explicit vector index when IndexParams is available."""
  128. config = MilvusIndexConfig(
  129. index_type="AUTOINDEX",
  130. metric_type="COSINE",
  131. )
  132. mock_index_params = MagicMock()
  133. result = config.build_index_params(mock_index_params)
  134. assert result == mock_index_params
  135. mock_index_params.add_index.assert_called_once_with(
  136. field_name="vector",
  137. index_type="AUTOINDEX",
  138. metric_type="COSINE",
  139. params={},
  140. )
  141. def test_create_indexes_uses_compatibility_helper(self):
  142. """Test that _create_indexes_after_collection uses _get_index_params (P1 fix)"""
  143. # Setup storage instance
  144. mock_embedding_func = MagicMock()
  145. mock_embedding_func.embedding_dim = 128
  146. storage = MilvusVectorDBStorage(
  147. namespace="test_entities",
  148. workspace="test_workspace",
  149. global_config={
  150. "embedding_batch_num": 100,
  151. "vector_db_storage_cls_kwargs": {
  152. "cosine_better_than_threshold": 0.3,
  153. "index_type": "HNSW",
  154. },
  155. },
  156. embedding_func=mock_embedding_func,
  157. meta_fields=set(),
  158. )
  159. # Mock the client
  160. mock_client = MagicMock()
  161. mock_index_params = MagicMock()
  162. storage._client = mock_client
  163. storage.final_namespace = "test_entities"
  164. # Spy on _get_index_params to verify it's called
  165. with patch.object(
  166. storage, "_get_index_params", return_value=mock_index_params
  167. ) as mock_get_index_params:
  168. # Call the method
  169. storage._create_indexes_after_collection()
  170. # Verify that _get_index_params was called at least once
  171. assert mock_get_index_params.call_count >= 1
  172. def test_version_probing_only_for_hnsw_sq(self):
  173. """Test that get_server_version is only called when index type requires it (P2 fix)"""
  174. from unittest.mock import AsyncMock
  175. mock_embedding_func = MagicMock()
  176. mock_embedding_func.embedding_dim = 128
  177. # Test with HNSW (no version requirement) - should NOT call get_server_version
  178. storage = MilvusVectorDBStorage(
  179. namespace="test_entities",
  180. workspace="test_workspace",
  181. global_config={
  182. "embedding_batch_num": 100,
  183. "vector_db_storage_cls_kwargs": {
  184. "cosine_better_than_threshold": 0.3,
  185. "index_type": "HNSW",
  186. },
  187. },
  188. embedding_func=mock_embedding_func,
  189. meta_fields=set(),
  190. )
  191. mock_client = MagicMock()
  192. storage._client = mock_client
  193. # Mock the init lock as an async context manager
  194. mock_lock = AsyncMock()
  195. with patch(
  196. "lightrag.kg.milvus_impl.get_data_init_lock", return_value=mock_lock
  197. ):
  198. with patch.object(storage, "_create_collection_if_not_exist"):
  199. asyncio.run(storage.initialize())
  200. # get_server_version should NOT be called for HNSW
  201. mock_client.get_server_version.assert_not_called()
  202. def test_version_probing_called_for_hnsw_sq(self):
  203. """Test that get_server_version IS called when HNSW_SQ is configured (P2 fix)"""
  204. from unittest.mock import AsyncMock
  205. mock_embedding_func = MagicMock()
  206. mock_embedding_func.embedding_dim = 128
  207. storage = MilvusVectorDBStorage(
  208. namespace="test_entities",
  209. workspace="test_workspace",
  210. global_config={
  211. "embedding_batch_num": 100,
  212. "vector_db_storage_cls_kwargs": {
  213. "cosine_better_than_threshold": 0.3,
  214. "index_type": "HNSW_SQ",
  215. },
  216. },
  217. embedding_func=mock_embedding_func,
  218. meta_fields=set(),
  219. )
  220. mock_client = MagicMock()
  221. mock_client.get_server_version.return_value = "2.6.9"
  222. storage._client = mock_client
  223. # Mock the init lock as an async context manager
  224. mock_lock = AsyncMock()
  225. with patch(
  226. "lightrag.kg.milvus_impl.get_data_init_lock", return_value=mock_lock
  227. ):
  228. with patch.object(storage, "_create_collection_if_not_exist"):
  229. asyncio.run(storage.initialize())
  230. # get_server_version SHOULD be called for HNSW_SQ
  231. mock_client.get_server_version.assert_called_once()
  232. def test_initialize_creates_missing_database_before_collection_setup(self):
  233. """Test that initialize bootstraps a missing configured Milvus database."""
  234. from unittest.mock import AsyncMock
  235. mock_embedding_func = MagicMock()
  236. mock_embedding_func.embedding_dim = 128
  237. storage = MilvusVectorDBStorage(
  238. namespace="test_entities",
  239. workspace="space1",
  240. global_config={
  241. "embedding_batch_num": 100,
  242. "working_dir": "/tmp/lightrag",
  243. "vector_db_storage_cls_kwargs": {
  244. "cosine_better_than_threshold": 0.3,
  245. },
  246. },
  247. embedding_func=mock_embedding_func,
  248. meta_fields=set(),
  249. )
  250. bootstrap_client = MagicMock()
  251. bootstrap_client.list_databases.return_value = ["default"]
  252. mock_lock = AsyncMock()
  253. with patch.dict(
  254. "os.environ",
  255. {
  256. "MILVUS_URI": "http://milvus:19530",
  257. "MILVUS_DB_NAME": "lightrag",
  258. },
  259. clear=False,
  260. ):
  261. with patch(
  262. "lightrag.kg.milvus_impl.MilvusClient", return_value=bootstrap_client
  263. ) as mock_client_cls:
  264. with patch(
  265. "lightrag.kg.milvus_impl.get_data_init_lock",
  266. return_value=mock_lock,
  267. ):
  268. with patch.object(storage, "_create_collection_if_not_exist"):
  269. asyncio.run(storage.initialize())
  270. mock_client_cls.assert_called_once_with(
  271. uri="http://milvus:19530",
  272. user=None,
  273. password=None,
  274. token=None,
  275. )
  276. bootstrap_client.list_databases.assert_called_once_with()
  277. bootstrap_client.create_database.assert_called_once_with("lightrag")
  278. bootstrap_client.use_database.assert_called_once_with("lightrag")
  279. def test_initialize_uses_existing_database_without_recreating_it(self):
  280. """Test that initialize switches to an existing configured Milvus database."""
  281. from unittest.mock import AsyncMock
  282. mock_embedding_func = MagicMock()
  283. mock_embedding_func.embedding_dim = 128
  284. storage = MilvusVectorDBStorage(
  285. namespace="test_entities",
  286. workspace="space1",
  287. global_config={
  288. "embedding_batch_num": 100,
  289. "working_dir": "/tmp/lightrag",
  290. "vector_db_storage_cls_kwargs": {
  291. "cosine_better_than_threshold": 0.3,
  292. },
  293. },
  294. embedding_func=mock_embedding_func,
  295. meta_fields=set(),
  296. )
  297. bootstrap_client = MagicMock()
  298. bootstrap_client.list_databases.return_value = ["default", "lightrag"]
  299. mock_lock = AsyncMock()
  300. with patch.dict(
  301. "os.environ",
  302. {
  303. "MILVUS_URI": "http://milvus:19530",
  304. "MILVUS_DB_NAME": "lightrag",
  305. },
  306. clear=False,
  307. ):
  308. with patch(
  309. "lightrag.kg.milvus_impl.MilvusClient", return_value=bootstrap_client
  310. ):
  311. with patch(
  312. "lightrag.kg.milvus_impl.get_data_init_lock",
  313. return_value=mock_lock,
  314. ):
  315. with patch.object(storage, "_create_collection_if_not_exist"):
  316. asyncio.run(storage.initialize())
  317. bootstrap_client.list_databases.assert_called_once_with()
  318. bootstrap_client.create_database.assert_not_called()
  319. bootstrap_client.use_database.assert_called_once_with("lightrag")
  320. def test_existing_collection_missing_vector_index_is_repaired(self):
  321. """Existing collections missing vector indexes should be repaired automatically."""
  322. mock_embedding_func = MagicMock()
  323. mock_embedding_func.embedding_dim = 128
  324. storage = MilvusVectorDBStorage(
  325. namespace="entities",
  326. workspace="space1",
  327. global_config={
  328. "embedding_batch_num": 100,
  329. "working_dir": "/tmp/lightrag",
  330. "vector_db_storage_cls_kwargs": {
  331. "cosine_better_than_threshold": 0.3,
  332. },
  333. },
  334. embedding_func=mock_embedding_func,
  335. meta_fields=set(),
  336. )
  337. storage.final_namespace = "space1_entities"
  338. storage._client = MagicMock()
  339. storage._client.has_collection.return_value = True
  340. load_error = RuntimeError(
  341. "there is no vector index on field: [vector], please create index firstly"
  342. )
  343. with patch.object(storage._client, "describe_collection", return_value={}):
  344. with patch.object(storage, "_validate_collection_compatibility"):
  345. with patch.object(
  346. storage,
  347. "_ensure_collection_loaded",
  348. side_effect=[load_error, None],
  349. ) as mock_load:
  350. with patch.object(
  351. storage, "_repair_missing_vector_index"
  352. ) as mock_repair:
  353. storage._create_collection_if_not_exist()
  354. assert mock_load.call_count == 2
  355. mock_repair.assert_called_once_with()
  356. def test_existing_collection_index_repair_failure_has_precise_error(self):
  357. """Index repair failures should not be reported as collection validation failures."""
  358. mock_embedding_func = MagicMock()
  359. mock_embedding_func.embedding_dim = 128
  360. storage = MilvusVectorDBStorage(
  361. namespace="entities",
  362. workspace="space1",
  363. global_config={
  364. "embedding_batch_num": 100,
  365. "working_dir": "/tmp/lightrag",
  366. "vector_db_storage_cls_kwargs": {
  367. "cosine_better_than_threshold": 0.3,
  368. },
  369. },
  370. embedding_func=mock_embedding_func,
  371. meta_fields=set(),
  372. )
  373. storage.final_namespace = "space1_entities"
  374. storage._client = MagicMock()
  375. storage._client.has_collection.return_value = True
  376. load_error = RuntimeError(
  377. "there is no vector index on field: [vector], please create index firstly"
  378. )
  379. with patch.object(storage._client, "describe_collection", return_value={}):
  380. with patch.object(storage, "_validate_collection_compatibility"):
  381. with patch.object(
  382. storage, "_ensure_collection_loaded", side_effect=load_error
  383. ):
  384. with patch.object(
  385. storage,
  386. "_repair_missing_vector_index",
  387. side_effect=RuntimeError("create index failed"),
  388. ):
  389. with pytest.raises(
  390. RuntimeError,
  391. match="Index repair failed for collection 'space1_entities'",
  392. ):
  393. storage._create_collection_if_not_exist()
  394. def test_existing_collection_non_index_validation_failure_still_raises(self):
  395. """Non-index validation failures should still stop initialization."""
  396. mock_embedding_func = MagicMock()
  397. mock_embedding_func.embedding_dim = 128
  398. storage = MilvusVectorDBStorage(
  399. namespace="entities",
  400. workspace="space1",
  401. global_config={
  402. "embedding_batch_num": 100,
  403. "working_dir": "/tmp/lightrag",
  404. "vector_db_storage_cls_kwargs": {
  405. "cosine_better_than_threshold": 0.3,
  406. },
  407. },
  408. embedding_func=mock_embedding_func,
  409. meta_fields=set(),
  410. )
  411. storage.final_namespace = "space1_entities"
  412. storage._client = MagicMock()
  413. storage._client.has_collection.return_value = True
  414. with patch.object(storage._client, "describe_collection", return_value={}):
  415. with patch.object(
  416. storage,
  417. "_validate_collection_compatibility",
  418. side_effect=RuntimeError("dimension mismatch"),
  419. ):
  420. with pytest.raises(
  421. RuntimeError,
  422. match="Collection validation failed for 'space1_entities'",
  423. ):
  424. storage._create_collection_if_not_exist()
  425. if __name__ == "__main__":
  426. pytest.main([__file__, "-v"])