hf.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import copy
  2. import os
  3. import warnings
  4. from functools import lru_cache
  5. import pipmaster as pm # Pipmaster for dynamic library install
  6. # install specific modules
  7. if not pm.is_installed("transformers"):
  8. pm.install("transformers")
  9. if not pm.is_installed("torch"):
  10. pm.install("torch")
  11. if not pm.is_installed("numpy"):
  12. pm.install("numpy")
  13. from transformers import AutoTokenizer, AutoModelForCausalLM
  14. from tenacity import (
  15. retry,
  16. stop_after_attempt,
  17. wait_exponential,
  18. retry_if_exception_type,
  19. )
  20. from lightrag.exceptions import (
  21. APIConnectionError,
  22. RateLimitError,
  23. APITimeoutError,
  24. )
  25. import torch
  26. import numpy as np
  27. from lightrag.utils import wrap_embedding_func_with_attrs
  28. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  29. @lru_cache(maxsize=1)
  30. def initialize_hf_model(model_name):
  31. hf_tokenizer = AutoTokenizer.from_pretrained(
  32. model_name, device_map="auto", trust_remote_code=True
  33. )
  34. hf_model = AutoModelForCausalLM.from_pretrained(
  35. model_name, device_map="auto", trust_remote_code=True
  36. )
  37. if hf_tokenizer.pad_token is None:
  38. hf_tokenizer.pad_token = hf_tokenizer.eos_token
  39. return hf_model, hf_tokenizer
  40. @retry(
  41. stop=stop_after_attempt(3),
  42. wait=wait_exponential(multiplier=1, min=4, max=10),
  43. retry=retry_if_exception_type(
  44. (RateLimitError, APIConnectionError, APITimeoutError)
  45. ),
  46. )
  47. async def hf_model_if_cache(
  48. model,
  49. prompt,
  50. system_prompt=None,
  51. history_messages=[],
  52. enable_cot: bool = False,
  53. **kwargs,
  54. ) -> str:
  55. if enable_cot:
  56. from lightrag.utils import logger
  57. logger.debug(
  58. "enable_cot=True is not supported for Hugging Face local models and will be ignored."
  59. )
  60. model_name = model
  61. hf_model, hf_tokenizer = initialize_hf_model(model_name)
  62. messages = []
  63. if system_prompt:
  64. messages.append({"role": "system", "content": system_prompt})
  65. messages.extend(history_messages)
  66. messages.append({"role": "user", "content": prompt})
  67. kwargs.pop("hashing_kv", None)
  68. input_prompt = ""
  69. try:
  70. input_prompt = hf_tokenizer.apply_chat_template(
  71. messages, tokenize=False, add_generation_prompt=True
  72. )
  73. except Exception:
  74. try:
  75. ori_message = copy.deepcopy(messages)
  76. if messages[0]["role"] == "system":
  77. messages[1]["content"] = (
  78. "<system>"
  79. + messages[0]["content"]
  80. + "</system>\n"
  81. + messages[1]["content"]
  82. )
  83. messages = messages[1:]
  84. input_prompt = hf_tokenizer.apply_chat_template(
  85. messages, tokenize=False, add_generation_prompt=True
  86. )
  87. except Exception:
  88. len_message = len(ori_message)
  89. for msgid in range(len_message):
  90. input_prompt = (
  91. input_prompt
  92. + "<"
  93. + ori_message[msgid]["role"]
  94. + ">"
  95. + ori_message[msgid]["content"]
  96. + "</"
  97. + ori_message[msgid]["role"]
  98. + ">\n"
  99. )
  100. input_ids = hf_tokenizer(
  101. input_prompt, return_tensors="pt", padding=True, truncation=True
  102. ).to("cuda")
  103. inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
  104. output = hf_model.generate(
  105. **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
  106. )
  107. response_text = hf_tokenizer.decode(
  108. output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
  109. )
  110. return response_text
  111. async def hf_model_complete(
  112. prompt,
  113. system_prompt=None,
  114. history_messages=[],
  115. keyword_extraction=False,
  116. entity_extraction=False,
  117. enable_cot: bool = False,
  118. **kwargs,
  119. ) -> str:
  120. """Run local Hugging Face inference with LightRAG-compatible shims.
  121. Structured output note:
  122. - This adapter does not support OpenAI-style ``response_format`` JSON mode.
  123. - If callers pass ``response_format``, it is stripped before generation.
  124. - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
  125. accepted only as compatibility shims; they emit warnings and are ignored.
  126. """
  127. # HuggingFace local inference has no JSON mode; drop response_format and
  128. # warn when legacy shim flags are set.
  129. if kwargs.pop("keyword_extraction", False) or keyword_extraction:
  130. warnings.warn(
  131. "hf_model_complete(keyword_extraction=True) is deprecated; "
  132. "pass response_format={'type': 'json_object'} instead.",
  133. DeprecationWarning,
  134. stacklevel=2,
  135. )
  136. if kwargs.pop("entity_extraction", False) or entity_extraction:
  137. warnings.warn(
  138. "hf_model_complete(entity_extraction=True) is deprecated; "
  139. "pass response_format={'type': 'json_object'} instead.",
  140. DeprecationWarning,
  141. stacklevel=2,
  142. )
  143. kwargs.pop("response_format", None)
  144. model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
  145. result = await hf_model_if_cache(
  146. model_name,
  147. prompt,
  148. system_prompt=system_prompt,
  149. history_messages=history_messages,
  150. enable_cot=enable_cot,
  151. **kwargs,
  152. )
  153. return result
  154. @wrap_embedding_func_with_attrs(
  155. embedding_dim=1024,
  156. max_token_size=8192,
  157. model_name="hf_embedding_model",
  158. supports_asymmetric=True,
  159. )
  160. async def hf_embed(
  161. texts: list[str],
  162. tokenizer,
  163. embed_model,
  164. context: str = "document",
  165. query_prefix: str | None = None,
  166. document_prefix: str | None = None,
  167. ) -> np.ndarray:
  168. """Generate embeddings for a list of texts using a Hugging Face model.
  169. Args:
  170. texts (list[str]): List of input texts to embed.
  171. tokenizer: Hugging Face tokenizer.
  172. embed_model: Hugging Face model for generating embeddings.
  173. context (str): Context indicating whether the texts are "query" or "document".
  174. query_prefix (str | None): Optional prefix to add to query texts.
  175. document_prefix (str | None): Optional prefix to add to document texts.
  176. Returns:
  177. np.ndarray: Array of embeddings.
  178. """
  179. # Detect the appropriate device
  180. if torch.cuda.is_available():
  181. device = next(embed_model.parameters()).device # Use CUDA if available
  182. elif torch.backends.mps.is_available():
  183. device = torch.device("mps") # Use MPS for Apple Silicon
  184. else:
  185. device = torch.device("cpu") # Fallback to CPU
  186. # Move the model to the detected device
  187. embed_model = embed_model.to(device)
  188. # Apply context-based prefixes if provided
  189. if context == "query" and query_prefix:
  190. texts = [query_prefix + text for text in texts]
  191. elif context == "document" and document_prefix:
  192. texts = [document_prefix + text for text in texts]
  193. # Tokenize the input texts and move them to the same device
  194. encoded_texts = tokenizer(
  195. texts, return_tensors="pt", padding=True, truncation=True
  196. ).to(device)
  197. # Perform inference
  198. with torch.no_grad():
  199. outputs = embed_model(
  200. input_ids=encoded_texts["input_ids"],
  201. attention_mask=encoded_texts["attention_mask"],
  202. )
  203. embeddings = outputs.last_hidden_state.mean(dim=1)
  204. # Convert embeddings to NumPy
  205. if embeddings.dtype == torch.bfloat16:
  206. return embeddings.detach().to(torch.float32).cpu().numpy()
  207. else:
  208. return embeddings.detach().cpu().numpy()