StarRing2022
committed on
Commit
•
a612470
1
Parent(s):
4b184b0
Update alpacatrain.py
Browse files - alpacatrain.py +1 -1
alpacatrain.py
CHANGED
@@ -8,7 +8,7 @@ EPOCHS = 100
|
|
8 |
LEARNING_RATE = 2e-5
|
9 |
CUTOFF_LEN = 256
|
10 |
|
11 |
-
model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile")
|
12 |
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
|
13 |
# model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
|
14 |
# tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)
|
|
|
8 |
LEARNING_RATE = 2e-5
|
9 |
CUTOFF_LEN = 256
|
10 |
|
11 |
+
model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile").to("cuda")
|
12 |
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
|
13 |
# model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
|
14 |
# tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)
|