import urllib.request, urllib.error, urllib.parse import json import pandas as pd import ssl import torch import re from pprint import pprint from captum.attr import visualization REST_URL = "http://data.bioontology.org" API_KEY = "604a90bc-ef14-4c26-a347-f4928fa086ea" ssl._create_default_https_context = ssl._create_unverified_context class PyTMinMaxScalerVectorized(object): """ From https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455 Transforms each channel to the range [0, 1]. """ def __call__(self, tensor): scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0]) tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0]) return tensor def get_diseases(text, pipe): results = pipe(text) diseases = [] disease_span = [] for result in results: ent = result['entity'] # start of a new entity if ent == 'B-DISEASE': disease_span = result['start'], result['end'] elif ent == 'I-DISEASE': if len(disease_span) == 0: disease_span = [] else: disease_span = disease_span[0], result['end'] else: if len(disease_span) > 1: disease = text[disease_span[0]: disease_span[1]] if len(disease) > 2: diseases.append(disease) disease_span = [] if len(disease_span) > 1: disease = text[disease_span[0]: disease_span[1]] diseases.append(disease) return diseases def find_end(text): """Find the end of the report.""" ends = [len(text)] patterns = [ re.compile(r'BY ELECTRONICALLY SIGNING THIS REPORT', re.I), re.compile(r'\n {3,}DR.', re.I), re.compile(r'[ ]{1,}RADLINE ', re.I), re.compile(r'.*electronically signed on', re.I), re.compile(r'M\[0KM\[0KM') ] for pattern in patterns: matchobj = pattern.search(text) if matchobj: ends.append(matchobj.start()) return min(ends) def pattern_repl(matchobj): """ Return a replacement string to be used for match object """ return ' '.rjust(len(matchobj.group(0))) def clean_text(text): """ Clean text """ # Replace [**Patterns**] with spaces. text = re.sub(r'\[\*\*.*?\*\*\]', pattern_repl, text) # Replace `_` with spaces. text = re.sub(r'_', ' ', text) start = 0 end = find_end(text) new_text = '' if start > 0: new_text += ' ' * start new_text = text[start:end] # make sure the new text has the same length of old text. if len(text) - end > 0: new_text += ' ' * (len(text) - end) return new_text def get_drg_link(drg_code): drg_code = str(drg_code) if len(drg_code) == 1: drg_code = '00' + drg_code elif len(drg_code) == 2: drg_code = '0' + drg_code return f'https://www.findacode.com/code.php?set=DRG&c={drg_code}' def prettify(dict_list, k): li = [di[k] for di in dict_list] result = "\n".join(l for l in li) return result def get_json(text_to_annotate): url = REST_URL + "/annotator?text=" + urllib.parse.quote(text_to_annotate) + "&ontologies=ICD9CM" +\ "&longest_only=false" + "&exclude_numbers=false" + "&whole_word_only=true" + '&exclude_synonyms=false' opener = urllib.request.build_opener() opener.addheaders = [('Authorization', 'apikey token=' + API_KEY)] try: return json.loads(opener.open(url).read()) except: return [] def parse_results(results): if len(results) == 0: return [] rlist = [] for result in results: annotations = result['annotations'] for annotation in annotations: start = annotation['from']-1 end = annotation['to'] - 1 text = annotation['text'] rlist.append({ 'start': start, 'end': end, 'text': text, 'link': result['annotatedClass']['@id'] }) return rlist def get_icd_annotations(text): response = get_json(text) annotation_list = parse_results(response) return annotation_list def subfinder(mylist, pattern): mylist = mylist.tolist() pattern = pattern.tolist() return list(filter(lambda x: x in pattern, mylist)) def tokenize_icds(tokenizer, annotations, token_ids): icd_tokens = torch.zeros(token_ids.shape) for annotation in annotations: icd = annotation['text'] icd_token_ids = tokenizer(icd, add_special_tokens=False, return_tensors='pt').input_ids[0] # find index of the beginning icd token starting_indices = (token_ids==icd_token_ids[0]).nonzero(as_tuple=False) num_icd_tokens = icd_token_ids.shape[0] # if there's more than 1 icd token for the given annotation if num_icd_tokens > 1: # if there's only one starting index if starting_indices.shape[0] == 1: starting_index = starting_indices.item() icd_tokens[starting_index: starting_index + num_icd_tokens] = 1 # if there's more than 1 starting index, determine which is the appropriate else: for starting_index in starting_indices: if token_ids[starting_index + num_icd_tokens] == icd_token_ids: icd_tokens[starting_index: starting_index + num_icd_tokens] = 1 # otherwise, set the corresponding index to a value of 1 else: icd_tokens[starting_indices] = 1 return icd_tokens def get_attribution(text, tokenizer, model_outputs, inputs, k=7): tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0]) padding_idx = tokens.index('[PAD]') tokens = tokens[:padding_idx][1:-1] attn = model_outputs[-1][0] agg_attn, final_text = reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn) return agg_attn, final_text def reconstruct_text(tokenizer, tokens, attn): """ find a word -> token_id mapping that allows you to perform an aggregation on the sub-tokens' attention values """ reconstructed_text = tokenizer.convert_tokens_to_string(tokens) num_subtokens = len([t for t in tokens if t.startswith('#')]) aggregated_attn = torch.zeros(len(tokens) - num_subtokens) token_indices = [0] token_idx = 0 reconstructed_tokens = [] for i, token in enumerate(tokens[1:], start=1): # case when a token is a subtoken if token.startswith('#'): token_indices.append(i) else: # reconstruct the tokens to make sure you're doing this correctly reconstructed_token = ''.join(tokens[i].replace('#', '') for i in token_indices) reconstructed_tokens.append(reconstructed_token) # find the corresponding attention vectors aggregated_attn[token_idx] = torch.mean(attn[token_indices]) # create new index list token_indices = [i] token_idx += 1 # reconstruct the tokens to make sure you're doing this correctly reconstructed_token = ''.join(tokens[i].replace('#', '') for i in token_indices) reconstructed_tokens.append(reconstructed_token) # find the corresponding attention vectors aggregated_attn[token_idx] = torch.mean(attn[token_indices]) # final representation of text final_text = ' '.join(reconstructed_tokens).replace(' .', '.') final_text = final_text.replace(' ,', ',') # final_text == reconstructed_text return aggregated_attn, reconstructed_tokens def load_rule(path): rule_df = pd.read_csv(path) # remove MDC 15 - neonate and couple other codes related to postcare if 'MS' in path: msk = (rule_df['MDC']!='15') & (~rule_df['MS-DRG'].isin([945, 946, 949, 950, 998, 999])) space = sorted(rule_df[msk]['DRG_CODE'].unique()) elif 'APR' in path: msk = (rule_df['MDC']!='15') & (~rule_df['APR-DRG'].isin([860, 863])) space = sorted(rule_df[msk]['DRG_CODE'].unique()) drg2idx = {} for d in space: drg2idx[d] = len(drg2idx) i2d = {v:k for k,v in drg2idx.items()} d2mdc, d2w = {}, {} for _, r in rule_df.iterrows(): drg = r['DRG_CODE'] mdc = r['MDC'] w = r['WEIGHT'] d2mdc[drg] = mdc d2w[drg] = w return rule_df, drg2idx, i2d, d2mdc, d2w def visualize_attn(model_results): class_id = model_results['class_dsc'] prob = model_results['prob'] attn = model_results['attn'] tokens = model_results['tokens'] scaler = PyTMinMaxScalerVectorized() normalized_attn = scaler(attn) viz_record = visualization.VisualizationDataRecord( word_attributions=normalized_attn, pred_prob=prob, pred_class=class_id, true_class=class_id, attr_class=0, attr_score=1, raw_input_ids=tokens, convergence_score=1 ) return visualize_text( viz_record, drg_link=model_results['drg_link'], icd_annotations=model_results['icd_results'], diseases=model_results['diseases'] ) def modify_attn_html(attn_html): attn_split = attn_html.split('' htmls.append(href_html) return "".join(htmls) def modify_code_html(html, link, icd=False): html = html.split('')[1].split('')[0] href_html = f'' if icd: href_html = href_html.replace('', '').replace('', '') return href_html def modify_drg_html(html, drg_link): return modify_code_html(html=html, link=drg_link, icd=False) def get_icd_html(icd_list): if len(icd_list) == 0: return 'N/A' final_html = '' icd_set = set() for icd_dict in icd_list: text, link = icd_dict['text'], icd_dict['link'] if text in icd_set: continue tmp_html = visualization.format_classname(classname=text) html = modify_code_html(html=tmp_html, link=link, icd=True) final_html += html icd_set.add(text) return final_html + '' def get_disease_html(diseases): if len(diseases) == 0: return 'N/A' diseases = list(set(diseases)) diseases_str = ', '.join(diseases) html = visualization.format_classname(classname=diseases_str) return html + '' # copied out of captum because we need raw html instead of a jupyter widget def visualize_text(datarecord, drg_link, icd_annotations, diseases): dom = [""] rows = [ "" "" "" "" ] pred_class_html = visualization.format_classname(datarecord.pred_class) icd_class_html = get_icd_html(icd_annotations) disease_html = get_disease_html(diseases) pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link) word_attn_html = visualization.format_word_importances( datarecord.raw_input_ids, datarecord.word_attributions ) rows.append( "".join( [ "", pred_class_html, word_attn_html, disease_html, icd_class_html, "", ] ) ) dom.append("".join(rows)) dom.append("
Predicted DRGWord ImportanceDiseasesICD Concepts
") html = "".join(dom) return html