offline_retrieval_check.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #!/usr/bin/env python3
  2. """Offline sanity check for the bundled LightRAG evaluation samples.
  3. The check uses a small deterministic lexical ranker. It does not start
  4. LightRAG, call the API server, compute embeddings, or call LLM/RAGAS services.
  5. """
  6. from __future__ import annotations
  7. import argparse
  8. import json
  9. import math
  10. import re
  11. import sys
  12. from collections import Counter
  13. from dataclasses import dataclass
  14. from pathlib import Path
  15. from typing import Any
  16. EVAL_DIR = Path(__file__).resolve().parent
  17. DEFAULT_DATASET = EVAL_DIR / "sample_dataset.json"
  18. DEFAULT_DOCS_DIR = EVAL_DIR / "sample_documents"
  19. DEFAULT_ORACLE = EVAL_DIR / "sample_retrieval_oracle.json"
  20. STOPWORDS = {
  21. "a",
  22. "an",
  23. "and",
  24. "are",
  25. "as",
  26. "at",
  27. "be",
  28. "by",
  29. "for",
  30. "from",
  31. "how",
  32. "in",
  33. "into",
  34. "is",
  35. "it",
  36. "its",
  37. "of",
  38. "on",
  39. "or",
  40. "that",
  41. "the",
  42. "their",
  43. "to",
  44. "what",
  45. "with",
  46. }
  47. @dataclass
  48. class Document:
  49. name: str
  50. tokens: Counter[str]
  51. @dataclass
  52. class QueryResult:
  53. question: str
  54. expected: list[str]
  55. ranked: list[str]
  56. def recall_at(self, top_k: int) -> float:
  57. hits = set(self.expected) & set(self.ranked[:top_k])
  58. return len(hits) / len(self.expected)
  59. def reciprocal_rank(self) -> float:
  60. for rank, document in enumerate(self.ranked, start=1):
  61. if document in self.expected:
  62. return 1 / rank
  63. return 0.0
  64. def tokenize(text: str) -> list[str]:
  65. tokens = re.findall(r"[a-z0-9]+", text.lower())
  66. return [token for token in tokens if token not in STOPWORDS and len(token) > 1]
  67. def load_cases(dataset_path: Path) -> list[dict[str, Any]]:
  68. payload = json.loads(dataset_path.read_text(encoding="utf-8"))
  69. cases = payload.get("test_cases")
  70. if not isinstance(cases, list):
  71. raise ValueError(f"{dataset_path} must contain a test_cases list")
  72. return cases
  73. def load_oracle(oracle_path: Path) -> dict[str, list[str]]:
  74. payload = json.loads(oracle_path.read_text(encoding="utf-8"))
  75. entries = payload.get("oracle")
  76. if not isinstance(entries, list):
  77. raise ValueError(f"{oracle_path} must contain an oracle list")
  78. oracle: dict[str, list[str]] = {}
  79. for entry in entries:
  80. question = str(entry.get("question", "")).strip()
  81. expected = entry.get("expected_documents")
  82. if not question or not isinstance(expected, list) or not expected:
  83. raise ValueError("Each oracle entry needs question and expected_documents")
  84. oracle[question] = [str(document) for document in expected]
  85. return oracle
  86. def load_documents(docs_dir: Path) -> list[Document]:
  87. documents: list[Document] = []
  88. for path in sorted(docs_dir.glob("*.md")):
  89. if path.name.lower() == "readme.md":
  90. continue
  91. documents.append(
  92. Document(
  93. name=path.name,
  94. tokens=Counter(tokenize(path.read_text(encoding="utf-8"))),
  95. )
  96. )
  97. if not documents:
  98. raise ValueError(f"No markdown sample documents found in {docs_dir}")
  99. return documents
  100. def inverse_document_frequency(documents: list[Document]) -> dict[str, float]:
  101. document_frequency: Counter[str] = Counter()
  102. for document in documents:
  103. document_frequency.update(document.tokens.keys())
  104. doc_count = len(documents)
  105. return {
  106. token: math.log((doc_count + 1) / (frequency + 1)) + 1
  107. for token, frequency in document_frequency.items()
  108. }
  109. def score_query(
  110. query_tokens: list[str],
  111. document: Document,
  112. idf: dict[str, float],
  113. ) -> float:
  114. score = 0.0
  115. for token in query_tokens:
  116. if token in document.tokens:
  117. score += (1 + math.log(document.tokens[token])) * idf.get(token, 0.0)
  118. return score
  119. def audit_samples(
  120. cases: list[dict[str, Any]],
  121. oracle: dict[str, list[str]],
  122. documents: list[Document],
  123. ) -> list[QueryResult]:
  124. idf = inverse_document_frequency(documents)
  125. results: list[QueryResult] = []
  126. for case in cases:
  127. question = str(case.get("question", "")).strip()
  128. if question not in oracle:
  129. raise ValueError(f"No oracle entry for question: {question}")
  130. query_tokens = tokenize(question)
  131. scored_documents = [
  132. (score_query(query_tokens, document, idf), document)
  133. for document in documents
  134. ]
  135. ranked = [
  136. document
  137. for score, document in sorted(
  138. scored_documents,
  139. key=lambda item: (-item[0], item[1].name),
  140. )
  141. if score > 0
  142. ]
  143. results.append(
  144. QueryResult(
  145. question=question,
  146. expected=oracle[question],
  147. ranked=[document.name for document in ranked],
  148. )
  149. )
  150. return results
  151. def summarize(results: list[QueryResult], top_k: int) -> dict[str, Any]:
  152. if not results:
  153. raise ValueError("No query results to summarize")
  154. recalls = [result.recall_at(top_k) for result in results]
  155. reciprocal_ranks = [result.reciprocal_rank() for result in results]
  156. return {
  157. "queries": len(results),
  158. "top_k": top_k,
  159. "average_recall_at_k": sum(recalls) / len(recalls),
  160. "mean_reciprocal_rank": sum(reciprocal_ranks) / len(reciprocal_ranks),
  161. "full_recall_queries": sum(recall == 1.0 for recall in recalls),
  162. "no_hit_queries": sum(recall == 0.0 for recall in recalls),
  163. }
  164. def print_report(results: list[QueryResult], top_k: int) -> None:
  165. summary = summarize(results, top_k)
  166. print("LightRAG sample retrieval check")
  167. print(f"Queries: {summary['queries']}")
  168. print(f"Top-k: {summary['top_k']}")
  169. print(f"Average recall@k: {summary['average_recall_at_k']:.3f}")
  170. print(f"Mean reciprocal rank: {summary['mean_reciprocal_rank']:.3f}")
  171. print(f"Full-recall queries: {summary['full_recall_queries']}/{summary['queries']}")
  172. print(f"No-hit queries: {summary['no_hit_queries']}")
  173. print()
  174. for index, result in enumerate(results, start=1):
  175. top_docs = ", ".join(result.ranked[:top_k])
  176. expected = ", ".join(result.expected)
  177. print(f"{index}. recall@{top_k}={result.recall_at(top_k):.3f}")
  178. print(f" expected: {expected}")
  179. print(f" top docs: {top_docs}")
  180. def parse_args(argv: list[str]) -> argparse.Namespace:
  181. parser = argparse.ArgumentParser(
  182. description="Run an offline retrieval check for LightRAG evaluation samples."
  183. )
  184. parser.add_argument("--dataset", default=str(DEFAULT_DATASET))
  185. parser.add_argument("--docs-dir", default=str(DEFAULT_DOCS_DIR))
  186. parser.add_argument("--oracle", default=str(DEFAULT_ORACLE))
  187. parser.add_argument("--top-k", type=int, default=2)
  188. parser.add_argument(
  189. "--strict",
  190. action="store_true",
  191. help="Exit non-zero unless every sample query has full recall@k.",
  192. )
  193. return parser.parse_args(argv)
  194. def main(argv: list[str] | None = None) -> int:
  195. args = parse_args(argv or sys.argv[1:])
  196. if args.top_k <= 0:
  197. print("--top-k must be positive", file=sys.stderr)
  198. return 2
  199. try:
  200. cases = load_cases(Path(args.dataset).expanduser())
  201. oracle = load_oracle(Path(args.oracle).expanduser())
  202. documents = load_documents(Path(args.docs_dir).expanduser())
  203. results = audit_samples(cases, oracle, documents)
  204. print_report(results, args.top_k)
  205. summary = summarize(results, args.top_k)
  206. except (OSError, ValueError, json.JSONDecodeError) as exc:
  207. print(f"Sample retrieval check failed: {exc}", file=sys.stderr)
  208. return 2
  209. if args.strict and summary["full_recall_queries"] != summary["queries"]:
  210. return 1
  211. return 0
  212. if __name__ == "__main__":
  213. raise SystemExit(main())