---
license: apache-2.0
---
Number of Epochs = 5
Dataset Size = 5.5 k samples [train/validation]
Number of labels used = 2
Thresholding = True
Thresholding value = 0.7
Below is the function to aplly thresholding to output logits.
```python
def get_prediction(text):
encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
outputs = model(**encoding)
logits = outputs.logits
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
probs = probs.detach().numpy()
label = np.argmax(probs, axis=-1)
if label == 1:
if probs[1] > 0.7:
return 1
else:
return 0
else:
return 0
```