StarRing2022 commited on
Commit
a612470
1 Parent(s): 4b184b0

Update alpacatrain.py

Browse files
Files changed (1) hide show
  1. alpacatrain.py +1 -1
alpacatrain.py CHANGED
@@ -8,7 +8,7 @@ EPOCHS = 100
8
  LEARNING_RATE = 2e-5
9
  CUTOFF_LEN = 256
10
 
11
- model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile")
12
  tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
13
  # model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
14
  # tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)
 
8
  LEARNING_RATE = 2e-5
9
  CUTOFF_LEN = 256
10
 
11
+ model = RwkvForCausalLM.from_pretrained("rwkv-430M-pile").to("cuda")
12
  tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-430M-pile", add_special_tokens=True)
13
  # model = RwkvForCausalLM.from_pretrained("rwkv-7b-pile")
14
  # tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-7b-pile", add_special_tokens=True)