AkashKhamkar commited on
Commit
3b2043c
1 Parent(s): d4b5722

Updating the code

Browse files
Files changed (1) hide show
  1. README.md +8 -7
README.md CHANGED
@@ -13,21 +13,22 @@ Below is the function to aplly thresholding to output logits.
13
 
14
  ```python
15
  def get_prediction(text):
16
- encoding = new_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
17
  encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
18
 
19
- outputs = new_model(**encoding)
20
 
21
  logits = outputs.logits
22
 
23
  sigmoid = torch.nn.Sigmoid()
 
24
  probs = probs.detach().numpy()
25
  label = np.argmax(probs, axis=-1)
26
  if label == 1:
27
- if probs[1] > 0.7:
28
- return 1
29
- else:
30
- return 0
31
  else:
32
- return 0
33
  ```
 
13
 
14
  ```python
15
  def get_prediction(text):
16
+ encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
17
  encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
18
 
19
+ outputs = model(**encoding)
20
 
21
  logits = outputs.logits
22
 
23
  sigmoid = torch.nn.Sigmoid()
24
+ probs = sigmoid(logits.squeeze().cpu())
25
  probs = probs.detach().numpy()
26
  label = np.argmax(probs, axis=-1)
27
  if label == 1:
28
+ if probs[1] > 0.7:
29
+ return 1
30
+ else:
31
+ return 0
32
  else:
33
+ return 0
34
  ```