StarRing2022 commited on
Commit
2e504fe
1 Parent(s): 7258298

Delete alpacatest.py

Browse files
Files changed (1) hide show
  1. alpacatest.py +0 -64
alpacatest.py DELETED
@@ -1,64 +0,0 @@
1
- from datasets import load_dataset
2
- from transformers import RwkvForCausalLM, GPTNeoXTokenizerFast,GPT2Config,pipeline,GenerationConfig
3
- import torch
4
- import numpy as np
5
- import gradio as gr
6
-
7
- if torch.cuda.is_available():
8
- device = "cuda"
9
- else:
10
- device = "cpu"
11
-
12
- model = RwkvForCausalLM.from_pretrained("rwkv-alpaca")
13
- model = model.to(device)
14
-
15
- tokenizer = GPTNeoXTokenizerFast.from_pretrained("rwkv-alpaca", add_special_tokens=True)
16
-
17
-
18
-
19
- #rwkv with alpaca
20
- def generate_prompt(instruction, input=None):
21
-
22
- return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
23
-
24
- ### Instruction:
25
- {instruction}
26
-
27
- ### Response:"""
28
-
29
- def evaluate(
30
- instruction,
31
- temperature=0.1,
32
- top_p=0.75,
33
- top_k=40,
34
- max_new_tokens=128,
35
- ):
36
- prompt = generate_prompt(instruction)
37
- input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
38
- out = model.generate(input_ids=input_ids,temperature=temperature,top_p=top_p,top_k=top_k,max_new_tokens=max_new_tokens)
39
- answer = tokenizer.decode(out[0])
40
- return answer.split("### Response:")[1].strip()
41
-
42
-
43
- gr.Interface(
44
- fn=evaluate,#接口函数
45
- inputs=[
46
- gr.components.Textbox(
47
- lines=2, label="Instruction", placeholder="Tell me about alpacas."
48
- ),
49
- gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
50
- gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
51
- gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
52
- gr.components.Slider(
53
- minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
54
- ),
55
- ],
56
- outputs=[
57
- gr.inputs.Textbox(
58
- lines=5,
59
- label="Output",
60
- )
61
- ],
62
- title="RWKV-Alpaca",
63
- description="RWKV,easy in HF.",
64
- ).launch()