hotchpotch's picture
up
0bef579
raw
history blame contribute delete
No virus
2.15 kB
import os
import streamlit as st
from yasem import SpladeEmbedder
if os.getenv("SPACE_ID"):
USE_HF_SPACE = True
os.environ["HF_HOME"] = "/data/.huggingface"
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
USE_HF_SPACE = False
MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/japanese-splade-base-v1")
@st.cache_resource
def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder:
embedder = SpladeEmbedder(
model_name,
)
return embedder
def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]:
embedder = get_embedder()
embeddings = embedder.encode([input_text])
token_values = embedder.get_token_values(embeddings[0])
sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) # type: ignore
return [(value, key) for key, value in sorted_tokens]
def main():
st.set_page_config(
page_title="SPLADE 日本語 demo",
layout="centered",
initial_sidebar_state="auto",
)
st.title("SPLADE 日本語 demo")
get_embedder()
st.markdown("""
[hotchpotch/japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1)を使って、テキストからSPLADEのスパースベクトルに変換するデモです。
""")
input_text = st.text_area("テキスト", height=200)
if st.button("変換"):
if input_text.strip():
with st.spinner("変換中..."):
sorted_tokens = get_token_values_sorted(input_text)
total_tokens = len(sorted_tokens)
st.markdown(f"### 結果 (トークン数: {total_tokens})")
if sorted_tokens:
formatted_data = [
{"スコア": freq, "単語(vocab)": word}
for freq, word in sorted_tokens
]
st.table(formatted_data)
else:
st.warning("入力テキストから有効な単語が見つかりませんでした。")
else:
st.warning("テキストを入力してください。")
if __name__ == "__main__":
main()