opensearch_storage_demo.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. """
  2. Integration test for OpenSearch Storage in LightRAG.
  3. Tests all 4 storage types against a live OpenSearch cluster:
  4. - KV Storage: CRUD, filter_keys
  5. - DocStatus Storage: CRUD, pagination (PIT + search_after), status counts
  6. - Graph Storage: nodes, edges, BFS traversal, search_labels
  7. - Vector Storage: k-NN upsert, query, get/delete
  8. Prerequisites:
  9. OpenSearch cluster running with k-NN plugin enabled.
  10. Set env vars: OPENSEARCH_HOSTS, OPENSEARCH_USER, OPENSEARCH_PASSWORD,
  11. OPENSEARCH_USE_SSL, OPENSEARCH_VERIFY_CERTS
  12. Usage:
  13. OPENSEARCH_HOSTS=localhost:9200 OPENSEARCH_USER=admin \
  14. OPENSEARCH_PASSWORD=<password> OPENSEARCH_USE_SSL=true \
  15. OPENSEARCH_VERIFY_CERTS=false python examples/opensearch_storage_demo.py
  16. """
  17. import asyncio
  18. import numpy as np
  19. from lightrag.kg.opensearch_impl import (
  20. OpenSearchKVStorage,
  21. OpenSearchDocStatusStorage,
  22. OpenSearchGraphStorage,
  23. OpenSearchVectorDBStorage,
  24. ClientManager,
  25. )
  26. from lightrag.kg.shared_storage import initialize_share_data
  27. from lightrag.base import DocStatus
  28. class MockEmbeddingFunc:
  29. """Mock embedding function for testing."""
  30. def __init__(self, dim=128):
  31. self.embedding_dim = dim
  32. self.max_token_size = 512
  33. self.model_name = "mock-embedding"
  34. async def __call__(self, texts, **kwargs):
  35. return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  36. CONFIG = {
  37. "embedding_batch_num": 10,
  38. "max_graph_nodes": 1000,
  39. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  40. }
  41. EMBED = MockEmbeddingFunc()
  42. PASSED = 0
  43. FAILED = 0
  44. def check(condition, msg):
  45. global PASSED, FAILED
  46. if condition:
  47. print(f" ✓ {msg}")
  48. PASSED += 1
  49. else:
  50. print(f" ✗ {msg}")
  51. FAILED += 1
  52. async def test_connection_manager():
  53. print("\n=== Connection Manager ===")
  54. client1 = await ClientManager.get_client()
  55. client2 = await ClientManager.get_client()
  56. check(client1 is client2, "Singleton pattern (same instance)")
  57. await ClientManager.release_client(client1)
  58. await ClientManager.release_client(client2)
  59. check(True, "Released clients")
  60. async def test_kv_storage():
  61. print("\n=== KV Storage ===")
  62. s = OpenSearchKVStorage(
  63. namespace="integ_kv",
  64. global_config=CONFIG,
  65. embedding_func=EMBED,
  66. workspace="integ",
  67. )
  68. await s.initialize()
  69. try:
  70. await s.upsert({"k1": {"content": "hello"}, "k2": {"content": "world"}})
  71. await s.index_done_callback()
  72. doc = await s.get_by_id("k1")
  73. check(doc is not None and doc.get("content") == "hello", "get_by_id")
  74. docs = await s.get_by_ids(["k1", "k2", "missing"])
  75. check(docs[0] is not None and docs[2] is None, "get_by_ids preserves order")
  76. missing = await s.filter_keys({"k1", "k99"})
  77. check(missing == {"k99"}, f"filter_keys: {missing}")
  78. check(not await s.is_empty(), "is_empty=False")
  79. await s.delete(["k2"])
  80. await s.index_done_callback()
  81. check(await s.get_by_id("k2") is None, "delete + verify")
  82. finally:
  83. await s.drop()
  84. await s.finalize()
  85. async def test_doc_status_storage():
  86. print("\n=== DocStatus Storage ===")
  87. s = OpenSearchDocStatusStorage(
  88. namespace="integ_ds",
  89. global_config=CONFIG,
  90. embedding_func=EMBED,
  91. workspace="integ",
  92. )
  93. await s.initialize()
  94. try:
  95. # Insert docs
  96. await s.upsert(
  97. {
  98. f"d{i}": {
  99. "status": "processed" if i % 2 == 0 else "pending",
  100. "file_path": f"/file{i}.txt",
  101. "content_summary": f"summary {i}",
  102. "content_length": i * 10,
  103. "chunks_count": i,
  104. "created_at": 1000 + i,
  105. "updated_at": 2000 + i,
  106. }
  107. for i in range(20)
  108. }
  109. )
  110. await s.index_done_callback()
  111. # Status counts
  112. counts = await s.get_all_status_counts()
  113. check(counts.get("all") == 20, f"all_status_counts: {counts}")
  114. check(
  115. counts.get("processed") == 10, f"processed count: {counts.get('processed')}"
  116. )
  117. # get_docs_by_status (uses PIT + search_after)
  118. processed = await s.get_docs_by_status(DocStatus.PROCESSED)
  119. check(len(processed) == 10, f"get_docs_by_status(processed): {len(processed)}")
  120. # get_docs_by_track_id (uses PIT + search_after)
  121. await s.upsert(
  122. {
  123. "tracked1": {
  124. "status": "processed",
  125. "file_path": "/t.txt",
  126. "content_summary": "s",
  127. "content_length": 1,
  128. "chunks_count": 1,
  129. "created_at": 100,
  130. "updated_at": 200,
  131. "track_id": "batch-42",
  132. }
  133. }
  134. )
  135. await s.index_done_callback()
  136. tracked = await s.get_docs_by_track_id("batch-42")
  137. check(len(tracked) == 1, f"get_docs_by_track_id: {len(tracked)}")
  138. # Paginated (uses PIT + search_after)
  139. page1, total = await s.get_docs_paginated(page=1, page_size=10)
  140. check(total == 21, f"paginated total: {total}")
  141. check(len(page1) == 10, f"page1 size: {len(page1)}")
  142. page2, _ = await s.get_docs_paginated(page=2, page_size=10)
  143. check(len(page2) == 10, f"page2 size: {len(page2)}")
  144. page3, _ = await s.get_docs_paginated(page=3, page_size=10)
  145. check(len(page3) == 1, f"page3 size: {len(page3)}")
  146. # With status filter
  147. filtered, ftotal = await s.get_docs_paginated(
  148. status_filter=DocStatus.PENDING, page=1, page_size=50
  149. )
  150. check(ftotal == 10, f"filtered total: {ftotal}")
  151. # get_doc_by_file_path
  152. doc = await s.get_doc_by_file_path("/file0.txt")
  153. check(doc is not None and doc["_id"] == "d0", "get_doc_by_file_path")
  154. finally:
  155. await s.drop()
  156. await s.finalize()
  157. async def test_graph_storage():
  158. print("\n=== Graph Storage ===")
  159. s = OpenSearchGraphStorage(
  160. namespace="integ_graph",
  161. global_config=CONFIG,
  162. embedding_func=EMBED,
  163. workspace="integ",
  164. )
  165. await s.initialize()
  166. try:
  167. # Upsert nodes and edges
  168. await s.upsert_node(
  169. "Alice", {"entity_type": "person", "description": "A researcher"}
  170. )
  171. await s.upsert_node(
  172. "Bob", {"entity_type": "person", "description": "A developer"}
  173. )
  174. await s.upsert_node(
  175. "Quantum", {"entity_type": "topic", "description": "Quantum computing"}
  176. )
  177. await s.upsert_edge(
  178. "Alice",
  179. "Bob",
  180. {"relationship": "knows", "weight": "1.0", "keywords": "collab"},
  181. )
  182. await s.upsert_edge(
  183. "Alice",
  184. "Quantum",
  185. {"relationship": "researches", "weight": "2.0", "keywords": "research"},
  186. )
  187. await s.upsert_edge(
  188. "Bob",
  189. "Quantum",
  190. {"relationship": "studies", "weight": "0.5", "keywords": "learning"},
  191. )
  192. await s.index_done_callback()
  193. check(await s.has_node("Alice"), "has_node(Alice)")
  194. check(not await s.has_node("Nobody"), "has_node(Nobody)=False")
  195. check(await s.has_edge("Alice", "Bob"), "has_edge(Alice,Bob)")
  196. node = await s.get_node("Alice")
  197. check(node is not None and node.get("entity_type") == "person", "get_node")
  198. check(node.get("entity_id") == "Alice", "entity_id field present")
  199. check(
  200. await s.node_degree("Alice") == 2,
  201. f"node_degree(Alice)={await s.node_degree('Alice')}",
  202. )
  203. edges = await s.get_node_edges("Alice")
  204. check(len(edges) == 2, f"get_node_edges: {len(edges)}")
  205. # Batch ops
  206. batch = await s.get_nodes_batch(["Alice", "Bob", "Missing"])
  207. check("Alice" in batch and "Missing" not in batch, "get_nodes_batch")
  208. degrees = await s.node_degrees_batch(["Alice", "Bob", "Quantum"])
  209. check(degrees.get("Alice") == 2, f"node_degrees_batch: {degrees}")
  210. # Knowledge graph (BFS)
  211. kg = await s.get_knowledge_graph("Alice", max_depth=2)
  212. check(len(kg.nodes) == 3, f"BFS nodes: {len(kg.nodes)}")
  213. check(len(kg.edges) == 3, f"BFS edges: {len(kg.edges)}")
  214. # get_all_labels (uses PIT)
  215. labels = await s.get_all_labels()
  216. check("Alice" in labels and "Bob" in labels, f"get_all_labels: {labels}")
  217. # get_all_nodes (uses PIT)
  218. all_nodes = await s.get_all_nodes()
  219. check(len(all_nodes) == 3, f"get_all_nodes: {len(all_nodes)}")
  220. # get_all_edges (uses PIT)
  221. all_edges = await s.get_all_edges()
  222. check(len(all_edges) == 3, f"get_all_edges: {len(all_edges)}")
  223. # search_labels
  224. found = await s.search_labels("ali", limit=10)
  225. check("Alice" in found, f"search_labels('ali'): {found}")
  226. # popular_labels
  227. popular = await s.get_popular_labels(limit=10)
  228. check(len(popular) > 0, f"get_popular_labels: {popular}")
  229. # Delete node (cascading)
  230. await s.delete_node("Bob")
  231. await s.index_done_callback()
  232. check(not await s.has_node("Bob"), "delete_node cascade")
  233. check(not await s.has_edge("Alice", "Bob"), "edges removed after delete_node")
  234. print(f" (PPL graphlookup: {s._ppl_graphlookup_available})")
  235. finally:
  236. await s.drop()
  237. await s.finalize()
  238. async def test_vector_storage():
  239. print("\n=== Vector Storage ===")
  240. s = OpenSearchVectorDBStorage(
  241. namespace="integ_vec",
  242. global_config=CONFIG,
  243. embedding_func=EMBED,
  244. workspace="integ",
  245. meta_fields={"content", "entity_name"},
  246. )
  247. await s.initialize()
  248. try:
  249. await s.upsert(
  250. {
  251. "v1": {"content": "apple fruit"},
  252. "v2": {"content": "banana fruit"},
  253. "v3": {"content": "quantum physics"},
  254. }
  255. )
  256. await s.index_done_callback()
  257. results = await s.query("apple", top_k=3)
  258. check(len(results) > 0, f"query returned {len(results)} results")
  259. check(all("distance" in r for r in results), "results have distance")
  260. doc = await s.get_by_id("v1")
  261. check(doc is not None and doc["id"] == "v1", "get_by_id")
  262. docs = await s.get_by_ids(["v1", "v2", "missing"])
  263. check(docs[0] is not None and docs[2] is None, "get_by_ids")
  264. vecs = await s.get_vectors_by_ids(["v1"])
  265. check("v1" in vecs and len(vecs["v1"]) == 128, "get_vectors_by_ids")
  266. await s.delete(["v3"])
  267. await s.index_done_callback()
  268. check(await s.get_by_id("v3") is None, "delete + verify")
  269. finally:
  270. await s.drop()
  271. await s.finalize()
  272. async def main():
  273. print("=" * 60)
  274. print("OpenSearch Storage Integration Tests")
  275. print("=" * 60)
  276. initialize_share_data(workers=1)
  277. try:
  278. await test_connection_manager()
  279. await test_kv_storage()
  280. await test_doc_status_storage()
  281. await test_graph_storage()
  282. await test_vector_storage()
  283. except Exception as e:
  284. print(f"\n✗ Fatal error: {e}")
  285. import traceback
  286. traceback.print_exc()
  287. print(f"\n{'=' * 60}")
  288. print(f"Results: {PASSED} passed, {FAILED} failed")
  289. print(f"{'=' * 60}")
  290. if FAILED > 0:
  291. exit(1)
  292. if __name__ == "__main__":
  293. asyncio.run(main())