# test_postgres_halfvec.py - tests for HALFVEC handling in the PostgreSQL vector storage.
  1. import pytest
  2. import numpy as np
  3. from unittest.mock import patch, AsyncMock
  4. from lightrag.utils import EmbeddingFunc
  5. from lightrag.kg.postgres_impl import (
  6. PGVectorStorage,
  7. PostgreSQLDB,
  8. _safe_index_name,
  9. )
  10. from lightrag.exceptions import DataMigrationError
  11. from lightrag.namespace import NameSpace
  12. # Mock PostgreSQLDB
  13. @pytest.fixture
  14. def mock_pg_db():
  15. """Mock PostgreSQL database connection"""
  16. db = AsyncMock()
  17. db.workspace = "test_workspace"
  18. db.vector_index_type = None
  19. # Mock query responses: list for search queries (multirows=True), dict for DDL checks
  20. async def mock_query(sql, params=None, multirows=False, **kwargs):
  21. if multirows:
  22. return []
  23. return {"exists": False, "count": 0}
  24. # Mock for execute that mimics PostgreSQLDB.execute() behavior
  25. async def mock_execute(sql, data=None, **kwargs):
  26. return None
  27. db.query = AsyncMock(side_effect=mock_query)
  28. db.execute = AsyncMock(side_effect=mock_execute)
  29. return db
  30. # Mock get_data_init_lock to avoid async lock issues in tests
  31. @pytest.fixture(autouse=True)
  32. def mock_data_init_lock():
  33. with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
  34. mock_lock_ctx = AsyncMock()
  35. mock_lock.return_value = mock_lock_ctx
  36. yield mock_lock
  37. # Mock ClientManager
  38. @pytest.fixture
  39. def mock_client_manager(mock_pg_db):
  40. with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
  41. mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
  42. mock_manager.release_client = AsyncMock()
  43. yield mock_manager
  44. # Mock Embedding function
  45. @pytest.fixture
  46. def mock_embedding_func():
  47. async def embed_func(texts, **kwargs):
  48. return np.array([[0.1] * 768 for _ in texts])
  49. # Note: EmbeddingFunc in this version of lightrag supports model_name
  50. func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
  51. return func
  52. @pytest.mark.asyncio
  53. async def test_postgres_halfvec_table_creation(
  54. mock_client_manager, mock_pg_db, mock_embedding_func
  55. ):
  56. """Test if table is created with HALFVEC type when HNSW_HALFVEC is selected"""
  57. # Set index type to HNSW_HALFVEC
  58. mock_pg_db.vector_index_type = "HNSW_HALFVEC"
  59. config = {
  60. "embedding_batch_num": 10,
  61. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
  62. }
  63. storage = PGVectorStorage(
  64. namespace=NameSpace.VECTOR_STORE_CHUNKS,
  65. global_config=config,
  66. embedding_func=mock_embedding_func,
  67. workspace="test_ws",
  68. )
  69. # Mock table doesn't exist
  70. mock_pg_db.check_table_exists = AsyncMock(return_value=False)
  71. # Initialize storage (should trigger table creation)
  72. await storage.initialize()
  73. # Verify table creation SQL contains HALFVEC(768)
  74. create_table_calls = [
  75. call
  76. for call in mock_pg_db.execute.call_args_list
  77. if "CREATE TABLE" in call[0][0]
  78. ]
  79. assert len(create_table_calls) > 0
  80. create_sql = create_table_calls[0][0][0]
  81. assert "HALFVEC(768)" in create_sql
  82. assert "VECTOR(768)" not in create_sql
  83. @pytest.mark.asyncio
  84. async def test_postgres_vector_table_creation_default(
  85. mock_client_manager, mock_pg_db, mock_embedding_func
  86. ):
  87. """Test if table is created with default VECTOR type when other index type is selected"""
  88. # Set index type to HNSW (default)
  89. mock_pg_db.vector_index_type = "HNSW"
  90. config = {
  91. "embedding_batch_num": 10,
  92. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
  93. }
  94. storage = PGVectorStorage(
  95. namespace=NameSpace.VECTOR_STORE_CHUNKS,
  96. global_config=config,
  97. embedding_func=mock_embedding_func,
  98. workspace="test_ws",
  99. )
  100. # Mock table doesn't exist
  101. mock_pg_db.check_table_exists = AsyncMock(return_value=False)
  102. # Initialize storage (should trigger table creation)
  103. await storage.initialize()
  104. # Verify table creation SQL contains VECTOR(768)
  105. create_table_calls = [
  106. call
  107. for call in mock_pg_db.execute.call_args_list
  108. if "CREATE TABLE" in call[0][0]
  109. ]
  110. assert len(create_table_calls) > 0
  111. create_sql = create_table_calls[0][0][0]
  112. assert "VECTOR(768)" in create_sql
  113. assert "HALFVEC(768)" not in create_sql
  114. # Namespaces that use vector search SQL templates (query path)
  115. QUERY_NAMESPACES = [
  116. NameSpace.VECTOR_STORE_CHUNKS,
  117. NameSpace.VECTOR_STORE_ENTITIES,
  118. NameSpace.VECTOR_STORE_RELATIONSHIPS,
  119. ]
  120. @pytest.mark.asyncio
  121. @pytest.mark.parametrize("namespace", QUERY_NAMESPACES)
  122. async def test_query_uses_halfvec_cast_when_hnsw_halfvec(
  123. mock_client_manager, mock_pg_db, mock_embedding_func, namespace
  124. ):
  125. """When HNSW_HALFVEC is set, generated search SQL uses ::halfvec (not ::vector)."""
  126. mock_pg_db.vector_index_type = "HNSW_HALFVEC"
  127. mock_pg_db.check_table_exists = AsyncMock(return_value=True)
  128. config = {
  129. "embedding_batch_num": 10,
  130. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
  131. }
  132. storage = PGVectorStorage(
  133. namespace=namespace,
  134. global_config=config,
  135. embedding_func=mock_embedding_func,
  136. workspace="test_ws",
  137. )
  138. await storage.initialize()
  139. query_embedding = [0.1] * 768
  140. await storage.query("test query", top_k=5, query_embedding=query_embedding)
  141. assert mock_pg_db.query.called
  142. call_args = mock_pg_db.query.call_args
  143. sql = call_args[0][0]
  144. assert "::halfvec" in sql
  145. assert "::vector" not in sql
  146. @pytest.mark.asyncio
  147. @pytest.mark.parametrize("namespace", QUERY_NAMESPACES)
  148. async def test_query_uses_vector_cast_when_hnsw_default(
  149. mock_client_manager, mock_pg_db, mock_embedding_func, namespace
  150. ):
  151. """When HNSW (default) is set, generated search SQL uses ::vector (not ::halfvec)."""
  152. mock_pg_db.vector_index_type = "HNSW"
  153. mock_pg_db.check_table_exists = AsyncMock(return_value=True)
  154. config = {
  155. "embedding_batch_num": 10,
  156. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
  157. }
  158. storage = PGVectorStorage(
  159. namespace=namespace,
  160. global_config=config,
  161. embedding_func=mock_embedding_func,
  162. workspace="test_ws",
  163. )
  164. await storage.initialize()
  165. query_embedding = [0.1] * 768
  166. await storage.query("test query", top_k=5, query_embedding=query_embedding)
  167. assert mock_pg_db.query.called
  168. call_args = mock_pg_db.query.call_args
  169. sql = call_args[0][0]
  170. assert "::vector" in sql
  171. assert "::halfvec" not in sql
  172. # ---------------------------------------------------------------------------
  173. # Index switching: old conflicting indexes are dropped
  174. # ---------------------------------------------------------------------------
  175. @pytest.mark.asyncio
  176. async def test_create_vector_index_drops_old_indexes_when_switching(mock_pg_db):
  177. """Switching from HNSW to HNSW_HALFVEC drops the old hnsw_cosine index."""
  178. mock_pg_db.vector_index_type = "HNSW_HALFVEC"
  179. mock_pg_db.hnsw_m = 16
  180. mock_pg_db.hnsw_ef = 64
  181. mock_pg_db.ivfflat_lists = 100
  182. mock_pg_db.vchordrq_build_options = ""
  183. table_name = "lightrag_vdb_chunks_test"
  184. async def mock_query(sql, params=None, multirows=False, **kwargs):
  185. if "pg_indexes" in sql:
  186. return None
  187. return None
  188. mock_pg_db.query = AsyncMock(side_effect=mock_query)
  189. mock_pg_db.execute = AsyncMock()
  190. # Call the real method with mock_pg_db as self
  191. await PostgreSQLDB._create_vector_index(mock_pg_db, table_name, 3072)
  192. execute_calls = [call[0][0] for call in mock_pg_db.execute.call_args_list]
  193. old_hnsw_name = _safe_index_name(table_name, "hnsw_cosine")
  194. old_ivfflat_name = _safe_index_name(table_name, "ivfflat_cosine")
  195. old_vchordrq_name = _safe_index_name(table_name, "vchordrq_cosine")
  196. drop_calls = [c for c in execute_calls if "DROP INDEX IF EXISTS" in c]
  197. dropped_names = {c.split("DROP INDEX IF EXISTS ")[1].strip() for c in drop_calls}
  198. assert old_hnsw_name in dropped_names
  199. assert old_ivfflat_name in dropped_names
  200. assert old_vchordrq_name in dropped_names
  201. new_index_name = _safe_index_name(table_name, "hnsw_halfvec_cosine")
  202. assert new_index_name not in dropped_names
  203. alter_calls = [c for c in execute_calls if "ALTER TABLE" in c]
  204. assert any("HALFVEC(3072)" in c for c in alter_calls)
  205. create_calls = [c for c in execute_calls if "CREATE INDEX" in c]
  206. assert any("halfvec_cosine_ops" in c for c in create_calls)
  207. @pytest.mark.asyncio
  208. async def test_create_vector_index_no_drop_when_index_exists(mock_pg_db):
  209. """If the target index already exists, no DROP or CREATE is issued."""
  210. mock_pg_db.vector_index_type = "HNSW_HALFVEC"
  211. mock_pg_db.hnsw_m = 16
  212. mock_pg_db.hnsw_ef = 64
  213. mock_pg_db.ivfflat_lists = 100
  214. mock_pg_db.vchordrq_build_options = ""
  215. table_name = "lightrag_vdb_chunks_test"
  216. async def mock_query(sql, params=None, multirows=False, **kwargs):
  217. if "pg_indexes" in sql:
  218. return {"?column?": 1}
  219. return None
  220. mock_pg_db.query = AsyncMock(side_effect=mock_query)
  221. mock_pg_db.execute = AsyncMock()
  222. await PostgreSQLDB._create_vector_index(mock_pg_db, table_name, 3072)
  223. execute_calls = [call[0][0] for call in mock_pg_db.execute.call_args_list]
  224. assert not any("DROP INDEX" in c for c in execute_calls)
  225. assert not any("CREATE INDEX" in c for c in execute_calls)
  226. # ---------------------------------------------------------------------------
  227. # HalfVector dimension detection in setup_table
  228. # ---------------------------------------------------------------------------
  229. class _MockHalfVector:
  230. """Mimics pgvector.halfvec.HalfVector for testing dimension detection."""
  231. def __init__(self, dim: int):
  232. self._dim = dim
  233. def dimensions(self) -> int:
  234. return self._dim
  235. def to_list(self):
  236. return [0.0] * self._dim
  237. @pytest.mark.asyncio
  238. async def test_setup_table_detects_halfvector_dimension_mismatch(mock_pg_db):
  239. """DataMigrationError is raised when a HalfVector column has a different dimension."""
  240. table_name = "lightrag_vdb_chunks_new"
  241. legacy_table = "lightrag_vdb_chunks"
  242. mock_pg_db.check_table_exists = AsyncMock(
  243. side_effect=lambda t: t.lower() == legacy_table.lower()
  244. )
  245. call_count = 0
  246. async def mock_query(sql, params=None, multirows=False, **kwargs):
  247. nonlocal call_count
  248. call_count += 1
  249. if "COUNT(*)" in sql:
  250. return {"count": 5}
  251. if "content_vector" in sql:
  252. return {"content_vector": _MockHalfVector(1024)}
  253. return None
  254. mock_pg_db.query = AsyncMock(side_effect=mock_query)
  255. mock_pg_db.execute = AsyncMock()
  256. with pytest.raises(DataMigrationError, match="Dimension mismatch"):
  257. await PGVectorStorage.setup_table(
  258. db=mock_pg_db,
  259. table_name=table_name,
  260. workspace="test_ws",
  261. embedding_dim=768,
  262. legacy_table_name=legacy_table,
  263. base_table=legacy_table,
  264. )
  265. @pytest.mark.asyncio
  266. async def test_setup_table_accepts_matching_halfvector_dimension(mock_pg_db):
  267. """No error when HalfVector dimension matches the expected embedding_dim."""
  268. table_name = "lightrag_vdb_chunks_new"
  269. legacy_table = "lightrag_vdb_chunks"
  270. mock_pg_db.check_table_exists = AsyncMock(
  271. side_effect=lambda t: t.lower() == legacy_table.lower()
  272. )
  273. mock_pg_db.vector_index_type = "HNSW_HALFVEC"
  274. async def mock_query(sql, params=None, multirows=False, **kwargs):
  275. if "COUNT(*)" in sql:
  276. return {"count": 5}
  277. if "content_vector" in sql:
  278. return {"content_vector": _MockHalfVector(768)}
  279. if multirows:
  280. return []
  281. return None
  282. mock_pg_db.query = AsyncMock(side_effect=mock_query)
  283. mock_pg_db.execute = AsyncMock()
  284. with patch.object(PGVectorStorage, "_pg_create_table", new_callable=AsyncMock):
  285. await PGVectorStorage.setup_table(
  286. db=mock_pg_db,
  287. table_name=table_name,
  288. workspace="test_ws",
  289. embedding_dim=768,
  290. legacy_table_name=legacy_table,
  291. base_table=legacy_table,
  292. )