File size: 9,247 Bytes
2d00e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import json
import pandas as pd
import re

from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
from typing import Dict, List, Sequence
from utils_report_parser import get_section_from_report

from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

@dataclass
class Report:
    patient_id: str|int
    text: str
    date: str
    summary: str|None = None

def clean(s: str) -> str:
    """Normalise a chunk of report text into a single tidy line.

    Steps, in order: newlines become spaces, underscores are deleted,
    ``[...]`` and ``(...)`` spans are removed together with their contents
    (non-greedy, so nesting is not handled), and runs of whitespace are
    collapsed to single spaces with the ends trimmed.
    """
    # Flatten first: ``.`` in the bracket patterns does not match "\n", so a
    # bracketed span that crossed a line break would otherwise survive.
    flattened = s.replace("\n", " ").replace("_", "")
    for bracket_pattern in (r"\[.*?\]", r"\(.*?\)"):
        flattened = re.sub(bracket_pattern, "", flattened)
    return " ".join(flattened.split())


def split_paragraphs(text: str, min_words: int = 10) -> List[str]:
    """Split a report into cleaned paragraphs, dropping short fragments.

    A paragraph boundary is a blank line. A "blank" line may contain spaces,
    tabs or a carriage return, so Windows-style ``\\r\\n\\r\\n`` separators
    and whitespace-only separator lines are recognised as well as ``\\n\\n``.

    Args:
        text (str): Raw report text.
        min_words (int): Paragraphs with this many words or fewer (after
            cleaning) are discarded. Defaults to 10.

    Returns:
        List[str]: Cleaned paragraphs in their original order.
    """
    # "\n", then only non-newline whitespace (spaces, tabs, "\r"), then "\n".
    # For plain "\n\n" separators this splits exactly like text.split("\n\n").
    raw_paragraphs = re.split(r"\n[^\S\n]*\n", text)
    cleaned_paragraphs = (clean(p) for p in raw_paragraphs)
    return [p for p in cleaned_paragraphs if len(p.split()) > min_words]


def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> Dict[str, List["Report"]]:
    """Group a report-level dataframe into per-patient lists of ``Report`` objects.

    Each input row is one report. The result maps ``str(patient_id)`` to that
    patient's reports, ordered by the date column.

    Args:
        df (pd.DataFrame): Input dataframe on report level. It is not modified.
        patient_id_column (str): Name of the patient ID column.
        text_column (str): Name of the report text column.
        date_column (str): Name of the date column (used to sort; values are
            compared as-is, so string dates sort lexically).

    Returns:
        Dict[str, List[Report]]: Reports per patient, keyed by the
        stringified patient ID, with patients in sorted order.
    """
    renamed = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    ordered = renamed.sort_values(by=["patient_id", "date"])

    reports_by_patient: Dict[str, List[Report]] = {}
    # Iterate the groups directly instead of using groupby().apply(). Applying
    # over the grouping column is deprecated in pandas 2.2, and would raise a
    # KeyError on "patient_id" once grouping columns are excluded by default.
    for patient_id, group in ordered.groupby("patient_id"):
        records = group[["patient_id", "text", "date"]].to_dict("records")
        reports_by_patient[str(patient_id)] = [Report(**record) for record in records]
    return reports_by_patient


class CaseMaker:
    """Split clinical reports by organ and keep only clinically relevant sentences.

    Uses two resources:
      * a JSON keyword file mapping each canonical organ name to a list of
        keyword/synonym strings (stored in ``organ_keyword_dict``), and
      * the ``d4data/biomedical-ner-all`` token-classification model, used to
        keep only sentences that contain a sign/symptom or disease/disorder entity.
    """

    def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"):
        """Load the organ keyword map and build the NER pipeline.

        Args:
            organ_keywords_dict_path (str): Path to a JSON file mapping each
                canonical organ name to a list of keyword strings.
        """
        # Context manager closes the file deterministically; JSON is UTF-8 text.
        with open(organ_keywords_dict_path, "r", encoding="utf-8") as f:
            self.organ_keyword_dict = json.load(f)

        self.ner_pipe = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(
                "d4data/biomedical-ner-all"
            ),
            tokenizer=AutoTokenizer.from_pretrained("d4data/biomedical-ner-all"),
            aggregation_strategy="simple",
            device_map="auto",
        )
        # The summarization pipeline is disabled. To enable summarize_report(),
        # replace the None below with:
        #     pipeline("text2text-generation", model="starmpcc/Asclepius-7B",
        #              device_map="auto")
        self.summ_pipe = None

    def standardize_organ(self, organ_entity: Dict) -> Dict:
        """Map an entity's name onto the canonical organ names in organ_keyword_dict.

        Matching is case-insensitive against both the canonical key and its
        keyword list, consistent with pick_organ_by_keyword. The dict is
        modified in place and also returned.

        Args:
            organ_entity (Dict): Entity dictionary; must contain a "word" key.

        Returns:
            Dict: The same dictionary. "word" is either the matching canonical
            organ name, or "Other" with "score" set to 0.0 when nothing matches.
        """
        word = organ_entity["word"].lower()
        for key, synonyms in self.organ_keyword_dict.items():
            if word == key.lower() or word in (s.lower() for s in synonyms):
                organ_entity["word"] = key
                return organ_entity
        # No match: treat as a bad match, so label it "Other" and zero the score.
        organ_entity["word"] = "Other"
        organ_entity["score"] = 0.0
        return organ_entity

    def pick_organ_by_keyword(self, s: str) -> str:
        """Return the first organ whose name or any keyword occurs in ``s``.

        Matching is a case-insensitive *substring* test, and organs are tried in
        the dictionary's order, so the first hit wins. NOTE: substring matching
        can fire inside longer words (e.g. "liver" in "deliver").

        Returns:
            str: Canonical organ name, or lowercase "other" when nothing matches
            (kept lowercase for compatibility: it is used as a dict key downstream).
        """
        lowered = s.lower()
        for organ, keywords in self.organ_keyword_dict.items():
            if any(keyword.lower() in lowered for keyword in [organ, *keywords]):
                return organ
        return "other"

    def parse_report_by_organ(self, report: str) -> Dict[str, str]:
        """Group a report's paragraphs by the organ each one mentions.

        Args:
            report (str): Input report.

        Returns:
            Dict[str, str]: Organ name -> its paragraphs joined with single
            spaces, in report order.
        """
        report_string_by_organ: Dict[str, str] = {}
        for paragraph in split_paragraphs(report):
            organ = self.pick_organ_by_keyword(paragraph)
            if organ in report_string_by_organ:
                # Separate paragraphs with a space. Direct concatenation fused
                # the last word of one paragraph with the first of the next.
                report_string_by_organ[organ] += " " + paragraph
            else:
                report_string_by_organ[organ] = paragraph
        return report_string_by_organ

    def trim_to_relevant_portion(self, report: str) -> str:
        """Keep only the "findings" sentences that mention a symptom or disease.

        Returns:
            str: Matching sentences joined by newlines ("" if none match).
        """
        report = get_section_from_report(report, "findings")
        # Guard against an empty (or possibly None -- behaviour of
        # get_section_from_report not visible here) section.
        if not report:
            return ""

        relevant_groups = {"Sign_symptom", "Disease_disorder"}
        relevant_sentences = [
            sentence
            for sentence in sent_tokenize(report)
            if any(ent["entity_group"] in relevant_groups for ent in self.ner_pipe(sentence))
        ]
        return "\n".join(relevant_sentences)

    def summarize_report(self, text: str) -> str:
        """Format text into a prompt and summarize it with ``self.summ_pipe``.

        Args:
            text (str): Input report.

        Returns:
            str: Cleaned model answer (the text after "[Instruction End]").

        Raises:
            RuntimeError: If the summarization pipeline has not been set up.
        """
        summ_pipe = getattr(self, "summ_pipe", None)
        if summ_pipe is None:
            raise RuntimeError(
                "Summarization pipeline is not initialised; set self.summ_pipe "
                "(see CaseMaker.__init__) before calling summarize_report()."
            )

        question = (
            "Can you provide a succinct summary of the key clinical findings "
            "and treatment recommendations outlined in this discharge summary?"
        )

        # The prompt text is kept verbatim, including its spelling. It appears
        # to follow the Asclepius model-card template -- confirm before editing.
        prompt = """
        You are an intelligent clinical languge model.
        Below is a snippet of patient's discharge summary and a following instruction from healthcare professional.
        Write a response that appropriately completes the instruction.
        The response should provide the accurate answer to the instruction, while being concise.

        [Discharge Summary Begin]
        {note}
        [Discharge Summary End]

        [Instruction Begin]
        {question}
        [Instruction End]
        """.format(
            question=question, note=text
        )

        # Budget roughly half the input word count; always request at least one token.
        max_new_tokens = max(1, len(text.split()) // 2)
        output = summ_pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
        answer = output.split("[Instruction End]")[-1]
        return clean(answer)

    def parse_records(self, reports: Sequence["Report"]) -> Dict[str, List["Report"]]:
        """Split each report by organ and keep only the relevant text per organ.

        Each report is split with parse_report_by_organ. The organ-level pieces
        become new Report objects (same patient_id and date), collected per
        organ across all reports. Each piece is then trimmed with
        trim_to_relevant_portion. The trimmed text is stored in ``summary``,
        and pieces with no relevant text are dropped.

        Args:
            reports (Sequence[Report]): Reports to process; "text", "date" and
                "patient_id" attributes are read.

        Returns:
            Dict[str, List[Report]]: Organ -> relevant organ-level reports, in
            input order. An organ may map to an empty list.
        """
        reports_by_organ: Dict[str, List[Report]] = {}
        for report in reports:
            for organ, organ_text in self.parse_report_by_organ(report.text).items():
                organ_level_report = Report(
                    text=organ_text, date=report.date, patient_id=report.patient_id
                )
                reports_by_organ.setdefault(organ, []).append(organ_level_report)

        summarized_reports_by_organ: Dict[str, List[Report]] = {}
        for organ, organ_reports in reports_by_organ.items():
            kept: List[Report] = []
            for organ_report in organ_reports:
                relevant_text = self.trim_to_relevant_portion(organ_report.text)
                if relevant_text:
                    organ_report.summary = relevant_text
                    kept.append(organ_report)
            summarized_reports_by_organ[organ] = kept

        return summarized_reports_by_organ

    def format_reports(self, all_reports: Dict[str, Sequence["Report"]]) -> Dict[str, str]:
        """Render each organ's reports as one markdown-ish string.

        Args:
            all_reports (Dict[str, Sequence[Report]]): Organ -> reports; each
                item needs "date" and "summary" attributes.

        Returns:
            Dict[str, str]: Organ -> "**Report <date>**\\n\\n<summary>" blocks
            joined by blank lines ("" for an empty list).
        """
        return {
            organ: "\n\n".join(
                f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports
            )
            for organ, organ_reports in all_reports.items()
        }