test_lightrag_ollama_chat.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
"""
LightRAG Ollama Compatibility Interface Test Script

This script tests LightRAG's Ollama-compatible interface (/api/chat and /api/generate), including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid, mix)
3. Error handling tests (streaming and non-streaming scenarios)
4. Concurrent request tests for the generate endpoint
Streaming responses use the JSON Lines format (one JSON object per line), as in the Ollama API.
"""
  9. import pytest
  10. import requests
  11. import json
  12. import argparse
  13. import time
  14. from typing import Dict, Any, Optional, List, Callable
  15. from dataclasses import dataclass, asdict
  16. from datetime import datetime
  17. from pathlib import Path
  18. from enum import Enum, auto
  19. class ErrorCode(Enum):
  20. """Error codes for MCP errors"""
  21. InvalidRequest = auto()
  22. InternalError = auto()
  23. class McpError(Exception):
  24. """Base exception class for MCP errors"""
  25. def __init__(self, code: ErrorCode, message: str):
  26. self.code = code
  27. self.message = message
  28. super().__init__(message)
  29. DEFAULT_CONFIG = {
  30. "server": {
  31. "host": "localhost",
  32. "port": 9621,
  33. "model": "lightrag:latest",
  34. "timeout": 300,
  35. "max_retries": 1,
  36. "retry_delay": 1,
  37. },
  38. "test_cases": {
  39. "basic": {"query": "唐僧有几个徒弟"},
  40. "generate": {"query": "电视剧西游记导演是谁"},
  41. },
  42. }
  43. # Example conversation history for testing
  44. EXAMPLE_CONVERSATION = [
  45. {"role": "user", "content": "你好"},
  46. {"role": "assistant", "content": "你好!我是一个AI助手,很高兴为你服务。"},
  47. {"role": "user", "content": "Who are you?"},
  48. {"role": "assistant", "content": "I'm a Knowledge base query assistant."},
  49. ]
  50. class OutputControl:
  51. """Output control class, manages the verbosity of test output"""
  52. _verbose: bool = False
  53. @classmethod
  54. def set_verbose(cls, verbose: bool) -> None:
  55. cls._verbose = verbose
  56. @classmethod
  57. def is_verbose(cls) -> bool:
  58. return cls._verbose
  59. @dataclass
  60. class ExecutionResult:
  61. """Test execution result data class"""
  62. name: str
  63. success: bool
  64. duration: float
  65. error: Optional[str] = None
  66. timestamp: str = ""
  67. def __post_init__(self):
  68. if not self.timestamp:
  69. self.timestamp = datetime.now().isoformat()
  70. class ExecutionStats:
  71. """Test execution statistics"""
  72. def __init__(self):
  73. self.results: List[ExecutionResult] = []
  74. self.start_time = datetime.now()
  75. def add_result(self, result: ExecutionResult):
  76. self.results.append(result)
  77. def export_results(self, path: str = "test_results.json"):
  78. """Export test results to a JSON file
  79. Args:
  80. path: Output file path
  81. """
  82. results_data = {
  83. "start_time": self.start_time.isoformat(),
  84. "end_time": datetime.now().isoformat(),
  85. "results": [asdict(r) for r in self.results],
  86. "summary": {
  87. "total": len(self.results),
  88. "passed": sum(1 for r in self.results if r.success),
  89. "failed": sum(1 for r in self.results if not r.success),
  90. "total_duration": sum(r.duration for r in self.results),
  91. },
  92. }
  93. with open(path, "w", encoding="utf-8") as f:
  94. json.dump(results_data, f, ensure_ascii=False, indent=2)
  95. print(f"\nTest results saved to: {path}")
  96. def print_summary(self):
  97. total = len(self.results)
  98. passed = sum(1 for r in self.results if r.success)
  99. failed = total - passed
  100. duration = sum(r.duration for r in self.results)
  101. print("\n=== Test Summary ===")
  102. print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
  103. print(f"Total duration: {duration:.2f} seconds")
  104. print(f"Total tests: {total}")
  105. print(f"Passed: {passed}")
  106. print(f"Failed: {failed}")
  107. if failed > 0:
  108. print("\nFailed tests:")
  109. for result in self.results:
  110. if not result.success:
  111. print(f"- {result.name}: {result.error}")
  112. def make_request(
  113. url: str, data: Dict[str, Any], stream: bool = False, check_status: bool = True
  114. ) -> requests.Response:
  115. """Send an HTTP request with retry mechanism
  116. Args:
  117. url: Request URL
  118. data: Request data
  119. stream: Whether to use streaming response
  120. check_status: Whether to check HTTP status code (default: True)
  121. Returns:
  122. requests.Response: Response object
  123. Raises:
  124. requests.exceptions.RequestException: Request failed after all retries
  125. requests.exceptions.HTTPError: HTTP status code is not 200 (when check_status is True)
  126. """
  127. server_config = CONFIG["server"]
  128. max_retries = server_config["max_retries"]
  129. retry_delay = server_config["retry_delay"]
  130. timeout = server_config["timeout"]
  131. for attempt in range(max_retries):
  132. try:
  133. response = requests.post(url, json=data, stream=stream, timeout=timeout)
  134. if check_status and response.status_code != 200:
  135. response.raise_for_status()
  136. return response
  137. except requests.exceptions.RequestException as e:
  138. if attempt == max_retries - 1: # Last retry
  139. raise
  140. print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
  141. time.sleep(retry_delay)
  142. def load_config() -> Dict[str, Any]:
  143. """Load configuration file
  144. First try to load from config.json in the current directory,
  145. if it doesn't exist, use the default configuration
  146. Returns:
  147. Configuration dictionary
  148. """
  149. config_path = Path("config.json")
  150. if config_path.exists():
  151. with open(config_path, "r", encoding="utf-8") as f:
  152. return json.load(f)
  153. return DEFAULT_CONFIG
  154. def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
  155. """Format and print JSON response data
  156. Args:
  157. data: Data dictionary to print
  158. title: Title to print
  159. indent: Number of spaces for JSON indentation
  160. """
  161. if OutputControl.is_verbose():
  162. if title:
  163. print(f"\n=== {title} ===")
  164. print(json.dumps(data, ensure_ascii=False, indent=indent))
  165. # Global configuration
  166. CONFIG = load_config()
  167. def get_base_url(endpoint: str = "chat") -> str:
  168. """Return the base URL for specified endpoint
  169. Args:
  170. endpoint: API endpoint name (chat or generate)
  171. Returns:
  172. Complete URL for the endpoint
  173. """
  174. server = CONFIG["server"]
  175. return f"http://{server['host']}:{server['port']}/api/{endpoint}"
  176. def create_chat_request_data(
  177. content: str,
  178. stream: bool = False,
  179. model: str = None,
  180. conversation_history: List[Dict[str, str]] = None,
  181. ) -> Dict[str, Any]:
  182. """Create chat request data
  183. Args:
  184. content: User message content
  185. stream: Whether to use streaming response
  186. model: Model name
  187. conversation_history: List of previous conversation messages
  188. Returns:
  189. Dictionary containing complete chat request data
  190. """
  191. messages = conversation_history or []
  192. messages.append({"role": "user", "content": content})
  193. return {
  194. "model": model or CONFIG["server"]["model"],
  195. "messages": messages,
  196. "stream": stream,
  197. }
  198. def create_generate_request_data(
  199. prompt: str,
  200. system: str = None,
  201. stream: bool = False,
  202. model: str = None,
  203. options: Dict[str, Any] = None,
  204. ) -> Dict[str, Any]:
  205. """Create generate request data
  206. Args:
  207. prompt: Generation prompt
  208. system: System prompt
  209. stream: Whether to use streaming response
  210. model: Model name
  211. options: Additional options
  212. Returns:
  213. Dictionary containing complete generate request data
  214. """
  215. data = {
  216. "model": model or CONFIG["server"]["model"],
  217. "prompt": prompt,
  218. "stream": stream,
  219. }
  220. if system:
  221. data["system"] = system
  222. if options:
  223. data["options"] = options
  224. return data
  225. # Global test statistics
  226. STATS = ExecutionStats()
  227. def run_test(func: Callable, name: str) -> None:
  228. """Run a test and record the results
  229. Args:
  230. func: Test function
  231. name: Test name
  232. """
  233. start_time = time.time()
  234. try:
  235. func()
  236. duration = time.time() - start_time
  237. STATS.add_result(ExecutionResult(name, True, duration))
  238. except Exception as e:
  239. duration = time.time() - start_time
  240. STATS.add_result(ExecutionResult(name, False, duration, str(e)))
  241. raise
  242. @pytest.mark.integration
  243. @pytest.mark.requires_api
  244. def test_non_stream_chat() -> None:
  245. """Test non-streaming call to /api/chat endpoint"""
  246. url = get_base_url()
  247. # Send request with conversation history
  248. data = create_chat_request_data(
  249. CONFIG["test_cases"]["basic"]["query"],
  250. stream=False,
  251. conversation_history=EXAMPLE_CONVERSATION,
  252. )
  253. response = make_request(url, data)
  254. # Print response
  255. if OutputControl.is_verbose():
  256. print("\n=== Non-streaming call response ===")
  257. response_json = response.json()
  258. # Print response content
  259. print_json_response(
  260. {"model": response_json["model"], "message": response_json["message"]},
  261. "Response content",
  262. )
  263. @pytest.mark.integration
  264. @pytest.mark.requires_api
  265. def test_stream_chat() -> None:
  266. """Test streaming call to /api/chat endpoint
  267. Use JSON Lines format to process streaming responses, each line is a complete JSON object.
  268. Response format:
  269. {
  270. "model": "lightrag:latest",
  271. "created_at": "2024-01-15T00:00:00Z",
  272. "message": {
  273. "role": "assistant",
  274. "content": "Partial response content",
  275. "images": null
  276. },
  277. "done": false
  278. }
  279. The last message will contain performance statistics, with done set to true.
  280. """
  281. url = get_base_url()
  282. # Send request with conversation history
  283. data = create_chat_request_data(
  284. CONFIG["test_cases"]["basic"]["query"],
  285. stream=True,
  286. conversation_history=EXAMPLE_CONVERSATION,
  287. )
  288. response = make_request(url, data, stream=True)
  289. if OutputControl.is_verbose():
  290. print("\n=== Streaming call response ===")
  291. output_buffer = []
  292. try:
  293. for line in response.iter_lines():
  294. if line: # Skip empty lines
  295. try:
  296. # Decode and parse JSON
  297. data = json.loads(line.decode("utf-8"))
  298. if data.get("done", True): # If it's the completion marker
  299. if (
  300. "total_duration" in data
  301. ): # Final performance statistics message
  302. # print_json_response(data, "Performance statistics")
  303. break
  304. else: # Normal content message
  305. message = data.get("message", {})
  306. content = message.get("content", "")
  307. if content: # Only collect non-empty content
  308. output_buffer.append(content)
  309. print(
  310. content, end="", flush=True
  311. ) # Print content in real-time
  312. except json.JSONDecodeError:
  313. print("Error decoding JSON from response line")
  314. finally:
  315. response.close() # Ensure the response connection is closed
  316. # Print a newline
  317. print()
  318. @pytest.mark.integration
  319. @pytest.mark.requires_api
  320. def test_query_modes() -> None:
  321. """Test different query mode prefixes
  322. Supported query modes:
  323. - /local: Local retrieval mode, searches only in highly relevant documents
  324. - /global: Global retrieval mode, searches across all documents
  325. - /naive: Naive mode, does not use any optimization strategies
  326. - /hybrid: Hybrid mode (default), combines multiple strategies
  327. - /mix: Mix mode
  328. Each mode will return responses in the same format, but with different retrieval strategies.
  329. """
  330. url = get_base_url()
  331. modes = ["local", "global", "naive", "hybrid", "mix"]
  332. for mode in modes:
  333. if OutputControl.is_verbose():
  334. print(f"\n=== Testing /{mode} mode ===")
  335. data = create_chat_request_data(
  336. f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
  337. )
  338. # Send request
  339. response = make_request(url, data)
  340. response_json = response.json()
  341. # Print response content
  342. print_json_response(
  343. {"model": response_json["model"], "message": response_json["message"]}
  344. )
  345. def create_error_test_data(error_type: str) -> Dict[str, Any]:
  346. """Create request data for error testing
  347. Args:
  348. error_type: Error type, supported:
  349. - empty_messages: Empty message list
  350. - invalid_role: Invalid role field
  351. - missing_content: Missing content field
  352. Returns:
  353. Request dictionary containing error data
  354. """
  355. error_data = {
  356. "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
  357. "invalid_role": {
  358. "model": "lightrag:latest",
  359. "messages": [{"invalid_role": "user", "content": "Test message"}],
  360. "stream": True,
  361. },
  362. "missing_content": {
  363. "model": "lightrag:latest",
  364. "messages": [{"role": "user"}],
  365. "stream": True,
  366. },
  367. }
  368. return error_data.get(error_type, error_data["empty_messages"])
  369. @pytest.mark.integration
  370. @pytest.mark.requires_api
  371. def test_stream_error_handling() -> None:
  372. """Test error handling for streaming responses
  373. Test scenarios:
  374. 1. Empty message list
  375. 2. Message format error (missing required fields)
  376. Error responses should be returned immediately without establishing a streaming connection.
  377. The status code should be 4xx, and detailed error information should be returned.
  378. """
  379. url = get_base_url()
  380. if OutputControl.is_verbose():
  381. print("\n=== Testing streaming response error handling ===")
  382. # Test empty message list
  383. if OutputControl.is_verbose():
  384. print("\n--- Testing empty message list (streaming) ---")
  385. data = create_error_test_data("empty_messages")
  386. response = make_request(url, data, stream=True, check_status=False)
  387. print(f"Status code: {response.status_code}")
  388. if response.status_code != 200:
  389. print_json_response(response.json(), "Error message")
  390. response.close()
  391. # Test invalid role field
  392. if OutputControl.is_verbose():
  393. print("\n--- Testing invalid role field (streaming) ---")
  394. data = create_error_test_data("invalid_role")
  395. response = make_request(url, data, stream=True, check_status=False)
  396. print(f"Status code: {response.status_code}")
  397. if response.status_code != 200:
  398. print_json_response(response.json(), "Error message")
  399. response.close()
  400. # Test missing content field
  401. if OutputControl.is_verbose():
  402. print("\n--- Testing missing content field (streaming) ---")
  403. data = create_error_test_data("missing_content")
  404. response = make_request(url, data, stream=True, check_status=False)
  405. print(f"Status code: {response.status_code}")
  406. if response.status_code != 200:
  407. print_json_response(response.json(), "Error message")
  408. response.close()
  409. @pytest.mark.integration
  410. @pytest.mark.requires_api
  411. def test_error_handling() -> None:
  412. """Test error handling for non-streaming responses
  413. Test scenarios:
  414. 1. Empty message list
  415. 2. Message format error (missing required fields)
  416. Error response format:
  417. {
  418. "detail": "Error description"
  419. }
  420. All errors should return appropriate HTTP status codes and clear error messages.
  421. """
  422. url = get_base_url()
  423. if OutputControl.is_verbose():
  424. print("\n=== Testing error handling ===")
  425. # Test empty message list
  426. if OutputControl.is_verbose():
  427. print("\n--- Testing empty message list ---")
  428. data = create_error_test_data("empty_messages")
  429. data["stream"] = False # Change to non-streaming mode
  430. response = make_request(url, data, check_status=False)
  431. print(f"Status code: {response.status_code}")
  432. print_json_response(response.json(), "Error message")
  433. # Test invalid role field
  434. if OutputControl.is_verbose():
  435. print("\n--- Testing invalid role field ---")
  436. data = create_error_test_data("invalid_role")
  437. data["stream"] = False # Change to non-streaming mode
  438. response = make_request(url, data, check_status=False)
  439. print(f"Status code: {response.status_code}")
  440. print_json_response(response.json(), "Error message")
  441. # Test missing content field
  442. if OutputControl.is_verbose():
  443. print("\n--- Testing missing content field ---")
  444. data = create_error_test_data("missing_content")
  445. data["stream"] = False # Change to non-streaming mode
  446. response = make_request(url, data, check_status=False)
  447. print(f"Status code: {response.status_code}")
  448. print_json_response(response.json(), "Error message")
  449. @pytest.mark.integration
  450. @pytest.mark.requires_api
  451. def test_non_stream_generate() -> None:
  452. """Test non-streaming call to /api/generate endpoint"""
  453. url = get_base_url("generate")
  454. data = create_generate_request_data(
  455. CONFIG["test_cases"]["generate"]["query"], stream=False
  456. )
  457. # Send request
  458. response = make_request(url, data)
  459. # Print response
  460. if OutputControl.is_verbose():
  461. print("\n=== Non-streaming generate response ===")
  462. response_json = response.json()
  463. # Print response content
  464. print(json.dumps(response_json, ensure_ascii=False, indent=2))
  465. @pytest.mark.integration
  466. @pytest.mark.requires_api
  467. def test_stream_generate() -> None:
  468. """Test streaming call to /api/generate endpoint"""
  469. url = get_base_url("generate")
  470. data = create_generate_request_data(
  471. CONFIG["test_cases"]["generate"]["query"], stream=True
  472. )
  473. # Send request and get streaming response
  474. response = make_request(url, data, stream=True)
  475. if OutputControl.is_verbose():
  476. print("\n=== Streaming generate response ===")
  477. output_buffer = []
  478. try:
  479. for line in response.iter_lines():
  480. if line: # Skip empty lines
  481. try:
  482. # Decode and parse JSON
  483. data = json.loads(line.decode("utf-8"))
  484. if data.get("done", True): # If it's the completion marker
  485. if (
  486. "total_duration" in data
  487. ): # Final performance statistics message
  488. break
  489. else: # Normal content message
  490. content = data.get("response", "")
  491. if content: # Only collect non-empty content
  492. output_buffer.append(content)
  493. print(
  494. content, end="", flush=True
  495. ) # Print content in real-time
  496. except json.JSONDecodeError:
  497. print("Error decoding JSON from response line")
  498. finally:
  499. response.close() # Ensure the response connection is closed
  500. # Print a newline
  501. print()
  502. @pytest.mark.integration
  503. @pytest.mark.requires_api
  504. def test_generate_with_system() -> None:
  505. """Test generate with system prompt"""
  506. url = get_base_url("generate")
  507. data = create_generate_request_data(
  508. CONFIG["test_cases"]["generate"]["query"],
  509. system="你是一个知识渊博的助手",
  510. stream=False,
  511. )
  512. # Send request
  513. response = make_request(url, data)
  514. # Print response
  515. if OutputControl.is_verbose():
  516. print("\n=== Generate with system prompt response ===")
  517. response_json = response.json()
  518. # Print response content
  519. print_json_response(
  520. {
  521. "model": response_json["model"],
  522. "response": response_json["response"],
  523. "done": response_json["done"],
  524. },
  525. "Response content",
  526. )
  527. @pytest.mark.integration
  528. @pytest.mark.requires_api
  529. def test_generate_error_handling() -> None:
  530. """Test error handling for generate endpoint"""
  531. url = get_base_url("generate")
  532. # Test empty prompt
  533. if OutputControl.is_verbose():
  534. print("\n=== Testing empty prompt ===")
  535. data = create_generate_request_data("", stream=False)
  536. response = make_request(url, data, check_status=False)
  537. print(f"Status code: {response.status_code}")
  538. print_json_response(response.json(), "Error message")
  539. # Test invalid options
  540. if OutputControl.is_verbose():
  541. print("\n=== Testing invalid options ===")
  542. data = create_generate_request_data(
  543. CONFIG["test_cases"]["basic"]["query"],
  544. options={"invalid_option": "value"},
  545. stream=False,
  546. )
  547. response = make_request(url, data, check_status=False)
  548. print(f"Status code: {response.status_code}")
  549. print_json_response(response.json(), "Error message")
  550. @pytest.mark.integration
  551. @pytest.mark.requires_api
  552. def test_generate_concurrent() -> None:
  553. """Test concurrent generate requests"""
  554. import asyncio
  555. import aiohttp
  556. from contextlib import asynccontextmanager
  557. @asynccontextmanager
  558. async def get_session():
  559. async with aiohttp.ClientSession() as session:
  560. yield session
  561. async def make_request(session, prompt: str, request_id: int):
  562. url = get_base_url("generate")
  563. data = create_generate_request_data(prompt, stream=False)
  564. try:
  565. async with session.post(url, json=data) as response:
  566. if response.status != 200:
  567. error_msg = (
  568. f"Request {request_id} failed with status {response.status}"
  569. )
  570. if OutputControl.is_verbose():
  571. print(f"\n{error_msg}")
  572. raise McpError(ErrorCode.InternalError, error_msg)
  573. result = await response.json()
  574. if "error" in result:
  575. error_msg = (
  576. f"Request {request_id} returned error: {result['error']}"
  577. )
  578. if OutputControl.is_verbose():
  579. print(f"\n{error_msg}")
  580. raise McpError(ErrorCode.InternalError, error_msg)
  581. return result
  582. except Exception as e:
  583. error_msg = f"Request {request_id} failed: {str(e)}"
  584. if OutputControl.is_verbose():
  585. print(f"\n{error_msg}")
  586. raise McpError(ErrorCode.InternalError, error_msg)
  587. async def run_concurrent_requests():
  588. prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
  589. async with get_session() as session:
  590. tasks = [
  591. make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
  592. ]
  593. results = await asyncio.gather(*tasks, return_exceptions=True)
  594. success_results = []
  595. error_messages = []
  596. for i, result in enumerate(results):
  597. if isinstance(result, Exception):
  598. error_messages.append(f"Request {i + 1} failed: {str(result)}")
  599. else:
  600. success_results.append((i + 1, result))
  601. if error_messages:
  602. for req_id, result in success_results:
  603. if OutputControl.is_verbose():
  604. print(f"\nRequest {req_id} succeeded:")
  605. print_json_response(result)
  606. error_summary = "\n".join(error_messages)
  607. raise McpError(
  608. ErrorCode.InternalError,
  609. f"Some concurrent requests failed:\n{error_summary}",
  610. )
  611. return results
  612. if OutputControl.is_verbose():
  613. print("\n=== Testing concurrent generate requests ===")
  614. # Run concurrent requests
  615. try:
  616. results = asyncio.run(run_concurrent_requests())
  617. # all success, print out results
  618. for i, result in enumerate(results, 1):
  619. print(f"\nRequest {i} result:")
  620. print_json_response(result)
  621. except McpError:
  622. # error message already printed
  623. raise
  624. def get_test_cases() -> Dict[str, Callable]:
  625. """Get all available test cases
  626. Returns:
  627. A dictionary mapping test names to test functions
  628. """
  629. return {
  630. "non_stream": test_non_stream_chat,
  631. "stream": test_stream_chat,
  632. "modes": test_query_modes,
  633. "errors": test_error_handling,
  634. "stream_errors": test_stream_error_handling,
  635. "non_stream_generate": test_non_stream_generate,
  636. "stream_generate": test_stream_generate,
  637. "generate_with_system": test_generate_with_system,
  638. "generate_errors": test_generate_error_handling,
  639. "generate_concurrent": test_generate_concurrent,
  640. }
  641. def create_default_config():
  642. """Create a default configuration file"""
  643. config_path = Path("config.json")
  644. if not config_path.exists():
  645. with open(config_path, "w", encoding="utf-8") as f:
  646. json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
  647. print(f"Default configuration file created: {config_path}")
  648. def parse_args() -> argparse.Namespace:
  649. """Parse command line arguments"""
  650. parser = argparse.ArgumentParser(
  651. description="LightRAG Ollama Compatibility Interface Testing",
  652. formatter_class=argparse.RawDescriptionHelpFormatter,
  653. epilog="""
  654. Configuration file (config.json):
  655. {
  656. "server": {
  657. "host": "localhost", # Server address
  658. "port": 9621, # Server port
  659. "model": "lightrag:latest" # Default model name
  660. },
  661. "test_cases": {
  662. "basic": {
  663. "query": "Test query", # Basic query text
  664. "stream_query": "Stream query" # Stream query text
  665. }
  666. }
  667. }
  668. """,
  669. )
  670. parser.add_argument(
  671. "-q",
  672. "--quiet",
  673. action="store_true",
  674. help="Silent mode, only display test result summary",
  675. )
  676. parser.add_argument(
  677. "-a",
  678. "--ask",
  679. type=str,
  680. help="Specify query content, which will override the query settings in the configuration file",
  681. )
  682. parser.add_argument(
  683. "--init-config", action="store_true", help="Create default configuration file"
  684. )
  685. parser.add_argument(
  686. "--output",
  687. type=str,
  688. default="",
  689. help="Test result output file path, default is not to output to a file",
  690. )
  691. parser.add_argument(
  692. "--tests",
  693. nargs="+",
  694. choices=list(get_test_cases().keys()) + ["all"],
  695. default=["all"],
  696. help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)",
  697. )
  698. return parser.parse_args()
  699. if __name__ == "__main__":
  700. args = parse_args()
  701. # Set output mode
  702. OutputControl.set_verbose(not args.quiet)
  703. # If query content is specified, update the configuration
  704. if args.ask:
  705. CONFIG["test_cases"]["basic"]["query"] = args.ask
  706. # If specified to create a configuration file
  707. if args.init_config:
  708. create_default_config()
  709. exit(0)
  710. test_cases = get_test_cases()
  711. try:
  712. if "all" in args.tests:
  713. # Run all tests except error handling tests
  714. if OutputControl.is_verbose():
  715. print("\n【Chat API Tests】")
  716. run_test(test_non_stream_chat, "Non-streaming Chat Test")
  717. run_test(test_stream_chat, "Streaming Chat Test")
  718. run_test(test_query_modes, "Chat Query Mode Test")
  719. if OutputControl.is_verbose():
  720. print("\n【Generate API Tests】")
  721. run_test(test_non_stream_generate, "Non-streaming Generate Test")
  722. run_test(test_stream_generate, "Streaming Generate Test")
  723. run_test(test_generate_with_system, "Generate with System Prompt Test")
  724. run_test(test_generate_concurrent, "Generate Concurrent Test")
  725. else:
  726. # Run specified tests
  727. for test_name in args.tests:
  728. if OutputControl.is_verbose():
  729. print(f"\n【Running Test: {test_name}】")
  730. run_test(test_cases[test_name], test_name)
  731. except Exception as e:
  732. print(f"\nAn error occurred: {str(e)}")
  733. finally:
  734. # Print test statistics
  735. STATS.print_summary()
  736. # If an output file path is specified, export the results
  737. if args.output:
  738. STATS.export_results(args.output)