lightrag_cloudflare_demo.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. import asyncio
  2. import os
  3. import inspect
  4. import logging
  5. import logging.config
  6. from lightrag import LightRAG, QueryParam
  7. from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
  8. import requests
  9. import numpy as np
  10. from dotenv import load_dotenv
  11. """This code is a modified version of lightrag_openai_demo.py"""
  12. # ideally, as always, env!
  13. load_dotenv(dotenv_path=".env", override=False)
  14. """ ----========= IMPORTANT CHANGE THIS! =========---- """
  15. cloudflare_api_key = "YOUR_API_KEY"
  16. account_id = "YOUR_ACCOUNT ID" # This is unique to your Cloudflare account
  17. # Authomatically changes
  18. api_base_url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
  19. # choose an embedding model
  20. EMBEDDING_MODEL = "@cf/baai/bge-m3"
  21. # choose a generative model
  22. LLM_MODEL = "@cf/meta/llama-3.2-3b-instruct"
  23. WORKING_DIR = "../dickens" # you can change output as desired
  24. # Cloudflare init
  25. class CloudflareWorker:
  26. def __init__(
  27. self,
  28. cloudflare_api_key: str,
  29. api_base_url: str,
  30. llm_model_name: str,
  31. embedding_model_name: str,
  32. max_tokens: int = 4080,
  33. max_response_tokens: int = 4080,
  34. ):
  35. self.cloudflare_api_key = cloudflare_api_key
  36. self.api_base_url = api_base_url
  37. self.llm_model_name = llm_model_name
  38. self.embedding_model_name = embedding_model_name
  39. self.max_tokens = max_tokens
  40. self.max_response_tokens = max_response_tokens
  41. async def _send_request(self, model_name: str, input_: dict, debug_log: str):
  42. headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}
  43. print(f"""
  44. data sent to Cloudflare
  45. ~~~~~~~~~~~
  46. {debug_log}
  47. """)
  48. try:
  49. response_raw = requests.post(
  50. f"{self.api_base_url}{model_name}", headers=headers, json=input_
  51. ).json()
  52. print(f"""
  53. Cloudflare worker responded with:
  54. ~~~~~~~~~~~
  55. {str(response_raw)}
  56. """)
  57. result = response_raw.get("result", {})
  58. if "data" in result: # Embedding case
  59. return np.array(result["data"])
  60. if "response" in result: # LLM response
  61. return result["response"]
  62. raise ValueError("Unexpected Cloudflare response format")
  63. except Exception as e:
  64. print(f"""
  65. Cloudflare API returned:
  66. ~~~~~~~~~
  67. Error: {e}
  68. """)
  69. input("Press Enter to continue...")
  70. return None
  71. async def query(self, prompt, system_prompt: str = "", **kwargs) -> str:
  72. # since no caching is used and we don't want to mess with everything lightrag, pop the kwarg it is
  73. kwargs.pop("hashing_kv", None)
  74. message = [
  75. {"role": "system", "content": system_prompt},
  76. {"role": "user", "content": prompt},
  77. ]
  78. input_ = {
  79. "messages": message,
  80. "max_tokens": self.max_tokens,
  81. "response_token_limit": self.max_response_tokens,
  82. }
  83. return await self._send_request(
  84. self.llm_model_name,
  85. input_,
  86. debug_log=f"\n- model used {self.llm_model_name}\n- system prompt: {system_prompt}\n- query: {prompt}",
  87. )
  88. async def embedding_chunk(self, texts: list[str]) -> np.ndarray:
  89. print(f"""
  90. TEXT inputted
  91. ~~~~~
  92. {texts}
  93. """)
  94. input_ = {
  95. "text": texts,
  96. "max_tokens": self.max_tokens,
  97. "response_token_limit": self.max_response_tokens,
  98. }
  99. return await self._send_request(
  100. self.embedding_model_name,
  101. input_,
  102. debug_log=f"\n-llm model name {self.embedding_model_name}\n- texts: {texts}",
  103. )
  104. def configure_logging():
  105. """Configure logging for the application"""
  106. # Reset any existing handlers to ensure clean configuration
  107. for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
  108. logger_instance = logging.getLogger(logger_name)
  109. logger_instance.handlers = []
  110. logger_instance.filters = []
  111. # Get log directory path from environment variable or use current directory
  112. log_dir = os.getenv("LOG_DIR", os.getcwd())
  113. log_file_path = os.path.abspath(
  114. os.path.join(log_dir, "lightrag_cloudflare_worker_demo.log")
  115. )
  116. print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
  117. os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
  118. # Get log file max size and backup count from environment variables
  119. log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
  120. log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
  121. logging.config.dictConfig(
  122. {
  123. "version": 1,
  124. "disable_existing_loggers": False,
  125. "formatters": {
  126. "default": {
  127. "format": "%(levelname)s: %(message)s",
  128. },
  129. "detailed": {
  130. "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  131. },
  132. },
  133. "handlers": {
  134. "console": {
  135. "formatter": "default",
  136. "class": "logging.StreamHandler",
  137. "stream": "ext://sys.stderr",
  138. },
  139. "file": {
  140. "formatter": "detailed",
  141. "class": "logging.handlers.RotatingFileHandler",
  142. "filename": log_file_path,
  143. "maxBytes": log_max_bytes,
  144. "backupCount": log_backup_count,
  145. "encoding": "utf-8",
  146. },
  147. },
  148. "loggers": {
  149. "lightrag": {
  150. "handlers": ["console", "file"],
  151. "level": "INFO",
  152. "propagate": False,
  153. },
  154. },
  155. }
  156. )
  157. # Set the logger level to INFO
  158. logger.setLevel(logging.INFO)
  159. # Enable verbose debug if needed
  160. set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
  161. if not os.path.exists(WORKING_DIR):
  162. os.mkdir(WORKING_DIR)
  163. async def initialize_rag():
  164. cloudflare_worker = CloudflareWorker(
  165. cloudflare_api_key=cloudflare_api_key,
  166. api_base_url=api_base_url,
  167. embedding_model_name=EMBEDDING_MODEL,
  168. llm_model_name=LLM_MODEL,
  169. )
  170. rag = LightRAG(
  171. working_dir=WORKING_DIR,
  172. max_parallel_insert=2,
  173. llm_model_func=cloudflare_worker.query,
  174. llm_model_name=os.getenv("LLM_MODEL", LLM_MODEL),
  175. summary_max_tokens=4080,
  176. embedding_func=EmbeddingFunc(
  177. embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
  178. max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "2048")),
  179. func=lambda texts: cloudflare_worker.embedding_chunk(
  180. texts,
  181. ),
  182. ),
  183. )
  184. await rag.initialize_storages() # Auto-initializes pipeline_status
  185. return rag
  186. async def print_stream(stream):
  187. async for chunk in stream:
  188. print(chunk, end="", flush=True)
  189. async def main():
  190. try:
  191. # Clear old data files
  192. files_to_delete = [
  193. "graph_chunk_entity_relation.graphml",
  194. "kv_store_doc_status.json",
  195. "kv_store_full_docs.json",
  196. "kv_store_text_chunks.json",
  197. "vdb_chunks.json",
  198. "vdb_entities.json",
  199. "vdb_relationships.json",
  200. ]
  201. for file in files_to_delete:
  202. file_path = os.path.join(WORKING_DIR, file)
  203. if os.path.exists(file_path):
  204. os.remove(file_path)
  205. print(f"Deleting old file:: {file_path}")
  206. # Initialize RAG instance
  207. rag = await initialize_rag()
  208. # Test embedding function
  209. test_text = ["This is a test string for embedding."]
  210. embedding = await rag.embedding_func(test_text)
  211. embedding_dim = embedding.shape[1]
  212. print("\n=======================")
  213. print("Test embedding function")
  214. print("========================")
  215. print(f"Test dict: {test_text}")
  216. print(f"Detected embedding dimension: {embedding_dim}\n\n")
  217. # Locate the location of what is needed to be added to the knowledge
  218. # Can add several simultaneously by modifying code
  219. with open("./book.txt", "r", encoding="utf-8") as f:
  220. await rag.ainsert(f.read())
  221. # Perform naive search
  222. print("\n=====================")
  223. print("Query mode: naive")
  224. print("=====================")
  225. resp = await rag.aquery(
  226. "What are the top themes in this story?",
  227. param=QueryParam(mode="naive", stream=True),
  228. )
  229. if inspect.isasyncgen(resp):
  230. await print_stream(resp)
  231. else:
  232. print(resp)
  233. # Perform local search
  234. print("\n=====================")
  235. print("Query mode: local")
  236. print("=====================")
  237. resp = await rag.aquery(
  238. "What are the top themes in this story?",
  239. param=QueryParam(mode="local", stream=True),
  240. )
  241. if inspect.isasyncgen(resp):
  242. await print_stream(resp)
  243. else:
  244. print(resp)
  245. # Perform global search
  246. print("\n=====================")
  247. print("Query mode: global")
  248. print("=====================")
  249. resp = await rag.aquery(
  250. "What are the top themes in this story?",
  251. param=QueryParam(mode="global", stream=True),
  252. )
  253. if inspect.isasyncgen(resp):
  254. await print_stream(resp)
  255. else:
  256. print(resp)
  257. # Perform hybrid search
  258. print("\n=====================")
  259. print("Query mode: hybrid")
  260. print("=====================")
  261. resp = await rag.aquery(
  262. "What are the top themes in this story?",
  263. param=QueryParam(mode="hybrid", stream=True),
  264. )
  265. if inspect.isasyncgen(resp):
  266. await print_stream(resp)
  267. else:
  268. print(resp)
  269. """ FOR TESTING (if you want to test straight away, after building. Uncomment this part"""
  270. """
  271. print("\n" + "=" * 60)
  272. print("AI ASSISTANT READY!")
  273. print("Ask questions about (your uploaded) regulations")
  274. print("Type 'quit' to exit")
  275. print("=" * 60)
  276. while True:
  277. question = input("\n🔥 Your question: ")
  278. if question.lower() in ['quit', 'exit', 'bye']:
  279. break
  280. print("\nThinking...")
  281. response = await rag.aquery(question, param=QueryParam(mode="hybrid"))
  282. print(f"\nAnswer: {response}")
  283. """
  284. except Exception as e:
  285. print(f"An error occurred: {e}")
  286. finally:
  287. if rag:
  288. await rag.llm_response_cache.index_done_callback()
  289. await rag.finalize_storages()
  290. if __name__ == "__main__":
  291. # Configure logging before running the main function
  292. configure_logging()
  293. asyncio.run(main())
  294. print("\nDone!")