Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Initial PR for generating and loading state dict #1329

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class BuilderArgs:
dynamic_shapes: bool = False
max_seq_length: Optional[int] = None

state_dict_path: Optional[Union[Path, str]] = None

def __post_init__(self):
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -185,6 +187,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
state_dict_path=args.state_dict_path,
)

@classmethod
Expand Down Expand Up @@ -579,26 +582,47 @@ def _initialize_model(
model = _load_model(builder_args)
device_sync(device=builder_args.device)

if quantize:
print(f"Quantizing the model with: {quantize}")
with measure_time("Time to quantize model: {time:.02f} seconds"):
quantize_model(
model,
builder_args.device,
quantize,
tokenizer,
support_tensor_subclass,
)
device_sync(device=builder_args.device)
state_dict_path = builder_args.state_dict_path
state_dict_exists: bool = state_dict_path and os.path.isfile(state_dict_path)
if quantize or state_dict_exists:

if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
or model.text_transformer_args.max_seq_length,
if quantize and state_dict_exists:
print(
"WARNING: Both a state_dict and quantize arg were provided; Ignoring quantize arg"
)

if state_dict_exists:
with measure_time("Time to load quantized state: {time:.02f} seconds"):
print(f"Loading the model_state in: {state_dict_path}")
model.load_state_dict(state_dict_path)
device_sync(device=builder_args.device)
else:
with measure_time("Time to quantize model: {time:.02f} seconds"):
print(f"Quantizing the model with: {quantize}")
quantize_model(
model,
builder_args.device,
quantize,
tokenizer,
support_tensor_subclass,
)
device_sync(device=builder_args.device)

if state_dict_path:
with measure_time(
"Time to save quantized state: {time:.02f} seconds"
):
print(f"Saving the quantized state dict")
torch.save(model.state_dict(), state_dict_path)

if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
or model.text_transformer_args.max_seq_length,
)

model.to(dtype=builder_args.precision)

print("-----------------------------------------------------------")
Expand Down
6 changes: 6 additions & 0 deletions torchchat/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ def _add_model_config_args(parser, verb: str) -> None:
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
)

model_config_parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after migration we shouldn't need anything special for quantized model right

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow? The things I'm testing out should work for tensor subclass right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I meant that with tensor subclass API, quantized checkpoint should be able to be loaded the same way as normal checkpoint.

the code path of loading a quantized model v.s. quantizing model on the fly might still make sense though, maybe just need to change the naming or something

"--state-dict-path",
type=str,
default=None,
help="Model state dict to load (if path exists) or write out to (if path doesn't exist). Supercedes --quantize arg.",
)
model_config_parser.add_argument(
"--dtype",
default="fast",
Expand Down
Loading