| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- import pytest
- import numpy as np
- from unittest.mock import patch, AsyncMock
- from lightrag.utils import EmbeddingFunc
- from lightrag.kg.postgres_impl import (
- PGVectorStorage,
- PostgreSQLDB,
- _safe_index_name,
- )
- from lightrag.exceptions import DataMigrationError
- from lightrag.namespace import NameSpace
# Stand-in for a PostgreSQLDB connection shared by the tests below
@pytest.fixture
def mock_pg_db():
    """Provide an ``AsyncMock`` posing as a PostgreSQL database connection.

    ``query`` answers multi-row calls (``multirows=True``) with an empty list
    and every other call with ``{"exists": False, "count": 0}``; ``execute``
    always resolves to ``None``.
    """

    async def _fake_query(sql, params=None, multirows=False, **kwargs):
        # search-style calls expect a list of rows, DDL/existence probes a dict
        return [] if multirows else {"exists": False, "count": 0}

    async def _fake_execute(sql, data=None, **kwargs):
        # stand-in for PostgreSQLDB.execute(); nothing is returned
        return None

    fake_db = AsyncMock()
    fake_db.workspace = "test_workspace"
    fake_db.vector_index_type = None
    fake_db.query = AsyncMock(side_effect=_fake_query)
    fake_db.execute = AsyncMock(side_effect=_fake_execute)
    return fake_db
# Autouse: replace get_data_init_lock so no real async lock is involved in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
    """Patch ``get_data_init_lock`` so every call returns a fresh-per-test ``AsyncMock``."""
    lock_patcher = patch("lightrag.kg.postgres_impl.get_data_init_lock")
    with lock_patcher as patched_getter:
        # AsyncMock supports ``async with``, so it can stand in for the lock
        patched_getter.return_value = AsyncMock()
        yield patched_getter
# Patch ClientManager so a requested client resolves to mock_pg_db
@pytest.fixture
def mock_client_manager(mock_pg_db):
    """Patch ``ClientManager``.

    ``get_client`` is an ``AsyncMock`` returning *mock_pg_db*, and
    ``release_client`` is a bare ``AsyncMock``.
    """
    patcher = patch("lightrag.kg.postgres_impl.ClientManager")
    with patcher as fake_manager:
        fake_manager.release_client = AsyncMock()
        fake_manager.get_client = AsyncMock(return_value=mock_pg_db)
        yield fake_manager
# Deterministic embedding function: every text maps to 768 copies of 0.1
@pytest.fixture
def mock_embedding_func():
    """Build an ``EmbeddingFunc`` (dim 768) whose vectors are all ``[0.1] * 768``."""
    dim = 768

    async def _constant_embed(texts, **kwargs):
        rows = [[0.1] * dim for _ in texts]
        return np.array(rows)

    # Note: EmbeddingFunc in this version of lightrag supports model_name
    return EmbeddingFunc(
        embedding_dim=dim, func=_constant_embed, model_name="test_model"
    )
@pytest.mark.asyncio
async def test_postgres_halfvec_table_creation(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """With index type HNSW_HALFVEC, the CREATE TABLE DDL uses HALFVEC(768), not VECTOR(768)."""
    mock_pg_db.vector_index_type = "HNSW_HALFVEC"
    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config={
            "embedding_batch_num": 10,
            "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
        },
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )
    # Report the table as missing so initialize() has to create it
    mock_pg_db.check_table_exists = AsyncMock(return_value=False)

    await storage.initialize()

    executed_sql = [c.args[0] for c in mock_pg_db.execute.call_args_list]
    create_statements = [sql for sql in executed_sql if "CREATE TABLE" in sql]
    assert create_statements

    first_create = create_statements[0]
    assert "HALFVEC(768)" in first_create
    assert "VECTOR(768)" not in first_create
@pytest.mark.asyncio
async def test_postgres_vector_table_creation_default(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """With plain HNSW, the CREATE TABLE DDL uses VECTOR(768) and never HALFVEC(768)."""
    mock_pg_db.vector_index_type = "HNSW"
    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config={
            "embedding_batch_num": 10,
            "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
        },
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )
    # Report the table as missing so initialize() has to create it
    mock_pg_db.check_table_exists = AsyncMock(return_value=False)

    await storage.initialize()

    executed_sql = [c.args[0] for c in mock_pg_db.execute.call_args_list]
    create_statements = [sql for sql in executed_sql if "CREATE TABLE" in sql]
    assert create_statements

    first_create = create_statements[0]
    assert "VECTOR(768)" in first_create
    assert "HALFVEC(768)" not in first_create
# Vector-store namespaces whose query() path renders a vector-search SQL
# template; used to parametrize the cast-type tests.
QUERY_NAMESPACES = [
    NameSpace.VECTOR_STORE_CHUNKS,
    NameSpace.VECTOR_STORE_ENTITIES,
    NameSpace.VECTOR_STORE_RELATIONSHIPS,
]
@pytest.mark.asyncio
@pytest.mark.parametrize("namespace", QUERY_NAMESPACES)
async def test_query_uses_halfvec_cast_when_hnsw_halfvec(
    mock_client_manager, mock_pg_db, mock_embedding_func, namespace
):
    """When HNSW_HALFVEC is set, generated search SQL uses ::halfvec (not ::vector)."""
    mock_pg_db.vector_index_type = "HNSW_HALFVEC"
    mock_pg_db.check_table_exists = AsyncMock(return_value=True)
    storage = PGVectorStorage(
        namespace=namespace,
        global_config={
            "embedding_batch_num": 10,
            "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
        },
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )
    await storage.initialize()

    await storage.query("test query", top_k=5, query_embedding=[0.1] * 768)

    assert mock_pg_db.query.called
    # The most recent db.query call carries the rendered search SQL
    search_sql = mock_pg_db.query.call_args.args[0]
    assert "::halfvec" in search_sql
    assert "::vector" not in search_sql
@pytest.mark.asyncio
@pytest.mark.parametrize("namespace", QUERY_NAMESPACES)
async def test_query_uses_vector_cast_when_hnsw_default(
    mock_client_manager, mock_pg_db, mock_embedding_func, namespace
):
    """When HNSW (default) is set, generated search SQL uses ::vector (not ::halfvec)."""
    mock_pg_db.vector_index_type = "HNSW"
    mock_pg_db.check_table_exists = AsyncMock(return_value=True)
    storage = PGVectorStorage(
        namespace=namespace,
        global_config={
            "embedding_batch_num": 10,
            "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
        },
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )
    await storage.initialize()

    await storage.query("test query", top_k=5, query_embedding=[0.1] * 768)

    assert mock_pg_db.query.called
    # The most recent db.query call carries the rendered search SQL
    search_sql = mock_pg_db.query.call_args.args[0]
    assert "::vector" in search_sql
    assert "::halfvec" not in search_sql
- # ---------------------------------------------------------------------------
- # Index switching: old conflicting indexes are dropped
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_create_vector_index_drops_old_indexes_when_switching(mock_pg_db):
    """Switching from HNSW to HNSW_HALFVEC drops the old conflicting indexes.

    With the target halfvec index reported as absent, the test asserts that
    ``_create_vector_index`` does the following:

    * drops the hnsw/ivfflat/vchordrq cosine indexes;
    * does not drop the new ``hnsw_halfvec_cosine`` index;
    * ALTERs the column to HALFVEC(3072);
    * creates an index using ``halfvec_cosine_ops``.
    """
    mock_pg_db.vector_index_type = "HNSW_HALFVEC"
    mock_pg_db.hnsw_m = 16
    mock_pg_db.hnsw_ef = 64
    mock_pg_db.ivfflat_lists = 100
    mock_pg_db.vchordrq_build_options = ""
    table_name = "lightrag_vdb_chunks_test"

    # Every lookup, including the pg_indexes existence probe, finds nothing,
    # so the target index is treated as missing.
    mock_pg_db.query = AsyncMock(return_value=None)
    mock_pg_db.execute = AsyncMock()

    # Call the real method with mock_pg_db as self
    await PostgreSQLDB._create_vector_index(mock_pg_db, table_name, 3072)

    execute_calls = [c[0][0] for c in mock_pg_db.execute.call_args_list]

    drop_marker = "DROP INDEX IF EXISTS"
    dropped_names = {
        # Keep only the identifier. This tolerates extra whitespace, a
        # trailing ';', or a trailing clause such as CASCADE.
        stmt.split(drop_marker, 1)[1].split()[0].rstrip(";")
        for stmt in execute_calls
        if drop_marker in stmt
    }
    for old_suffix in ("hnsw_cosine", "ivfflat_cosine", "vchordrq_cosine"):
        assert _safe_index_name(table_name, old_suffix) in dropped_names

    new_index_name = _safe_index_name(table_name, "hnsw_halfvec_cosine")
    assert new_index_name not in dropped_names

    alter_calls = [c for c in execute_calls if "ALTER TABLE" in c]
    assert any("HALFVEC(3072)" in c for c in alter_calls)

    create_calls = [c for c in execute_calls if "CREATE INDEX" in c]
    assert any("halfvec_cosine_ops" in c for c in create_calls)
@pytest.mark.asyncio
async def test_create_vector_index_no_drop_when_index_exists(mock_pg_db):
    """If the target index already exists, no DROP or CREATE is issued."""
    index_settings = (
        ("vector_index_type", "HNSW_HALFVEC"),
        ("hnsw_m", 16),
        ("hnsw_ef", 64),
        ("ivfflat_lists", 100),
        ("vchordrq_build_options", ""),
    )
    for attr_name, attr_value in index_settings:
        setattr(mock_pg_db, attr_name, attr_value)
    table_name = "lightrag_vdb_chunks_test"

    async def index_probe(sql, params=None, multirows=False, **kwargs):
        # Only the pg_indexes lookup reports a hit; every other query finds nothing
        return {"?column?": 1} if "pg_indexes" in sql else None

    mock_pg_db.query = AsyncMock(side_effect=index_probe)
    mock_pg_db.execute = AsyncMock()

    await PostgreSQLDB._create_vector_index(mock_pg_db, table_name, 3072)

    statements = [c.args[0] for c in mock_pg_db.execute.call_args_list]
    for forbidden in ("DROP INDEX", "CREATE INDEX"):
        assert all(forbidden not in stmt for stmt in statements)
- # ---------------------------------------------------------------------------
- # HalfVector dimension detection in setup_table
- # ---------------------------------------------------------------------------
- class _MockHalfVector:
- """Mimics pgvector.halfvec.HalfVector for testing dimension detection."""
- def __init__(self, dim: int):
- self._dim = dim
- def dimensions(self) -> int:
- return self._dim
- def to_list(self):
- return [0.0] * self._dim
@pytest.mark.asyncio
async def test_setup_table_detects_halfvector_dimension_mismatch(mock_pg_db):
    """DataMigrationError is raised when a HalfVector column has a different dimension.

    The legacy table reports 5 rows whose sampled ``content_vector`` is a
    1024-dim HalfVector-like object. The expected ``embedding_dim`` is 768,
    so ``setup_table`` must raise "Dimension mismatch".
    """
    table_name = "lightrag_vdb_chunks_new"
    legacy_table = "lightrag_vdb_chunks"
    # Only the legacy table exists (case-insensitive match); presumably this
    # steers setup_table toward migrating legacy data - confirm in setup_table.
    mock_pg_db.check_table_exists = AsyncMock(
        side_effect=lambda t: t.lower() == legacy_table.lower()
    )

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            return {"count": 5}  # legacy table is non-empty
        if "content_vector" in sql:
            # Sampled vector reports 1024 dims vs the expected 768
            return {"content_vector": _MockHalfVector(1024)}
        return None

    mock_pg_db.query = AsyncMock(side_effect=mock_query)
    mock_pg_db.execute = AsyncMock()

    with pytest.raises(DataMigrationError, match="Dimension mismatch"):
        await PGVectorStorage.setup_table(
            db=mock_pg_db,
            table_name=table_name,
            workspace="test_ws",
            embedding_dim=768,
            legacy_table_name=legacy_table,
            base_table=legacy_table,
        )
@pytest.mark.asyncio
async def test_setup_table_accepts_matching_halfvector_dimension(mock_pg_db):
    """setup_table completes without raising when the HalfVector dimension equals embedding_dim."""
    legacy_table = "lightrag_vdb_chunks"
    table_name = "lightrag_vdb_chunks_new"
    # Only the legacy table is reported as existing (case-insensitive match)
    mock_pg_db.check_table_exists = AsyncMock(
        side_effect=lambda t: t.lower() == legacy_table.lower()
    )
    mock_pg_db.vector_index_type = "HNSW_HALFVEC"

    async def fake_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            return {"count": 5}
        if "content_vector" in sql:
            # Sampled vector has exactly the expected 768 dims
            return {"content_vector": _MockHalfVector(768)}
        return [] if multirows else None

    mock_pg_db.query = AsyncMock(side_effect=fake_query)
    mock_pg_db.execute = AsyncMock()

    # Stub out real table creation; only the dimension check is under test
    create_table_stub = patch.object(
        PGVectorStorage, "_pg_create_table", new_callable=AsyncMock
    )
    with create_table_stub:
        await PGVectorStorage.setup_table(
            db=mock_pg_db,
            table_name=table_name,
            workspace="test_ws",
            embedding_dim=768,
            legacy_table_name=legacy_table,
            base_table=legacy_table,
        )
|