download_cache.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. """
  2. Download all necessary cache files for offline deployment.
  3. This module provides a CLI command to download tiktoken model cache files
  4. for offline environments where internet access is not available.
  5. """
  6. import os
  7. import sys
  8. from pathlib import Path
  9. # Known tiktoken encoding names (not model names)
  10. # These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
  11. TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
  12. def download_tiktoken_cache(cache_dir: str = None, models: list = None):
  13. """Download tiktoken models to local cache
  14. Args:
  15. cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
  16. models: List of model names or encoding names to download. If None, downloads common ones.
  17. Returns:
  18. Tuple of (success_count, failed_models, actual_cache_dir)
  19. """
  20. # If user specified a cache directory, set it BEFORE importing tiktoken
  21. # tiktoken reads TIKTOKEN_CACHE_DIR at import time
  22. user_specified_cache = cache_dir is not None
  23. if user_specified_cache:
  24. cache_dir = os.path.abspath(cache_dir)
  25. os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
  26. cache_path = Path(cache_dir)
  27. cache_path.mkdir(parents=True, exist_ok=True)
  28. print(f"Using specified cache directory: {cache_dir}")
  29. else:
  30. # Check if TIKTOKEN_CACHE_DIR is already set in environment
  31. env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
  32. if env_cache_dir:
  33. cache_dir = env_cache_dir
  34. print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
  35. else:
  36. # Use tiktoken's default location (tempdir/data-gym-cache)
  37. import tempfile
  38. cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
  39. print(f"Using tiktoken default cache directory: {cache_dir}")
  40. # Now import tiktoken (it will use the cache directory we determined)
  41. try:
  42. import tiktoken
  43. except ImportError:
  44. print("Error: tiktoken is not installed.")
  45. print("Install with: pip install tiktoken")
  46. sys.exit(1)
  47. # Common models used by LightRAG and OpenAI
  48. if models is None:
  49. models = [
  50. "gpt-4o-mini", # Default model for LightRAG
  51. "gpt-4o", # GPT-4 Omni
  52. "gpt-4", # GPT-4
  53. "gpt-3.5-turbo", # GPT-3.5 Turbo
  54. "text-embedding-ada-002", # Legacy embedding model
  55. "text-embedding-3-small", # Small embedding model
  56. "text-embedding-3-large", # Large embedding model
  57. "cl100k_base", # Default encoding for LightRAG
  58. ]
  59. print(f"\nDownloading {len(models)} tiktoken models...")
  60. print("=" * 70)
  61. success_count = 0
  62. failed_models = []
  63. for i, model in enumerate(models, 1):
  64. try:
  65. print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
  66. # Use get_encoding for encoding names, encoding_for_model for model names
  67. if model in TIKTOKEN_ENCODING_NAMES:
  68. encoding = tiktoken.get_encoding(model)
  69. else:
  70. encoding = tiktoken.encoding_for_model(model)
  71. # Trigger download by encoding a test string
  72. encoding.encode("test")
  73. print("✓ Done")
  74. success_count += 1
  75. except KeyError as e:
  76. print(f"✗ Failed: Unknown model or encoding '{model}'")
  77. failed_models.append((model, str(e)))
  78. except Exception as e:
  79. print(f"✗ Failed: {e}")
  80. failed_models.append((model, str(e)))
  81. print("=" * 70)
  82. print(f"\n✓ Successfully cached {success_count}/{len(models)} models")
  83. if failed_models:
  84. print(f"\n✗ Failed to download {len(failed_models)} models:")
  85. for model, error in failed_models:
  86. print(f" - {model}: {error}")
  87. print(f"\nCache location: {cache_dir}")
  88. print("\nFor offline deployment:")
  89. print(" 1. Copy directory to offline server:")
  90. print(f" tar -czf tiktoken_cache.tar.gz {cache_dir}")
  91. print(" scp tiktoken_cache.tar.gz user@offline-server:/path/to/")
  92. print("")
  93. print(" 2. On offline server, extract and set environment variable:")
  94. print(" tar -xzf tiktoken_cache.tar.gz")
  95. print(" export TIKTOKEN_CACHE_DIR=/path/to/tiktoken_cache")
  96. print("")
  97. print(" 3. Or copy to default location:")
  98. print(f" cp -r {cache_dir} ~/.tiktoken_cache/")
  99. return success_count, failed_models
  100. def main():
  101. """Main entry point for the CLI command"""
  102. import argparse
  103. parser = argparse.ArgumentParser(
  104. prog="lightrag-download-cache",
  105. description="Download cache files for LightRAG offline deployment",
  106. formatter_class=argparse.RawDescriptionHelpFormatter,
  107. epilog="""
  108. Examples:
  109. # Download to default location (~/.tiktoken_cache)
  110. lightrag-download-cache
  111. # Download to specific directory
  112. lightrag-download-cache --cache-dir ./offline_cache/tiktoken
  113. # Download specific models only
  114. lightrag-download-cache --models gpt-4o-mini gpt-4
  115. For more information, visit: https://github.com/HKUDS/LightRAG
  116. """,
  117. )
  118. parser.add_argument(
  119. "--cache-dir",
  120. help="Cache directory path (default: ~/.tiktoken_cache)",
  121. default=None,
  122. )
  123. parser.add_argument(
  124. "--models",
  125. nargs="+",
  126. help="Specific models to download (default: common models)",
  127. default=None,
  128. )
  129. parser.add_argument(
  130. "--version", action="version", version="%(prog)s (LightRAG cache downloader)"
  131. )
  132. args = parser.parse_args()
  133. print("=" * 70)
  134. print("LightRAG Offline Cache Downloader")
  135. print("=" * 70)
  136. try:
  137. success_count, failed_models = download_tiktoken_cache(
  138. args.cache_dir, args.models
  139. )
  140. print("\n" + "=" * 70)
  141. print("Download Complete")
  142. print("=" * 70)
  143. # Exit with error code if all downloads failed
  144. if success_count == 0:
  145. print("\n✗ All downloads failed. Please check your internet connection.")
  146. sys.exit(1)
  147. # Exit with warning code if some downloads failed
  148. elif failed_models:
  149. print(
  150. f"\n⚠ Some downloads failed ({len(failed_models)}/{success_count + len(failed_models)})"
  151. )
  152. sys.exit(2)
  153. else:
  154. print("\n✓ All cache files downloaded successfully!")
  155. sys.exit(0)
  156. except KeyboardInterrupt:
  157. print("\n\n✗ Download interrupted by user")
  158. sys.exit(130)
  159. except Exception as e:
  160. print(f"\n\n✗ Error: {e}")
  161. import traceback
  162. traceback.print_exc()
  163. sys.exit(1)
  164. if __name__ == "__main__":
  165. main()