rerank.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. from __future__ import annotations
  2. import os
  3. import aiohttp
  4. from typing import Any, List, Dict, Optional, Tuple
  5. from tenacity import (
  6. retry,
  7. stop_after_attempt,
  8. wait_exponential,
  9. retry_if_exception_type,
  10. )
  11. from .utils import logger
  12. from dotenv import load_dotenv
# Load settings from the ".env" file in the current folder (relative path),
# which lets each LightRAG instance use its own .env file.
# override=False: variables already present in the OS environment take
# precedence over the values found in the .env file.
load_dotenv(dotenv_path=".env", override=False)
  17. def chunk_documents_for_rerank(
  18. documents: List[str],
  19. max_tokens: int = 480,
  20. overlap_tokens: int = 32,
  21. tokenizer_model: str = "gpt-4o-mini",
  22. ) -> Tuple[List[str], List[int]]:
  23. """
  24. Chunk documents that exceed token limit for reranking.
  25. Args:
  26. documents: List of document strings to chunk
  27. max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
  28. overlap_tokens: Number of tokens to overlap between chunks
  29. tokenizer_model: Model name for tiktoken tokenizer
  30. Returns:
  31. Tuple of (chunked_documents, original_doc_indices)
  32. - chunked_documents: List of document chunks (may be more than input)
  33. - original_doc_indices: Maps each chunk back to its original document index
  34. """
  35. # Clamp overlap_tokens to ensure the loop always advances
  36. # If overlap_tokens >= max_tokens, the chunking loop would hang
  37. if overlap_tokens >= max_tokens:
  38. original_overlap = overlap_tokens
  39. # Ensure overlap is at least 1 token less than max to guarantee progress
  40. # For very small max_tokens (e.g., 1), set overlap to 0
  41. overlap_tokens = max(0, max_tokens - 1)
  42. logger.warning(
  43. f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
  44. f"Clamping to {overlap_tokens} to prevent infinite loop."
  45. )
  46. try:
  47. from .utils import TiktokenTokenizer
  48. tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
  49. except Exception as e:
  50. logger.warning(
  51. f"Failed to initialize tokenizer: {e}. Using character-based approximation."
  52. )
  53. # Fallback: approximate 1 token ≈ 4 characters
  54. max_chars = max_tokens * 4
  55. overlap_chars = overlap_tokens * 4
  56. chunked_docs = []
  57. doc_indices = []
  58. for idx, doc in enumerate(documents):
  59. if len(doc) <= max_chars:
  60. chunked_docs.append(doc)
  61. doc_indices.append(idx)
  62. else:
  63. # Split into overlapping chunks
  64. start = 0
  65. while start < len(doc):
  66. end = min(start + max_chars, len(doc))
  67. chunk = doc[start:end]
  68. chunked_docs.append(chunk)
  69. doc_indices.append(idx)
  70. if end >= len(doc):
  71. break
  72. start = end - overlap_chars
  73. return chunked_docs, doc_indices
  74. # Use tokenizer for accurate chunking
  75. chunked_docs = []
  76. doc_indices = []
  77. for idx, doc in enumerate(documents):
  78. tokens = tokenizer.encode(doc)
  79. if len(tokens) <= max_tokens:
  80. # Document fits in one chunk
  81. chunked_docs.append(doc)
  82. doc_indices.append(idx)
  83. else:
  84. # Split into overlapping chunks
  85. start = 0
  86. while start < len(tokens):
  87. end = min(start + max_tokens, len(tokens))
  88. chunk_tokens = tokens[start:end]
  89. chunk_text = tokenizer.decode(chunk_tokens)
  90. chunked_docs.append(chunk_text)
  91. doc_indices.append(idx)
  92. if end >= len(tokens):
  93. break
  94. start = end - overlap_tokens
  95. return chunked_docs, doc_indices
  96. def aggregate_chunk_scores(
  97. chunk_results: List[Dict[str, Any]],
  98. doc_indices: List[int],
  99. num_original_docs: int,
  100. aggregation: str = "max",
  101. ) -> List[Dict[str, Any]]:
  102. """
  103. Aggregate rerank scores from document chunks back to original documents.
  104. Args:
  105. chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
  106. doc_indices: Maps each chunk index to original document index
  107. num_original_docs: Total number of original documents
  108. aggregation: Strategy for aggregating scores ("max", "mean", "first")
  109. Returns:
  110. List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
  111. """
  112. # Group scores by original document index
  113. doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
  114. for result in chunk_results:
  115. chunk_idx = result["index"]
  116. score = result["relevance_score"]
  117. if 0 <= chunk_idx < len(doc_indices):
  118. original_doc_idx = doc_indices[chunk_idx]
  119. doc_scores[original_doc_idx].append(score)
  120. # Aggregate scores
  121. aggregated_results = []
  122. for doc_idx, scores in doc_scores.items():
  123. if not scores:
  124. continue
  125. if aggregation == "max":
  126. final_score = max(scores)
  127. elif aggregation == "mean":
  128. final_score = sum(scores) / len(scores)
  129. elif aggregation == "first":
  130. final_score = scores[0]
  131. else:
  132. logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
  133. final_score = max(scores)
  134. aggregated_results.append(
  135. {
  136. "index": doc_idx,
  137. "relevance_score": final_score,
  138. }
  139. )
  140. # Sort by relevance score (descending)
  141. aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
  142. return aggregated_results
  143. @retry(
  144. stop=stop_after_attempt(3),
  145. wait=wait_exponential(multiplier=1, min=4, max=60),
  146. retry=(
  147. retry_if_exception_type(aiohttp.ClientError)
  148. | retry_if_exception_type(aiohttp.ClientResponseError)
  149. ),
  150. )
  151. async def generic_rerank_api(
  152. query: str,
  153. documents: List[str],
  154. model: str,
  155. base_url: str,
  156. api_key: Optional[str],
  157. top_n: Optional[int] = None,
  158. return_documents: Optional[bool] = None,
  159. extra_body: Optional[Dict[str, Any]] = None,
  160. response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
  161. request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
  162. enable_chunking: bool = False,
  163. max_tokens_per_doc: int = 480,
  164. ) -> List[Dict[str, Any]]:
  165. """
  166. Generic rerank API call for Jina/Cohere/Aliyun models.
  167. Args:
  168. query: The search query
  169. documents: List of strings to rerank
  170. model: Model name to use
  171. base_url: API endpoint URL
  172. api_key: API key for authentication
  173. top_n: Number of top results to return
  174. return_documents: Whether to return document text (Jina only)
  175. extra_body: Additional body parameters
  176. response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
  177. request_format: Request format type
  178. enable_chunking: Whether to chunk documents exceeding token limit
  179. max_tokens_per_doc: Maximum tokens per document for chunking
  180. Returns:
  181. List of dictionary of ["index": int, "relevance_score": float]
  182. """
  183. if not base_url:
  184. raise ValueError("Base URL is required")
  185. headers = {"Content-Type": "application/json"}
  186. if api_key is not None:
  187. headers["Authorization"] = f"Bearer {api_key}"
  188. # Handle document chunking if enabled
  189. original_documents = documents
  190. doc_indices = None
  191. original_top_n = top_n # Save original top_n for post-aggregation limiting
  192. if enable_chunking:
  193. documents, doc_indices = chunk_documents_for_rerank(
  194. documents, max_tokens=max_tokens_per_doc
  195. )
  196. logger.debug(
  197. f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
  198. )
  199. # When chunking is enabled, disable top_n at API level to get all chunk scores
  200. # This ensures proper document-level coverage after aggregation
  201. # We'll apply top_n to aggregated document results instead
  202. if top_n is not None:
  203. logger.debug(
  204. f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
  205. )
  206. top_n = None
  207. # Build request payload based on request format
  208. if request_format == "aliyun":
  209. # Aliyun format: nested input/parameters structure
  210. payload = {
  211. "model": model,
  212. "input": {
  213. "query": query,
  214. "documents": documents,
  215. },
  216. "parameters": {},
  217. }
  218. # Add optional parameters to parameters object
  219. if top_n is not None:
  220. payload["parameters"]["top_n"] = top_n
  221. if return_documents is not None:
  222. payload["parameters"]["return_documents"] = return_documents
  223. # Add extra parameters to parameters object
  224. if extra_body:
  225. payload["parameters"].update(extra_body)
  226. else:
  227. # Standard format for Jina/Cohere/OpenAI
  228. payload = {
  229. "model": model,
  230. "query": query,
  231. "documents": documents,
  232. }
  233. # Add optional parameters
  234. if top_n is not None:
  235. payload["top_n"] = top_n
  236. # Only Jina API supports return_documents parameter
  237. if return_documents is not None and response_format in ("standard",):
  238. payload["return_documents"] = return_documents
  239. # Add extra parameters
  240. if extra_body:
  241. payload.update(extra_body)
  242. logger.debug(
  243. f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
  244. )
  245. async with aiohttp.ClientSession() as session:
  246. async with session.post(base_url, headers=headers, json=payload) as response:
  247. if response.status != 200:
  248. error_text = await response.text()
  249. content_type = response.headers.get("content-type", "").lower()
  250. is_html_error = (
  251. error_text.strip().startswith("<!DOCTYPE html>")
  252. or "text/html" in content_type
  253. )
  254. if is_html_error:
  255. if response.status == 502:
  256. clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
  257. elif response.status == 503:
  258. clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
  259. elif response.status == 504:
  260. clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
  261. else:
  262. clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
  263. else:
  264. clean_error = error_text
  265. logger.error(f"Rerank API error {response.status}: {clean_error}")
  266. raise aiohttp.ClientResponseError(
  267. request_info=response.request_info,
  268. history=response.history,
  269. status=response.status,
  270. message=f"Rerank API error: {clean_error}",
  271. )
  272. response_json = await response.json()
  273. if response_format == "aliyun":
  274. # Aliyun format: {"output": {"results": [...]}}
  275. results = response_json.get("output", {}).get("results", [])
  276. if not isinstance(results, list):
  277. logger.warning(
  278. f"Expected 'output.results' to be list, got {type(results)}: {results}"
  279. )
  280. results = []
  281. elif response_format == "standard":
  282. # Standard format: {"results": [...]}
  283. results = response_json.get("results", [])
  284. if not isinstance(results, list):
  285. logger.warning(
  286. f"Expected 'results' to be list, got {type(results)}: {results}"
  287. )
  288. results = []
  289. else:
  290. raise ValueError(f"Unsupported response format: {response_format}")
  291. if not results:
  292. logger.warning("Rerank API returned empty results")
  293. return []
  294. # Standardize return format
  295. standardized_results = [
  296. {"index": result["index"], "relevance_score": result["relevance_score"]}
  297. for result in results
  298. ]
  299. # Aggregate chunk scores back to original documents if chunking was enabled
  300. if enable_chunking and doc_indices:
  301. standardized_results = aggregate_chunk_scores(
  302. standardized_results,
  303. doc_indices,
  304. len(original_documents),
  305. aggregation="max",
  306. )
  307. # Apply original top_n limit at document level (post-aggregation)
  308. # This preserves document-level semantics: top_n limits documents, not chunks
  309. if (
  310. original_top_n is not None
  311. and len(standardized_results) > original_top_n
  312. ):
  313. standardized_results = standardized_results[:original_top_n]
  314. return standardized_results
  315. async def cohere_rerank(
  316. query: str,
  317. documents: List[str],
  318. top_n: Optional[int] = None,
  319. api_key: Optional[str] = None,
  320. model: str = "rerank-v3.5",
  321. base_url: str = "https://api.cohere.com/v2/rerank",
  322. extra_body: Optional[Dict[str, Any]] = None,
  323. enable_chunking: bool = False,
  324. max_tokens_per_doc: int = 4096,
  325. ) -> List[Dict[str, Any]]:
  326. """
  327. Rerank documents using Cohere API.
  328. Supports both standard Cohere API and Cohere-compatible proxies
  329. Args:
  330. query: The search query
  331. documents: List of strings to rerank
  332. top_n: Number of top results to return
  333. api_key: API key for authentication
  334. model: rerank model name (default: rerank-v3.5)
  335. base_url: API endpoint
  336. extra_body: Additional body for http request(reserved for extra params)
  337. enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
  338. max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
  339. Returns:
  340. List of dictionary of ["index": int, "relevance_score": float]
  341. Example:
  342. >>> # Standard Cohere API
  343. >>> results = await cohere_rerank(
  344. ... query="What is the meaning of life?",
  345. ... documents=["Doc1", "Doc2"],
  346. ... api_key="your-cohere-key"
  347. ... )
  348. >>> # LiteLLM proxy with user authentication
  349. >>> results = await cohere_rerank(
  350. ... query="What is vector search?",
  351. ... documents=["Doc1", "Doc2"],
  352. ... model="answerai-colbert-small-v1",
  353. ... base_url="https://llm-proxy.example.com/v2/rerank",
  354. ... api_key="your-proxy-key",
  355. ... enable_chunking=True,
  356. ... max_tokens_per_doc=480
  357. ... )
  358. """
  359. if api_key is None:
  360. api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
  361. return await generic_rerank_api(
  362. query=query,
  363. documents=documents,
  364. model=model,
  365. base_url=base_url,
  366. api_key=api_key,
  367. top_n=top_n,
  368. return_documents=None, # Cohere doesn't support this parameter
  369. extra_body=extra_body,
  370. response_format="standard",
  371. enable_chunking=enable_chunking,
  372. max_tokens_per_doc=max_tokens_per_doc,
  373. )
  374. async def jina_rerank(
  375. query: str,
  376. documents: List[str],
  377. top_n: Optional[int] = None,
  378. api_key: Optional[str] = None,
  379. model: str = "jina-reranker-v2-base-multilingual",
  380. base_url: str = "https://api.jina.ai/v1/rerank",
  381. extra_body: Optional[Dict[str, Any]] = None,
  382. ) -> List[Dict[str, Any]]:
  383. """
  384. Rerank documents using Jina AI API.
  385. Args:
  386. query: The search query
  387. documents: List of strings to rerank
  388. top_n: Number of top results to return
  389. api_key: API key
  390. model: rerank model name
  391. base_url: API endpoint
  392. extra_body: Additional body for http request(reserved for extra params)
  393. Returns:
  394. List of dictionary of ["index": int, "relevance_score": float]
  395. """
  396. if api_key is None:
  397. api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
  398. return await generic_rerank_api(
  399. query=query,
  400. documents=documents,
  401. model=model,
  402. base_url=base_url,
  403. api_key=api_key,
  404. top_n=top_n,
  405. return_documents=False,
  406. extra_body=extra_body,
  407. response_format="standard",
  408. )
  409. async def ali_rerank(
  410. query: str,
  411. documents: List[str],
  412. top_n: Optional[int] = None,
  413. api_key: Optional[str] = None,
  414. model: str = "gte-rerank-v2",
  415. base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
  416. extra_body: Optional[Dict[str, Any]] = None,
  417. ) -> List[Dict[str, Any]]:
  418. """
  419. Rerank documents using Aliyun DashScope API.
  420. Args:
  421. query: The search query
  422. documents: List of strings to rerank
  423. top_n: Number of top results to return
  424. api_key: Aliyun API key
  425. model: rerank model name
  426. base_url: API endpoint
  427. extra_body: Additional body for http request(reserved for extra params)
  428. Returns:
  429. List of dictionary of ["index": int, "relevance_score": float]
  430. """
  431. if api_key is None:
  432. api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
  433. return await generic_rerank_api(
  434. query=query,
  435. documents=documents,
  436. model=model,
  437. base_url=base_url,
  438. api_key=api_key,
  439. top_n=top_n,
  440. return_documents=False, # Aliyun doesn't need this parameter
  441. extra_body=extra_body,
  442. response_format="aliyun",
  443. request_format="aliyun",
  444. )
  445. """Please run this test as a module:
  446. python -m lightrag.rerank
  447. """
  448. if __name__ == "__main__":
  449. import asyncio
  450. async def main():
  451. # Example usage - documents should be strings, not dictionaries
  452. docs = [
  453. "The capital of France is Paris.",
  454. "Tokyo is the capital of Japan.",
  455. "London is the capital of England.",
  456. ]
  457. query = "What is the capital of France?"
  458. # Test Jina rerank
  459. try:
  460. print("=== Jina Rerank ===")
  461. result = await jina_rerank(
  462. query=query,
  463. documents=docs,
  464. top_n=2,
  465. )
  466. print("Results:")
  467. for item in result:
  468. print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
  469. print(f"Document: {docs[item['index']]}")
  470. except Exception as e:
  471. print(f"Jina Error: {e}")
  472. # Test Cohere rerank
  473. try:
  474. print("\n=== Cohere Rerank ===")
  475. result = await cohere_rerank(
  476. query=query,
  477. documents=docs,
  478. top_n=2,
  479. )
  480. print("Results:")
  481. for item in result:
  482. print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
  483. print(f"Document: {docs[item['index']]}")
  484. except Exception as e:
  485. print(f"Cohere Error: {e}")
  486. # Test Aliyun rerank
  487. try:
  488. print("\n=== Aliyun Rerank ===")
  489. result = await ali_rerank(
  490. query=query,
  491. documents=docs,
  492. top_n=2,
  493. )
  494. print("Results:")
  495. for item in result:
  496. print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
  497. print(f"Document: {docs[item['index']]}")
  498. except Exception as e:
  499. print(f"Aliyun Error: {e}")
  500. asyncio.run(main())