dpo.py

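"""DPO fine-tune a 4-bit-quantized base model with LoRA adapters, using Unsloth and TRL's DPOTrainer."""
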
# Unsloth patches Transformers/TRL internals at import time, so import it first.
from unsloth import FastLanguageModel

import argparse

import torch
from datasets import load_dataset
from transformers import TrainingArguments
from trl import DPOTrainer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base", type=str, help="Path or Hub id of the base model.")
    parser.add_argument("--json", type=str, help="Path to the preference-pair dataset (JSON).")
    parser.add_argument("--out", type=str, help="Output directory for training checkpoints.")
    parser.add_argument("--push", action="store_true", help="Push the merged model to the Hub after training.")
    return parser.parse_args()


def main():
    args = get_args()
    model_path = args.base
    json_path = args.json
    output_dir = args.out
    max_seq_length = 4096

    # Load the base model in 4-bit with bf16 compute and FlashAttention 2.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path,
        max_seq_length = max_seq_length,
        dtype = torch.bfloat16,
        load_in_4bit = True,
        load_in_8bit = False,
        attn_implementation = "flash_attention_2",
    )

    # Attach LoRA adapters to all attention and MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r = 32,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 64,
        lora_dropout = 0,
        bias = "none",
        use_gradient_checkpointing = True,
        random_state = 4337,
    )

    # Load the preference dataset passed via --json and hold out 1% for eval.
    full_dataset = load_dataset("json", data_files=json_path)
    ds = full_dataset["train"].train_test_split(test_size=0.01)
    train_dataset = ds["train"]
    test_dataset = ds["test"]

    training_args = TrainingArguments(
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        num_train_epochs = 3,
        remove_unused_columns = False,  # DPOTrainer needs the raw prompt/chosen/rejected columns
        gradient_accumulation_steps = 1,
        learning_rate = 5e-7,
        logging_first_step = True,
        logging_steps = 1,
        output_dir = output_dir,
        optim = "rmsprop",
        bf16 = True,
        gradient_checkpointing = True,
        eval_delay = 1000,  # only takes effect if an evaluation_strategy is enabled; the Trainer default is "no"
        save_strategy = "steps",
        save_steps = 500,
        save_total_limit = 5,
        ddp_find_unused_parameters = False,
    )

    # No explicit ref_model is passed; with a PEFT model, TRL uses the
    # adapter-disabled base weights as the DPO reference. beta sets the
    # strength of the implicit KL penalty against that reference.
    dpo_trainer = DPOTrainer(
        model,
        args = training_args,
        beta = 0.1,
        train_dataset = train_dataset,
        eval_dataset = test_dataset,
        tokenizer = tokenizer,
        max_length = max_seq_length,
        max_target_length = 3092,
        max_prompt_length = 1024,
        generate_during_eval = False,
    )

    dpo_trainer.train()

    # Merge the LoRA adapters into the base weights and save a 16-bit copy.
    model.save_pretrained_merged("dpo", tokenizer, save_method = "merged_16bit")
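    # The --push flag was parsed but never acted on; a minimal sketch using
    # Unsloth's push_to_hub_merged. The repo id "dpo" is a placeholder, and a
    # Hub token is assumed to be configured in the environment.
    if args.push:
        model.push_to_hub_merged("dpo", tokenizer, save_method = "merged_16bit")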


if __name__ == "__main__":
    main()