klasocki commited on
Commit
72fb4d7
1 Parent(s): 1114dde

Add option to choose device for the pipeline

Browse files
Files changed (1) hide show
  1. commafixer/src/baseline.py +4 -4
commafixer/src/baseline.py CHANGED
@@ -2,8 +2,8 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipelin
2
 
3
 
4
  class BaselineCommaFixer:
5
- def __init__(self):
6
- self._ner = _create_baseline_pipeline()
7
 
8
  def fix_commas(self, s: str) -> str:
9
  return _fix_commas_based_on_pipeline_output(
@@ -12,10 +12,10 @@ class BaselineCommaFixer:
12
  )
13
 
14
 
15
- def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForTokenClassification.from_pretrained(model_name)
18
- return pipeline('ner', model=model, tokenizer=tokenizer)
19
 
20
 
21
  def _remove_punctuation(s: str) -> str:
 
2
 
3
 
4
  class BaselineCommaFixer:
5
+ def __init__(self, device=-1):
6
+ self._ner = _create_baseline_pipeline(device=device)
7
 
8
  def fix_commas(self, s: str) -> str:
9
  return _fix_commas_based_on_pipeline_output(
 
12
  )
13
 
14
 
15
+ def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large", device=-1) -> NerPipeline:
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForTokenClassification.from_pretrained(model_name)
18
+ return pipeline('ner', model=model, tokenizer=tokenizer, device=device)
19
 
20
 
21
  def _remove_punctuation(s: str) -> str: