test_workspace_migration_isolation.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. """
  2. Tests for workspace isolation during PostgreSQL migration.
  3. This test module verifies that setup_table() properly filters migration data
  4. by workspace, preventing cross-workspace data leakage during legacy table migration.
  5. Critical Bug: Migration copied ALL records from legacy table regardless of workspace,
  6. causing workspace A to receive workspace B's data, violating multi-tenant isolation.
  7. """
  8. import pytest
  9. from unittest.mock import AsyncMock
  10. from lightrag.kg.postgres_impl import PGVectorStorage
  11. class TestWorkspaceMigrationIsolation:
  12. """Test suite for workspace-scoped migration in PostgreSQL."""
  13. async def test_migration_filters_by_workspace(self):
  14. """
  15. Test that migration only copies data from the specified workspace.
  16. Scenario: Legacy table contains data from multiple workspaces.
  17. Migrate only workspace_a's data to new table.
  18. Expected: New table contains only workspace_a data, workspace_b data excluded.
  19. """
  20. db = AsyncMock()
  21. # Configure mock return values to avoid unawaited coroutine warnings
  22. db._create_vector_index.return_value = None
  23. # Track state for new table count (starts at 0, increases after migration)
  24. new_table_record_count = {"count": 0}
  25. # Mock table existence checks
  26. async def table_exists_side_effect(db_instance, name):
  27. if name.lower() == "lightrag_doc_chunks": # legacy
  28. return True
  29. elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
  30. return False # New table doesn't exist initially
  31. return False
  32. # Mock data for workspace_a
  33. mock_records_a = [
  34. {
  35. "id": "a1",
  36. "workspace": "workspace_a",
  37. "content": "content_a1",
  38. "content_vector": [0.1] * 1536,
  39. },
  40. {
  41. "id": "a2",
  42. "workspace": "workspace_a",
  43. "content": "content_a2",
  44. "content_vector": [0.2] * 1536,
  45. },
  46. ]
  47. # Mock query responses
  48. async def query_side_effect(sql, params, **kwargs):
  49. multirows = kwargs.get("multirows", False)
  50. sql_upper = sql.upper()
  51. # Count query for new table workspace data (verification before migration)
  52. if (
  53. "COUNT(*)" in sql_upper
  54. and "MODEL_1536D" in sql_upper
  55. and "WHERE WORKSPACE" in sql_upper
  56. ):
  57. return new_table_record_count # Initially 0
  58. # Count query with workspace filter (legacy table) - for workspace count
  59. elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
  60. if params and params[0] == "workspace_a":
  61. return {"count": 2} # workspace_a has 2 records
  62. elif params and params[0] == "workspace_b":
  63. return {"count": 3} # workspace_b has 3 records
  64. return {"count": 0}
  65. # Count query for legacy table (total, no workspace filter)
  66. elif (
  67. "COUNT(*)" in sql_upper
  68. and "LIGHTRAG" in sql_upper
  69. and "WHERE WORKSPACE" not in sql_upper
  70. ):
  71. return {"count": 5} # Total records in legacy
  72. # SELECT with workspace filter for migration (multirows)
  73. elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
  74. workspace = params[0] if params else None
  75. if workspace == "workspace_a":
  76. # Handle keyset pagination: check for "id >" pattern
  77. if "id >" in sql.lower():
  78. # Keyset pagination: params = [workspace, last_id, limit]
  79. last_id = params[1] if len(params) > 1 else None
  80. # Find records after last_id
  81. found_idx = -1
  82. for i, rec in enumerate(mock_records_a):
  83. if rec["id"] == last_id:
  84. found_idx = i
  85. break
  86. if found_idx >= 0:
  87. return mock_records_a[found_idx + 1 :]
  88. return []
  89. else:
  90. # First batch: params = [workspace, limit]
  91. return mock_records_a
  92. return [] # No data for other workspaces
  93. return {}
  94. db.query.side_effect = query_side_effect
  95. db.execute = AsyncMock()
  96. # Mock check_table_exists on db
  97. async def check_table_exists_side_effect(name):
  98. if name.lower() == "lightrag_doc_chunks": # legacy
  99. return True
  100. elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
  101. return False # New table doesn't exist initially
  102. return False
  103. db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
  104. # Track migration through _run_with_retry calls
  105. migration_executed = []
  106. async def mock_run_with_retry(operation, *args, **kwargs):
  107. migration_executed.append(True)
  108. new_table_record_count["count"] = 2 # Simulate 2 records migrated
  109. return None
  110. db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
  111. # Migrate for workspace_a only - correct parameter order
  112. await PGVectorStorage.setup_table(
  113. db,
  114. "LIGHTRAG_DOC_CHUNKS_model_1536d",
  115. workspace="workspace_a", # CRITICAL: Only migrate workspace_a
  116. embedding_dim=1536,
  117. legacy_table_name="LIGHTRAG_DOC_CHUNKS",
  118. base_table="LIGHTRAG_DOC_CHUNKS",
  119. )
  120. # Verify the migration was triggered
  121. assert (
  122. len(migration_executed) > 0
  123. ), "Migration should have been executed for workspace_a"
  124. async def test_migration_without_workspace_raises_error(self):
  125. """
  126. Test that migration without workspace parameter raises ValueError.
  127. Scenario: setup_table called without workspace parameter.
  128. Expected: ValueError is raised because workspace is required.
  129. """
  130. db = AsyncMock()
  131. # workspace is now a required parameter - calling with None should raise ValueError
  132. with pytest.raises(ValueError, match="workspace must be provided"):
  133. await PGVectorStorage.setup_table(
  134. db,
  135. "lightrag_doc_chunks_model_1536d",
  136. workspace=None, # No workspace - should raise ValueError
  137. embedding_dim=1536,
  138. legacy_table_name="lightrag_doc_chunks",
  139. base_table="lightrag_doc_chunks",
  140. )
  141. async def test_no_cross_workspace_contamination(self):
  142. """
  143. Test that workspace B's migration doesn't include workspace A's data.
  144. Scenario: Migration for workspace_b only.
  145. Expected: Only workspace_b data is queried, workspace_a data excluded.
  146. """
  147. db = AsyncMock()
  148. # Configure mock return values to avoid unawaited coroutine warnings
  149. db._create_vector_index.return_value = None
  150. # Track which workspace is being queried
  151. queried_workspace = None
  152. new_table_count = {"count": 0}
  153. # Mock data for workspace_b
  154. mock_records_b = [
  155. {
  156. "id": "b1",
  157. "workspace": "workspace_b",
  158. "content": "content_b1",
  159. "content_vector": [0.3] * 1536,
  160. },
  161. ]
  162. async def table_exists_side_effect(db_instance, name):
  163. if name.lower() == "lightrag_doc_chunks": # legacy
  164. return True
  165. elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
  166. return False
  167. return False
  168. async def query_side_effect(sql, params, **kwargs):
  169. nonlocal queried_workspace
  170. multirows = kwargs.get("multirows", False)
  171. sql_upper = sql.upper()
  172. # Count query for new table workspace data (should be 0 initially)
  173. if (
  174. "COUNT(*)" in sql_upper
  175. and "MODEL_1536D" in sql_upper
  176. and "WHERE WORKSPACE" in sql_upper
  177. ):
  178. return new_table_count
  179. # Count query with workspace filter (legacy table)
  180. elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
  181. queried_workspace = params[0] if params else None
  182. return {"count": 1} # 1 record for the queried workspace
  183. # Count query for legacy table total (no workspace filter)
  184. elif (
  185. "COUNT(*)" in sql_upper
  186. and "LIGHTRAG" in sql_upper
  187. and "WHERE WORKSPACE" not in sql_upper
  188. ):
  189. return {"count": 3} # 3 total records in legacy
  190. # SELECT with workspace filter for migration (multirows)
  191. elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
  192. workspace = params[0] if params else None
  193. if workspace == "workspace_b":
  194. # Handle keyset pagination: check for "id >" pattern
  195. if "id >" in sql.lower():
  196. # Keyset pagination: params = [workspace, last_id, limit]
  197. last_id = params[1] if len(params) > 1 else None
  198. # Find records after last_id
  199. found_idx = -1
  200. for i, rec in enumerate(mock_records_b):
  201. if rec["id"] == last_id:
  202. found_idx = i
  203. break
  204. if found_idx >= 0:
  205. return mock_records_b[found_idx + 1 :]
  206. return []
  207. else:
  208. # First batch: params = [workspace, limit]
  209. return mock_records_b
  210. return [] # No data for other workspaces
  211. return {}
  212. db.query.side_effect = query_side_effect
  213. db.execute = AsyncMock()
  214. # Mock check_table_exists on db
  215. async def check_table_exists_side_effect(name):
  216. if name.lower() == "lightrag_doc_chunks": # legacy
  217. return True
  218. elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
  219. return False
  220. return False
  221. db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
  222. # Track migration through _run_with_retry calls
  223. migration_executed = []
  224. async def mock_run_with_retry(operation, *args, **kwargs):
  225. migration_executed.append(True)
  226. new_table_count["count"] = 1 # Simulate migration
  227. return None
  228. db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
  229. # Migrate workspace_b - correct parameter order
  230. await PGVectorStorage.setup_table(
  231. db,
  232. "LIGHTRAG_DOC_CHUNKS_model_1536d",
  233. workspace="workspace_b", # Only migrate workspace_b
  234. embedding_dim=1536,
  235. legacy_table_name="LIGHTRAG_DOC_CHUNKS",
  236. base_table="LIGHTRAG_DOC_CHUNKS",
  237. )
  238. # Verify only workspace_b was queried
  239. assert queried_workspace == "workspace_b", "Should only query workspace_b"