juanesvelez committed on
Commit
50cc7b6
1 Parent(s): fd28976

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import solara as sol
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ model_name = "datificate/gpt2-small-spanish"
6
+ model = AutoModelForCausalLM.from_pretrained(model_name)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
+ def predict_next_token(text):
10
+ inputs = tokenizer(text, return_tensors="pt")
11
+ outputs = model(**inputs)
12
+ next_token_logits = outputs.logits[:, -1, :]
13
+ next_token_probs = torch.softmax(next_token_logits, dim=-1)
14
+ top_k_probs, top_k_indices = torch.topk(next_token_probs, 10)
15
+ top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices[0])
16
+ return list(zip(top_k_tokens, top_k_probs[0].tolist()))
17
+
18
+ @sol.component
19
+ def NextTokenPredictionApp():
20
+ text = sol.use_state("")
21
+ predictions = sol.use_state([])
22
+
23
+ def on_text_change(new_text):
24
+ text.set(new_text)
25
+ preds = predict_next_token(new_text)
26
+ predictions.set(preds)
27
+
28
+ sol.InputText(value=text.value, on_change=on_text_change, placeholder="Escribe algo en español...")
29
+
30
+ if predictions.value:
31
+ sol.Markdown("### Predicciones de tokens:")
32
+ for token, prob in predictions.value:
33
+ sol.Markdown(f"- {token}: {prob:.4f}")
34
+
35
+ sol.run(NextTokenPredictionApp)