binding_options.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833
  1. """
  2. Module that implements containers for specific LLM bindings.
  3. This module provides container implementations for various Large Language Model
  4. bindings and integrations.
  5. """
  6. from argparse import ArgumentParser, Namespace
  7. import argparse
  8. import json
  9. from dataclasses import asdict, dataclass, field
  10. from typing import Any, ClassVar, List, get_args, get_origin
  11. from lightrag.utils import get_env_value
  12. from lightrag.constants import DEFAULT_TEMPERATURE
  13. def _resolve_optional_type(field_type: Any) -> Any:
  14. """Return the concrete type for Optional/Union annotations."""
  15. origin = get_origin(field_type)
  16. if origin in (list, dict, tuple):
  17. return field_type
  18. args = get_args(field_type)
  19. if args:
  20. non_none_args = [arg for arg in args if arg is not type(None)]
  21. if len(non_none_args) == 1:
  22. return non_none_args[0]
  23. return field_type
  24. # =============================================================================
  25. # BindingOptions Base Class
  26. # =============================================================================
  27. #
  28. # The BindingOptions class serves as the foundation for all LLM provider bindings
  29. # in LightRAG. It provides a standardized framework for:
  30. #
  31. # 1. Configuration Management:
  32. # - Defines how each LLM provider's configuration parameters are structured
  33. # - Handles default values and type information for each parameter
  34. # - Maps configuration options to command-line arguments and environment variables
  35. #
  36. # 2. Environment Integration:
  37. # - Automatically generates environment variable names from binding parameters
  38. # - Provides methods to create sample .env files for easy configuration
  39. # - Supports configuration via environment variables with fallback to defaults
  40. #
  41. # 3. Command-Line Interface:
  42. # - Dynamically generates command-line arguments for all registered bindings
  43. # - Maintains consistent naming conventions across different LLM providers
  44. # - Provides help text and type validation for each configuration option
  45. #
  46. # 4. Extensibility:
  47. # - Uses class introspection to automatically discover all binding subclasses
  48. # - Requires minimal boilerplate code when adding new LLM provider bindings
  49. # - Maintains separation of concerns between different provider configurations
  50. #
  51. # This design pattern ensures that adding support for a new LLM provider requires
  52. # only defining the provider-specific parameters and help text, while the base
  53. # class handles all the common functionality for argument parsing, environment
  54. # variable handling, and configuration management.
  55. #
# Instances of a derived class of BindingOptions can be used to store multiple
# runtime configurations of options for a single LLM provider, using the
# asdict() method to convert the options to a dictionary.
  59. #
  60. # =============================================================================
  61. @dataclass
  62. class BindingOptions:
  63. """Base class for binding options."""
  64. # mandatory name of binding
  65. _binding_name: ClassVar[str]
  66. # optional help message for each option
  67. _help: ClassVar[dict[str, str]]
  68. @staticmethod
  69. def _all_class_vars(klass: type, include_inherited=True) -> dict[str, Any]:
  70. """Print class variables, optionally including inherited ones"""
  71. if include_inherited:
  72. # Get all class variables from MRO
  73. vars_dict = {}
  74. for base in reversed(klass.__mro__[:-1]): # Exclude 'object'
  75. vars_dict.update(
  76. {
  77. k: v
  78. for k, v in base.__dict__.items()
  79. if (
  80. not k.startswith("_")
  81. and not callable(v)
  82. and not isinstance(v, classmethod)
  83. )
  84. }
  85. )
  86. else:
  87. # Only direct class variables
  88. vars_dict = {
  89. k: v
  90. for k, v in klass.__dict__.items()
  91. if (
  92. not k.startswith("_")
  93. and not callable(v)
  94. and not isinstance(v, classmethod)
  95. )
  96. }
  97. return vars_dict
  98. @classmethod
  99. def add_args(cls, parser: ArgumentParser):
  100. group = parser.add_argument_group(f"{cls._binding_name} binding options")
  101. for arg_item in cls.args_env_name_type_value():
  102. # Handle JSON parsing for list types
  103. if arg_item["type"] is List[str]:
  104. def json_list_parser(value):
  105. try:
  106. parsed = json.loads(value)
  107. if not isinstance(parsed, list):
  108. raise argparse.ArgumentTypeError(
  109. f"Expected JSON array, got {type(parsed).__name__}"
  110. )
  111. return parsed
  112. except json.JSONDecodeError as e:
  113. raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
  114. # Get environment variable with JSON parsing
  115. env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
  116. if env_value is not argparse.SUPPRESS:
  117. try:
  118. env_value = json_list_parser(env_value)
  119. except argparse.ArgumentTypeError:
  120. env_value = argparse.SUPPRESS
  121. group.add_argument(
  122. f"--{arg_item['argname']}",
  123. type=json_list_parser,
  124. default=env_value,
  125. help=arg_item["help"],
  126. )
  127. # Handle JSON parsing for dict types
  128. elif arg_item["type"] is dict:
  129. def json_dict_parser(value):
  130. try:
  131. parsed = json.loads(value)
  132. if not isinstance(parsed, dict):
  133. raise argparse.ArgumentTypeError(
  134. f"Expected JSON object, got {type(parsed).__name__}"
  135. )
  136. return parsed
  137. except json.JSONDecodeError as e:
  138. raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
  139. # Get environment variable with JSON parsing
  140. env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
  141. if env_value is not argparse.SUPPRESS:
  142. try:
  143. env_value = json_dict_parser(env_value)
  144. except argparse.ArgumentTypeError:
  145. env_value = argparse.SUPPRESS
  146. group.add_argument(
  147. f"--{arg_item['argname']}",
  148. type=json_dict_parser,
  149. default=env_value,
  150. help=arg_item["help"],
  151. )
  152. # Handle boolean types specially to avoid argparse bool() constructor issues
  153. elif arg_item["type"] is bool:
  154. def bool_parser(value):
  155. """Custom boolean parser that handles string representations correctly"""
  156. if isinstance(value, bool):
  157. return value
  158. if isinstance(value, str):
  159. return value.lower() in ("true", "1", "yes", "t", "on")
  160. return bool(value)
  161. # Get environment variable with proper type conversion
  162. env_value = get_env_value(
  163. f"{arg_item['env_name']}", argparse.SUPPRESS, bool
  164. )
  165. group.add_argument(
  166. f"--{arg_item['argname']}",
  167. type=bool_parser,
  168. default=env_value,
  169. help=arg_item["help"],
  170. )
  171. else:
  172. resolved_type = arg_item["type"]
  173. if resolved_type is not None:
  174. resolved_type = _resolve_optional_type(resolved_type)
  175. group.add_argument(
  176. f"--{arg_item['argname']}",
  177. type=resolved_type,
  178. default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
  179. help=arg_item["help"],
  180. )
  181. @classmethod
  182. def args_env_name_type_value(cls):
  183. import dataclasses
  184. args_prefix = f"{cls._binding_name}".replace("_", "-")
  185. env_var_prefix = f"{cls._binding_name}_".upper()
  186. help = cls._help
  187. # Check if this is a dataclass and use dataclass fields
  188. if dataclasses.is_dataclass(cls):
  189. for field in dataclasses.fields(cls):
  190. # Skip private fields
  191. if field.name.startswith("_"):
  192. continue
  193. # Get default value
  194. if field.default is not dataclasses.MISSING:
  195. default_value = field.default
  196. elif field.default_factory is not dataclasses.MISSING:
  197. default_value = field.default_factory()
  198. else:
  199. default_value = None
  200. argdef = {
  201. "argname": f"{args_prefix}-{field.name}",
  202. "env_name": f"{env_var_prefix}{field.name.upper()}",
  203. "type": _resolve_optional_type(field.type),
  204. "default": default_value,
  205. "help": f"{cls._binding_name} -- " + help.get(field.name, ""),
  206. }
  207. yield argdef
  208. else:
  209. # Fallback to old method for non-dataclass classes
  210. class_vars = {
  211. key: value
  212. for key, value in cls._all_class_vars(cls).items()
  213. if not callable(value) and not key.startswith("_")
  214. }
  215. # Get type hints to properly detect List[str] types
  216. type_hints = {}
  217. for base in cls.__mro__:
  218. if hasattr(base, "__annotations__"):
  219. type_hints.update(base.__annotations__)
  220. for class_var in class_vars:
  221. # Use type hint if available, otherwise fall back to type of value
  222. var_type = type_hints.get(class_var, type(class_vars[class_var]))
  223. argdef = {
  224. "argname": f"{args_prefix}-{class_var}",
  225. "env_name": f"{env_var_prefix}{class_var.upper()}",
  226. "type": var_type,
  227. "default": class_vars[class_var],
  228. "help": f"{cls._binding_name} -- " + help.get(class_var, ""),
  229. }
  230. yield argdef
  231. @classmethod
  232. def generate_dot_env_sample(cls):
  233. """
  234. Generate a sample .env file for all LightRAG binding options.
  235. This method creates a .env file that includes all the binding options
  236. defined by the subclasses of BindingOptions. It uses the args_env_name_type_value()
  237. method to get the list of all options and their default values.
  238. Returns:
  239. str: A string containing the contents of the sample .env file.
  240. """
  241. from io import StringIO
  242. sample_top = (
  243. "#" * 80
  244. + "\n"
  245. + (
  246. "# Autogenerated .env entries list for LightRAG binding options\n"
  247. "#\n"
  248. "# To generate run:\n"
  249. "# $ python -m lightrag.llm.binding_options\n"
  250. )
  251. + "#" * 80
  252. + "\n"
  253. )
  254. sample_bottom = (
  255. ("#\n# End of .env entries for LightRAG binding options\n")
  256. + "#" * 80
  257. + "\n"
  258. )
  259. sample_stream = StringIO()
  260. sample_stream.write(sample_top)
  261. for klass in cls.__subclasses__():
  262. for arg_item in klass.args_env_name_type_value():
  263. if arg_item["help"]:
  264. sample_stream.write(f"# {arg_item['help']}\n")
  265. # Handle JSON formatting for list and dict types
  266. if arg_item["type"] is List[str] or arg_item["type"] is dict:
  267. default_value = json.dumps(arg_item["default"])
  268. else:
  269. default_value = arg_item["default"]
  270. sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
  271. sample_stream.write(sample_bottom)
  272. return sample_stream.getvalue()
  273. @classmethod
  274. def options_dict(cls, args: Namespace) -> dict[str, Any]:
  275. """
  276. Extract options dictionary for a specific binding from parsed arguments.
  277. This method filters the parsed command-line arguments to return only those
  278. that belong to the specific binding class. It removes the binding prefix
  279. from argument names to create a clean options dictionary.
  280. Args:
  281. args (Namespace): Parsed command-line arguments containing all binding options
  282. Returns:
  283. dict[str, Any]: Dictionary mapping option names (without prefix) to their values
  284. Example:
  285. If args contains {'ollama_num_ctx': 512, 'other_option': 'value'}
  286. and this is called on OllamaOptions, it returns {'num_ctx': 512}
  287. """
  288. prefix = cls._binding_name + "_"
  289. skipchars = len(prefix)
  290. options = {
  291. key[skipchars:]: value
  292. for key, value in vars(args).items()
  293. if key.startswith(prefix)
  294. }
  295. return options
  296. @classmethod
  297. def options_dict_for_role(
  298. cls, args: Namespace, role: str, is_cross_provider: bool = False
  299. ) -> dict[str, Any]:
  300. """
  301. Extract role-specific provider options with proper inheritance.
  302. Same provider:
  303. - inherit the base binding options from parsed args
  304. - overlay any role-specific environment variable overrides
  305. Cross provider:
  306. - start from empty provider options
  307. - overlay any role-specific environment variable overrides
  308. Role env vars follow the pattern:
  309. `{ROLE}_{BINDING_PREFIX}_{FIELD}`
  310. e.g. `EXTRACT_OPENAI_LLM_TEMPERATURE`
  311. """
  312. import os
  313. if is_cross_provider:
  314. base: dict[str, Any] = {}
  315. else:
  316. base = cls.options_dict(args)
  317. role_upper = role.upper()
  318. env_prefix = cls._binding_name.upper() + "_"
  319. for arg_item in cls.args_env_name_type_value():
  320. original_env = arg_item["env_name"]
  321. role_env = f"{role_upper}_{original_env}"
  322. field_name = original_env[len(env_prefix) :].lower()
  323. env_raw = os.getenv(role_env)
  324. if env_raw is None:
  325. continue
  326. field_type = _resolve_optional_type(arg_item["type"])
  327. try:
  328. if field_type is bool:
  329. base[field_name] = env_raw.lower() in (
  330. "true",
  331. "1",
  332. "yes",
  333. "t",
  334. "on",
  335. )
  336. elif field_type in (list, List[str]):
  337. base[field_name] = json.loads(env_raw)
  338. elif field_type is dict:
  339. base[field_name] = json.loads(env_raw)
  340. elif field_type is int:
  341. base[field_name] = int(env_raw)
  342. elif field_type is float:
  343. base[field_name] = float(env_raw)
  344. else:
  345. base[field_name] = env_raw
  346. except (ValueError, json.JSONDecodeError):
  347. base[field_name] = env_raw
  348. return base
  349. def asdict(self) -> dict[str, Any]:
  350. """
  351. Convert an instance of binding options to a dictionary.
  352. This method uses dataclasses.asdict() to convert the dataclass instance
  353. into a dictionary representation, including all its fields and values.
  354. Returns:
  355. dict[str, Any]: Dictionary representation of the binding options instance
  356. """
  357. return asdict(self)
  358. # =============================================================================
  359. # Binding Options for Ollama
  360. # =============================================================================
  361. #
  362. # Ollama binding options provide configuration for the Ollama local LLM server.
  363. # These options control model behavior, sampling parameters, hardware utilization,
  364. # and performance settings. The parameters are based on Ollama's API specification
  365. # and provide fine-grained control over model inference and generation.
  366. #
  367. # The _OllamaOptionsMixin defines the complete set of available options, while
  368. # OllamaEmbeddingOptions and OllamaLLMOptions provide specialized configurations
  369. # for embedding and language model tasks respectively.
  370. # =============================================================================
  371. @dataclass
  372. class _OllamaOptionsMixin:
  373. """Options for Ollama bindings."""
  374. # Core context and generation parameters
  375. num_ctx: int = 32768 # Context window size (number of tokens)
  376. num_predict: int = 128 # Maximum number of tokens to predict
  377. num_keep: int = 0 # Number of tokens to keep from the initial prompt
  378. seed: int = -1 # Random seed for generation (-1 for random)
  379. # Sampling parameters
  380. temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0-2.0)
  381. top_k: int = 40 # Top-k sampling parameter
  382. top_p: float = 0.9 # Top-p (nucleus) sampling parameter
  383. tfs_z: float = 1.0 # Tail free sampling parameter
  384. typical_p: float = 1.0 # Typical probability mass
  385. min_p: float = 0.0 # Minimum probability threshold
  386. # Repetition control
  387. repeat_last_n: int = 64 # Number of tokens to consider for repetition penalty
  388. repeat_penalty: float = 1.1 # Penalty for repetition
  389. presence_penalty: float = 0.0 # Penalty for token presence
  390. frequency_penalty: float = 0.0 # Penalty for token frequency
  391. # Mirostat sampling
  392. mirostat: int = (
  393. # Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)
  394. 0
  395. )
  396. mirostat_tau: float = 5.0 # Mirostat target entropy
  397. mirostat_eta: float = 0.1 # Mirostat learning rate
  398. # Hardware and performance parameters
  399. numa: bool = False # Enable NUMA optimization
  400. num_batch: int = 512 # Batch size for processing
  401. num_gpu: int = -1 # Number of GPUs to use (-1 for auto)
  402. main_gpu: int = 0 # Main GPU index
  403. low_vram: bool = False # Optimize for low VRAM
  404. num_thread: int = 0 # Number of CPU threads (0 for auto)
  405. # Memory and model parameters
  406. f16_kv: bool = True # Use half-precision for key/value cache
  407. logits_all: bool = False # Return logits for all tokens
  408. vocab_only: bool = False # Only load vocabulary
  409. use_mmap: bool = True # Use memory mapping for model files
  410. use_mlock: bool = False # Lock model in memory
  411. embedding_only: bool = False # Only use for embeddings
  412. # Output control
  413. penalize_newline: bool = True # Penalize newline tokens
  414. stop: List[str] = field(default_factory=list) # Stop sequences
  415. # optional help strings
  416. _help: ClassVar[dict[str, str]] = {
  417. "num_ctx": "Context window size (number of tokens)",
  418. "num_predict": "Maximum number of tokens to predict",
  419. "num_keep": "Number of tokens to keep from the initial prompt",
  420. "seed": "Random seed for generation (-1 for random)",
  421. "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
  422. "top_k": "Top-k sampling parameter (0 = disabled)",
  423. "top_p": "Top-p (nucleus) sampling parameter (0.0-1.0)",
  424. "tfs_z": "Tail free sampling parameter (1.0 = disabled)",
  425. "typical_p": "Typical probability mass (1.0 = disabled)",
  426. "min_p": "Minimum probability threshold (0.0 = disabled)",
  427. "repeat_last_n": "Number of tokens to consider for repetition penalty",
  428. "repeat_penalty": "Penalty for repetition (1.0 = no penalty)",
  429. "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
  430. "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
  431. "mirostat": "Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)",
  432. "mirostat_tau": "Mirostat target entropy",
  433. "mirostat_eta": "Mirostat learning rate",
  434. "numa": "Enable NUMA optimization",
  435. "num_batch": "Batch size for processing",
  436. "num_gpu": "Number of GPUs to use (-1 for auto)",
  437. "main_gpu": "Main GPU index",
  438. "low_vram": "Optimize for low VRAM",
  439. "num_thread": "Number of CPU threads (0 for auto)",
  440. "f16_kv": "Use half-precision for key/value cache",
  441. "logits_all": "Return logits for all tokens",
  442. "vocab_only": "Only load vocabulary",
  443. "use_mmap": "Use memory mapping for model files",
  444. "use_mlock": "Lock model in memory",
  445. "embedding_only": "Only use for embeddings",
  446. "penalize_newline": "Penalize newline tokens",
  447. "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
  448. }
  449. @dataclass
  450. class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
  451. """Options for Ollama embeddings with specialized configuration for embedding tasks."""
  452. # mandatory name of binding
  453. _binding_name: ClassVar[str] = "ollama_embedding"
  454. @dataclass
  455. class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
  456. """Options for Ollama LLM with specialized configuration for LLM tasks."""
  457. # mandatory name of binding
  458. _binding_name: ClassVar[str] = "ollama_llm"
  459. # =============================================================================
  460. # Binding Options for Gemini
  461. # =============================================================================
  462. @dataclass
  463. class GeminiLLMOptions(BindingOptions):
  464. """Options for Google Gemini models."""
  465. _binding_name: ClassVar[str] = "gemini_llm"
  466. temperature: float = DEFAULT_TEMPERATURE
  467. top_p: float = 0.95
  468. top_k: int = 40
  469. max_output_tokens: int | None = None
  470. candidate_count: int = 1
  471. presence_penalty: float = 0.0
  472. frequency_penalty: float = 0.0
  473. stop_sequences: List[str] = field(default_factory=list)
  474. seed: int | None = None
  475. thinking_config: dict | None = None
  476. safety_settings: dict | None = None
  477. _help: ClassVar[dict[str, str]] = {
  478. "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
  479. "top_p": "Nucleus sampling parameter (0.0-1.0)",
  480. "top_k": "Limits sampling to the top K tokens (1 disables the limit)",
  481. "max_output_tokens": "Maximum tokens generated in the response",
  482. "candidate_count": "Number of candidates returned per request",
  483. "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
  484. "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
  485. "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
  486. "seed": "Random seed for reproducible generation (leave empty for random)",
  487. "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
  488. "safety_settings": "JSON object with Gemini safety settings overrides",
  489. }
  490. @dataclass
  491. class GeminiEmbeddingOptions(BindingOptions):
  492. """Options for Google Gemini embedding models."""
  493. _binding_name: ClassVar[str] = "gemini_embedding"
  494. task_type: str | None = None
  495. _help: ClassVar[dict[str, str]] = {
  496. "task_type": "Task type for embedding optimization. If not specified, automatically determined from context (RETRIEVAL_QUERY for queries, RETRIEVAL_DOCUMENT for documents). Supported types: RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION",
  497. }
  498. # =============================================================================
  499. # Binding Options for OpenAI
  500. # =============================================================================
  501. #
  502. # OpenAI binding options provide configuration for OpenAI's API and Azure OpenAI.
  503. # These options control model behavior, sampling parameters, and generation settings.
  504. # The parameters are based on OpenAI's API specification and provide fine-grained
  505. # control over model inference and generation.
  506. #
  507. # =============================================================================
  508. @dataclass
  509. class OpenAILLMOptions(BindingOptions):
  510. """Options for OpenAI LLM with configuration for OpenAI and Azure OpenAI API calls."""
  511. # mandatory name of binding
  512. _binding_name: ClassVar[str] = "openai_llm"
  513. # Sampling and generation parameters
  514. frequency_penalty: float = 0.0 # Penalty for token frequency (-2.0 to 2.0)
  515. max_completion_tokens: int = None # Maximum number of tokens to generate
  516. presence_penalty: float = 0.0 # Penalty for token presence (-2.0 to 2.0)
  517. reasoning_effort: str = "medium" # Reasoning effort level (low, medium, high)
  518. safety_identifier: str = "" # Safety identifier for content filtering
  519. service_tier: str = "" # Service tier for API usage
  520. stop: List[str] = field(default_factory=list) # Stop sequences
  521. temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0 to 2.0)
  522. top_p: float = 1.0 # Nucleus sampling parameter (0.0 to 1.0)
  523. max_tokens: int = None # Maximum number of tokens to generate(deprecated, use max_completion_tokens instead)
  524. extra_body: dict = None # Extra body parameters for OpenRouter of vLLM
  525. # Help descriptions
  526. _help: ClassVar[dict[str, str]] = {
  527. "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0, positive values discourage repetition)",
  528. "max_completion_tokens": "Maximum number of tokens to generate (optional, leave empty for model default)",
  529. "presence_penalty": "Penalty for token presence (-2.0 to 2.0, positive values encourage new topics)",
  530. "reasoning_effort": "Reasoning effort level for o1 models (low, medium, high)",
  531. "safety_identifier": "Safety identifier for content filtering (optional)",
  532. "service_tier": "Service tier for API usage (optional)",
  533. "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
  534. "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
  535. "top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
  536. "max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
  537. "extra_body": 'Extra body parameters for OpenRouter of vLLM (JSON dict, e.g., \'"reasoning": {"reasoning": {"enabled": false}}\')',
  538. }
  539. # =============================================================================
  540. # Binding Options for AWS Bedrock
  541. # =============================================================================
  542. #
  543. # Bedrock binding options map to the subset of the Bedrock Converse API
  544. # inferenceConfig that LightRAG's bedrock driver actually forwards. See
  545. # ``lightrag/llm/bedrock.py`` for the whitelist — any field added here that is
  546. # not in that whitelist will be silently dropped by the driver.
  547. # =============================================================================
  548. @dataclass
  549. class BedrockLLMOptions(BindingOptions):
  550. """Options for AWS Bedrock LLM (Converse API inferenceConfig)."""
  551. _binding_name: ClassVar[str] = "bedrock_llm"
  552. temperature: float = DEFAULT_TEMPERATURE
  553. max_tokens: int | None = None
  554. top_p: float = 1.0
  555. stop_sequences: List[str] = field(default_factory=list)
  556. extra_fields: dict = None # Converse API additionalModelRequestFields
  557. _help: ClassVar[dict[str, str]] = {
  558. "temperature": "Controls randomness (0.0-1.0 for most Bedrock models)",
  559. "max_tokens": "Maximum tokens generated in the response (leave empty for model default)",
  560. "top_p": "Nucleus sampling parameter (0.0-1.0)",
  561. "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"</s>\"]')",
  562. "extra_fields": 'Model-specific request fields forwarded as Converse API additionalModelRequestFields (JSON dict, e.g., \'{"reasoning_config": {"type": "enabled"}}\')',
  563. }
  564. # =============================================================================
  565. # Main Section - For Testing and Sample Generation
  566. # =============================================================================
  567. #
  568. # When run as a script, this module:
  569. # 1. Generates and prints a sample .env file with all binding options
# 2. If the "test" argument is provided, demonstrates argument parsing with the Ollama and OpenAI bindings
  571. #
  572. # Usage:
  573. # python -m lightrag.llm.binding_options # Generate .env sample
  574. # python -m lightrag.llm.binding_options test # Test argument parsing
  575. #
  576. # =============================================================================
  577. if __name__ == "__main__":
  578. import sys
  579. import dotenv
  580. # from io import StringIO
  581. dotenv.load_dotenv(dotenv_path=".env", override=False)
  582. # env_strstream = StringIO(
  583. # ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
  584. # )
  585. # # Load environment variables from .env file
  586. # dotenv.load_dotenv(stream=env_strstream)
  587. if len(sys.argv) > 1 and sys.argv[1] == "test":
  588. # Add arguments for OllamaEmbeddingOptions, OllamaLLMOptions, and OpenAILLMOptions
  589. parser = ArgumentParser(description="Test binding options")
  590. OllamaEmbeddingOptions.add_args(parser)
  591. OllamaLLMOptions.add_args(parser)
  592. OpenAILLMOptions.add_args(parser)
  593. # Parse arguments test
  594. args = parser.parse_args(
  595. [
  596. "--ollama-embedding-num_ctx",
  597. "1024",
  598. "--ollama-llm-num_ctx",
  599. "2048",
  600. "--openai-llm-temperature",
  601. "0.7",
  602. "--openai-llm-max_completion_tokens",
  603. "1000",
  604. "--openai-llm-stop",
  605. '["</s>", "\\n\\n"]',
  606. "--openai-llm-reasoning",
  607. '{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
  608. ]
  609. )
  610. print("Final args for LLM and Embedding:")
  611. print(f"{args}\n")
  612. print("Ollama LLM options:")
  613. print(OllamaLLMOptions.options_dict(args))
  614. print("\nOllama Embedding options:")
  615. print(OllamaEmbeddingOptions.options_dict(args))
  616. print("\nOpenAI LLM options:")
  617. print(OpenAILLMOptions.options_dict(args))
  618. # Test creating OpenAI options instance
  619. openai_options = OpenAILLMOptions(
  620. temperature=0.8,
  621. max_completion_tokens=1500,
  622. frequency_penalty=0.1,
  623. presence_penalty=0.2,
  624. stop=["<|end|>", "\n\n"],
  625. )
  626. print("\nOpenAI LLM options instance:")
  627. print(openai_options.asdict())
  628. # Test creating OpenAI options instance with reasoning parameter
  629. openai_options_with_reasoning = OpenAILLMOptions(
  630. temperature=0.9,
  631. max_completion_tokens=2000,
  632. reasoning={
  633. "effort": "medium",
  634. "max_tokens": 1500,
  635. "exclude": True,
  636. "enabled": True,
  637. },
  638. )
  639. print("\nOpenAI LLM options instance with reasoning:")
  640. print(openai_options_with_reasoning.asdict())
  641. # Test dict parsing functionality
  642. print("\n" + "=" * 50)
  643. print("TESTING DICT PARSING FUNCTIONALITY")
  644. print("=" * 50)
  645. # Test valid JSON dict parsing
  646. test_parser = ArgumentParser(description="Test dict parsing")
  647. OpenAILLMOptions.add_args(test_parser)
  648. try:
  649. test_args = test_parser.parse_args(
  650. ["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
  651. )
  652. print("✓ Valid JSON dict parsing successful:")
  653. print(
  654. f" Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
  655. )
  656. except Exception as e:
  657. print(f"✗ Valid JSON dict parsing failed: {e}")
  658. # Test invalid JSON dict parsing
  659. try:
  660. test_args = test_parser.parse_args(
  661. [
  662. "--openai-llm-reasoning",
  663. '{"effort": "low", "max_tokens": 1000', # Missing closing brace
  664. ]
  665. )
  666. print("✗ Invalid JSON should have failed but didn't")
  667. except SystemExit:
  668. print("✓ Invalid JSON dict parsing correctly rejected")
  669. except Exception as e:
  670. print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")
  671. # Test non-dict JSON parsing
  672. try:
  673. test_args = test_parser.parse_args(
  674. [
  675. "--openai-llm-reasoning",
  676. '["not", "a", "dict"]', # Array instead of dict
  677. ]
  678. )
  679. print("✗ Non-dict JSON should have failed but didn't")
  680. except SystemExit:
  681. print("✓ Non-dict JSON parsing correctly rejected")
  682. except Exception as e:
  683. print(f"✓ Non-dict JSON parsing correctly rejected: {e}")
  684. print("\n" + "=" * 50)
  685. print("TESTING ENVIRONMENT VARIABLE SUPPORT")
  686. print("=" * 50)
  687. # Test environment variable support for dict
  688. import os
  689. os.environ["OPENAI_LLM_REASONING"] = (
  690. '{"effort": "high", "max_tokens": 3000, "exclude": false}'
  691. )
  692. env_parser = ArgumentParser(description="Test env var dict parsing")
  693. OpenAILLMOptions.add_args(env_parser)
  694. try:
  695. env_args = env_parser.parse_args(
  696. []
  697. ) # No command line args, should use env var
  698. reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
  699. "reasoning"
  700. )
  701. if reasoning_from_env:
  702. print("✓ Environment variable dict parsing successful:")
  703. print(f" Parsed reasoning from env: {reasoning_from_env}")
  704. else:
  705. print("✗ Environment variable dict parsing failed: No reasoning found")
  706. except Exception as e:
  707. print(f"✗ Environment variable dict parsing failed: {e}")
  708. finally:
  709. # Clean up environment variable
  710. if "OPENAI_LLM_REASONING" in os.environ:
  711. del os.environ["OPENAI_LLM_REASONING"]
  712. else:
  713. print(BindingOptions.generate_dot_env_sample())