hanzla commited on
Commit
db2e21a
1 Parent(s): c567676
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -7,11 +7,15 @@ import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  model_name = "ModularityAI/gemma-2b-datascience-it-raft"
 
 
 
 
10
 
11
  pipeline = transformers.pipeline(
12
  "text-generation",
13
- model=model_name,
14
- model_kwargs={"torch_dtype": torch.bfloat16},
15
  device="cuda",
16
  )
17
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  model_name = "ModularityAI/gemma-2b-datascience-it-raft"
10
+ tokenizer_name = "google/gemma-2b-it"
11
+
12
+ model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,device='cuda')
13
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,device='cuda')
14
 
15
  pipeline = transformers.pipeline(
16
  "text-generation",
17
+ model=model,
18
+ tokenizer=tokenizer,
19
  device="cuda",
20
  )
21