diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 740f344a8..a8a2c7da8 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -5,20 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 
-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -42,6 +38,9 @@
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
@@ -498,9 +497,10 @@ def _add_speculative_execution_args(parser) -> None:
 
 
 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )
 
     if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -521,6 +521,9 @@ def arg_init(args):
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
@@ -534,5 +537,8 @@ def arg_init(args):
         vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index f95cbdaef..f428e4cc6 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -11,25 +11,23 @@
 from pathlib import Path
 from typing import Optional
 
-import torch
-
-from torchchat.model import TransformerArgs
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from torchchat.model import ModelArgs
-
 
-@torch.inference_mode()
 def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ) -> None:
+
+    # Local imports to avoid expensive imports
+    from torchchat.model import ModelArgs, TransformerArgs
+    import torch
+
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
@@ -58,10 +56,11 @@ def convert_hf_checkpoint(
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
         # Confirm we can load it
-        loaded_result = torch.load(
-            str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
-        )
-        del loaded_result  # No longer needed
+        with torch.inference_mode():
+            loaded_result = torch.load(
+                str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
+            )
+            del loaded_result  # No longer needed
         print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
         os.rename(consolidated_pth, model_dir / "model.pth")
         os.rename(tokenizer_pth, model_dir / "tokenizer.model")
@@ -130,7 +129,8 @@ def load_safetensors():
     state_dict = None
     for loader in loaders:
         try:
-            state_dict = loader()
+            with torch.inference_mode():
+                state_dict = loader()
             break
         except Exception:
             continue
@@ -173,7 +173,6 @@ def load_safetensors():
             os.remove(file)
 
 
-@torch.inference_mode()
 def convert_hf_checkpoint_to_tune(
     *,
     model_dir: Optional[Path] = None,
diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index 005bb6ef2..2685ec2f3 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -13,18 +13,31 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
-import torch
 
 ##########################################################################
 ###                unpack packed weights                               ###
 
 
+class _LazyImportTorch:
+    """This is a wrapper around the import of torch that only performs the
+    import when an actual attribute is needed off of torch.
+    """
+    @staticmethod
+    def __getattribute__(name: str) -> Any:
+        import torch
+        return getattr(torch, name)
+
+
+# Alias torch to the lazy import
+torch = _LazyImportTorch()
+
+
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,
-    input_dtype: torch.dtype,
+    input_dtype: "torch.dtype",
     unpacked_dims: Tuple,
-) -> torch.Tensor:
+) -> "torch.Tensor":
     """Given a packed weight matrix `packed_weights`, a Callable
     implementing a packed linear function for the packed format, and the
     unpacked dimensions of the weights, recreate the unpacked weight
@@ -169,26 +182,27 @@ def name_to_dtype(name, device):
         return torch.bfloat16
 
     try:
-        return name_to_dtype_dict[name]
+        return _name_to_dtype_dict[name]()
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
 
 def allowable_dtype_names() -> List[str]:
-    return name_to_dtype_dict.keys()
-
-
-name_to_dtype_dict = {
-    "fp32": torch.float,
-    "fp16": torch.float16,
-    "bf16": torch.bfloat16,
-    "float": torch.float,
-    "half": torch.float16,
-    "float32": torch.float,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "fast": None,
-    "fast16": None,
+    return _name_to_dtype_dict.keys()
+
+
+# NOTE: values are wrapped in lambdas to avoid proactive imports for torch
+_name_to_dtype_dict = {
+    "fp32": lambda: torch.float,
+    "fp16": lambda: torch.float16,
+    "bf16": lambda: torch.bfloat16,
+    "float": lambda: torch.float,
+    "half": lambda: torch.float16,
+    "float32": lambda: torch.float,
+    "float16": lambda: torch.float16,
+    "bfloat16": lambda: torch.bfloat16,
+    "fast": lambda: None,
+    "fast16": lambda: None,
 }
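
Note on the _LazyImportTorch shim added to build_utils.py above: it defers the real torch import until an attribute is first read off the module-level alias, so code paths that never touch torch no longer pay the import cost. The snippet below is a standalone sketch of the same pattern, using json as a stand-in for torch so it runs without torch installed; _LazyImportJson and json_proxy are illustrative names and are not part of torchchat.

import sys
from typing import Any


class _LazyImportJson:
    """Defer `import json` until an attribute is actually requested."""

    @staticmethod
    def __getattribute__(name: str) -> Any:
        # Executed on every attribute access; after the first access this
        # import is cheap because the module is cached in sys.modules.
        import json
        return getattr(json, name)


# Alias the proxy the same way build_utils.py aliases torch
json_proxy = _LazyImportJson()

print("json" in sys.modules)             # expected False in a fresh interpreter
print(json_proxy.dumps({"lazy": True}))  # first attribute access triggers the import
print("json" in sys.modules)             # True: the real module is now loaded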
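
Similarly, the PyTorch version gate in arg_init now reads the installed distribution's metadata via importlib.metadata.version("torch") instead of importing torch just to inspect torch.__version__. A minimal sketch of that approach, assuming only that importlib.metadata can see the installed distribution:

import importlib.metadata
import sys

# Reading package metadata does not import (or initialize) the package itself.
try:
    torch_version = importlib.metadata.version("torch")
except importlib.metadata.PackageNotFoundError:
    torch_version = None

print(torch_version)           # e.g. a version string, or None if torch is absent
print("torch" in sys.modules)  # still False here: only the metadata was read

The diff keeps the original string comparison against "2.3"; only the source of the version string changes.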