Step_3_openai_compatible.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import os
  2. import re
  3. import json
  4. from lightrag import LightRAG, QueryParam
  5. from lightrag.llm.openai import openai_complete_if_cache, openai_embed
  6. from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
  7. import numpy as np
## For Upstage API
# Please check that embedding_dim=4096 is set in lightrag.py and llm.py in the lightrag directory.
  10. async def llm_model_func(
  11. prompt, system_prompt=None, history_messages=[], **kwargs
  12. ) -> str:
  13. return await openai_complete_if_cache(
  14. "solar-mini",
  15. prompt,
  16. system_prompt=system_prompt,
  17. history_messages=history_messages,
  18. api_key=os.getenv("UPSTAGE_API_KEY"),
  19. base_url="https://api.upstage.ai/v1/solar",
  20. **kwargs,
  21. )
  22. async def embedding_func(texts: list[str]) -> np.ndarray:
  23. return await openai_embed(
  24. texts,
  25. model="solar-embedding-1-large-query",
  26. api_key=os.getenv("UPSTAGE_API_KEY"),
  27. base_url="https://api.upstage.ai/v1/solar",
  28. )
## /For Upstage API
  30. def extract_queries(file_path):
  31. with open(file_path, "r") as f:
  32. data = f.read()
  33. data = data.replace("**", "")
  34. queries = re.findall(r"- Question \d+: (.+)", data)
  35. return queries
  36. async def process_query(query_text, rag_instance, query_param):
  37. try:
  38. result = await rag_instance.aquery(query_text, param=query_param)
  39. return {"query": query_text, "result": result}, None
  40. except Exception as e:
  41. return None, {"query": query_text, "error": str(e)}
  42. def run_queries_and_save_to_json(
  43. queries, rag_instance, query_param, output_file, error_file
  44. ):
  45. loop = always_get_an_event_loop()
  46. with (
  47. open(output_file, "a", encoding="utf-8") as result_file,
  48. open(error_file, "a", encoding="utf-8") as err_file,
  49. ):
  50. result_file.write("[\n")
  51. first_entry = True
  52. for query_text in queries:
  53. result, error = loop.run_until_complete(
  54. process_query(query_text, rag_instance, query_param)
  55. )
  56. if result:
  57. if not first_entry:
  58. result_file.write(",\n")
  59. json.dump(result, result_file, ensure_ascii=False, indent=4)
  60. first_entry = False
  61. elif error:
  62. json.dump(error, err_file, ensure_ascii=False, indent=4)
  63. err_file.write("\n")
  64. result_file.write("\n]")
  65. if __name__ == "__main__":
  66. cls = "mix"
  67. mode = "hybrid"
  68. WORKING_DIR = f"../{cls}"
  69. rag = LightRAG(working_dir=WORKING_DIR)
  70. rag = LightRAG(
  71. working_dir=WORKING_DIR,
  72. llm_model_func=llm_model_func,
  73. embedding_func=EmbeddingFunc(embedding_dim=4096, func=embedding_func),
  74. )
  75. query_param = QueryParam(mode=mode)
  76. base_dir = "../datasets/questions"
  77. queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
  78. run_queries_and_save_to_json(
  79. queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
  80. )