Support finetuning with LoRA (huggingface#431)
* Add support for LoRA models

* Workaround random bug

* Pass only trainable params to the optimizer in trainer

* Add peft dependency

* Fix edge case in peft + pipelines

* Fix dropout check

* Add notebook for Whisper LoRA

* Simplify pipelines logic, requiring adapter weights to be merged in

* Remove notebook for now
katalinic-gc authored Sep 6, 2023
1 parent efd424f commit 8c4a1dd
Showing 5 changed files with 120 additions and 22 deletions.
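The changes below wire LoRA fine-tuning (via Hugging Face `peft`) into the pipelined-model conversion, the pipeline factory, and the IPU trainer. As a rough end-to-end sketch of the workflow this enables (the model name, LoRA hyperparameters, IPU config and trainer arguments are illustrative assumptions, not part of this commit):

    from peft import LoraConfig, get_peft_model
    from transformers import WhisperForConditionalGeneration
    from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments

    # Wrap a base model with LoRA adapters; only the adapter matrices remain trainable.
    base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.0, target_modules=["q_proj", "v_proj"])
    model = get_peft_model(base, lora_config)

    ipu_config = IPUConfig.from_pretrained("Graphcore/whisper-tiny-ipu")  # hypothetical IPU config checkpoint
    train_dataset = ...  # a preprocessed datasets.Dataset, prepared elsewhere
    trainer = IPUTrainer(
        model=model,
        ipu_config=ipu_config,
        args=IPUTrainingArguments(output_dir="whisper-lora"),  # plus the usual training hyperparameters
        train_dataset=train_dataset,
    )
    trainer.train()

The trainer then converts the `PeftModel` to its pipelined counterpart (via `to_pipelined` below) and, per the trainer changes further down, passes only the trainable adapter parameters to the optimizer.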
45 changes: 42 additions & 3 deletions optimum/graphcore/modeling_utils.py
@@ -20,6 +20,7 @@
import poptorch
import torch
import torch.nn.functional as F
from peft import PeftModel, PeftType, get_peft_model
from torch import nn
from transformers import PreTrainedModel

@@ -53,9 +54,11 @@ def wrapper(cls):


def to_pipelined(model: nn.Module, ipu_config: IPUConfig, force: bool = False):
model_cls = model.__class__
model_cls = model.get_base_model().__class__ if isinstance(model, PeftModel) else model.__class__
pipelined_cls = _PRETRAINED_TO_PIPELINED_REGISTRY.get(model_cls, None)
if pipelined_cls is not None:
if pipelined_cls is not None and isinstance(model, PeftModel):
return pipelined_cls.from_peft(model, ipu_config)
elif pipelined_cls is not None:
return pipelined_cls.from_transformers(model, ipu_config)
# If the user defined his/her own model and already subclassed from PipelineMixin. I.e., the model is already pipelined.
elif isinstance(model, PipelineMixin):
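With the dispatch above, a `PeftModel` is routed to the new `from_peft` constructor while a plain `transformers` model keeps the existing `from_transformers` path, and an already-pipelined model still falls through to the `PipelineMixin` branch. Continuing the usage sketch from the top (`model` and `ipu_config` as defined there), the lower-level call would look roughly like:

    from optimum.graphcore.modeling_utils import to_pipelined

    pipelined = to_pipelined(model, ipu_config)
    # For a PeftModel this returns a peft.PeftModel wrapping the registered Pipelined* subclass,
    # rather than the Pipelined* subclass itself.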
@@ -92,9 +95,9 @@ def from_transformers(cls, model: PreTrainedModel, ipu_config: IPUConfig):
config = copy.deepcopy(model.config)
generation_config = copy.deepcopy(model.generation_config)
pipelined_model = cls(config)
pipelined_model.generation_config = generation_config
pipelined_model.load_state_dict(model.state_dict())
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
pipelined_model.training = model.training
return pipelined_model

@@ -120,6 +123,42 @@ def from_pretrained_transformers(cls, model_name_or_path: str, ipu_config: IPUCo
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
return pipelined_model

@classmethod
def from_peft(cls, model: PeftModel, ipu_config: IPUConfig):
"""
Creates a pipelined version of model from a [`~peft.PeftModel`] instance.
Currently, only `peft.PeftType.LORA` is supported.
Args:
model ([`~peft.PeftModel`]):
The model to convert to a pipelined model.
ipu_config ([`IPUConfig`]):
The `IPUConfig` instance of the pipelined model.
Returns:
An instance of `peft.PeftModel` wrapping a pipelined version of the base model.
"""
# Technically speaking, instead of returning an instance of a `PipelineMixin`, such as Pipelined<Model>For<Task>,
# we return an instance of a `peft.PeftModel` which wraps such a pipelined model and defers attribute access.
if len(model.peft_config) > 1 or model.active_adapter != "default":
raise ValueError("Currently only `PeftModel` instances with the `'default'` adapter are supported.")
if model.peft_type != PeftType.LORA:
raise ValueError(f"Currently only LoRA is supported, received {model.peft_type}.")

pretrained = model.get_base_model()
config = copy.deepcopy(pretrained.config)
generation_config = copy.deepcopy(pretrained.generation_config)
peft_config = model.active_peft_config

pipelined_model = cls(config)
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
peft_pipelined_model = get_peft_model(pipelined_model, peft_config)
peft_pipelined_model.load_state_dict(model.state_dict())
peft_pipelined_model.training = model.training
return peft_pipelined_model

@classmethod
def from_model(cls, model: nn.Module):
clone = copy.deepcopy(model)
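The `from_peft` constructor added above rebuilds the pipelined base model from the PEFT model's config, re-applies the same LoRA configuration with `get_peft_model`, and then copies the full state dict across, so both the frozen base weights and the adapter weights end up in the pipelined copy; only a single `'default'` LoRA adapter is supported. Continuing the sketch, the result behaves like any other `PeftModel`:

    pipelined.print_trainable_parameters()        # only the lora_A / lora_B matrices report as trainable
    assert pipelined.active_adapter == "default"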
7 changes: 5 additions & 2 deletions optimum/graphcore/models/whisper/modeling_whisper.py
@@ -136,7 +136,7 @@ def forward(
if layer_head_mask is not None:
raise ValueError("layer_head_mask is not supported yet with serialized attention.")

if self.dropout or self.training:
if self.dropout and self.training:
raise ValueError("dropout is not supported yet with serialized attention.")

if attention_mask is not None:
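The corrected check only rejects serialized attention when dropout would actually be applied, i.e. a nonzero dropout probability while the module is in training mode; previously it also raised during training with dropout disabled, which blocked fine-tuning. A tiny illustration with made-up values:

    dropout, training = 0.0, True
    old_check = bool(dropout) or training    # True: raised even though dropout is switched off
    new_check = bool(dropout) and training   # False: training with dropout == 0.0 is now allowed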
@@ -594,10 +594,13 @@ def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=Fal
)
logger.info(f"Decoder Embedding --> IPU {decoder_embedding_ipu}")

prev_ipu = decoder_layer_ipu[0]
for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
if ipu != prev_ipu:
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
prev_ipu = ipu
logger.info(f"Decoder {index:<2} --> IPU {ipu}")

self.model.decoder.layer_norm = poptorch.BeginBlock(
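In `parallelize`, `poptorch.BeginBlock` now only wraps a decoder layer when its assigned IPU differs from the previous layer's, so consecutive layers that share an IPU stay in the block opened earlier (the first decoder layer stays in the block opened for the decoder embedding). A hedged illustration with a made-up per-layer assignment:

    decoder_layer_ipu = [0, 0, 1, 1]   # hypothetical assignment of four decoder layers to IPUs
    prev_ipu = decoder_layer_ipu[0]
    block_starts = []
    for index, ipu in enumerate(decoder_layer_ipu):
        if ipu != prev_ipu:
            block_starts.append(index)
        prev_ipu = ipu
    print(block_starts)   # [2]: only layer 2 opens a new block; layers 0-1 stay with the embedding on IPU 0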
6 changes: 6 additions & 0 deletions optimum/graphcore/pipelines/__init__.py
@@ -18,6 +18,7 @@
import poptorch
import torch
import transformers.pipelines
from peft import PeftModel
from transformers import (
AudioClassificationPipeline,
AutoFeatureExtractor,
@@ -375,6 +376,11 @@ def pipeline(
break
except ValueError:
continue
elif isinstance(model, PeftModel):
raise TypeError(
"Instead of providing `model` as an instance of `PeftModel`, please call `merge_and_unload()` if LoRA "
"or equivalent to obtain the original `PreTrainedModel` back with adapter weights merged in."
)
elif isinstance(model, PreTrainedModel):
if tokenizer is None and load_tokenizer:
raise ValueError("If you pass a model as a PreTrainedModel, you must pass a tokenizer as well")
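Rather than special-casing adapters inside the pipeline machinery, the factory now asks the caller to fold the LoRA weights into the base model first. A hedged sketch of the intended call pattern (the task string, processor objects and `lora_model` are illustrative):

    from optimum.graphcore.pipelines import pipeline

    merged = lora_model.merge_and_unload()   # folds the LoRA weights into the base model, returning a PreTrainedModel
    asr = pipeline(
        "automatic-speech-recognition",
        model=merged,
        tokenizer=tokenizer,                 # assumed to have been loaded alongside the model
        feature_extractor=feature_extractor,
    )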
83 changes: 66 additions & 17 deletions optimum/graphcore/trainer.py
@@ -35,6 +35,7 @@
import torch
from huggingface_hub import Repository
from packaging import version
from peft import PeftModel
from poptorch import DataLoaderMode, PoplarExecutor
from poptorch.optim import LAMB, AdamW
from torch import nn, optim
@@ -125,6 +126,9 @@
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# TODO: Import from transformers.utils when updating transformers version.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"


@dataclass
class IPUTrainerState(TrainerState):
@@ -841,20 +845,24 @@ def create_optimizer(self):
bias_parameters = {n for n, _ in self.model.named_parameters() if "bias" in n}
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
# Disable LAMB updates for bias parameters
"params": [p for n, p in self.model.named_parameters() if n in bias_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in bias_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"max_weight_norm": 0.0,
},
{
"params": [
p
for n, p in self.model.named_parameters()
if n not in decay_parameters and n not in bias_parameters
if n not in decay_parameters and n not in bias_parameters and p.requires_grad
],
"weight_decay": 0.0,
},
@@ -868,11 +876,17 @@
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [
p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
"params": [
p
for n, p in self.model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
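The `p.requires_grad` filter added to every parameter group keeps the frozen base weights of a `PeftModel` out of the optimizer, so only the adapter weights receive updates. The same idea in isolation, for the LoRA model from the earlier sketch:

    # With a LoRA-wrapped model, only the adapter tensors have requires_grad set.
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    # Typically these are the lora_A / lora_B matrices, plus any modules_to_save listed in the LoraConfig.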
@@ -1326,15 +1340,25 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
if model is None:
model = self.model

if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)

if not any(
os.path.isfile(f)
for f in [
weights_file,
weights_index_file,
adapter_weights_file,
]
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

logger.info(f"Loading model from {resume_from_checkpoint}.")

if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
if os.path.isfile(config_file):
config = PretrainedConfig.from_json_file(config_file)
checkpoint_version = config.transformers_version
if checkpoint_version is not None and checkpoint_version != __version__:
logger.warning(
@@ -1343,23 +1367,46 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
"yield to errors or unwanted behavior."
)

if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
if os.path.isfile(weights_file):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
state_dict = torch.load(weights_file, map_location="cpu")
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
# release memory
del state_dict
self._issue_warnings_after_load(load_result)

# Load adapters following PR # 24096 (> 4.29.2)
elif isinstance(model, PeftModel):
# If training a model using PEFT & LoRA, assume that adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(resume_from_checkpoint):
model.load_adapter(resume_from_checkpoint, model.active_adapter)
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
"Check some examples here: https://github.com/huggingface/peft/issues/96"
)
else:
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")

def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = self.model
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
if os.path.exists(best_model_path):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
self._load_state_dict_in_model(state_dict)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
if os.path.exists(best_model_path) or os.path.exists(best_adapter_model_path):
if isinstance(model, PeftModel):
# If training a model using PEFT & LoRA, assume that adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(best_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
self._load_state_dict_in_model(state_dict)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}. If you are running a distributed training "
@@ -1677,8 +1724,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `transformers.PreTrainedModel`, only saving its state dict.")
if not isinstance(self.model, (PreTrainedModel, PeftModel)):
logger.info(
"Trainer.model is not a `transformers.PreTrainedModel` or `peft.PeftModel`, only saving its state dict."
)
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
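Since `PeftModel` is now treated like `PreTrainedModel` here, a LoRA model is saved through its own `save_pretrained` rather than the raw state-dict dump above; for a PeftModel that presumably writes only the small adapter checkpoint rather than the full weights:

    # Roughly what PeftModel.save_pretrained(output_dir) produces for a LoRA model:
    #   adapter_config.json   (the serialized LoraConfig)
    #   adapter_model.bin     (adapter weights only, a few megabytes instead of the full model)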
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@
"optimum==1.6.1",
"diffusers[torch]==0.12.1",
"cppimport==22.8.2",
"peft==0.3.0",
"datasets",
"tokenizers",
"typeguard",
