SanctiMolyDemo1 / app.py
alex6095's picture
Update app.py
5ef0c66
raw
history blame contribute delete
No virus
6.19 kB
import torch
import re
import streamlit as st
import pandas as pd
from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration
from tokenization_kobert import KoBertTokenizer
from tokenizers import SentencePieceBPETokenizer
@st.cache(allow_output_mutation=True)
def get_topic():
model = DistilBertForSequenceClassification.from_pretrained(
'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
model.eval()
tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
return model, tokenizer
@st.cache(allow_output_mutation=True)
def get_date():
model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart')
model.eval()
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
return model, tokenizer
class RegexSubstitution(object):
"""Regex substitution class for transform"""
def __init__(self, regex, sub=''):
if isinstance(regex, re.Pattern):
self.regex = regex
else:
self.regex = re.compile(regex)
self.sub = sub
def __call__(self, target):
if isinstance(target, list):
return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
else:
return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
'''
topics_raw = ['IT/๊ณผํ•™', '๊ฒฝ์ œ', '๋ฌธํ™”', '๋ฏธ์šฉ/๊ฑด๊ฐ•', '์‚ฌํšŒ', '์ƒํ™œ', '์Šคํฌ์ธ ', '์—ฐ์˜ˆ', '์ •์น˜']
topic_model, topic_tokenizer = get_topic()
date_model, date_tokenizer = get_date()
st.sidebar.header('Menu')
name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction'])
if name == 'Topic Classification':
title = 'News Topic Classification'
model, tokenizer = topic_model, topic_tokenizer
elif name == 'Date Prediction':
title = 'News Date prediction'
model, tokenizer = date_model, date_tokenizer
st.title(title)
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
if name == 'Topic Classification':
st.markdown("## Predict Topic")
col1, col2 = st.columns(2)
if text:
with st.spinner('processing..'):
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
encoded_dict = tokenizer(
text=text,
add_special_tokens=True,
max_length=512,
truncation=True,
return_tensors='pt',
return_length=True
)
input_ids = encoded_dict['input_ids']
input_ids_len = encoded_dict['length'].unsqueeze(0)
attn_mask = torch.arange(input_ids.size(1))
attn_mask = attn_mask[None, :] < input_ids_len[:, None]
outputs = model(input_ids=input_ids, attention_mask=attn_mask)
_, preds = torch.max(outputs.logits, 1)
col1.write(topics_raw[preds.squeeze(0)])
softmax = torch.nn.Softmax(dim=1)
prob = softmax(outputs.logits).squeeze(0).detach()
chart_data = pd.DataFrame({
'Topic': topics_raw,
'Probability': prob
})
chart_data = chart_data.set_index('Topic')
col2.bar_chart(chart_data)
elif name == 'Date Prediction':
st.markdown("## Predict 3 possible Date")
if text:
with st.spinner('processing..'):
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
raw_input_ids = tokenizer.encode(text)
input_ids = [tokenizer.bos_token_id] + \
raw_input_ids + [tokenizer.eos_token_id]
outputs = model.generate(torch.tensor([input_ids]),
early_stopping=True,
do_sample=True, #์ƒ˜ํ”Œ๋ง ์ „๋žต ์‚ฌ์šฉ
max_length=50, # ์ตœ๋Œ€ ๋””์ฝ”๋”ฉ ๊ธธ์ด๋Š” 50
top_k=50, # ํ™•๋ฅ  ์ˆœ์œ„๊ฐ€ 50์œ„ ๋ฐ–์ธ ํ† ํฐ์€ ์ƒ˜ํ”Œ๋ง์—์„œ ์ œ์™ธ
top_p=0.95, # ๋ˆ„์  ํ™•๋ฅ ์ด 95%์ธ ํ›„๋ณด์ง‘ํ•ฉ์—์„œ๋งŒ ์ƒ์„ฑ
num_return_sequences=3 #3๊ฐœ์˜ ๊ฒฐ๊ณผ๋ฅผ ๋””์ฝ”๋”ฉํ•ด๋‚ธ๋‹ค
)
pred_print = []
for output in outputs:
pred_print.append(tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True))
st.write(", ".join(pred_print))