Add mistral-7b-instruct alias, fix lints
GregoryComer committed Apr 17, 2024
1 parent 0a37583 commit 3f6eb29
Showing 6 changed files with 16 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull.yml
@@ -173,4 +173,4 @@ jobs:
echo "::group::Run inference"
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti"
echo "::endgroup::"
echo "::endgroup::"
7 changes: 1 addition & 6 deletions .github/workflows/test_torchchat_commands.yml
@@ -54,9 +54,4 @@ jobs:
      cat ./output_eager1
      cat ./output_eager2
      echo "Tests complete."
-   - name: Test download
-     run: |
-       python torchchat.py download mistral-7b-instruct
-       python torchchat.py generate mistral-7b-instruct
8 changes: 4 additions & 4 deletions build/builder.py
@@ -35,7 +35,7 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False
is_chat_model: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
@@ -82,15 +82,15 @@ def from_args(cls, args): # -> BuilderArgs:
            args.checkpoint_dir,
            args.dso_path,
            args.pte_path,
-           args.gguf_path
+           args.gguf_path,
        ]:
            path = str(path)
-           if path.endswith('/'):
+           if path.endswith("/"):
                path = path[:-1]
            path_basename = os.path.basename(path)
            if "chat" in path_basename:
                is_chat_model = True

        return cls(
            checkpoint_path=checkpoint_path,
            checkpoint_dir=args.checkpoint_dir,
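For context on the builder.py hunk above: from_args flags a checkpoint as a chat model whenever one of the supplied artifact paths has a basename containing "chat". A minimal standalone sketch of that heuristic, using a hypothetical looks_like_chat_model helper (the real logic lives inline in BuilderArgs.from_args):

    import os

    def looks_like_chat_model(path) -> bool:
        # Mirror the diff above: normalize to str, strip a trailing "/",
        # then check whether the basename mentions "chat".
        path = str(path)
        if path.endswith("/"):
            path = path[:-1]
        return "chat" in os.path.basename(path)

    # A chat checkpoint directory is flagged; a base instruct model is not.
    assert looks_like_chat_model("checkpoints/meta-llama/Llama-2-7b-chat-hf/")
    assert not looks_like_chat_model("checkpoints/mistralai/Mistral-7B-Instruct-v0.2")
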
6 changes: 5 additions & 1 deletion build/model.py
@@ -96,7 +96,11 @@ def from_name(cls, name: str):

# Aliases for well-known models. Maps a short name to a HuggingFace path. These
# can be used from the CLI in place of the full model path.
-model_aliases = {"llama2": "meta-llama/Llama-2-7b-chat-hf"}
+model_aliases = {
+    "llama2": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2-7": "meta-llama/Llama-2-7b-chat-hf",
+    "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.2",
+}

transformer_configs = {
"CodeLlama-7b-Python-hf": {
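The expanded model_aliases table maps short CLI names to Hugging Face repo paths, so a command like `python torchchat.py download mistral-7b-instruct` can stand in for the full mistralai/Mistral-7B-Instruct-v0.2 path. A minimal sketch of how such a lookup could be applied; the resolve_model_name helper below is hypothetical, for illustration only (the actual resolution happens in build/model.py):

    model_aliases = {
        "llama2": "meta-llama/Llama-2-7b-chat-hf",
        "llama2-7": "meta-llama/Llama-2-7b-chat-hf",
        "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.2",
    }

    def resolve_model_name(name: str) -> str:
        # Known alias -> full Hugging Face path; anything else passes through.
        return model_aliases.get(name, name)

    assert resolve_model_name("mistral-7b-instruct") == "mistralai/Mistral-7B-Instruct-v0.2"
    assert resolve_model_name("meta-llama/Llama-2-7b-chat-hf") == "meta-llama/Llama-2-7b-chat-hf"
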
6 changes: 0 additions & 6 deletions cli.py
@@ -222,12 +222,6 @@ def _add_arguments_common(parser):
default=".model-artifacts",
help="The directory to store downloaded model artifacts",
)
-    parser.add_argument(
-        "--chat",
-        action="store_true",
-        help="Use torchchat to for an interactive chat session.",
-    )
-

def arg_init(args):

7 changes: 5 additions & 2 deletions generate.py
@@ -30,6 +30,7 @@

B_INST, E_INST = "[INST]", "[/INST]"

+
@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -346,14 +347,16 @@ def _main(
is_speculative = speculative_builder_args.checkpoint_path is not None

if generator_args.chat_mode and not builder_args.is_chat_model:
print("""
print(
"""
*******************************************************
This model is not known to support the chat function.
We will enable chat mode based on your instructions.
If the model is not trained to support chat, it will
produce nonsensical or false output.
*******************************************************
""")
"""
)
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

tokenizer = _initialize_tokenizer(tokenizer_args)
