lightrag_openai_mongodb_graph_demo.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import os
  2. import asyncio
  3. from lightrag import LightRAG, QueryParam
  4. from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
  5. from lightrag.utils import EmbeddingFunc
  6. import numpy as np
  7. #########
  8. # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
  9. # import nest_asyncio
  10. # nest_asyncio.apply()
  11. #########
  12. WORKING_DIR = "./mongodb_test_dir"
  13. if not os.path.exists(WORKING_DIR):
  14. os.mkdir(WORKING_DIR)
  15. os.environ["OPENAI_API_KEY"] = "sk-"
  16. os.environ["MONGO_URI"] = "mongodb://0.0.0.0:27017/?directConnection=true"
  17. os.environ["MONGO_DATABASE"] = "LightRAG"
  18. os.environ["MONGO_KG_COLLECTION"] = "MDB_KG"
  19. # Embedding Configuration and Functions
  20. EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
  21. EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
  22. async def embedding_func(texts: list[str]) -> np.ndarray:
  23. # Note: openai_embed is decorated with @wrap_embedding_func_with_attrs,
  24. # which wraps it in an EmbeddingFunc. Using .func accesses the original
  25. # unwrapped function to avoid double wrapping when we create our own
  26. # EmbeddingFunc with custom configuration in create_embedding_function_instance().
  27. return await openai_embed.func(
  28. texts,
  29. model=EMBEDDING_MODEL,
  30. )
  31. async def get_embedding_dimension():
  32. test_text = ["This is a test sentence."]
  33. embedding = await embedding_func(test_text)
  34. return embedding.shape[1]
  35. async def create_embedding_function_instance():
  36. # Get embedding dimension
  37. embedding_dimension = await get_embedding_dimension()
  38. # Create embedding function instance
  39. return EmbeddingFunc(
  40. embedding_dim=embedding_dimension,
  41. max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
  42. func=embedding_func,
  43. )
  44. async def initialize_rag():
  45. embedding_func_instance = await create_embedding_function_instance()
  46. rag = LightRAG(
  47. working_dir=WORKING_DIR,
  48. llm_model_func=gpt_4o_mini_complete,
  49. embedding_func=embedding_func_instance,
  50. graph_storage="MongoGraphStorage",
  51. log_level="DEBUG",
  52. )
  53. await rag.initialize_storages() # Auto-initializes pipeline_status
  54. return rag
  55. def main():
  56. # Initialize RAG instance
  57. rag = asyncio.run(initialize_rag())
  58. with open("./book.txt", "r", encoding="utf-8") as f:
  59. rag.insert(f.read())
  60. # Perform naive search
  61. print(
  62. rag.query(
  63. "What are the top themes in this story?", param=QueryParam(mode="naive")
  64. )
  65. )
  66. # Perform local search
  67. print(
  68. rag.query(
  69. "What are the top themes in this story?", param=QueryParam(mode="local")
  70. )
  71. )
  72. # Perform global search
  73. print(
  74. rag.query(
  75. "What are the top themes in this story?", param=QueryParam(mode="global")
  76. )
  77. )
  78. # Perform hybrid search
  79. print(
  80. rag.query(
  81. "What are the top themes in this story?", param=QueryParam(mode="hybrid")
  82. )
  83. )
  84. if __name__ == "__main__":
  85. main()