chroma_impl.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. import asyncio
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Any, final
  5. import numpy as np
  6. from lightrag.base import BaseVectorStorage
  7. from lightrag.utils import logger
  8. import pipmaster as pm
  9. if not pm.is_installed("chromadb"):
  10. pm.install("chromadb")
  11. from chromadb import HttpClient, PersistentClient # type: ignore
  12. from chromadb.config import Settings # type: ignore
  13. @final
  14. @dataclass
  15. class ChromaVectorDBStorage(BaseVectorStorage):
  16. """ChromaDB vector storage implementation."""
  def __post_init__(self):
      """Read the storage configuration and open the ChromaDB collection.

      All options are read from
      ``self.global_config["vector_db_storage_cls_kwargs"]``:

      - ``cosine_better_than_threshold`` (required): stored on the instance.
      - ``collection_settings``: per-key overrides merged over the default
        HNSW settings listed in the body.
      - ``local_path``: when truthy, a ``PersistentClient`` is opened at that
        path. Otherwise an ``HttpClient`` is built from ``host``, ``port``,
        ``auth_provider``, ``auth_token``, ``auth_header_name`` and
        ``auth_credentials``.

      On success ``self._client``, ``self._collection`` and
      ``self._max_batch_size`` are set.

      Raises:
          ValueError: If ``cosine_better_than_threshold`` is not configured.
          Exception: Any failure inside the setup is logged and re-raised.
      """
      self._validate_embedding_func()
      try:
          options = self.global_config.get("vector_db_storage_cls_kwargs", {})

          threshold = options.get("cosine_better_than_threshold")
          if threshold is None:
              raise ValueError(
                  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
              )
          self.cosine_better_than_threshold = threshold

          overrides = options.get("collection_settings", {})
          # Default HNSW index settings for ChromaDB.
          hnsw_defaults = {
              # Distance metric used for similarity search.
              "hnsw:space": "cosine",
              # Neighbors explored while building the index
              # (higher = better recall, slower indexing).
              "hnsw:construction_ef": 128,
              # Neighbors explored while searching
              # (higher = better recall, slower search).
              "hnsw:search_ef": 128,
              # Connections per node in the HNSW graph
              # (higher = better recall, more memory).
              "hnsw:M": 16,
              # Vectors processed per batch during indexing.
              "hnsw:batch_size": 100,
              # Updates before an index synchronization is forced
              # (lower = more frequent syncs, slower indexing).
              "hnsw:sync_threshold": 1000,
          }
          # User-supplied keys win over the defaults.
          collection_settings = {**hnsw_defaults, **overrides}

          storage_path = options.get("local_path", None)
          if storage_path:
              # A local path is configured: use the persistent client.
              self._client = PersistentClient(
                  path=storage_path,
                  settings=Settings(
                      allow_reset=True,
                      anonymized_telemetry=False,
                  ),
              )
          else:
              # No local path: talk to a server over HTTP.
              provider = options.get(
                  "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
              )
              credentials = options.get("auth_token", "secret-token")
              request_headers = {}
              if "token_authn" in provider:
                  # Token auth: the token also travels in a request header.
                  header_name = options.get("auth_header_name", "X-Chroma-Token")
                  request_headers = {header_name: credentials}
              elif "basic_authn" in provider:
                  # Basic auth: credentials come from a separate option and
                  # no extra header is added here.
                  credentials = options.get("auth_credentials", "admin:admin")

              self._client = HttpClient(
                  host=options.get("host", "localhost"),
                  port=options.get("port", 8000),
                  headers=request_headers,
                  settings=Settings(
                      chroma_api_impl="rest",
                      chroma_client_auth_provider=provider,
                      chroma_client_auth_credentials=credentials,
                      allow_reset=True,
                      anonymized_telemetry=False,
                  ),
              )

          collection_metadata = {
              **collection_settings,
              "dimension": self.embedding_func.embedding_dim,
          }
          self._collection = self._client.get_or_create_collection(
              name=self.namespace,
              metadata=collection_metadata,
          )

          # The global "embedding_batch_num" takes precedence; otherwise fall
          # back to the collection's "hnsw:batch_size" (32 if that is absent).
          fallback_batch_size = collection_settings.get("hnsw:batch_size", 32)
          self._max_batch_size = self.global_config.get(
              "embedding_batch_num", fallback_batch_size
          )
      except Exception as e:
          logger.error(f"ChromaDB initialization failed: {str(e)}")
          raise
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
      """Embed the given records and upsert them into the collection.

      Args:
          data: Mapping of record ID to its fields. Each value must have a
              ``"content"`` entry, which is embedded and stored as the
              document. Fields whose key is in ``self.meta_fields`` are
              copied into the stored metadata together with a
              ``"created_at"`` Unix timestamp.

      Returns:
          Nothing for empty input. NOTE(review): for non-empty input the
          list of upserted IDs is returned although the annotation says
          ``None``; this is kept so existing callers are unaffected.

      Raises:
          Exception: Any failure while embedding or writing is logged and
              re-raised.
      """
      logger.debug(f"Inserting {len(data)} to {self.namespace}")
      if not data:
          return
      try:
          import time

          current_time = int(time.time())
          ids = list(data.keys())
          documents = [item["content"] for item in data.values()]
          # Every metadata dict carries "created_at", so it is never empty
          # and needs no per-record fallback.
          metadatas = [
              {
                  **{k: v for k, v in item.items() if k in self.meta_fields},
                  "created_at": current_time,
              }
              for item in data.values()
          ]

          batch_size = self._max_batch_size

          # Embed all batches concurrently. asyncio.gather returns results in
          # argument order, so row i of the concatenated array is the
          # embedding of documents[i].
          embedding_tasks = [
              self.embedding_func(documents[start : start + batch_size])
              for start in range(0, len(documents), batch_size)
          ]
          embeddings = np.concatenate(await asyncio.gather(*embedding_tasks))

          # Write to the collection in batches of the same size.
          for start in range(0, len(ids), batch_size):
              batch_slice = slice(start, start + batch_size)
              self._collection.upsert(
                  ids=ids[batch_slice],
                  embeddings=embeddings[batch_slice].tolist(),
                  documents=documents[batch_slice],
                  metadatas=metadatas[batch_slice],
              )
          return ids
      except Exception as e:
          logger.error(f"Error during ChromaDB upsert: {str(e)}")
          raise
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
      """Return up to ``top_k`` stored records similar to ``query``.

      The query text is embedded, ``2 * top_k`` nearest neighbors are
      requested from the collection, and neighbors whose cosine similarity
      is below ``self.cosine_better_than_threshold`` are discarded.

      Args:
          query: Text to search for.
          top_k: Maximum number of records to return.

      Returns:
          A list of dicts with ``"id"``, ``"distance"``, ``"content"`` and
          ``"created_at"`` plus every stored metadata field. Despite its
          name, ``"distance"`` holds the cosine *similarity*
          (``1 - ChromaDB distance``), so higher means closer.

      Raises:
          Exception: Any failure while embedding or querying is logged and
              re-raised.
      """
      try:
          embedding = await self.embedding_func(
              [query], _priority=5
          )  # higher priority for query
          results = self._collection.query(
              query_embeddings=embedding.tolist()
              if not isinstance(embedding, list)
              else embedding,
              n_results=top_k * 2,  # Request more results to allow for filtering
              include=["metadatas", "distances", "documents"],
          )

          # For a cosine-space collection ChromaDB reports cosine *distance*
          # (0 = identical direction, 1 = orthogonal). Convert it to a
          # similarity with (1 - distance) and keep only hits at or above the
          # configured threshold. Results are per query; a single query was
          # sent, so everything of interest is at index 0.
          hits = zip(
              results["ids"][0],
              results["distances"][0],
              results["documents"][0],
              results["metadatas"][0],
              strict=True,
          )
          matches = []
          for record_id, distance, document, metadata in hits:
              similarity = 1 - distance
              if similarity < self.cosine_better_than_threshold:
                  continue
              matches.append(
                  {
                      "id": record_id,
                      "distance": similarity,
                      "content": document,
                      "created_at": metadata.get("created_at"),
                      **metadata,
                  }
              )
          # Twice as many neighbors were requested, so trim back to top_k.
          return matches[:top_k]
      except Exception as e:
          logger.error(f"Error during ChromaDB query: {str(e)}")
          raise
  async def index_done_callback(self) -> None:
      """Do nothing when indexing finishes.

      Per the original author's note, ChromaDB handles persistence
      automatically, so this method performs no work.
      """
      return None
  async def delete_entity(self, entity_name: str) -> None:
      """Remove a single entity from the collection.

      Args:
          entity_name: ID of the entity to delete; it is passed to the
              collection as the only element of the ID list.

      Raises:
          Exception: Any failure is logged and re-raised.
      """
      try:
          logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}")
          target_ids = [entity_name]
          self._collection.delete(ids=target_ids)
      except Exception as e:
          logger.error(f"Error during entity deletion: {str(e)}")
          raise
  async def delete_entity_relation(self, entity_name: str) -> None:
      """Delete an entity together with its relations.

      This storage simply forwards to ``delete_entity``; no separate
      relation handling is done here.

      Args:
          entity_name: ID of the entity to delete.
      """
      return await self.delete_entity(entity_name)
  async def delete(self, ids: list[str]) -> None:
      """Delete the vectors with the specified IDs.

      Args:
          ids: List of vector IDs to be deleted. An empty list is a no-op.

      Raises:
          Exception: Any error raised by the collection is logged and
              re-raised.
      """
      # Nothing to delete: skip the call, matching the empty-input guards in
      # upsert and get_by_ids.
      if not ids:
          return
      try:
          self._collection.delete(ids=ids)
          logger.debug(
              f"Successfully deleted {len(ids)} vectors from {self.namespace}"
          )
      except Exception as e:
          logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
          raise
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
      """Fetch one stored record by its ID.

      Args:
          id: Identifier of the record to look up.

      Returns:
          A dict with ``"id"``, ``"vector"``, ``"content"`` and
          ``"created_at"`` plus every stored metadata field, or ``None``
          when the ID is not present. Errors are logged and also yield
          ``None`` rather than propagating.
      """
      try:
          fetched = self._collection.get(
              ids=[id], include=["metadatas", "embeddings", "documents"]
          )
          # Empty response or no matching ID.
          if not (fetched and fetched["ids"]):
              return None

          # Only one ID was requested, so the hit is at index 0.
          record = {
              "id": fetched["ids"][0],
              "vector": fetched["embeddings"][0],
              "content": fetched["documents"][0],
          }
          metadata = fetched["metadatas"][0]
          record["created_at"] = metadata.get("created_at")
          # Stored metadata fields are merged last.
          record.update(metadata)
          return record
      except Exception as e:
          logger.error(f"Error retrieving vector data for ID {id}: {e}")
          return None
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
      """Fetch several stored records by their IDs.

      Args:
          ids: Identifiers to look up.

      Returns:
          An empty list when ``ids`` is empty, when nothing is found, or
          when an error occurs (the error is logged). Otherwise a list
          aligned with ``ids``: position ``i`` holds the record for
          ``ids[i]``, or ``None`` if that ID was not found. Each record has
          ``"id"``, ``"vector"``, ``"content"`` and ``"created_at"`` plus
          every stored metadata field.
      """
      if not ids:
          return []
      try:
          fetched = self._collection.get(
              ids=ids, include=["metadatas", "embeddings", "documents"]
          )
          if not (fetched and fetched["ids"]):
              return []

          vectors = fetched["embeddings"]
          texts = fetched["documents"]
          metas = fetched["metadatas"]

          # Index the hits by stringified ID so they can be returned in the
          # caller's order.
          records_by_id: dict[str, dict[str, Any]] = {
              str(found_id): {
                  "id": found_id,
                  "vector": vectors[pos],
                  "content": texts[pos],
                  "created_at": metas[pos].get("created_at"),
                  **metas[pos],
              }
              for pos, found_id in enumerate(fetched["ids"])
          }

          # Requested IDs with no hit map to None.
          return [records_by_id.get(str(wanted_id)) for wanted_id in ids]
      except Exception as e:
          logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
          return []
  async def drop(self) -> dict[str, str]:
      """Delete every document from the ChromaDB collection.

      All IDs in the collection are listed and then deleted in one call.
      The collection itself is left in place.

      Returns:
          dict[str, str]: Operation status and message
          - On success: {"status": "success", "message": "data dropped"}
          - On failure: {"status": "error", "message": "<error details>"}
      """
      try:
          # include=[] requests no payload fields; only the IDs are used.
          existing = self._collection.get(include=[])
          stale_ids = existing["ids"] if existing else None
          if stale_ids:
              self._collection.delete(ids=stale_ids)

          logger.info(
              f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
          )
          outcome = {"status": "success", "message": "data dropped"}
      except Exception as e:
          logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
          outcome = {"status": "error", "message": str(e)}
      return outcome