Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 11, 2024
1 parent f29d7ad commit a7021d3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
63 changes: 36 additions & 27 deletions src/brevitas_examples/llm/llm_quant/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,79 +4,88 @@
"""

from copy import deepcopy
from time import sleep

from accelerate.utils.operations import send_to_device
import torch
from tqdm import tqdm

from brevitas.graph.gptq import gptq_mode
from accelerate.utils.operations import send_to_device
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr


@torch.no_grad()
def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None):
if True:
blocks = model.model.layers #getattr(model, block_name)
def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
cache_state = model.config.use_cache
model.config.use_cache = False
blocks = recurse_getattr(model, block_name)
first_block = blocks[0]
cached_args, cached_kwargs = [], []

# Intercept input to first block
def intercept_input(module, args, kwargs):
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
cached_args.append(args)
cached_kwargs.append(kwargs)
raise RuntimeError
raise StopFwdException

# Intercept output from block N-1 to set it as input to block N
def intercept_output(module, args, kwargs, output):
if isinstance(output, tuple):
output = output[0]
output = send_to_device(output, 'cpu')
cached_args.append((output,))
raise RuntimeError

raise StopFwdException

# Collect input to first block
hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
for inps in dataloader:
try:
model(**inps)
except:
except StopFwdException:
pass
hook.remove()


# Iterate through all the blocks
for index, block in enumerate(tqdm(blocks)):

with gptq_mode(block,
use_quant_activations=False,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=False) as gptq:
use_quant_activations=False,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=False) as gptq:
for _ in tqdm(range(gptq.num_layers)):
for args, kwargs in zip(cached_args, cached_kwargs):
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
gptq.update()
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []

if index < len(blocks)-1:
hook = blocks[index].register_forward_hook(intercept_output, with_kwargs=True)
if index < len(blocks) - 1:
# Once the block is done, we need to update the input to the next block
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []
hook = block.register_forward_hook(intercept_output, with_kwargs=True)
for args, kwargs in zip(past_cached_args, past_cached_kwargs):
try:
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
except Exception as e:
except StopFwdException:
pass
hook.remove()

# Restore cache state
model.config.use_cache = cache_state

else:
with gptq_mode(model,
use_quant_activations=False,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=False) as gptq:
use_quant_activations=False,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=False) as gptq:
gptq_model = gptq.model
for _ in tqdm(range(gptq.num_layers)):
for inps in dataloader:
Expand Down
8 changes: 7 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def validate(args):
def main(args):
validate(args)
set_seed(args.seed)

if args.export_prefix is None:
args.export_prefix = f"{args.model.replace('/', '--')}"

Expand Down Expand Up @@ -325,6 +324,13 @@ def parse_args(args):
choices=['wikitext2', 'c4'],
default='wikitext2',
help='Dataset to use for quantization (default: %(default)s)')
parser.add_argument(
'--gptq-block-name',
type=str,
default=None,
help=
'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)'
)
parser.add_argument(
'--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
parser.add_argument(
Expand Down

0 comments on commit a7021d3

Please sign in to comment.