Skip to content

Commit

Permalink
format, add group_size named arg
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekgfb committed Apr 9, 2024
1 parent a6c31f0 commit 7dfaba2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,17 @@ def create_quantized_state_dict(self):

# we use Int4 packaged in an int8 for now, packing to follow
# return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
return WeightOnlyInt8QuantHandler(self.mod, bitwidth=4, group_size=self.groupsize).create_quantized_state_dict()
return WeightOnlyInt8QuantHandler(
self.mod, bitwidth=4, group_size=self.groupsize
).create_quantized_state_dict()

def _convert_for_runtime(self):
# we use Int4 packaged in an int8 for now, packing to follow
# ALSO: all code must work for CPU, CUDA, MPS
# return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
return WeightOnlyInt4GPTQQuantHandler(self.mod, bitwidth=4, self.groupsize).convert_for_runtime()
return WeightOnlyInt4GPTQQuantHandler(
self.mod, bitwidth=4, group_size=self.groupsize
).convert_for_runtime()

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
Expand Down

0 comments on commit 7dfaba2

Please sign in to comment.