test_milvus_index_config.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. """
  2. Tests for Milvus index configuration
  3. This test suite validates the MilvusIndexConfig class and its integration
  4. with MilvusVectorDBStorage.
  5. """
  6. import pytest
  7. import os
  8. from unittest.mock import patch, MagicMock
  9. from lightrag.kg.milvus_impl import (
  10. MilvusIndexConfig,
  11. SUPPORTED_INDEX_TYPES,
  12. SUPPORTED_METRIC_TYPES,
  13. SUPPORTED_SQ_TYPES,
  14. SUPPORTED_REFINE_TYPES,
  15. )
  16. class TestMilvusIndexConfig:
  17. """MilvusIndexConfig unit tests"""
  18. def test_default_values(self):
  19. """Test default configuration"""
  20. config = MilvusIndexConfig()
  21. assert config.index_type == "AUTOINDEX"
  22. assert config.metric_type == "COSINE"
  23. assert config.hnsw_m == 16
  24. assert config.hnsw_ef_construction == 360
  25. assert config.hnsw_ef == 200
  26. assert config.sq_type == "SQ8"
  27. assert not config.sq_refine
  28. assert config.sq_refine_type == "FP32"
  29. assert config.sq_refine_k == 10
  30. assert config.ivf_nlist == 1024
  31. assert config.ivf_nprobe == 16
  32. def test_env_override(self):
  33. """Test environment variable override"""
  34. with patch.dict(
  35. os.environ,
  36. {
  37. "MILVUS_INDEX_TYPE": "HNSW",
  38. "MILVUS_METRIC_TYPE": "L2",
  39. "MILVUS_HNSW_M": "64",
  40. },
  41. ):
  42. config = MilvusIndexConfig()
  43. assert config.index_type == "HNSW"
  44. assert config.metric_type == "L2"
  45. assert config.hnsw_m == 64
  46. def test_init_param_priority(self):
  47. """Test initialization parameter priority over environment variables"""
  48. with patch.dict(os.environ, {"MILVUS_INDEX_TYPE": "IVF_FLAT"}):
  49. config = MilvusIndexConfig(index_type="HNSW")
  50. assert config.index_type == "HNSW" # Init param takes precedence
  51. def test_case_insensitive_index_type(self):
  52. """Test that index type is case insensitive"""
  53. config = MilvusIndexConfig(index_type="hnsw")
  54. assert config.index_type == "HNSW"
  55. def test_case_insensitive_metric_type(self):
  56. """Test that metric type is case insensitive"""
  57. config = MilvusIndexConfig(metric_type="cosine")
  58. assert config.metric_type == "COSINE"
  59. def test_invalid_index_type(self):
  60. """Test invalid index type raises exception"""
  61. with pytest.raises(ValueError, match="Unsupported index type"):
  62. MilvusIndexConfig(index_type="INVALID_INDEX")
  63. def test_invalid_metric_type(self):
  64. """Test invalid metric type raises exception"""
  65. with pytest.raises(ValueError, match="Unsupported metric type"):
  66. MilvusIndexConfig(metric_type="INVALID_METRIC")
  67. def test_invalid_hnsw_m_range_low(self):
  68. """Test HNSW M parameter range validation (too low)"""
  69. with pytest.raises(ValueError, match="hnsw_m must be in"):
  70. MilvusIndexConfig(hnsw_m=1) # Less than 2
  71. def test_invalid_hnsw_m_range_high(self):
  72. """Test HNSW M parameter range validation (too high)"""
  73. with pytest.raises(ValueError, match="hnsw_m must be in"):
  74. MilvusIndexConfig(hnsw_m=3000) # Greater than 2048
  75. def test_valid_hnsw_m_boundary(self):
  76. """Test HNSW M parameter boundary values"""
  77. config_low = MilvusIndexConfig(hnsw_m=2)
  78. assert config_low.hnsw_m == 2
  79. config_high = MilvusIndexConfig(hnsw_m=2048)
  80. assert config_high.hnsw_m == 2048
  81. def test_invalid_hnsw_ef_construction(self):
  82. """Test HNSW efConstruction validation"""
  83. with pytest.raises(ValueError, match="hnsw_ef_construction must be"):
  84. MilvusIndexConfig(hnsw_ef_construction=0)
  85. def test_invalid_ivf_nlist_low(self):
  86. """Test IVF nlist parameter range validation (too low)"""
  87. with pytest.raises(ValueError, match="ivf_nlist must be in"):
  88. MilvusIndexConfig(ivf_nlist=0)
  89. def test_invalid_ivf_nlist_high(self):
  90. """Test IVF nlist parameter range validation (too high)"""
  91. with pytest.raises(ValueError, match="ivf_nlist must be in"):
  92. MilvusIndexConfig(ivf_nlist=70000)
  93. def test_invalid_sq_type(self):
  94. """Test invalid sq_type"""
  95. with pytest.raises(ValueError, match="Unsupported sq_type"):
  96. MilvusIndexConfig(index_type="HNSW_SQ", sq_type="INVALID")
  97. def test_invalid_refine_type(self):
  98. """Test invalid refine_type"""
  99. with pytest.raises(ValueError, match="Unsupported refine_type"):
  100. MilvusIndexConfig(
  101. index_type="HNSW_SQ", sq_refine=True, sq_refine_type="INVALID"
  102. )
  103. def test_version_validation_hnsw_sq_pass(self):
  104. """Test HNSW_SQ version validation passes with valid versions"""
  105. config = MilvusIndexConfig(index_type="HNSW_SQ")
  106. # Version meets requirement
  107. config.validate_milvus_version("2.6.8") # Exactly required
  108. config.validate_milvus_version("2.6.9") # Above requirement
  109. config.validate_milvus_version("2.7.0") # Higher version
  110. def test_version_validation_hnsw_sq_fail(self):
  111. """Test HNSW_SQ version validation fails with invalid versions"""
  112. config = MilvusIndexConfig(index_type="HNSW_SQ")
  113. # Version does not meet requirement
  114. with pytest.raises(ValueError, match="HNSW_SQ requires Milvus 2.6.8"):
  115. config.validate_milvus_version("2.6.7") # Below 2.6.8
  116. with pytest.raises(ValueError, match="HNSW_SQ requires Milvus 2.6.8"):
  117. config.validate_milvus_version("2.5.0") # Much lower
  118. def test_version_validation_hnsw_sq_with_sq4u(self):
  119. """Test HNSW_SQ + SQ4U version validation"""
  120. config = MilvusIndexConfig(index_type="HNSW_SQ", sq_type="SQ4U")
  121. # Passes with valid version
  122. config.validate_milvus_version("2.6.9")
  123. # Fails with invalid version
  124. with pytest.raises(ValueError, match="HNSW_SQ requires Milvus 2.6.8"):
  125. config.validate_milvus_version("2.6.0")
  126. def test_version_validation_hnsw_no_requirement(self):
  127. """Test normal HNSW has no version restriction"""
  128. config = MilvusIndexConfig(index_type="HNSW")
  129. # No version restriction
  130. config.validate_milvus_version("2.4.0") # Lower version OK
  131. config.validate_milvus_version("2.6.9") # Higher version OK
  132. def test_version_validation_with_dev_suffix(self):
  133. """Test version validation handles dev suffixes"""
  134. config = MilvusIndexConfig(index_type="HNSW_SQ")
  135. # Should handle "2.6.9-dev" format
  136. config.validate_milvus_version("2.6.9-dev")
  137. def test_build_index_params_autoindex(self):
  138. """Test AUTOINDEX generates explicit index parameters"""
  139. config = MilvusIndexConfig(index_type="AUTOINDEX")
  140. mock_index_params = MagicMock()
  141. result = config.build_index_params(mock_index_params)
  142. assert result is mock_index_params
  143. mock_index_params.add_index.assert_called_once_with(
  144. field_name="vector",
  145. index_type="AUTOINDEX",
  146. metric_type="COSINE",
  147. params={},
  148. )
  149. def test_build_index_params_hnsw(self):
  150. """Test HNSW index parameters construction"""
  151. config = MilvusIndexConfig(
  152. index_type="HNSW",
  153. metric_type="COSINE",
  154. hnsw_m=32,
  155. hnsw_ef_construction=256,
  156. )
  157. mock_index_params = MagicMock()
  158. config.build_index_params(mock_index_params)
  159. mock_index_params.add_index.assert_called_once()
  160. call_kwargs = mock_index_params.add_index.call_args[1]
  161. assert call_kwargs["index_type"] == "HNSW"
  162. assert call_kwargs["metric_type"] == "COSINE"
  163. assert call_kwargs["params"]["M"] == 32
  164. assert call_kwargs["params"]["efConstruction"] == 256
  165. def test_build_index_params_hnsw_sq(self):
  166. """Test HNSW_SQ index parameters construction"""
  167. config = MilvusIndexConfig(
  168. index_type="HNSW_SQ",
  169. sq_type="SQ8",
  170. sq_refine=True,
  171. sq_refine_type="FP32",
  172. )
  173. mock_index_params = MagicMock()
  174. config.build_index_params(mock_index_params)
  175. call_kwargs = mock_index_params.add_index.call_args[1]
  176. assert call_kwargs["index_type"] == "HNSW_SQ"
  177. assert call_kwargs["params"]["sq_type"] == "SQ8"
  178. assert call_kwargs["params"]["refine"] is True
  179. assert call_kwargs["params"]["refine_type"] == "FP32"
  180. def test_build_index_params_hnsw_sq_no_refine(self):
  181. """Test HNSW_SQ without refinement"""
  182. config = MilvusIndexConfig(index_type="HNSW_SQ", sq_type="SQ8", sq_refine=False)
  183. mock_index_params = MagicMock()
  184. config.build_index_params(mock_index_params)
  185. call_kwargs = mock_index_params.add_index.call_args[1]
  186. assert call_kwargs["index_type"] == "HNSW_SQ"
  187. assert call_kwargs["params"]["sq_type"] == "SQ8"
  188. assert "refine" not in call_kwargs["params"]
  189. assert "refine_type" not in call_kwargs["params"]
  190. def test_build_index_params_ivf_flat(self):
  191. """Test IVF_FLAT index parameters construction"""
  192. config = MilvusIndexConfig(index_type="IVF_FLAT", ivf_nlist=2048)
  193. mock_index_params = MagicMock()
  194. config.build_index_params(mock_index_params)
  195. call_kwargs = mock_index_params.add_index.call_args[1]
  196. assert call_kwargs["index_type"] == "IVF_FLAT"
  197. assert call_kwargs["params"]["nlist"] == 2048
  198. def test_build_index_params_with_none(self):
  199. """Test that RuntimeError is raised when index_params is None for custom types"""
  200. config = MilvusIndexConfig(index_type="HNSW")
  201. # Pass None to simulate when compatibility helper returns None
  202. with pytest.raises(RuntimeError, match="IndexParams not available"):
  203. config.build_index_params(None)
  204. def test_build_search_params_hnsw(self):
  205. """Test HNSW search parameters construction"""
  206. config = MilvusIndexConfig(index_type="HNSW", hnsw_ef=150)
  207. params = config.build_search_params()
  208. assert params["params"]["ef"] == 150
  209. def test_build_search_params_hnsw_sq_with_refine(self):
  210. """Test HNSW_SQ with refinement search parameters"""
  211. config = MilvusIndexConfig(
  212. index_type="HNSW_SQ", hnsw_ef=200, sq_refine=True, sq_refine_k=20
  213. )
  214. params = config.build_search_params()
  215. assert params["params"]["ef"] == 200
  216. assert params["params"]["refine_k"] == 20
  217. def test_build_search_params_hnsw_sq_no_refine(self):
  218. """Test HNSW_SQ without refinement search parameters"""
  219. config = MilvusIndexConfig(index_type="HNSW_SQ", hnsw_ef=200, sq_refine=False)
  220. params = config.build_search_params()
  221. assert params["params"]["ef"] == 200
  222. assert "refine_k" not in params["params"]
  223. def test_build_search_params_ivf(self):
  224. """Test IVF search parameters construction"""
  225. config = MilvusIndexConfig(index_type="IVF_FLAT", ivf_nprobe=32)
  226. params = config.build_search_params()
  227. assert params["params"]["nprobe"] == 32
  228. def test_build_search_params_autoindex(self):
  229. """Test AUTOINDEX search parameters (empty)"""
  230. config = MilvusIndexConfig(index_type="AUTOINDEX")
  231. params = config.build_search_params()
  232. assert params == {}
  233. def test_to_dict_hnsw(self):
  234. """Test configuration export for HNSW"""
  235. config = MilvusIndexConfig(index_type="HNSW")
  236. d = config.to_dict()
  237. assert d["index_type"] == "HNSW"
  238. assert d["hnsw_m"] == 16
  239. assert d["sq_type"] is None # Not HNSW_SQ
  240. assert d["ivf_nlist"] is None # Not IVF
  241. def test_to_dict_hnsw_sq(self):
  242. """Test configuration export for HNSW_SQ"""
  243. config = MilvusIndexConfig(index_type="HNSW_SQ", sq_type="SQ8")
  244. d = config.to_dict()
  245. assert d["index_type"] == "HNSW_SQ"
  246. assert d["sq_type"] == "SQ8"
  247. assert d["ivf_nlist"] is None
  248. def test_to_dict_ivf(self):
  249. """Test configuration export for IVF"""
  250. config = MilvusIndexConfig(index_type="IVF_FLAT")
  251. d = config.to_dict()
  252. assert d["index_type"] == "IVF_FLAT"
  253. assert d["ivf_nlist"] == 1024
  254. assert d["sq_type"] is None
  255. def test_env_bool_parsing(self):
  256. """Test boolean environment variable parsing"""
  257. with patch.dict(os.environ, {"MILVUS_HNSW_SQ_REFINE": "true"}):
  258. config = MilvusIndexConfig(index_type="HNSW_SQ")
  259. assert config.sq_refine is True
  260. with patch.dict(os.environ, {"MILVUS_HNSW_SQ_REFINE": "false"}):
  261. config = MilvusIndexConfig(index_type="HNSW_SQ")
  262. assert not config.sq_refine
  263. with patch.dict(os.environ, {"MILVUS_HNSW_SQ_REFINE": "1"}):
  264. config = MilvusIndexConfig(index_type="HNSW_SQ")
  265. assert config.sq_refine is True
  266. with patch.dict(os.environ, {"MILVUS_HNSW_SQ_REFINE": "0"}):
  267. config = MilvusIndexConfig(index_type="HNSW_SQ")
  268. assert not config.sq_refine
  269. def test_env_int_parsing_invalid(self):
  270. """Test integer environment variable parsing with invalid value"""
  271. with patch.dict(os.environ, {"MILVUS_HNSW_M": "invalid"}):
  272. config = MilvusIndexConfig()
  273. assert config.hnsw_m == 16 # Falls back to default (Milvus 2.4+)
  274. def test_all_index_types_supported(self):
  275. """Test all supported index types can be configured"""
  276. for index_type in SUPPORTED_INDEX_TYPES:
  277. if index_type == "HNSW_SQ":
  278. # HNSW_SQ requires special parameters
  279. config = MilvusIndexConfig(index_type=index_type, sq_type="SQ8")
  280. else:
  281. config = MilvusIndexConfig(index_type=index_type)
  282. assert config.index_type == index_type
  283. def test_all_metric_types_supported(self):
  284. """Test all supported metric types can be configured"""
  285. for metric_type in SUPPORTED_METRIC_TYPES:
  286. config = MilvusIndexConfig(metric_type=metric_type)
  287. assert config.metric_type == metric_type
  288. def test_all_sq_types_supported(self):
  289. """Test all supported sq_types can be configured"""
  290. for sq_type in SUPPORTED_SQ_TYPES:
  291. config = MilvusIndexConfig(index_type="HNSW_SQ", sq_type=sq_type)
  292. assert config.sq_type == sq_type
  293. def test_all_refine_types_supported(self):
  294. """Test all supported refine_types can be configured"""
  295. for refine_type in SUPPORTED_REFINE_TYPES:
  296. config = MilvusIndexConfig(
  297. index_type="HNSW_SQ", sq_refine=True, sq_refine_type=refine_type
  298. )
  299. assert config.sq_refine_type == refine_type
  300. def test_get_config_field_names(self):
  301. """Test get_config_field_names() returns all dataclass fields"""
  302. field_names = MilvusIndexConfig.get_config_field_names()
  303. # Check that it's a set
  304. assert isinstance(field_names, set)
  305. # Check that all expected fields are present
  306. expected_fields = {
  307. "index_type",
  308. "metric_type",
  309. "hnsw_m",
  310. "hnsw_ef_construction",
  311. "hnsw_ef",
  312. "sq_type",
  313. "sq_refine",
  314. "sq_refine_type",
  315. "sq_refine_k",
  316. "ivf_nlist",
  317. "ivf_nprobe",
  318. }
  319. assert field_names == expected_fields
  320. def test_get_config_field_names_single_source_of_truth(self):
  321. """Test that get_config_field_names() provides single source of truth for configuration parameters"""
  322. # This test ensures that when we add new fields to MilvusIndexConfig,
  323. # they are automatically included in get_config_field_names()
  324. # without needing to update hardcoded lists elsewhere
  325. from dataclasses import fields as dataclass_fields
  326. # Get fields directly from dataclass
  327. direct_fields = {f.name for f in dataclass_fields(MilvusIndexConfig)}
  328. # Get fields via the method
  329. method_fields = MilvusIndexConfig.get_config_field_names()
  330. # They should be identical
  331. assert direct_fields == method_fields, (
  332. f"Method should return same fields as dataclass. "
  333. f"Difference: {direct_fields.symmetric_difference(method_fields)}"
  334. )
  335. if __name__ == "__main__":
  336. pytest.main([__file__, "-v"])