-
Notifications
You must be signed in to change notification settings - Fork 507
/
orpo_training.py
523 lines (494 loc) · 22.4 KB
/
orpo_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Train a model from base model using ORPO
"""
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Dict, Optional
import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType
from transformers import (
AutoConfig,
BloomForCausalLM,
AutoModelForCausalLM,
AutoModel,
LlamaForCausalLM,
BloomTokenizerFast,
AutoTokenizer,
HfArgumentParser,
BitsAndBytesConfig,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import ORPOConfig, ORPOTrainer
from template import get_conv_template
# Process-wide environment tweaks applied at import time:
# - TOKENIZERS_PARALLELISM=FALSE turns off the tokenizers library's internal parallelism.
# - KMP_DUPLICATE_LIB_OK=TRUE is an Intel OpenMP runtime flag that tolerates duplicate OpenMP libraries.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "FALSE",
    "KMP_DUPLICATE_LIB_OK": "TRUE",
})

# Lookup table from the --model_type CLI value to a (config class, model class, tokenizer class) triple.
MODEL_CLASSES = dict(
    bloom=(AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoConfig, AutoModel, AutoTokenizer),
    llama=(AutoConfig, LlamaForCausalLM, AutoTokenizer),
    baichuan=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
)
@dataclass
class ScriptArguments:
    """
    Command-line arguments for ORPO fine-tuning of a causal language model.

    Every field becomes a ``--<field_name>`` flag when parsed with ``HfArgumentParser``.
    Fields are grouped into model-loading, dataset and training options.

    Raises (from ``__post_init__``):
        ValueError: if ``model_type`` is missing or not a key of ``MODEL_CLASSES``, if
            ``model_name_or_path`` is missing, or if both 4-bit and 8-bit loading are requested.
    """
    # Model arguments
    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The model checkpoint for weights initialization."}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The tokenizer for weights initialization."}
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )
    # Dataset arguments
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
    template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "Eval batch size per device"})
    max_source_length: Optional[int] = field(default=2048, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=512, metadata={"help": "Max length of output text"})
    # NOTE(review): min_target_length is not read by the training code in this file.
    min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=4, metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Training arguments
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
    target_modules: Optional[str] = field(default="all", metadata={"help": "The target modules for peft"})
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=16.0)
    # NOTE(review): peft_path is not read by the training code in this file.
    peft_path: Optional[str] = field(default=None)
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
    # Kept so existing command lines that pass --beta still parse; ORPO reads orpo_beta instead.
    beta: Optional[float] = field(
        default=0.1, metadata={"help": "Unused by ORPO training (kept for compatibility); use --orpo_beta instead."}
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "Learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "The number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "The weight decay"})
    optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer type"})
    fp16: Optional[bool] = field(default=True, metadata={"help": "Whether to use fp16"})
    bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to use bf16"})
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use gradient checkpointing"}
    )
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "The number of gradient accumulation steps"}
    )
    save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
    eval_steps: Optional[int] = field(default=50, metadata={"help": "X steps to evaluate the model"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "X steps to log the model"})
    output_dir: Optional[str] = field(default="outputs-dpo", metadata={"help": "The output directory"})
    max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
    eval_strategy: Optional[str] = field(default="steps", metadata={"help": "Evaluation strategy"})
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={"help": "Remove unused columns from the dataset if `datasets.Dataset` is used"},
    )
    report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
    orpo_beta: float = field(
        default=0.1,
        metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
    )

    def __post_init__(self):
        """Validate required arguments and mutually exclusive options right after parsing."""
        if self.model_type is None:
            raise ValueError("You must specify a valid model_type to run training.")
        if self.model_type not in MODEL_CLASSES:
            # Fail early with a readable message instead of a bare KeyError later on.
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; choose one of: {', '.join(MODEL_CLASSES.keys())}"
            )
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
        if self.load_in_4bit and self.load_in_8bit:
            raise ValueError("Only one of load_in_4bit and load_in_8bit can be set.")
def print_trainable_parameters(model):
    """
    Print the number of trainable parameters, the total parameter count and the trainable percentage.

    A parameter whose ``numel()`` is 0 but which carries a ``ds_numel`` attribute (as DeepSpeed
    ZeRO-3 partitioned parameters do) is counted by ``ds_numel``, so the totals describe the full
    model rather than the local shard. A model without parameters reports 0.0% instead of raising
    ZeroDivisionError.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if num_params == 0 and hasattr(param, "ds_numel"):
            # ZeRO-3 shards report an empty local tensor; use the full logical size.
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def find_all_linear_names(peft_model, int4=False, int8=False):
    """
    Collect the short (last dotted component) names of every linear layer in ``peft_model``.

    Approach follows the QLoRA reference code. The layer type searched for depends on the flags:
    ``bitsandbytes`` ``Linear4bit`` when ``int4`` is set (takes precedence), ``Linear8bitLt`` when
    only ``int8`` is set, otherwise ``torch.nn.Linear``. Modules whose qualified name contains
    ``lm_head`` or ``output_layer`` are skipped so the output projection is not targeted.

    Returns:
        A sorted list of unique leaf module names.
    """
    if int4 or int8:
        # bitsandbytes is only needed (and only imported) for quantized models.
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit if int4 else bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    skipped_markers = ('lm_head', 'output_layer')
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in peft_model.named_modules()
        if isinstance(submodule, linear_cls)
        and not any(marker in qualified_name for marker in skipped_markers)
    }
    return sorted(leaf_names)
def _load_tokenizer(args, tokenizer_class, prompt_template):
    """Load the tokenizer and make sure eos, bos and pad tokens are all defined."""
    tokenizer_kwargs = {
        "cache_dir": args.cache_dir,
        "use_fast": args.use_fast_tokenizer,
        "trust_remote_code": args.trust_remote_code,
    }
    tokenizer_name_or_path = args.tokenizer_name_or_path or args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.eos_token_id is None:
        # An eos token is required; fall back to the prompt template's stop string.
        tokenizer.eos_token = prompt_template.stop_str
        tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token})
        logger.info(f"Add eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}")
    if tokenizer.bos_token_id is None:
        # Reuse eos as bos when the tokenizer defines none.
        tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
        tokenizer.bos_token_id = tokenizer.eos_token_id
        logger.info(f"Add bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}")
    if tokenizer.pad_token_id is None:
        # Prefer unk as the padding token, otherwise fall back to eos.
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Add pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}")
    logger.debug(f"Tokenizer: {tokenizer}")
    return tokenizer


def _find_json_files(data_dir):
    """Return every *.json and *.jsonl file found recursively under data_dir."""
    return glob(f'{data_dir}/**/*.json', recursive=True) + glob(f'{data_dir}/**/*.jsonl', recursive=True)


def _load_raw_datasets(args):
    """Load raw preference data (Hub dataset or local json/jsonl dirs), carving out a validation split if absent."""
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        load_args = (args.dataset_name, args.dataset_config_name)
        load_kwargs = {"cache_dir": args.cache_dir}
    else:
        data_files = {}
        if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
            train_data_files = _find_json_files(args.train_file_dir)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
            eval_data_files = _find_json_files(args.validation_file_dir)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        if not any(data_files.values()):
            raise ValueError(
                "No data found: pass --dataset_name, or a --train_file_dir / --validation_file_dir "
                "containing *.json or *.jsonl files."
            )
        load_args = ('json',)
        load_kwargs = {"data_files": data_files, "cache_dir": args.cache_dir}
    raw_datasets = load_dataset(*load_args, **load_kwargs)
    # If no validation data is there, validation_split_percentage will be used to divide the dataset.
    if "validation" not in raw_datasets.keys():
        pct = args.validation_split_percentage
        raw_datasets["validation"] = load_dataset(*load_args, split=f"train[:{pct}%]", **load_kwargs)
        raw_datasets["train"] = load_dataset(*load_args, split=f"train[{pct}%:]", **load_kwargs)
    logger.info(f"Raw datasets: {raw_datasets}")
    return raw_datasets


def _prepare_preference_split(dataset, max_samples, format_fn, full_max_length, args, split_name, shuffle=False):
    """Truncate, format and length-filter one preference split; return (dataset, number of selected raw samples).

    Raises:
        ValueError: if the split is empty, or if no example survives the length filter.
    """
    num_samples = len(dataset)
    if max_samples is not None and max_samples > 0:
        num_samples = min(num_samples, max_samples)
        dataset = dataset.select(range(num_samples))
    if num_samples == 0:
        raise ValueError(f"The {split_name} dataset is empty.")
    logger.debug(f"Example {split_name}_dataset[0]: {dataset[0]}")
    if shuffle:
        dataset = dataset.shuffle()
    dataset = dataset.map(
        format_fn,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=dataset.column_names,
        load_from_cache_file=not args.overwrite_cache,
        desc="Running tokenizer on dataset",
    )
    # NOTE(review): these are *character* lengths of the raw strings compared against a budget that
    # ORPOConfig uses as a *token* max_length, so the filter is stricter than the token limit.
    dataset = dataset.filter(
        lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
        and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
    )
    logger.debug(f"Num {split_name}_samples: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(
            f"No {split_name} examples left after filtering prompt+response to at most {full_max_length} "
            f"characters; increase --max_source_length / --max_target_length."
        )
    logger.debug(f"First {split_name} example:")
    first_example = dataset[0]
    logger.debug(f"prompt:\n{first_example['prompt']}")
    logger.debug(f"chosen:\n{first_example['chosen']}")
    logger.debug(f"rejected:\n{first_example['rejected']}")
    return dataset, num_samples


def _load_model(args, config_class, model_class):
    """Load the causal LM (bitsandbytes-quantized when requested) and prepare it for training."""
    torch_dtype = (
        args.torch_dtype
        if args.torch_dtype in ["auto", None]
        else getattr(torch, args.torch_dtype)
    )
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size != 1:
        # Multi-process run: place the whole model on this process's LOCAL_RANK device.
        args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
    logger.info(f"Device map: {args.device_map}")
    if args.qlora and is_deepspeed_zero3_enabled():
        logger.warning("ZeRO3 is currently incompatible with QLoRA.")
    config = config_class.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=torch_dtype,
        cache_dir=args.cache_dir,
    )
    if args.load_in_4bit or args.load_in_8bit:
        logger.info(f"Quantizing model, load_in_4bit: {args.load_in_4bit}, load_in_8bit: {args.load_in_8bit}")
    quantization_config = None
    # Honor --load_in_4bit/--load_in_8bit even without --qlora: find_all_linear_names searches for
    # bitsandbytes layers whenever these flags are set, so the model must really be quantized.
    if args.qlora or args.load_in_4bit or args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        device_map=args.device_map,
        trust_remote_code=args.trust_remote_code,
        quantization_config=quantization_config,
    )
    # fixed FP16 ValueError: keep trainable parameters in float32
    # (presumably the AMP grad scaler's "Attempting to unscale FP16 gradients" error).
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        # Disable the generation cache while gradient checkpointing is on.
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    return model


def _build_peft_config(args, model):
    """Return a LoRA config when --use_peft is set, otherwise None (full-parameter training)."""
    if not args.use_peft:
        logger.info("Fine-tuning method: Full parameters training")
        return None
    logger.info("Fine-tuning method: LoRA(PEFT)")
    target_modules = args.target_modules.split(',') if args.target_modules else None
    if target_modules and 'all' in target_modules:
        target_modules = find_all_linear_names(model, int4=args.load_in_4bit, int8=args.load_in_8bit)
    logger.info(f"Peft target_modules: {target_modules}")
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )


def main():
    """Parse CLI arguments, build preference datasets and the model, then run ORPO training and/or evaluation.

    Raises:
        ValueError: when a split requested by --do_train / --do_eval is missing or empty, or no data is found.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    logger.info(f"Parse args: {args}")
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    if args.model_type == 'bloom':
        args.use_fast_tokenizer = True

    # Load tokenizer
    prompt_template = get_conv_template(args.template_name)
    tokenizer = _load_tokenizer(args, tokenizer_class, prompt_template)

    # Get datasets
    raw_datasets = _load_raw_datasets(args)

    # Preprocessing the datasets
    full_max_length = args.max_source_length + args.max_target_length

    def return_prompt_and_responses(examples) -> Dict[str, list]:
        """Convert a batch of paired records into ORPO's {'prompt', 'chosen', 'rejected'} columns.

        Each value is a list of strings, one per example. Prompts are built as
        system_prompt + history[[q, a], [q, a], ...] + question via the conversation template.
        """
        prompts = []
        for system, history, question in zip(examples["system"], examples["history"], examples["question"]):
            system_prompt = system or ""
            history_with_question = history + [[question, '']] if history else [[question, '']]
            prompts.append(prompt_template.get_prompt(messages=history_with_question, system_prompt=system_prompt))
        return {
            "prompt": prompts,
            "chosen": examples["response_chosen"],
            "rejected": examples["response_rejected"],
        }

    train_dataset = None
    max_train_samples = 0
    if args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset, max_train_samples = _prepare_preference_split(
            raw_datasets['train'], args.max_train_samples, return_prompt_and_responses,
            full_max_length, args, split_name="train", shuffle=True,
        )

    eval_dataset = None
    max_eval_samples = 0
    if args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset, max_eval_samples = _prepare_preference_split(
            raw_datasets["validation"], args.max_eval_samples, return_prompt_and_responses,
            full_max_length, args, split_name="eval",
        )

    # Load model
    model = _load_model(args, config_class, model_class)

    # Trainer cannot run periodic evaluation without an eval dataset.
    eval_strategy = args.eval_strategy
    if eval_dataset is None and eval_strategy != "no":
        logger.warning(f"No eval dataset (use --do_eval); overriding eval_strategy={eval_strategy!r} with 'no'.")
        eval_strategy = "no"

    training_args = ORPOConfig(
        max_length=full_max_length,
        max_prompt_length=args.max_source_length,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        max_steps=args.max_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        evaluation_strategy=eval_strategy,
        eval_steps=args.eval_steps,
        output_dir=args.output_dir,
        report_to=args.report_to,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.warmup_steps,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        remove_unused_columns=args.remove_unused_columns,
        run_name=f"orpo_{args.model_type}",
        beta=args.orpo_beta,
    )

    # Initialize ORPO trainer
    peft_config = _build_peft_config(args, model)
    trainer = ORPOTrainer(
        model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )
    print_trainable_parameters(trainer.model)

    # Training
    if args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train()
        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        if trainer.is_world_process_zero():
            logger.debug(f"Training metrics: {metrics}")
            logger.info(f"Saving model checkpoint to {args.output_dir}")
        # save_model runs on every rank: it only writes from the main process, and under DeepSpeed
        # ZeRO-3 it gathers sharded weights collectively, so a rank-0-only call can hang. It already
        # writes the (adapter) weights, so no separate model.save_pretrained is needed.
        trainer.save_model(args.output_dir)
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(args.output_dir)

    # Evaluation
    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        if trainer.is_world_process_zero():
            logger.debug(f"Eval metrics: {metrics}")


if __name__ == "__main__":
    main()