diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 9c7e10d03..193e5c7bd 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -173,4 +173,4 @@ jobs:
 
         echo "::group::Run inference"
         bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti"
-        echo "::endgroup::"
\ No newline at end of file
+        echo "::endgroup::"
diff --git a/.github/workflows/test_torchchat_commands.yml b/.github/workflows/test_torchchat_commands.yml
index d6bc44ddc..dc0a17588 100644
--- a/.github/workflows/test_torchchat_commands.yml
+++ b/.github/workflows/test_torchchat_commands.yml
@@ -54,9 +54,4 @@ jobs:
           cat ./output_eager1
           cat ./output_eager2
           echo "Tests complete."
-
-      - name: Test download
-        run: |
-          python torchchat.py download mistral-7b-instruct
-          python torchchat.py generate mistral-7b-instruct
-          
\ No newline at end of file
+          
\ No newline at end of file
diff --git a/build/builder.py b/build/builder.py
index 5201beea1..1d603c624 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -35,7 +35,7 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
     is_chat_model: bool = False
-    
+
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -82,15 +82,15 @@ def from_args(cls, args): # -> BuilderArgs:
             args.checkpoint_dir,
             args.dso_path,
             args.pte_path,
-            args.gguf_path
+            args.gguf_path,
         ]:
             path = str(path)
-            if path.endswith('/'):
+            if path.endswith("/"):
                 path = path[:-1]
             path_basename = os.path.basename(path)
             if "chat" in path_basename:
                 is_chat_model = True
-        
+
         return cls(
             checkpoint_path=checkpoint_path,
             checkpoint_dir=args.checkpoint_dir,
diff --git a/build/model.py b/build/model.py
index d233ad150..4edecc229 100644
--- a/build/model.py
+++ b/build/model.py
@@ -96,7 +96,11 @@ def from_name(cls, name: str):
 
 # Aliases for well-known models. Maps a short name to a HuggingFace path. These
 # can be used from the CLI in-place of the full model path.
-model_aliases = {"llama2": "meta-llama/Llama-2-7b-chat-hf"}
+model_aliases = {
+    "llama2": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2-7": "meta-llama/Llama-2-7b-chat-hf",
+    "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.2",
+}
 
 transformer_configs = {
     "CodeLlama-7b-Python-hf": {
diff --git a/cli.py b/cli.py
index 6c8b92544..2f0f49a56 100644
--- a/cli.py
+++ b/cli.py
@@ -222,12 +222,6 @@ def _add_arguments_common(parser):
         default=".model-artifacts",
         help="The directory to store downloaded model artifacts",
     )
-    parser.add_argument(
-        "--chat",
-        action="store_true",
-        help="Use torchchat to for an interactive chat session.",
-    )
-
 
 
 def arg_init(args):
diff --git a/generate.py b/generate.py
index bbcc5179f..709060b5d 100644
--- a/generate.py
+++ b/generate.py
@@ -30,6 +30,7 @@
 
 B_INST, E_INST = "[INST]", "[/INST]"
 
+
 @dataclass
 class GeneratorArgs:
     prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -346,14 +347,16 @@ def _main(
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
     if generator_args.chat_mode and not builder_args.is_chat_model:
-        print("""
+        print(
+            """
 *******************************************************
 This model is not known to support the chat function.
 We will enable chat mode based on your instructions.
 If the model is not trained to support chat, it will produce
 nonsensical or false output.
 *******************************************************
-        """)
+        """
+        )
         # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
 
     tokenizer = _initialize_tokenizer(tokenizer_args)