Amelia-James's picture
Create app.py
7ae1348 verified
raw
history blame
No virus
3.95 kB
import streamlit as st
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBart50Tokenizer,
    MBartForConditionalGeneration,
    MBartTokenizer,
)
# Load the multilingual (mBART-50) model and tokenizer used for summarization.
# Streamlit re-executes this whole script on every widget interaction, so the
# weights are loaded through st.cache_resource: they are downloaded and
# initialized once per server process instead of once per rerun.
# NOTE(review): "facebook/mbart-large-50" is the pretrained multilingual
# checkpoint, not one fine-tuned for summarization, so its output may largely
# echo the input. Consider a summarization-fine-tuned checkpoint.
@st.cache_resource
def load_summarization_model(model_name="facebook/mbart-large-50"):
    """Load and return the ``(model, tokenizer)`` pair for *model_name*.

    mBART-50 checkpoints have their own tokenizer class (MBart50Tokenizer).
    It places the language code before the tokens, unlike the mBART-25
    MBartTokenizer, so the two are not interchangeable.
    """
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50Tokenizer.from_pretrained(model_name)
    return model, tokenizer


(
    multilingual_summarization_model,
    multilingual_summarization_tokenizer,
) = load_summarization_model()
# Display name -> locale-style language code. The part before "_" looks like
# an ISO 639-1 language code (e.g. "fr" in "fr_XX"). Insertion order matters:
# the language dropdowns in the UI are built from this dict's keys.
# NOTE(review): several suffixes (e.g. "el_EL", "no_NO", "sr_SR") are not in
# mBART-50's language-code list; confirm before feeding these codes to an
# mBART tokenizer.
LANGUAGES = dict(
    English="en_XX",
    French="fr_XX",
    Spanish="es_XX",
    German="de_DE",
    Chinese="zh_CN",
    Russian="ru_RU",
    Arabic="ar_AR",
    Portuguese="pt_PT",
    Hindi="hi_IN",
    Italian="it_IT",
    Japanese="ja_XX",
    Korean="ko_KR",
    Dutch="nl_NL",
    Polish="pl_PL",
    Turkish="tr_TR",
    Swedish="sv_SE",
    Greek="el_EL",
    Finnish="fi_FI",
    Hungarian="hu_HU",
    Danish="da_DK",
    Norwegian="no_NO",
    Czech="cs_CZ",
    Romanian="ro_RO",
    Thai="th_TH",
    Hebrew="he_IL",
    Vietnamese="vi_VN",
    Indonesian="id_ID",
    Malay="ms_MY",
    Bengali="bn_BD",
    Ukrainian="uk_UA",
    Urdu="ur_PK",
    Swahili="sw_KE",
    Serbian="sr_SR",
    Croatian="hr_HR",
    Slovak="sk_SK",
    Lithuanian="lt_LT",
    Latvian="lv_LV",
    Estonian="et_EE",
    Bulgarian="bg_BG",
    Macedonian="mk_MK",
    Albanian="sq_AL",
    Georgian="ka_GE",
    Armenian="hy_AM",
    Kazakh="kk_KZ",
    Uzbek="uz_UZ",
    Tajik="tg_TJ",
    Kyrgyz="ky_KG",
    Turkmen="tk_TM",
)
# Helpers to locate and load the Opus-MT translation model for a language pair.
def to_opus_mt_code(lang_code):
    """Return the Opus-MT language code for a locale-style *lang_code*.

    Helsinki-NLP Opus-MT checkpoints are named with bare language codes
    (``opus-mt-fr-en``), whereas this app's language table uses
    locale-style codes such as ``"fr_XX"``. Only the part before the first
    underscore is kept; a code without an underscore is returned unchanged.

    >>> to_opus_mt_code("fr_XX")
    'fr'
    """
    return lang_code.split("_", 1)[0]


def get_translation_model(source_lang, target_lang):
    """Load the Helsinki-NLP Opus-MT model and tokenizer for a language pair.

    Args:
        source_lang: Source language code, either locale-style (``"fr_XX"``)
            or a bare Opus-MT code (``"fr"``).
        target_lang: Target language code in the same format.

    Returns:
        A ``(MarianMTModel, MarianTokenizer)`` tuple.

    Raises:
        OSError: if the ``Helsinki-NLP/opus-mt-<src>-<tgt>`` checkpoint cannot
            be loaded (not every language pair is published).

    NOTE(review): the weights are loaded on every call; consider caching
    (e.g. ``st.cache_resource``) since Streamlit reruns the script often.
    """
    # Without this conversion the name would be "opus-mt-fr_XX-en_XX",
    # which is not a published checkpoint.
    src = to_opus_mt_code(source_lang)
    tgt = to_opus_mt_code(target_lang)
    model_name = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    try:
        model = MarianMTModel.from_pretrained(model_name)
        tokenizer = MarianTokenizer.from_pretrained(model_name)
    except OSError as err:
        raise OSError(
            f"Could not load translation model '{model_name}' for "
            f"{source_lang} -> {target_lang}; the language pair may not be available."
        ) from err
    return model, tokenizer
# Function to translate text
def translate_text(text, source_lang, target_lang):
    """Translate *text* from *source_lang* to *target_lang* with Opus-MT.

    Args:
        text: The text to translate. Input longer than the tokenizer's
            maximum length is truncated.
        source_lang: Source language code, as accepted by
            ``get_translation_model``.
        target_lang: Target language code, as accepted by
            ``get_translation_model``.

    Returns:
        The decoded translation with special tokens removed.
    """
    model, tokenizer = get_translation_model(source_lang, target_lang)
    inputs = tokenizer([text], return_tensors="pt", truncation=True)
    # Marian models have a fixed number of position embeddings (512 for the
    # Opus-MT checkpoints). Generating past that is not supported, so the
    # previous hard-coded 1024 limit is capped at the model's own limit.
    max_len = min(1024, getattr(model.config, "max_position_embeddings", 1024))
    # **inputs passes the attention mask along with the input ids.
    translated_ids = model.generate(**inputs, max_length=max_len)
    return tokenizer.decode(translated_ids[0], skip_special_tokens=True)
# Summarization function with multi-language support
def summarize_text(text, source_language="English", target_language="English"):
    """Summarize *text*, using English as the pivot language.

    Non-English input is first translated to English and then summarized
    with the mBART model. If *target_language* is not English, the English
    summary is translated into it.

    Args:
        text: The text to summarize.
        source_language: Display name (a key of ``LANGUAGES``) of the input.
        target_language: Display name (a key of ``LANGUAGES``) of the output.

    Returns:
        The summary string in *target_language*.

    Raises:
        KeyError: if either language name is not a key of ``LANGUAGES``.
    """
    pivot_code = "en_XX"
    source_lang_code = LANGUAGES[source_language]
    target_lang_code = LANGUAGES[target_language]

    if source_lang_code != pivot_code:
        text = translate_text(text, source_lang_code, pivot_code)

    tokenizer = multilingual_summarization_tokenizer
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    summary_ids = multilingual_summarization_model.generate(
        **inputs,  # includes the attention mask
        num_beams=4,
        max_length=200,
        early_stopping=True,
        # mBART-50 expects the output-language token as the first generated
        # token; force English so the pivot summary is in the expected language.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(pivot_code),
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    if target_lang_code != pivot_code:
        summary = translate_text(summary, pivot_code, target_lang_code)
    return summary
# Streamlit interface
st.title("Multi-Language Text Summarization Tool")
# Both dropdowns share the same options and default to English.
language_names = list(LANGUAGES)
default_index = language_names.index("English")
text = st.text_area("Input Text")
source_language = st.selectbox("Source Language", options=language_names, index=default_index)
target_language = st.selectbox("Target Language", options=language_names, index=default_index)
if st.button("Summarize"):
    # Whitespace-only input counts as empty.
    if text.strip():
        try:
            # Model loading and generation can take a while; show progress.
            with st.spinner("Summarizing..."):
                summary = summarize_text(text, source_language, target_language)
        except OSError as err:
            # A checkpoint (e.g. an unpublished Opus-MT language pair) could
            # not be loaded; report it instead of showing a raw traceback.
            st.error(f"Could not load a required model: {err}")
        else:
            st.subheader("Summary")
            st.write(summary)
    else:
        st.warning("Please enter text to summarize.")