---
license: apache-2.0
---
Number of Epochs = 5 <br>
Dataset Size = 5.5k samples (train/validation) <br>
Number of labels used = 2 <br>
Thresholding = True<br>
Thresholding value = 0.7<br>
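Because the function below passes each logit $z$ through a sigmoid, the 0.7 probability threshold is equivalent to a fixed cutoff on the raw logit for label 1. The short derivation below is added for illustration and follows only from the definition of the sigmoid:

$$
\sigma(z) = \frac{1}{1 + e^{-z}} > 0.7 \iff e^{-z} < \frac{0.3}{0.7} \iff z > \ln\frac{0.7}{0.3} = \ln\frac{7}{3} \approx 0.847
$$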
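Below is the function used to apply thresholding to the output logits. It assumes `tokenizer` and `model` have already been loaded for this checkpoint.
```python
import numpy as np
import torch

THRESHOLD = 0.7  # minimum sigmoid probability required to predict label 1

def get_prediction(text):
    # Tokenize to a fixed length of 128 tokens and move tensors to the model's device
    encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    encoding = {k: v.to(model.device) for k, v in encoding.items()}
    # Inference only, so gradients are not needed
    with torch.no_grad():
        outputs = model(**encoding)
    logits = outputs.logits
    # Sigmoid maps each logit to an independent probability in (0, 1)
    probs = torch.sigmoid(logits.squeeze().cpu()).numpy()
    label = np.argmax(probs, axis=-1)
    # Predict 1 only if label 1 is the top class AND its probability exceeds the threshold
    if label == 1 and probs[1] > THRESHOLD:
        return 1
    return 0
```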
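The thresholding rule can be checked without loading the model. The sketch below reproduces the same decision on a list of two raw logits using only the standard library; `threshold_prediction` and `sigmoid` are hypothetical helper names, not part of this model's code, and the logits passed in are made-up values for illustration.

```python
import math

def sigmoid(z):
    # Numerically stable sigmoid: avoids overflow in math.exp for large |z|
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def threshold_prediction(logits, threshold=0.7):
    """Hypothetical helper: return 1 only if label 1 has the highest
    sigmoid probability and that probability exceeds `threshold`."""
    probs = [sigmoid(z) for z in logits]
    # Index of the highest probability (first one wins on ties, like np.argmax)
    top = max(range(len(probs)), key=lambda i: probs[i])
    return 1 if top == 1 and probs[1] > threshold else 0

# Usage with illustrative logits for [label 0, label 1]
print(threshold_prediction([0.0, 2.0]))  # sigmoid(2.0) ~ 0.88 > 0.7, prints 1
print(threshold_prediction([0.0, 0.5]))  # sigmoid(0.5) ~ 0.62 < 0.7, prints 0
print(threshold_prediction([3.0, 2.0]))  # label 0 is the top class, prints 0
```

Note that label 1 can be the top class and still be rejected: any logit for label 1 at or below roughly 0.847 falls under the 0.7 probability threshold, so the function returns 0.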