adamo1139 committed
Commit 91a3383
1 Parent(s): 1e20652

Upload yi-34b-dpo-unsloth-1.py

Files changed (1)
  1. yi-34b-dpo-unsloth-1.py +130 -0
yi-34b-dpo-unsloth-1.py ADDED
@@ -0,0 +1,130 @@
from unsloth import FastLanguageModel
from datasets import Dataset, load_dataset
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch

max_seq_length = 4096  # Choose any! Unsloth supports RoPE scaling internally.
dtype = None  # None for auto detection. Float16 for Tesla T4/V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "...../yi-34b-200k-llamafied",  # Choose any model, e.g. mistralai/Mistral-7B-Instruct-v0.2
    max_seq_length = max_seq_length,
    attn_implementation = "flash_attention_2",
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)


#@title Alignment Handbook utils
import os
import re
from typing import List, Literal, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError


#DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

# Yi uses the ChatML format: <|im_start|>{role}\n{content}<|im_end|>\n
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['prompt']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chosen'] + "<|im_end|>\n"

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

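# Illustration (added comment): with the ChatML template set above, chatml_format()
# yields roughly the following fields for a row with a non-empty system message:
#   prompt:   <|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n
#   chosen:   {chosen}<|im_end|>\n
#   rejected: {rejected}<|im_end|>\n
# DPOTrainer later concatenates prompt+chosen and prompt+rejected internally.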

# Load dataset
dataset = load_dataset("adamo1139/rawrr_v1", split="train")

import pprint

pprint.pprint("""NOT a formatted dataset
""")
pprint.pprint(dataset[25])
pprint.pprint(dataset[26])
pprint.pprint(dataset[27])
pprint.pprint(dataset[28])
pprint.pprint(dataset[29])

# Save columns
original_columns = dataset.column_names

# Format dataset
dataset = dataset.map(
    chatml_format,
    remove_columns=original_columns
)

# Print sample
pprint.pprint("""formatted dataset""")
pprint.pprint(dataset[25])
pprint.pprint(dataset[26])
pprint.pprint(dataset[27])
pprint.pprint(dataset[28])
pprint.pprint(dataset[29])

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0,  # Currently only supports dropout = 0
    bias = "none",  # Currently only supports bias = "none"
    use_gradient_checkpointing = True,
    random_state = 3407,
    use_rslora = False,  # Unsloth also supports rank-stabilized LoRA
    loftq_config = None,  # and LoftQ
)

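# Note (added comment): the LoRA adapters are attached before the trainer is built.
# With ref_model = None below, TRL's DPOTrainer detects the PEFT model and uses the
# same weights with adapters disabled as the frozen reference policy, so a second
# full copy of the 34B base model does not have to be kept in memory.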

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import DPOTrainer

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 16,
        warmup_ratio = 0.05,
        num_train_epochs = 1,
        learning_rate = 5e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs2",
    ),
    beta = 0.1,  # strength of the implicit KL penalty in the DPO loss
    train_dataset = dataset,
    # eval_dataset = raw_datasets["test"],
    tokenizer = tokenizer,
    max_length = 500,
    max_prompt_length = 500,
)
dpo_trainer.train()
model.save_pretrained("yi-34b-200k-rawrr_v1_unsloth_1")  # Local saving
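The save_pretrained call above stores only the LoRA adapter, not merged full weights. Below is a minimal reload sketch for inference, assuming the same (elided) base-model path used for training and the local adapter directory written by the script; exact arguments may vary with the installed Unsloth and PEFT versions.

from unsloth import FastLanguageModel
from peft import PeftModel

# Reload the 4-bit base model the same way it was loaded for training
# (the base-model path is elided here, exactly as in the script above).
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "...../yi-34b-200k-llamafied",
    max_seq_length = 4096,
    load_in_4bit = True,
)

# Attach the DPO-trained LoRA adapter saved by the script.
model = PeftModel.from_pretrained(base_model, "yi-34b-200k-rawrr_v1_unsloth_1")
model.eval()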