StarRing2022
committed on
Commit
•
4b184b0
1
Parent(s):
e01a67a
Update alpacatest.py
Browse files- alpacatest.py +3 -2
alpacatest.py
CHANGED
@@ -9,7 +9,8 @@ if torch.cuda.is_available():
|
|
9 |
else:
|
10 |
device = "cpu"
|
11 |
|
12 |
-
model = RwkvForCausalLM.from_pretrained("rwkv-alpaca"
|
|
|
13 |
|
14 |
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-alpaca", add_special_tokens=True)
|
15 |
|
@@ -33,7 +34,7 @@ def evaluate(
|
|
33 |
max_new_tokens=128,
|
34 |
):
|
35 |
prompt = generate_prompt(instruction)
|
36 |
-
input_ids = tokenizer.encode(prompt, return_tensors='pt')
|
37 |
out = model.generate(input_ids=input_ids,temperature=temperature,top_p=top_p,top_k=top_k,max_new_tokens=max_new_tokens)
|
38 |
answer = tokenizer.decode(out[0])
|
39 |
return answer.split("### Response:")[1].strip()
|
|
|
9 |
else:
|
10 |
device = "cpu"
|
11 |
|
12 |
+
model = RwkvForCausalLM.from_pretrained("rwkv-alpaca")
|
13 |
+
model = model.to(device)
|
14 |
|
15 |
tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-alpaca", add_special_tokens=True)
|
16 |
|
|
|
34 |
max_new_tokens=128,
|
35 |
):
|
36 |
prompt = generate_prompt(instruction)
|
37 |
+
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
|
38 |
out = model.generate(input_ids=input_ids,temperature=temperature,top_p=top_p,top_k=top_k,max_new_tokens=max_new_tokens)
|
39 |
answer = tokenizer.decode(out[0])
|
40 |
return answer.split("### Response:")[1].strip()
|