AkashKhamkar
commited on
Commit
•
3b2043c
1
Parent(s):
d4b5722
Updating the code
Browse files
README.md
CHANGED
@@ -13,21 +13,22 @@ Below is the function to aplly thresholding to output logits.
|
|
13 |
|
14 |
```python
|
15 |
def get_prediction(text):
|
16 |
-
encoding =
|
17 |
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
|
18 |
|
19 |
-
outputs =
|
20 |
|
21 |
logits = outputs.logits
|
22 |
|
23 |
sigmoid = torch.nn.Sigmoid()
|
|
|
24 |
probs = probs.detach().numpy()
|
25 |
label = np.argmax(probs, axis=-1)
|
26 |
if label == 1:
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
else:
|
32 |
-
|
33 |
```
|
|
|
13 |
|
14 |
```python
|
15 |
def get_prediction(text):
|
16 |
+
encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
|
17 |
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
|
18 |
|
19 |
+
outputs = model(**encoding)
|
20 |
|
21 |
logits = outputs.logits
|
22 |
|
23 |
sigmoid = torch.nn.Sigmoid()
|
24 |
+
probs = sigmoid(logits.squeeze().cpu())
|
25 |
probs = probs.detach().numpy()
|
26 |
label = np.argmax(probs, axis=-1)
|
27 |
if label == 1:
|
28 |
+
if probs[1] > 0.7:
|
29 |
+
return 1
|
30 |
+
else:
|
31 |
+
return 0
|
32 |
else:
|
33 |
+
return 0
|
34 |
```
|