klasocki committed on
Commit
8868222
1 Parent(s): f42d24c

Refactor, introduce CommaFixerInterface and remove duplication

Browse files
commafixer/routers/baseline.py CHANGED
@@ -1,8 +1,8 @@
1
- from fastapi import APIRouter, HTTPException
2
  import logging
3
 
4
  from commafixer.src.baseline import BaselineCommaFixer
5
-
6
 
7
  logger = logging.Logger(__name__)
8
  logging.basicConfig(level=logging.INFO)
@@ -16,10 +16,4 @@ router.model = BaselineCommaFixer()
16
  @router.post('/fix-commas/')
17
  async def fix_commas_with_baseline(data: dict):
18
  json_field_name = 's'
19
- if json_field_name in data:
20
- logger.debug('Fixing commas.')
21
- return {json_field_name: router.model.fix_commas(data['s'])}
22
- else:
23
- msg = f"Text '{json_field_name}' missing"
24
- logger.debug(msg)
25
- raise HTTPException(status_code=400, detail=msg)
 
1
+ from fastapi import APIRouter
2
  import logging
3
 
4
  from commafixer.src.baseline import BaselineCommaFixer
5
+ from common import fix_commas_request_handler
6
 
7
  logger = logging.Logger(__name__)
8
  logging.basicConfig(level=logging.INFO)
 
16
  @router.post('/fix-commas/')
17
  async def fix_commas_with_baseline(data: dict):
18
  json_field_name = 's'
19
+ return fix_commas_request_handler(json_field_name, data, logger, router.model)
 
 
 
 
 
 
commafixer/routers/common.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException
2
+ from logging import Logger
3
+
4
+ from comma_fixer_interface import CommaFixerInterface
5
+
6
+
7
+ def fix_commas_request_handler(
8
+ json_field_name: str,
9
+ data: dict[str, str],
10
+ logger: Logger,
11
+ model: CommaFixerInterface
12
+ ) -> dict[str, str]:
13
+ if json_field_name in data:
14
+ logger.debug('Fixing commas.')
15
+ return {json_field_name: model.fix_commas(data['s'])}
16
+ else:
17
+ msg = f"Text '{json_field_name}' missing"
18
+ logger.debug(msg)
19
+ raise HTTPException(status_code=400, detail=msg)
commafixer/routers/fixer.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
2
  import logging
3
 
4
  from commafixer.src.fixer import CommaFixer
 
5
 
6
 
7
  logger = logging.Logger(__name__)
@@ -16,10 +17,4 @@ router.model = CommaFixer()
16
  @router.post('/')
17
  async def fix_commas(data: dict):
18
  json_field_name = 's'
19
- if json_field_name in data:
20
- logger.debug('Fixing commas.')
21
- return {json_field_name: router.model.fix_commas(data['s'])}
22
- else:
23
- msg = f"Text '{json_field_name}' missing"
24
- logger.debug(msg)
25
- raise HTTPException(status_code=400, detail=msg)
 
2
  import logging
3
 
4
  from commafixer.src.fixer import CommaFixer
5
+ from commafixer.routers.common import fix_commas_request_handler
6
 
7
 
8
  logger = logging.Logger(__name__)
 
17
  @router.post('/')
18
  async def fix_commas(data: dict):
19
  json_field_name = 's'
20
+ return fix_commas_request_handler(json_field_name, data, logger, router.model)
 
 
 
 
 
 
commafixer/src/baseline.py CHANGED
@@ -1,8 +1,10 @@
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
  import re
3
 
 
4
 
5
- class BaselineCommaFixer:
 
6
  """
7
  A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
8
  It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
 
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
  import re
3
 
4
+ from commafixer.src.comma_fixer_interface import CommaFixerInterface
5
 
6
+
7
+ class BaselineCommaFixer(CommaFixerInterface):
8
  """
9
  A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
10
  It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
commafixer/src/comma_fixer_interface.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class CommaFixerInterface(ABC):
5
+ @abstractmethod
6
+ def fix_commas(self, s: str) -> str:
7
+ pass
commafixer/src/fixer.py CHANGED
@@ -3,8 +3,10 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipelin
3
  import nltk
4
  import re
5
 
 
6
 
7
- class CommaFixer:
 
8
  """
9
  A wrapper class for the fine-tuned comma fixer model.
10
  """
@@ -84,7 +86,7 @@ def _fix_commas_based_on_labels_and_offsets(
84
 
85
  def _should_insert_comma(label, result, current_offset) -> bool:
86
  # Only insert commas for the final token of a word, that is, if next word starts with a space.
87
- # TODO perharps for low confidence tokens, we should use the original decision of the user in the input?
88
  return label == 'B-COMMA' and result[current_offset].isspace()
89
 
90
 
 
3
  import nltk
4
  import re
5
 
6
+ from commafixer.src.comma_fixer_interface import CommaFixerInterface
7
 
8
+
9
+ class CommaFixer(CommaFixerInterface):
10
  """
11
  A wrapper class for the fine-tuned comma fixer model.
12
  """
 
86
 
87
  def _should_insert_comma(label, result, current_offset) -> bool:
88
  # Only insert commas for the final token of a word, that is, if next word starts with a space.
89
+ # TODO perhaps for low confidence tokens, we should use the original decision of the user in the input?
90
  return label == 'B-COMMA' and result[current_offset].isspace()
91
 
92