eval_rag_quality.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. #!/usr/bin/env python3
  2. """
  3. RAGAS Evaluation Script for LightRAG System
  4. Evaluates RAG response quality using RAGAS metrics:
  5. - Faithfulness: Is the answer factually accurate based on context?
  6. - Answer Relevance: Is the answer relevant to the question?
  7. - Context Recall: Is all relevant information retrieved?
  8. - Context Precision: Is retrieved context clean without noise?
  9. Usage:
  10. # Use defaults (sample_dataset.json, http://localhost:9621)
  11. python lightrag/evaluation/eval_rag_quality.py
  12. # Specify custom dataset
  13. python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json
  14. python lightrag/evaluation/eval_rag_quality.py -d my_test.json
  15. # Specify custom RAG endpoint
  16. python lightrag/evaluation/eval_rag_quality.py --ragendpoint http://my-server.com:9621
  17. python lightrag/evaluation/eval_rag_quality.py -r http://my-server.com:9621
  18. # Specify both
  19. python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621
  20. # Get help
  21. python lightrag/evaluation/eval_rag_quality.py --help
  22. Results are saved to: lightrag/evaluation/results/
  23. - results_YYYYMMDD_HHMMSS.csv (CSV export for analysis)
  24. - results_YYYYMMDD_HHMMSS.json (Full results with details)
  25. Technical Notes:
  26. - Uses stable RAGAS API (LangchainLLMWrapper) for maximum compatibility
  27. - Supports custom OpenAI-compatible endpoints via EVAL_LLM_BINDING_HOST
  28. - Enables bypass_n mode for endpoints that don't support 'n' parameter
  29. - Deprecation warnings are suppressed for cleaner output
  30. """
  31. import argparse
  32. import asyncio
  33. import csv
  34. import json
  35. import math
  36. import os
  37. import sys
  38. import time
  39. import warnings
  40. from datetime import datetime
  41. from pathlib import Path
  42. from typing import Any, Dict, List
  43. import httpx
  44. from dotenv import load_dotenv
  45. from lightrag.utils import logger
# Suppress LangchainLLMWrapper deprecation warning
# We use LangchainLLMWrapper for stability and compatibility with all RAGAS versions
warnings.filterwarnings(
    "ignore",
    message=".*LangchainLLMWrapper is deprecated.*",
    category=DeprecationWarning,
)

# Suppress token usage warning for custom OpenAI-compatible endpoints
# Custom endpoints (vLLM, SGLang, etc.) often don't return usage information
# This is non-critical as token tracking is not required for RAGAS evaluation
warnings.filterwarnings(
    "ignore",
    message=".*Unexpected type for token usage.*",
    category=UserWarning,
)

# Put the directory three levels above this file at the front of sys.path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

# Load ".env" via a relative path, which allows a different .env file
# for each lightrag instance.
# override=False: the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

# Conditional imports - a missing dependency does NOT raise here; the failure is
# recorded in RAGAS_AVAILABLE and RAGEvaluator.__init__ raises ImportError later.
try:
    from datasets import Dataset
    from ragas import evaluate
    from ragas.metrics import (
        AnswerRelevancy,
        ContextPrecision,
        ContextRecall,
        Faithfulness,
    )
    from ragas.llms import LangchainLLMWrapper
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from tqdm.auto import tqdm

    RAGAS_AVAILABLE = True
except ImportError:
    RAGAS_AVAILABLE = False
    # Only these three names get placeholders. The metric classes, ChatOpenAI,
    # OpenAIEmbeddings and tqdm may be left undefined on this path, so they
    # must only be used when RAGAS_AVAILABLE is True.
    Dataset = None
    evaluate = None
    LangchainLLMWrapper = None

# Timeouts (seconds) passed to httpx.Timeout in RAGEvaluator.evaluate_responses:
# connect / read are set explicitly, TOTAL is passed as the positional default.
CONNECT_TIMEOUT_SECONDS = 180.0
READ_TIMEOUT_SECONDS = 300.0
TOTAL_TIMEOUT_SECONDS = 180.0
  89. def _is_nan(value: Any) -> bool:
  90. """Return True when value is a float NaN."""
  91. return isinstance(value, float) and math.isnan(value)
  92. class RAGEvaluator:
  93. """Evaluate RAG system quality using RAGAS metrics"""
  94. def __init__(self, test_dataset_path: str = None, rag_api_url: str = None):
  95. """
  96. Initialize evaluator with test dataset
  97. Args:
  98. test_dataset_path: Path to test dataset JSON file
  99. rag_api_url: Base URL of LightRAG API (e.g., http://localhost:9621)
  100. If None, will try to read from environment or use default
  101. Environment Variables:
  102. EVAL_LLM_MODEL: LLM model for evaluation (default: gpt-4o-mini)
  103. EVAL_EMBEDDING_MODEL: Embedding model for evaluation (default: text-embedding-3-small)
  104. EVAL_LLM_BINDING_API_KEY: API key for LLM (fallback to OPENAI_API_KEY)
  105. EVAL_LLM_BINDING_HOST: Custom endpoint URL for LLM (optional)
  106. EVAL_EMBEDDING_BINDING_API_KEY: API key for embeddings (fallback: EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY)
  107. EVAL_EMBEDDING_BINDING_HOST: Custom endpoint URL for embeddings (fallback: EVAL_LLM_BINDING_HOST)
  108. Raises:
  109. ImportError: If ragas or datasets packages are not installed
  110. EnvironmentError: If EVAL_LLM_BINDING_API_KEY and OPENAI_API_KEY are both not set
  111. """
  112. # Validate RAGAS dependencies are installed
  113. if not RAGAS_AVAILABLE:
  114. raise ImportError(
  115. "RAGAS dependencies not installed. "
  116. "Install with: pip install ragas datasets"
  117. )
  118. # Configure evaluation LLM (for RAGAS scoring)
  119. eval_llm_api_key = os.getenv("EVAL_LLM_BINDING_API_KEY") or os.getenv(
  120. "OPENAI_API_KEY"
  121. )
  122. if not eval_llm_api_key:
  123. raise EnvironmentError(
  124. "EVAL_LLM_BINDING_API_KEY or OPENAI_API_KEY is required for evaluation. "
  125. "Set EVAL_LLM_BINDING_API_KEY to use a custom API key, "
  126. "or ensure OPENAI_API_KEY is set."
  127. )
  128. eval_model = os.getenv("EVAL_LLM_MODEL", "gpt-4o-mini")
  129. eval_llm_base_url = os.getenv("EVAL_LLM_BINDING_HOST")
  130. # Configure evaluation embeddings (for RAGAS scoring)
  131. # Fallback chain: EVAL_EMBEDDING_BINDING_API_KEY -> EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY
  132. eval_embedding_api_key = (
  133. os.getenv("EVAL_EMBEDDING_BINDING_API_KEY")
  134. or os.getenv("EVAL_LLM_BINDING_API_KEY")
  135. or os.getenv("OPENAI_API_KEY")
  136. )
  137. eval_embedding_model = os.getenv(
  138. "EVAL_EMBEDDING_MODEL", "text-embedding-3-large"
  139. )
  140. # Fallback chain: EVAL_EMBEDDING_BINDING_HOST -> EVAL_LLM_BINDING_HOST -> None
  141. eval_embedding_base_url = os.getenv("EVAL_EMBEDDING_BINDING_HOST") or os.getenv(
  142. "EVAL_LLM_BINDING_HOST"
  143. )
  144. # Create LLM and Embeddings instances for RAGAS
  145. llm_kwargs = {
  146. "model": eval_model,
  147. "api_key": eval_llm_api_key,
  148. "max_retries": int(os.getenv("EVAL_LLM_MAX_RETRIES", "5")),
  149. "request_timeout": int(os.getenv("EVAL_LLM_TIMEOUT", "180")),
  150. }
  151. embedding_kwargs = {
  152. "model": eval_embedding_model,
  153. "api_key": eval_embedding_api_key,
  154. }
  155. if eval_llm_base_url:
  156. llm_kwargs["base_url"] = eval_llm_base_url
  157. if eval_embedding_base_url:
  158. embedding_kwargs["base_url"] = eval_embedding_base_url
  159. # Create base LangChain LLM
  160. base_llm = ChatOpenAI(**llm_kwargs)
  161. self.eval_embeddings = OpenAIEmbeddings(**embedding_kwargs)
  162. # Wrap LLM with LangchainLLMWrapper and enable bypass_n mode for custom endpoints
  163. # This ensures compatibility with endpoints that don't support the 'n' parameter
  164. # by generating multiple outputs through repeated prompts instead of using 'n' parameter
  165. try:
  166. self.eval_llm = LangchainLLMWrapper(
  167. langchain_llm=base_llm,
  168. bypass_n=True, # Enable bypass_n to avoid passing 'n' to OpenAI API
  169. )
  170. logger.debug("Successfully configured bypass_n mode for LLM wrapper")
  171. except Exception as e:
  172. logger.warning(
  173. "Could not configure LangchainLLMWrapper with bypass_n: %s. "
  174. "Using base LLM directly, which may cause warnings with custom endpoints.",
  175. e,
  176. )
  177. self.eval_llm = base_llm
  178. if test_dataset_path is None:
  179. test_dataset_path = Path(__file__).parent / "sample_dataset.json"
  180. if rag_api_url is None:
  181. rag_api_url = os.getenv("LIGHTRAG_API_URL", "http://localhost:9621")
  182. self.test_dataset_path = Path(test_dataset_path)
  183. self.rag_api_url = rag_api_url.rstrip("/")
  184. self.results_dir = Path(__file__).parent / "results"
  185. self.results_dir.mkdir(exist_ok=True)
  186. # Load test dataset
  187. self.test_cases = self._load_test_dataset()
  188. # Store configuration values for display
  189. self.eval_model = eval_model
  190. self.eval_embedding_model = eval_embedding_model
  191. self.eval_llm_base_url = eval_llm_base_url
  192. self.eval_embedding_base_url = eval_embedding_base_url
  193. self.eval_max_retries = llm_kwargs["max_retries"]
  194. self.eval_timeout = llm_kwargs["request_timeout"]
  195. # Display configuration
  196. self._display_configuration()
  197. def _display_configuration(self):
  198. """Display all evaluation configuration settings"""
  199. logger.info("Evaluation Models:")
  200. logger.info(" • LLM Model: %s", self.eval_model)
  201. logger.info(" • Embedding Model: %s", self.eval_embedding_model)
  202. # Display LLM endpoint
  203. if self.eval_llm_base_url:
  204. logger.info(" • LLM Endpoint: %s", self.eval_llm_base_url)
  205. logger.info(
  206. " • Bypass N-Parameter: Enabled (use LangchainLLMWrapper for compatibility)"
  207. )
  208. else:
  209. logger.info(" • LLM Endpoint: OpenAI Official API")
  210. # Display Embedding endpoint (only if different from LLM)
  211. if self.eval_embedding_base_url:
  212. if self.eval_embedding_base_url != self.eval_llm_base_url:
  213. logger.info(
  214. " • Embedding Endpoint: %s", self.eval_embedding_base_url
  215. )
  216. # If same as LLM endpoint, no need to display separately
  217. elif not self.eval_llm_base_url:
  218. # Both using OpenAI - already displayed above
  219. pass
  220. else:
  221. # LLM uses custom endpoint, but embeddings use OpenAI
  222. logger.info(" • Embedding Endpoint: OpenAI Official API")
  223. logger.info("Concurrency & Rate Limiting:")
  224. query_top_k = int(os.getenv("EVAL_QUERY_TOP_K", "10"))
  225. logger.info(" • Query Top-K: %s Entities/Relations", query_top_k)
  226. logger.info(" • LLM Max Retries: %s", self.eval_max_retries)
  227. logger.info(" • LLM Timeout: %s seconds", self.eval_timeout)
  228. logger.info("Test Configuration:")
  229. logger.info(" • Total Test Cases: %s", len(self.test_cases))
  230. logger.info(" • Test Dataset: %s", self.test_dataset_path.name)
  231. logger.info(" • LightRAG API: %s", self.rag_api_url)
  232. logger.info(" • Results Directory: %s", self.results_dir.name)
  233. def _load_test_dataset(self) -> List[Dict[str, str]]:
  234. """Load test cases from JSON file"""
  235. if not self.test_dataset_path.exists():
  236. raise FileNotFoundError(f"Test dataset not found: {self.test_dataset_path}")
  237. with open(self.test_dataset_path) as f:
  238. data = json.load(f)
  239. return data.get("test_cases", [])
  240. async def generate_rag_response(
  241. self,
  242. question: str,
  243. client: httpx.AsyncClient,
  244. ) -> Dict[str, Any]:
  245. """
  246. Generate RAG response by calling LightRAG API.
  247. Args:
  248. question: The user query.
  249. client: Shared httpx AsyncClient for connection pooling.
  250. Returns:
  251. Dictionary with 'answer' and 'contexts' keys.
  252. 'contexts' is a list of strings (one per retrieved document).
  253. Raises:
  254. Exception: If LightRAG API is unavailable.
  255. """
  256. try:
  257. payload = {
  258. "query": question,
  259. "mode": "mix",
  260. "include_references": True,
  261. "include_chunk_content": True, # NEW: Request chunk content in references
  262. "response_type": "Multiple Paragraphs",
  263. "top_k": int(os.getenv("EVAL_QUERY_TOP_K", "10")),
  264. }
  265. # Get API key from environment for authentication
  266. api_key = os.getenv("LIGHTRAG_API_KEY")
  267. # Prepare headers with optional authentication
  268. headers = {}
  269. if api_key:
  270. headers["X-API-Key"] = api_key
  271. # Single optimized API call - gets both answer AND chunk content
  272. response = await client.post(
  273. f"{self.rag_api_url}/query",
  274. json=payload,
  275. headers=headers if headers else None,
  276. )
  277. response.raise_for_status()
  278. result = response.json()
  279. answer = result.get("response", "No response generated")
  280. references = result.get("references", [])
  281. # DEBUG: Inspect the API response
  282. logger.debug("🔍 References Count: %s", len(references))
  283. if references:
  284. first_ref = references[0]
  285. logger.debug("🔍 First Reference Keys: %s", list(first_ref.keys()))
  286. if "content" in first_ref:
  287. content_preview = first_ref["content"]
  288. if isinstance(content_preview, list) and content_preview:
  289. logger.debug(
  290. "🔍 Content Preview (first chunk): %s...",
  291. content_preview[0][:100],
  292. )
  293. elif isinstance(content_preview, str):
  294. logger.debug("🔍 Content Preview: %s...", content_preview[:100])
  295. # Extract chunk content from enriched references
  296. # Note: content is now a list of chunks per reference (one file may have multiple chunks)
  297. contexts = []
  298. for ref in references:
  299. content = ref.get("content", [])
  300. if isinstance(content, list):
  301. # Flatten the list: each chunk becomes a separate context
  302. contexts.extend(content)
  303. elif isinstance(content, str):
  304. # Backward compatibility: if content is still a string (shouldn't happen)
  305. contexts.append(content)
  306. return {
  307. "answer": answer,
  308. "contexts": contexts, # List of strings from actual retrieved chunks
  309. }
  310. except httpx.ConnectError as e:
  311. raise Exception(
  312. f"❌ Cannot connect to LightRAG API at {self.rag_api_url}\n"
  313. f" Make sure LightRAG server is running:\n"
  314. f" python -m lightrag.api.lightrag_server\n"
  315. f" Error: {str(e)}"
  316. )
  317. except httpx.HTTPStatusError as e:
  318. raise Exception(
  319. f"LightRAG API error {e.response.status_code}: {e.response.text}"
  320. )
  321. except httpx.ReadTimeout as e:
  322. raise Exception(
  323. f"Request timeout after waiting for response\n"
  324. f" Question: {question[:100]}...\n"
  325. f" Error: {str(e)}"
  326. )
  327. except Exception as e:
  328. raise Exception(f"Error calling LightRAG API: {type(e).__name__}: {str(e)}")
  329. async def evaluate_single_case(
  330. self,
  331. idx: int,
  332. test_case: Dict[str, str],
  333. rag_semaphore: asyncio.Semaphore,
  334. eval_semaphore: asyncio.Semaphore,
  335. client: httpx.AsyncClient,
  336. progress_counter: Dict[str, int],
  337. position_pool: asyncio.Queue,
  338. pbar_creation_lock: asyncio.Lock,
  339. ) -> Dict[str, Any]:
  340. """
  341. Evaluate a single test case with two-stage pipeline concurrency control
  342. Args:
  343. idx: Test case index (1-based)
  344. test_case: Test case dictionary with question and ground_truth
  345. rag_semaphore: Semaphore to control overall concurrency (covers entire function)
  346. eval_semaphore: Semaphore to control RAGAS evaluation concurrency (Stage 2)
  347. client: Shared httpx AsyncClient for connection pooling
  348. progress_counter: Shared dictionary for progress tracking
  349. position_pool: Queue of available tqdm position indices
  350. pbar_creation_lock: Lock to serialize tqdm creation and prevent race conditions
  351. Returns:
  352. Evaluation result dictionary
  353. """
  354. # rag_semaphore controls the entire evaluation process to prevent
  355. # all RAG responses from being generated at once when eval is slow
  356. async with rag_semaphore:
  357. question = test_case["question"]
  358. ground_truth = test_case["ground_truth"]
  359. # Stage 1: Generate RAG response
  360. try:
  361. rag_response = await self.generate_rag_response(
  362. question=question, client=client
  363. )
  364. except Exception as e:
  365. logger.error("Error generating response for test %s: %s", idx, str(e))
  366. progress_counter["completed"] += 1
  367. return {
  368. "test_number": idx,
  369. "question": question,
  370. "error": str(e),
  371. "metrics": {},
  372. "ragas_score": 0,
  373. "timestamp": datetime.now().isoformat(),
  374. }
  375. # *** CRITICAL FIX: Use actual retrieved contexts, NOT ground_truth ***
  376. retrieved_contexts = rag_response["contexts"]
  377. # Prepare dataset for RAGAS evaluation with CORRECT contexts
  378. eval_dataset = Dataset.from_dict(
  379. {
  380. "question": [question],
  381. "answer": [rag_response["answer"]],
  382. "contexts": [retrieved_contexts],
  383. "ground_truth": [ground_truth],
  384. }
  385. )
  386. # Stage 2: Run RAGAS evaluation (controlled by eval_semaphore)
  387. # IMPORTANT: Create fresh metric instances for each evaluation to avoid
  388. # concurrent state conflicts when multiple tasks run in parallel
  389. async with eval_semaphore:
  390. pbar = None
  391. position = None
  392. try:
  393. # Acquire a position from the pool for this tqdm progress bar
  394. position = await position_pool.get()
  395. # Serialize tqdm creation to prevent race conditions
  396. # Multiple tasks creating tqdm simultaneously can cause display conflicts
  397. async with pbar_creation_lock:
  398. # Create tqdm progress bar with assigned position to avoid overlapping
  399. # leave=False ensures the progress bar is cleared after completion,
  400. # preventing accumulation of completed bars and allowing position reuse
  401. pbar = tqdm(
  402. total=4,
  403. desc=f"Eval-{idx:02d}",
  404. position=position,
  405. leave=False,
  406. )
  407. # Give tqdm time to initialize and claim its screen position
  408. await asyncio.sleep(0.05)
  409. eval_results = evaluate(
  410. dataset=eval_dataset,
  411. metrics=[
  412. Faithfulness(),
  413. AnswerRelevancy(),
  414. ContextRecall(),
  415. ContextPrecision(),
  416. ],
  417. llm=self.eval_llm,
  418. embeddings=self.eval_embeddings,
  419. _pbar=pbar,
  420. )
  421. # Convert to DataFrame (RAGAS v0.3+ API)
  422. df = eval_results.to_pandas()
  423. # Extract scores from first row
  424. scores_row = df.iloc[0]
  425. # Extract scores (RAGAS v0.3+ uses .to_pandas())
  426. result = {
  427. "test_number": idx,
  428. "question": question,
  429. "answer": rag_response["answer"][:200] + "..."
  430. if len(rag_response["answer"]) > 200
  431. else rag_response["answer"],
  432. "ground_truth": ground_truth[:200] + "..."
  433. if len(ground_truth) > 200
  434. else ground_truth,
  435. "project": test_case.get("project", "unknown"),
  436. "metrics": {
  437. "faithfulness": float(scores_row.get("faithfulness", 0)),
  438. "answer_relevance": float(
  439. scores_row.get("answer_relevancy", 0)
  440. ),
  441. "context_recall": float(
  442. scores_row.get("context_recall", 0)
  443. ),
  444. "context_precision": float(
  445. scores_row.get("context_precision", 0)
  446. ),
  447. },
  448. "timestamp": datetime.now().isoformat(),
  449. }
  450. # Calculate RAGAS score (average of all metrics, excluding NaN values)
  451. metrics = result["metrics"]
  452. valid_metrics = [v for v in metrics.values() if not _is_nan(v)]
  453. ragas_score = (
  454. sum(valid_metrics) / len(valid_metrics) if valid_metrics else 0
  455. )
  456. result["ragas_score"] = round(ragas_score, 4)
  457. # Update progress counter
  458. progress_counter["completed"] += 1
  459. return result
  460. except Exception as e:
  461. logger.error("Error evaluating test %s: %s", idx, str(e))
  462. progress_counter["completed"] += 1
  463. return {
  464. "test_number": idx,
  465. "question": question,
  466. "error": str(e),
  467. "metrics": {},
  468. "ragas_score": 0,
  469. "timestamp": datetime.now().isoformat(),
  470. }
  471. finally:
  472. # Force close progress bar to ensure completion
  473. if pbar is not None:
  474. pbar.close()
  475. # Release the position back to the pool for reuse
  476. if position is not None:
  477. await position_pool.put(position)
  478. async def evaluate_responses(self) -> List[Dict[str, Any]]:
  479. """
  480. Evaluate all test cases in parallel with two-stage pipeline and return metrics
  481. Returns:
  482. List of evaluation results with metrics
  483. """
  484. # Get evaluation concurrency from environment (default to 2 for parallel evaluation)
  485. max_async = int(os.getenv("EVAL_MAX_CONCURRENT", "2"))
  486. logger.info("%s", "=" * 70)
  487. logger.info("🚀 Starting RAGAS Evaluation of LightRAG System")
  488. logger.info("🔧 RAGAS Evaluation (Stage 2): %s concurrent", max_async)
  489. logger.info("%s", "=" * 70)
  490. # Create two-stage pipeline semaphores
  491. # Stage 1: RAG generation - allow x2 concurrency to keep evaluation fed
  492. rag_semaphore = asyncio.Semaphore(max_async * 2)
  493. # Stage 2: RAGAS evaluation - primary bottleneck
  494. eval_semaphore = asyncio.Semaphore(max_async)
  495. # Create progress counter (shared across all tasks)
  496. progress_counter = {"completed": 0}
  497. # Create position pool for tqdm progress bars
  498. # Positions range from 0 to max_async-1, ensuring no overlapping displays
  499. position_pool = asyncio.Queue()
  500. for i in range(max_async):
  501. await position_pool.put(i)
  502. # Create lock to serialize tqdm creation and prevent race conditions
  503. # This ensures progress bars are created one at a time, avoiding display conflicts
  504. pbar_creation_lock = asyncio.Lock()
  505. # Create shared HTTP client with connection pooling and proper timeouts
  506. # Timeout: 3 minutes for connect, 5 minutes for read (LLM can be slow)
  507. timeout = httpx.Timeout(
  508. TOTAL_TIMEOUT_SECONDS,
  509. connect=CONNECT_TIMEOUT_SECONDS,
  510. read=READ_TIMEOUT_SECONDS,
  511. )
  512. limits = httpx.Limits(
  513. max_connections=(max_async + 1) * 2, # Allow buffer for RAG stage
  514. max_keepalive_connections=max_async + 1,
  515. )
  516. async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
  517. # Create tasks for all test cases
  518. tasks = [
  519. self.evaluate_single_case(
  520. idx,
  521. test_case,
  522. rag_semaphore,
  523. eval_semaphore,
  524. client,
  525. progress_counter,
  526. position_pool,
  527. pbar_creation_lock,
  528. )
  529. for idx, test_case in enumerate(self.test_cases, 1)
  530. ]
  531. # Run all evaluations in parallel (limited by two-stage semaphores)
  532. results = await asyncio.gather(*tasks)
  533. return list(results)
  534. def _export_to_csv(self, results: List[Dict[str, Any]]) -> Path:
  535. """
  536. Export evaluation results to CSV file
  537. Args:
  538. results: List of evaluation results
  539. Returns:
  540. Path to the CSV file
  541. CSV Format:
  542. - question: The test question
  543. - project: Project context
  544. - faithfulness: Faithfulness score (0-1)
  545. - answer_relevance: Answer relevance score (0-1)
  546. - context_recall: Context recall score (0-1)
  547. - context_precision: Context precision score (0-1)
  548. - ragas_score: Overall RAGAS score (0-1)
  549. - timestamp: When evaluation was run
  550. """
  551. csv_path = (
  552. self.results_dir / f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
  553. )
  554. with open(csv_path, "w", newline="", encoding="utf-8") as f:
  555. fieldnames = [
  556. "test_number",
  557. "question",
  558. "project",
  559. "faithfulness",
  560. "answer_relevance",
  561. "context_recall",
  562. "context_precision",
  563. "ragas_score",
  564. "status",
  565. "timestamp",
  566. ]
  567. writer = csv.DictWriter(f, fieldnames=fieldnames)
  568. writer.writeheader()
  569. for idx, result in enumerate(results, 1):
  570. metrics = result.get("metrics", {})
  571. writer.writerow(
  572. {
  573. "test_number": idx,
  574. "question": result.get("question", ""),
  575. "project": result.get("project", "unknown"),
  576. "faithfulness": f"{metrics.get('faithfulness', 0):.4f}",
  577. "answer_relevance": f"{metrics.get('answer_relevance', 0):.4f}",
  578. "context_recall": f"{metrics.get('context_recall', 0):.4f}",
  579. "context_precision": f"{metrics.get('context_precision', 0):.4f}",
  580. "ragas_score": f"{result.get('ragas_score', 0):.4f}",
  581. "status": "success" if metrics else "error",
  582. "timestamp": result.get("timestamp", ""),
  583. }
  584. )
  585. return csv_path
  586. def _format_metric(self, value: float, width: int = 6) -> str:
  587. """
  588. Format a metric value for display, handling NaN gracefully
  589. Args:
  590. value: The metric value to format
  591. width: The width of the formatted string
  592. Returns:
  593. Formatted string (e.g., "0.8523" or " N/A ")
  594. """
  595. if _is_nan(value):
  596. return "N/A".center(width)
  597. return f"{value:.4f}".rjust(width)
  598. def _display_results_table(self, results: List[Dict[str, Any]]):
  599. """
  600. Display evaluation results in a formatted table
  601. Args:
  602. results: List of evaluation results
  603. """
  604. logger.info("")
  605. logger.info("%s", "=" * 115)
  606. logger.info("📊 EVALUATION RESULTS SUMMARY")
  607. logger.info("%s", "=" * 115)
  608. # Table header
  609. logger.info(
  610. "%-4s | %-50s | %6s | %7s | %6s | %7s | %6s | %6s",
  611. "#",
  612. "Question",
  613. "Faith",
  614. "AnswRel",
  615. "CtxRec",
  616. "CtxPrec",
  617. "RAGAS",
  618. "Status",
  619. )
  620. logger.info("%s", "-" * 115)
  621. # Table rows
  622. for result in results:
  623. test_num = result.get("test_number", 0)
  624. question = result.get("question", "")
  625. # Truncate question to 50 chars
  626. question_display = (
  627. (question[:47] + "...") if len(question) > 50 else question
  628. )
  629. metrics = result.get("metrics", {})
  630. if metrics:
  631. # Success case - format each metric, handling NaN values
  632. faith = metrics.get("faithfulness", 0)
  633. ans_rel = metrics.get("answer_relevance", 0)
  634. ctx_rec = metrics.get("context_recall", 0)
  635. ctx_prec = metrics.get("context_precision", 0)
  636. ragas = result.get("ragas_score", 0)
  637. status = "✓"
  638. logger.info(
  639. "%-4d | %-50s | %s | %s | %s | %s | %s | %6s",
  640. test_num,
  641. question_display,
  642. self._format_metric(faith, 6),
  643. self._format_metric(ans_rel, 7),
  644. self._format_metric(ctx_rec, 6),
  645. self._format_metric(ctx_prec, 7),
  646. self._format_metric(ragas, 6),
  647. status,
  648. )
  649. else:
  650. # Error case
  651. error = result.get("error", "Unknown error")
  652. error_display = (error[:20] + "...") if len(error) > 23 else error
  653. logger.info(
  654. "%-4d | %-50s | %6s | %7s | %6s | %7s | %6s | ✗ %s",
  655. test_num,
  656. question_display,
  657. "N/A",
  658. "N/A",
  659. "N/A",
  660. "N/A",
  661. "N/A",
  662. error_display,
  663. )
  664. logger.info("%s", "=" * 115)
  665. def _calculate_benchmark_stats(
  666. self, results: List[Dict[str, Any]]
  667. ) -> Dict[str, Any]:
  668. """
  669. Calculate benchmark statistics from evaluation results
  670. Args:
  671. results: List of evaluation results
  672. Returns:
  673. Dictionary with benchmark statistics
  674. """
  675. # Filter out results with errors
  676. valid_results = [r for r in results if r.get("metrics")]
  677. total_tests = len(results)
  678. successful_tests = len(valid_results)
  679. failed_tests = total_tests - successful_tests
  680. if not valid_results:
  681. return {
  682. "total_tests": total_tests,
  683. "successful_tests": 0,
  684. "failed_tests": failed_tests,
  685. "success_rate": 0.0,
  686. }
  687. # Calculate averages for each metric (handling NaN values correctly)
  688. # Track both sum and count for each metric to handle NaN values properly
  689. metrics_data = {
  690. "faithfulness": {"sum": 0.0, "count": 0},
  691. "answer_relevance": {"sum": 0.0, "count": 0},
  692. "context_recall": {"sum": 0.0, "count": 0},
  693. "context_precision": {"sum": 0.0, "count": 0},
  694. "ragas_score": {"sum": 0.0, "count": 0},
  695. }
  696. for result in valid_results:
  697. metrics = result.get("metrics", {})
  698. # For each metric, sum non-NaN values and count them
  699. faithfulness = metrics.get("faithfulness", 0)
  700. if not _is_nan(faithfulness):
  701. metrics_data["faithfulness"]["sum"] += faithfulness
  702. metrics_data["faithfulness"]["count"] += 1
  703. answer_relevance = metrics.get("answer_relevance", 0)
  704. if not _is_nan(answer_relevance):
  705. metrics_data["answer_relevance"]["sum"] += answer_relevance
  706. metrics_data["answer_relevance"]["count"] += 1
  707. context_recall = metrics.get("context_recall", 0)
  708. if not _is_nan(context_recall):
  709. metrics_data["context_recall"]["sum"] += context_recall
  710. metrics_data["context_recall"]["count"] += 1
  711. context_precision = metrics.get("context_precision", 0)
  712. if not _is_nan(context_precision):
  713. metrics_data["context_precision"]["sum"] += context_precision
  714. metrics_data["context_precision"]["count"] += 1
  715. ragas_score = result.get("ragas_score", 0)
  716. if not _is_nan(ragas_score):
  717. metrics_data["ragas_score"]["sum"] += ragas_score
  718. metrics_data["ragas_score"]["count"] += 1
  719. # Calculate averages using actual counts for each metric
  720. avg_metrics = {}
  721. for metric_name, data in metrics_data.items():
  722. if data["count"] > 0:
  723. avg_val = data["sum"] / data["count"]
  724. avg_metrics[metric_name] = (
  725. round(avg_val, 4) if not _is_nan(avg_val) else 0.0
  726. )
  727. else:
  728. avg_metrics[metric_name] = 0.0
  729. # Find min and max RAGAS scores (filter out NaN)
  730. ragas_scores = []
  731. for r in valid_results:
  732. score = r.get("ragas_score", 0)
  733. if _is_nan(score):
  734. continue # Skip NaN values
  735. ragas_scores.append(score)
  736. min_score = min(ragas_scores) if ragas_scores else 0
  737. max_score = max(ragas_scores) if ragas_scores else 0
  738. return {
  739. "total_tests": total_tests,
  740. "successful_tests": successful_tests,
  741. "failed_tests": failed_tests,
  742. "success_rate": round(successful_tests / total_tests * 100, 2),
  743. "average_metrics": avg_metrics,
  744. "min_ragas_score": round(min_score, 4),
  745. "max_ragas_score": round(max_score, 4),
  746. }
  async def run(self) -> Dict[str, Any]:
      """Run complete evaluation pipeline

      Awaits ``evaluate_responses``, derives benchmark statistics from the
      results, displays the results table, writes a timestamped JSON summary
      into ``self.results_dir``, exports the results to CSV and logs a
      summary report.

      Returns:
          The summary dictionary that was written to the JSON file, with the
          keys "timestamp", "total_tests", "elapsed_time_seconds",
          "benchmark_stats" and "results".
      """
      start_time = time.time()

      # Evaluate responses
      results = await self.evaluate_responses()
      elapsed_time = time.time() - start_time

      # Calculate benchmark statistics
      benchmark_stats = self._calculate_benchmark_stats(results)

      # Save results
      summary = {
          "timestamp": datetime.now().isoformat(),
          "total_tests": len(results),
          "elapsed_time_seconds": round(elapsed_time, 2),
          "benchmark_stats": benchmark_stats,
          "results": results,
      }

      # Display results table
      self._display_results_table(results)

      # Save JSON results
      json_path = (
          self.results_dir
          / f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
      )
      # Explicit encoding so the output does not depend on the platform default.
      with open(json_path, "w", encoding="utf-8") as f:
          json.dump(summary, f, indent=2)

      # Export to CSV
      csv_path = self._export_to_csv(results)

      # An empty result list must not abort the report with ZeroDivisionError.
      avg_time_per_test = elapsed_time / len(results) if results else 0.0

      # Print summary
      logger.info("")
      logger.info("%s", "=" * 70)
      logger.info("📊 EVALUATION COMPLETE")
      logger.info("%s", "=" * 70)
      logger.info("Total Tests: %s", len(results))
      logger.info("Successful: %s", benchmark_stats["successful_tests"])
      logger.info("Failed: %s", benchmark_stats["failed_tests"])
      logger.info("Success Rate: %.2f%%", benchmark_stats["success_rate"])
      logger.info("Elapsed Time: %.2f seconds", elapsed_time)
      logger.info("Avg Time/Test: %.2f seconds", avg_time_per_test)

      # Print benchmark metrics
      logger.info("")
      logger.info("%s", "=" * 70)
      logger.info("📈 BENCHMARK RESULTS (Average)")
      logger.info("%s", "=" * 70)
      # When every test failed the statistics may carry no averages or
      # min/max scores; report zeros instead of raising KeyError after the
      # result files have already been written.
      avg = benchmark_stats.get("average_metrics", {})
      logger.info("Average Faithfulness: %.4f", avg.get("faithfulness", 0.0))
      logger.info(
          "Average Answer Relevance: %.4f", avg.get("answer_relevance", 0.0)
      )
      logger.info("Average Context Recall: %.4f", avg.get("context_recall", 0.0))
      logger.info(
          "Average Context Precision: %.4f", avg.get("context_precision", 0.0)
      )
      logger.info("Average RAGAS Score: %.4f", avg.get("ragas_score", 0.0))
      logger.info("%s", "-" * 70)
      logger.info(
          "Min RAGAS Score: %.4f",
          benchmark_stats.get("min_ragas_score", 0.0),
      )
      logger.info(
          "Max RAGAS Score: %.4f",
          benchmark_stats.get("max_ragas_score", 0.0),
      )

      logger.info("")
      logger.info("%s", "=" * 70)
      logger.info("📁 GENERATED FILES")
      logger.info("%s", "=" * 70)
      logger.info("Results Dir: %s", self.results_dir.absolute())
      logger.info(" • CSV: %s", csv_path.name)
      logger.info(" • JSON: %s", json_path.name)
      logger.info("%s", "=" * 70)

      return summary
  async def main():
      """
      Command-line entry point for the RAGAS evaluation.

      Parses the options below, builds a ``RAGEvaluator`` from them and awaits
      its ``run()`` coroutine. Any exception is logged with its traceback and
      the process exits with status 1.

      Command-line arguments:
          --dataset, -d: Path to the test dataset JSON file. None is passed
              when omitted (help text: sample_dataset.json).
          --ragendpoint, -r: LightRAG API endpoint URL. None is passed when
              omitted (help text: http://localhost:9621 or $LIGHTRAG_API_URL).

      Usage:
          python lightrag/evaluation/eval_rag_quality.py
          python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json
          python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621
      """
      banner = "=" * 70
      # Shown verbatim after the option list (RawDescriptionHelpFormatter).
      usage_examples = (
          "\n"
          "Examples:\n"
          "# Use defaults\n"
          "python lightrag/evaluation/eval_rag_quality.py\n"
          "# Specify custom dataset\n"
          "python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json\n"
          "# Specify custom RAG endpoint\n"
          "python lightrag/evaluation/eval_rag_quality.py --ragendpoint http://my-server.com:9621\n"
          "# Specify both\n"
          "python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621\n"
      )

      try:
          # Build the command-line interface
          cli = argparse.ArgumentParser(
              description="RAGAS Evaluation Script for LightRAG System",
              formatter_class=argparse.RawDescriptionHelpFormatter,
              epilog=usage_examples,
          )
          cli.add_argument(
              "--dataset",
              "-d",
              type=str,
              default=None,
              help="Path to test dataset JSON file (default: sample_dataset.json in evaluation directory)",
          )
          cli.add_argument(
              "--ragendpoint",
              "-r",
              type=str,
              default=None,
              help="LightRAG API endpoint URL (default: http://localhost:9621 or $LIGHTRAG_API_URL environment variable)",
          )
          options = cli.parse_args()

          logger.info("%s", banner)
          logger.info("🔍 RAGAS Evaluation - Using Real LightRAG API")
          logger.info("%s", banner)

          evaluator = RAGEvaluator(
              test_dataset_path=options.dataset,
              rag_api_url=options.ragendpoint,
          )
          await evaluator.run()
      except Exception as e:
          # Top-level boundary: record the traceback, then signal failure.
          logger.exception("❌ Error: %s", e)
          sys.exit(1)
  if __name__ == "__main__":
      # Executed as a script: drive the async entry point to completion.
      entry_point = main()
      asyncio.run(entry_point)