StarRing2022 commited on
Commit
4b184b0
1 Parent(s): e01a67a

Update alpacatest.py

Browse files
Files changed (1) hide show
  1. alpacatest.py +3 -2
alpacatest.py CHANGED
@@ -9,7 +9,8 @@ if torch.cuda.is_available():
9
  else:
10
  device = "cpu"
11
 
12
- model = RwkvForCausalLM.from_pretrained("rwkv-alpaca",device_map='auto') #仅500MB,自训练,使用alpaca
 
13
 
14
  tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-alpaca", add_special_tokens=True)
15
 
@@ -33,7 +34,7 @@ def evaluate(
33
  max_new_tokens=128,
34
  ):
35
  prompt = generate_prompt(instruction)
36
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
37
  out = model.generate(input_ids=input_ids,temperature=temperature,top_p=top_p,top_k=top_k,max_new_tokens=max_new_tokens)
38
  answer = tokenizer.decode(out[0])
39
  return answer.split("### Response:")[1].strip()
 
9
  else:
10
  device = "cpu"
11
 
12
+ model = RwkvForCausalLM.from_pretrained("rwkv-alpaca")
13
+ model = model.to(device)
14
 
15
  tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-alpaca", add_special_tokens=True)
16
 
 
34
  max_new_tokens=128,
35
  ):
36
  prompt = generate_prompt(instruction)
37
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
38
  out = model.generate(input_ids=input_ids,temperature=temperature,top_p=top_p,top_k=top_k,max_new_tokens=max_new_tokens)
39
  answer = tokenizer.decode(out[0])
40
  return answer.split("### Response:")[1].strip()