StarRing2022 committed on
Commit 7258298
1 Parent(s): 9482798

Delete alpacatrain.py

Files changed (1)
  1. alpacatrain.py +0 -59
alpacatrain.py DELETED
@@ -1,59 +0,0 @@
- from datasets import load_dataset
- from transformers import RwkvForCausalLM, GPTNeoXTokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
-
- MICRO_BATCH_SIZE = 8
- BATCH_SIZE = 128
- GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
- EPOCHS = 100
- LEARNING_RATE = 2e-5
- CUTOFF_LEN = 256
-
- model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile").to("cuda")
- tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
- # model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
- # tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.pad_token_id = tokenizer.eos_token_id
-
- data = load_dataset("json", data_files="test.json")
-
- def generate_prompt(data_point):
-
-     return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-
- ### Instruction:
- {data_point["instruction"]}
-
- ### Response:
- {data_point["output"]}"""
-
-
- data = data.shuffle().map(
-     lambda data_point: tokenizer(
-         generate_prompt(data_point),
-         truncation=True,
-         max_length=CUTOFF_LEN,
-         padding="max_length",
-     )
- )
-
- trainer = Trainer(
-     model=model,
-     train_dataset=data["train"],
-     args=TrainingArguments(
-         per_device_train_batch_size=MICRO_BATCH_SIZE,
-         gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-         warmup_steps=100,
-         num_train_epochs=EPOCHS,
-         learning_rate=LEARNING_RATE,
-         fp16=True,
-         logging_steps=1,
-         output_dir="rwkv-alpaca",
-         save_total_limit=3,
-     ),
-     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
- )
- model.config.use_cache = False
- trainer.train(resume_from_checkpoint=False)
-
- model.save_pretrained("rwkv-alpaca")
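
For context: the deleted script fine-tunes an RWKV model on Alpaca-style instruction data with the Hugging Face Trainer and writes the checkpoint to rwkv-alpaca. Below is a minimal sketch, not part of this commit, of how such a checkpoint could be loaded for generation; the rwkv-alpaca directory, the base tokenizer path, and the example instruction are assumptions drawn from the script above.

# Illustrative sketch (assumptions: checkpoint saved to "rwkv-alpaca" by the
# script above, tokenizer reused from the base model, Alpaca-style prompt).
from transformers import RwkvForCausalLM, GPTNeoXTokenizerFast

model = RwkvForCausalLM.from_pretrained("rwkv-alpaca").to("cuda")
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile")
tokenizer.pad_token = tokenizer.eos_token

# Build a prompt in the same format the training data used.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGive three tips for staying healthy.\n\n"
    "### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))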