Gguf cleanup (#230)
* Clean up GGUF loading. Move model loading to the meta device.

* remove cpu

* Fix CI and validation scripts (#154)

* missing device (#232)

* Use generator args to group all arguments to generator (#231)

* prompt

* chat_mode, num_samples

* Move more generator args to use dataclass (#233)

* prompt

* chat_mode, num_samples

* move more args

* more gen args

* update

* args

* undo some changes

* typos

* Minor lint fixes (#236)

* remove redundancy & remove int4 linear test from ET tests (#237)

* remove redundancy

* no int4 linear on ET

* small changes

---------

Co-authored-by: Guang Yang <[email protected]>
Co-authored-by: Michael Gschwind <[email protected]>
Co-authored-by: Mergen Nachin <[email protected]>
4 people authored and malfet committed Jul 17, 2024
1 parent 467e387 commit 11afeff
Showing 3 changed files with 106 additions and 175 deletions.
28 changes: 14 additions & 14 deletions build/builder.py
@@ -154,21 +154,13 @@ def device_sync(device):
sys.path.append(str(wd))


def _load_model(builder_args):
if builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)

# TODO: to take advantage of mmap, maybe we write converted gguf to file
# and read back in?
# TODO: should we add check that builder_args.precision is aligned with quant scheme, e.g., bfloat16
# is needed for int4
model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
else:
return _load_model_not_gguf(builder_args)
def _load_model_gguf(builder_args):
assert builder_args.gguf_path
model = Transformer.from_gguf(builder_args.gguf_path)
return model


def _load_model_not_gguf(builder_args):
def _load_model_default(builder_args):
assert not builder_args.gguf_path

with torch.device("meta"):
@@ -218,9 +210,17 @@ def _load_model_not_gguf(builder_args):

model.load_state_dict(checkpoint, assign=True, strict=False)

return model


def _load_model(builder_args):
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
else:
model = _load_model_default(builder_args)

if builder_args.use_tp:
from tp import apply_tp

print("Applying tensor parallel to model ...")
apply_tp(model)

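
The refactored builder above constructs the Transformer under torch.device("meta") and only materializes real tensors when the checkpoint is applied via load_state_dict(..., assign=True). Below is a minimal sketch of that pattern, independent of this repository; TinyModel and the random "checkpoint" are hypothetical stand-ins used purely for illustration.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)


# Construct the module on the meta device: parameters have shapes and dtypes,
# but no storage is allocated yet.
with torch.device("meta"):
    model = TinyModel()

# assign=True replaces the meta parameters with the checkpoint tensors instead
# of copying into them, which is what lets the builder skip the usual random
# initialization on CPU/GPU.
checkpoint = {"linear.weight": torch.randn(16, 16), "linear.bias": torch.randn(16)}
model.load_state_dict(checkpoint, assign=True)
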
247 changes: 89 additions & 158 deletions build/gguf_loader.py
@@ -17,69 +17,22 @@
import torch
import torch.nn as nn

wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))

from gguf import GGUFValueType, ReaderTensor
from quantize import (
group_dequantize_tensor_from_qparams,
pack_scales_and_zeros,
WeightOnlyInt4Linear,
)

from build.gguf_util import F16, F32, Q4_0, Q6_K

wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))

from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
from model import ModelArgs, Transformer

logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class AttentionArgs:
head_count: int
head_count_kv: int
layer_norm_rms_epsilon: float


@dataclass
class RopeArgs:
dimension_count: int | None = None
freq_base: float | None = None


@dataclass
class GGUFModelArgs:
arch: str
embedding_length: int
block_count: int
feed_forward_length: int
vocab_size: int
attention: AttentionArgs
rope: RopeArgs


@dataclass
class GGUFWeights:
tensors: list[ReaderTensor]


def _create_pt_model(
gguf_model_args: GGUFModelArgs,
) -> nn.Module:
llama_model_args = ModelArgs(
dim=gguf_model_args.embedding_length,
n_layers=gguf_model_args.block_count,
n_heads=gguf_model_args.attention.head_count,
n_local_heads=gguf_model_args.attention.head_count_kv,
vocab_size=gguf_model_args.vocab_size,
norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
hidden_dim=gguf_model_args.feed_forward_length,
)
pt_model = Transformer(llama_model_args)
pt_model.eval()
return pt_model


_name_replacements = [
("blk", "layers"),
("token_embd", "tok_embeddings"),
@@ -102,29 +55,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
return result


def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
arch = metadata["general.architecture"]
assert (
arch == "llama"
), f"Only general.architecture=llama is supported, but got general.architecture={arch}"
return GGUFModelArgs(
arch=arch,
embedding_length=metadata[f"{arch}.embedding_length"],
block_count=metadata[f"{arch}.block_count"],
feed_forward_length=metadata[f"{arch}.feed_forward_length"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
attention=AttentionArgs(
head_count=metadata[f"{arch}.attention.head_count"],
head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
),
rope=RopeArgs(
freq_base=metadata.get(f"{arch}.rope.freq_base", None),
dimension_count=metadata.get(f"{arch}.rope.dimension_count", None),
),
)


def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
if fqn == "":
return module
@@ -153,74 +83,6 @@ def _fqn_last(fqn: str) -> str:
return atoms[-1]


def load_weights(
pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles=8
) -> None:
fqns = []
for fqn in pt_model.state_dict():
assert _fqn_last(fqn) == "weight"
fqns.append(_fqn_up(fqn))

state_dict = {}
for fqn in fqns:
mod = _fqn_lookup(fqn, pt_model)

t = weight_map[f"{fqn}.weight"]

if (
isinstance(mod, torch.nn.Linear)
and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
assert all(t.shape == (in_features, out_features))

q, s, z = Q4_0.unpack(t)
scales_and_zeros = pack_scales_and_zeros(s, z)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q, inner_k_tiles
)

state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

parent = _fqn_lookup(_fqn_up(fqn), pt_model)
setattr(
parent,
_fqn_last(fqn),
WeightOnlyInt4Linear(
"cpu", # TODO: should --device work for gguf load? (yes?!)
in_features,
out_features,
bias=False,
groupsize=Q4_0.groupsize,
inner_k_tiles=inner_k_tiles,
),
)
else:
# All other weights are dequantized to float
if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
as_float = group_dequantize_tensor_from_qparams(
*Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize
)
elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
as_float = group_dequantize_tensor_from_qparams(
*Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize
)
elif t.tensor_type == gguf.GGMLQuantizationType.F16:
as_float = F16.unpack(t)
elif t.tensor_type == gguf.GGMLQuantizationType.F32:
as_float = F32.unpack(t)
else:
raise ValueError(f"Unsupported tensor type {t.tensor_type}")

state_dict[f"{fqn}.weight"] = as_float.to("cpu")

pt_model.load_state_dict(state_dict)
return pt_model


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
metadata: dict[str, Any] = {}

@@ -244,34 +106,103 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
return metadata


def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
def load_model(gguf_file: str) -> torch.nn.Module:
"""
Load a LLaMa model from a GGUF file and return a PT nn.Module.
Parses the GGUF file and returns an nn.Module on meta device.
"""
if not Path(gguf_file).is_file():
raise ValueError(f"Could not find file {gguf_file}")

logger.info("Parsing GGUF metadata.")
reader = gguf.GGUFReader(gguf_file, "r")
metadata = _get_metadata(reader)
model_args = _build_model_args(metadata)

arch = metadata["general.architecture"]
assert (
model_args.arch == "llama"
arch == "llama"
), "Only LLaMa models are supported by this converter."

logger.info("Creating initial PT model.")
pt_model = _create_pt_model(model_args)
model_args = ModelArgs(
dim=metadata[f"{arch}.embedding_length"],
n_layers=metadata[f"{arch}.block_count"],
n_heads=metadata[f"{arch}.attention.head_count"],
n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
hidden_dim=metadata[f"{arch}.feed_forward_length"],
)

logger.info("Reading GGUF weights.")
gguf_weights = GGUFWeights(tensors=reader.tensors)
# TODO: what to do with rope args like
# metadata.get(f"{arch}.rope.freq_base", None)
# metadata.get(f"{arch}.rope.dimension_count", None)

logger.info("Building GGUF weight map.")
# map from fqn in pt_model to gguf tensor
with torch.device("meta"):
model = Transformer(model_args)
return model


def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
When load_as_quantized, the method tries to preserve the GGUF quantization when it
is natively supported by PyTorch, otherwise it converts quantized tensors to FP32.
"""

model = load_model(gguf_file)

reader = gguf.GGUFReader(gguf_file, "r")
weight_map = {
_convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
for tensor in gguf_weights.tensors
for tensor in reader.tensors
}

logger.info("Loading weights into state_dict")
pt_model = load_weights(pt_model, weight_map, inner_k_tiles=8)
return pt_model
state_dict = {}
for fqn in weight_map:
assert _fqn_last(fqn) == "weight"
fqn = _fqn_up(fqn)

mod = _fqn_lookup(fqn, model)
t = weight_map[f"{fqn}.weight"]

if (
isinstance(mod, torch.nn.Linear)
and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
and load_as_quantized
):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
assert all(t.shape == (in_features, out_features))

q, s, z = Q4_0.unpack(t)
scales_and_zeros = pack_scales_and_zeros(s, z)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q, inner_k_tiles
)

state_dict[f"{fqn}.weight"] = weight_int4pack
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

parent = _fqn_lookup(_fqn_up(fqn), model)
setattr(
parent,
_fqn_last(fqn),
WeightOnlyInt4Linear(
"meta",
in_features,
out_features,
bias=False,
groupsize=Q4_0.groupsize,
inner_k_tiles=inner_k_tiles,
),
)
else:
state_dict[f"{fqn}.weight"] = to_float(t)

return model, state_dict


def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
model, state_dict = load_model_and_state_dict(gguf_file, load_as_quantized=True)
model.load_state_dict(state_dict, assign=True)
return model
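
A hedged usage sketch of the new loader entry point in build/gguf_loader.py: the file path is a hypothetical placeholder, and the flow mirrors what Transformer.from_gguf now does (parse on the meta device, then materialize with assign=True).

from build.gguf_loader import load_model_and_state_dict

# Hypothetical local file; any llama-architecture GGUF checkpoint applies.
gguf_path = "checkpoints/llama-2-7b.Q4_0.gguf"

# load_as_quantized=True keeps Q4_0 linear weights in PyTorch's packed int4
# format; everything else is dequantized to float.
model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True)

# The model comes back on the meta device; assign=True swaps in the real tensors.
model.load_state_dict(state_dict, assign=True)
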
6 changes: 3 additions & 3 deletions build/model.py
@@ -247,9 +247,9 @@ def from_params(cls, params_path: str):

@classmethod
def from_gguf(cls, gguf_path: str):
from build.gguf_loader import load_llama_from_gguf_file

model = load_llama_from_gguf_file(gguf_path)
from build.gguf_loader import load_model_and_state_dict
model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
model.load_state_dict(state_dict, assign=True)
return model


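
The same path through the model API, as a short sketch; the GGUF path is again a hypothetical placeholder.

from build.model import Transformer

# from_gguf parses the GGUF metadata, builds the Transformer on the meta
# device, and loads the converted state_dict (int4-packed where supported).
model = Transformer.from_gguf("checkpoints/llama-2-7b.Q4_0.gguf")
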
