StarRing2022 committed
Commit 7258298
1 Parent(s): 9482798

Delete alpacatrain.py

alpacatrain.py: +0 -59
alpacatrain.py
DELETED
@@ -1,59 +0,0 @@
-from datasets import load_dataset
-from transformers import RwkvForCausalLM, GPTNeoXTokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
-
-MICRO_BATCH_SIZE = 8
-BATCH_SIZE = 128
-GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
-EPOCHS = 100
-LEARNING_RATE = 2e-5
-CUTOFF_LEN = 256
-
-model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile").to("cuda")
-tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
-# model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
-# tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)
-tokenizer.pad_token = tokenizer.eos_token
-tokenizer.pad_token_id = tokenizer.eos_token_id
-
-data = load_dataset("json", data_files="test.json")
-
-def generate_prompt(data_point):
-
-    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-
-### Instruction:
-{data_point["instruction"]}
-
-### Response:
-{data_point["output"]}"""
-
-
-data = data.shuffle().map(
-    lambda data_point: tokenizer(
-        generate_prompt(data_point),
-        truncation=True,
-        max_length=CUTOFF_LEN,
-        padding="max_length",
-    )
-)
-
-trainer = Trainer(
-    model=model,
-    train_dataset=data["train"],
-    args=TrainingArguments(
-        per_device_train_batch_size=MICRO_BATCH_SIZE,
-        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-        warmup_steps=100,
-        num_train_epochs=EPOCHS,
-        learning_rate=LEARNING_RATE,
-        fp16=True,
-        logging_steps=1,
-        output_dir="rwkv-alpaca",
-        save_total_limit=3,
-    ),
-    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
-)
-model.config.use_cache = False
-trainer.train(resume_from_checkpoint=False)
-
-model.save_pretrained("rwkv-alpaca")
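For reference on what the deleted script consumed: its generate_prompt() indexes each record's "instruction" and "output" keys, so the test.json it loads via load_dataset("json", ...) is expected to hold Alpaca-style records. A minimal sketch of producing such a file; the sample instructions are made up:

import json

# Hypothetical records in the Alpaca-style format generate_prompt() expects:
# each entry must provide "instruction" and "output" keys.
samples = [
    {"instruction": "Give an antonym of 'hot'.", "output": "Cold."},
    {"instruction": "Name three primary colors.", "output": "Red, blue, and yellow."},
]

# The datasets "json" loader accepts a top-level JSON array like this.
with open("test.json", "w") as f:
    json.dump(samples, f)

Note the effective batch size the script targets: MICRO_BATCH_SIZE x GRADIENT_ACCUMULATION_STEPS = 8 x 16 = 128 examples per optimizer step.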
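And a minimal sketch of loading the checkpoint the script wrote, assuming a training run completed before the file was deleted. The prompt rebuilds the training template up to "### Response:"; the instruction text is illustrative only, and since the script saved only the model, the tokenizer is reloaded from the base checkpoint:

from transformers import RwkvForCausalLM, GPTNeoXTokenizerFast

# Reload the fine-tuned weights written by model.save_pretrained("rwkv-alpaca").
model = RwkvForCausalLM.from_pretrained("rwkv-alpaca").to("cuda")
# The tokenizer was never saved, so it comes from the base model directory.
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile")

# Rebuild the training-time prompt and let the model complete the response.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nName three primary colors.\n\n"
    "### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))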