waleko committed on
Commit
797d112
1 Parent(s): 0aaba1c

Add limit on tokens length

Browse files
Files changed (1) hide show
  1. translate.py +3 -0
translate.py CHANGED
@@ -34,6 +34,9 @@ def translator_fn(input_text: str, k=10) -> TranslationResult:
34
  input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
35
  input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
36
 
 
 
 
37
  # Generate output
38
  outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
39
  output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
 
34
  input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
35
  input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
36
 
37
+ if len(input_tokens) > model.config.d_model:
38
+ raise ValueError("Input text is too long")
39
+
40
  # Generate output
41
  outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
42
  output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)