|
from functools import lru_cache

import streamlit as st
from transformers import MBartForConditionalGeneration, MBartTokenizer, MarianMTModel, MarianTokenizer
|
|
|
|
|
@st.cache_resource
def _load_summarization_components():
    """Load the mBART summarization model/tokenizer once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the multi-gigabyte checkpoint would be re-loaded on
    every rerun.
    """
    # NOTE(review): 'facebook/mbart-large-50' is a pretrained (not
    # summarization-fine-tuned) checkpoint — verify it actually produces
    # useful summaries, or swap in a summarization-tuned model.
    model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')
    tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-50')
    return model, tokenizer


# Keep the original module-level names so the rest of the file is unchanged.
multilingual_summarization_model, multilingual_summarization_tokenizer = _load_summarization_components()
|
|
|
|
|
# Display name -> mBART-style locale code for every language the UI offers.
# Only the "en_XX" value is treated specially (it skips the translation hops).
# NOTE(review): Helsinki-NLP Marian model names on the Hub use bare ISO-639-1
# codes (e.g. "fr"), not these locale codes — confirm how these values are
# consumed by the translation layer.
LANGUAGES = {
    "English": "en_XX",
    "French": "fr_XX",
    "Spanish": "es_XX",
    "German": "de_DE",
    "Chinese": "zh_CN",
    "Russian": "ru_RU",
    "Arabic": "ar_AR",
    "Portuguese": "pt_PT",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Korean": "ko_KR",
    "Dutch": "nl_NL",
    "Polish": "pl_PL",
    "Turkish": "tr_TR",
    "Swedish": "sv_SE",
    "Greek": "el_EL",
    "Finnish": "fi_FI",
    "Hungarian": "hu_HU",
    "Danish": "da_DK",
    "Norwegian": "no_NO",
    "Czech": "cs_CZ",
    "Romanian": "ro_RO",
    "Thai": "th_TH",
    "Hebrew": "he_IL",
    "Vietnamese": "vi_VN",
    "Indonesian": "id_ID",
    "Malay": "ms_MY",
    "Bengali": "bn_BD",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Swahili": "sw_KE",
    "Serbian": "sr_SR",
    "Croatian": "hr_HR",
    "Slovak": "sk_SK",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Estonian": "et_EE",
    "Bulgarian": "bg_BG",
    "Macedonian": "mk_MK",
    "Albanian": "sq_AL",
    "Georgian": "ka_GE",
    "Armenian": "hy_AM",
    "Kazakh": "kk_KZ",
    "Uzbek": "uz_UZ",
    "Tajik": "tg_TJ",
    "Kyrgyz": "ky_KG",
    "Turkmen": "tk_TM",
}
|
|
|
|
|
@lru_cache(maxsize=8)
def get_translation_model(source_lang, target_lang):
    """Load (and memoize) the Helsinki-NLP MarianMT model for a language pair.

    Accepts either bare ISO-639-1 codes ("fr") or mBART-style locale codes
    ("fr_XX"); only the ISO prefix is used, because Marian model names on the
    Hugging Face Hub have the form ``opus-mt-<src>-<tgt>`` with bare ISO
    codes.

    Args:
        source_lang: source language code.
        target_lang: target language code.

    Returns:
        A ``(MarianMTModel, MarianTokenizer)`` tuple.

    Raises:
        OSError: if no OPUS-MT model exists for the requested pair.
    """
    # Bug fix: the original interpolated the full locale code (e.g. "fr_XX"),
    # producing nonexistent model names like "opus-mt-fr_XX-en_XX".
    src = source_lang.split("_")[0]
    tgt = target_lang.split("_")[0]
    model_name = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    # lru_cache keeps recently used pairs in memory instead of re-loading the
    # checkpoint from disk/network on every translation call.
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
|
|
|
|
|
def translate_text(text, source_lang, target_lang):
    """Translate *text* using the OPUS-MT model for (source_lang, target_lang).

    Args:
        text: input string; silently truncated to the model's max length.
        source_lang: source language code (see ``get_translation_model``).
        target_lang: target language code.

    Returns:
        The translated string.
    """
    model, tokenizer = get_translation_model(source_lang, target_lang)
    inputs = tokenizer([text], return_tensors="pt", truncation=True)
    # Pass the attention mask along with the input ids (the original dropped
    # it), so any padded positions are ignored during generation.
    translated_ids = model.generate(**inputs, max_length=1024)
    return tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
|
|
|
|
def summarize_text(text, source_language="English", target_language="English"):
    """Summarize *text*, translating into/out of English around the summarizer.

    Pipeline: source language -> English (MarianMT) -> mBART summarization ->
    target language (MarianMT). A translation hop is skipped when the
    corresponding language is already English.

    Args:
        text: the document to summarize.
        source_language: display name of the input language; must be a key of
            ``LANGUAGES`` (``KeyError`` otherwise).
        target_language: display name of the desired output language; must be
            a key of ``LANGUAGES``.

    Returns:
        The summary in the target language, or "" for empty/whitespace input.
    """
    source_lang_code = LANGUAGES[source_language]
    target_lang_code = LANGUAGES[target_language]

    # Robustness: nothing to summarize — don't feed an empty sequence to the
    # models (the original would run the full pipeline on "").
    if not text or not text.strip():
        return ""

    # Normalize to English first so a single summarization model suffices.
    if source_lang_code != "en_XX":
        text = translate_text(text, source_lang_code, "en_XX")

    inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    # Pass the attention mask too (the original forwarded only input_ids), so
    # padded positions are ignored by beam search.
    summary_ids = multilingual_summarization_model.generate(
        **inputs, num_beams=4, max_length=200, early_stopping=True
    )
    summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    if target_lang_code != "en_XX":
        summary = translate_text(summary, "en_XX", target_lang_code)

    return summary
|
|
|
|
|
# ---- Streamlit UI -----------------------------------------------------------
st.title("Multi-Language Text Summarization Tool")

text = st.text_area("Input Text")

# Build the options list once instead of three times.
language_names = list(LANGUAGES.keys())
default_index = language_names.index("English")
source_language = st.selectbox("Source Language", options=language_names, index=default_index)
target_language = st.selectbox("Target Language", options=language_names, index=default_index)

if st.button("Summarize"):
    # Bug fix: the original `if text:` accepted whitespace-only input.
    if text and text.strip():
        # Model loading + generation can take a while; show progress.
        with st.spinner("Summarizing..."):
            summary = summarize_text(text, source_language, target_language)
        st.subheader("Summary")
        st.write(summary)
    else:
        st.warning("Please enter text to summarize.")
|
|