test_evaluation_offline_retrieval_check.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import tempfile
  2. import unittest
  3. from pathlib import Path
  4. from lightrag.evaluation.offline_retrieval_check import (
  5. audit_samples,
  6. load_cases,
  7. load_documents,
  8. load_oracle,
  9. summarize,
  10. )
  11. class OfflineRetrievalCheckTests(unittest.TestCase):
  12. def test_expected_document_ranks_first(self):
  13. with tempfile.TemporaryDirectory() as temp_dir:
  14. root = Path(temp_dir)
  15. docs_dir = root / "docs"
  16. docs_dir.mkdir()
  17. (docs_dir / "alpha.md").write_text(
  18. "Alpha covers vector search and filtering.",
  19. encoding="utf-8",
  20. )
  21. (docs_dir / "beta.md").write_text(
  22. "Beta covers deployment and monitoring.",
  23. encoding="utf-8",
  24. )
  25. dataset = root / "dataset.json"
  26. dataset.write_text(
  27. '{"test_cases":[{"question":"Which file explains vector search?"}]}',
  28. encoding="utf-8",
  29. )
  30. oracle = root / "oracle.json"
  31. oracle.write_text(
  32. '{"oracle":[{"question":"Which file explains vector search?",'
  33. '"expected_documents":["alpha.md"]}]}',
  34. encoding="utf-8",
  35. )
  36. results = audit_samples(
  37. load_cases(dataset),
  38. load_oracle(oracle),
  39. load_documents(docs_dir),
  40. )
  41. summary = summarize(results, top_k=1)
  42. self.assertEqual(results[0].ranked[0], "alpha.md")
  43. self.assertEqual(summary["queries"], 1)
  44. self.assertEqual(summary["average_recall_at_k"], 1.0)
  45. def test_zero_score_documents_do_not_count_as_hits(self):
  46. with tempfile.TemporaryDirectory() as temp_dir:
  47. root = Path(temp_dir)
  48. docs_dir = root / "docs"
  49. docs_dir.mkdir()
  50. (docs_dir / "alpha.md").write_text(
  51. "Alpha covers deployment pipelines.",
  52. encoding="utf-8",
  53. )
  54. (docs_dir / "beta.md").write_text(
  55. "Beta covers monitoring dashboards.",
  56. encoding="utf-8",
  57. )
  58. dataset = root / "dataset.json"
  59. dataset.write_text(
  60. '{"test_cases":[{"question":"Which file explains vector search?"}]}',
  61. encoding="utf-8",
  62. )
  63. oracle = root / "oracle.json"
  64. oracle.write_text(
  65. '{"oracle":[{"question":"Which file explains vector search?",'
  66. '"expected_documents":["alpha.md"]}]}',
  67. encoding="utf-8",
  68. )
  69. results = audit_samples(
  70. load_cases(dataset),
  71. load_oracle(oracle),
  72. load_documents(docs_dir),
  73. )
  74. summary = summarize(results, top_k=1)
  75. self.assertEqual(results[0].ranked, [])
  76. self.assertEqual(summary["average_recall_at_k"], 0.0)
  77. self.assertEqual(summary["no_hit_queries"], 1)
  78. def test_sample_oracle_has_full_recall_at_two(self):
  79. results = audit_samples(
  80. load_cases(Path("lightrag/evaluation/sample_dataset.json")),
  81. load_oracle(Path("lightrag/evaluation/sample_retrieval_oracle.json")),
  82. load_documents(Path("lightrag/evaluation/sample_documents")),
  83. )
  84. summary = summarize(results, top_k=2)
  85. self.assertEqual(summary["queries"], 6)
  86. self.assertEqual(summary["full_recall_queries"], 6)
  87. self.assertEqual(summary["no_hit_queries"], 0)
  88. if __name__ == "__main__":
  89. unittest.main()