Skip to content

Commit

Permalink
Bump torchtune pin (#1157)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Khuu authored Sep 17, 2024
1 parent b0c933c commit f730056
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion install/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ PYTORCH_NIGHTLY_VERSION=dev20240814
VISION_NIGHTLY_VERSION=dev20240814

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20240910
TUNE_NIGHTLY_VERSION=dev20240916


# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
Expand Down
8 changes: 4 additions & 4 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,22 +727,22 @@ def chat(
if generator_args.image_prompts is not None:
print("Image prompts", generator_args.image_prompts)

# Support for just the first image prompt for now
images = [Image.open(generator_args.image_prompts[0])]
messages = [
Message(
role="user",
content=[
{"type": "image"},
{"type": "image", "content": images[0]},
{"type": "text", "content": generator_args.prompt},
],
eot=True,
),
Message(role="assistant", content=""),
]

images = [Image.open(generator_args.image_prompts[0])]
transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))

data = transform({"images": images, "messages": messages}, inference=True)
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]
Expand Down

0 comments on commit f730056

Please sign in to comment.