Spaces:
Sleeping
Sleeping
import csv
import datetime
import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load the Sentence-BERT encoder once at module level; analyze_data reuses it.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Serialize the recommendation rows as CSV into an in-memory BytesIO buffer.
def save_recommendations_to_csv(recommendations):
    """Serialize recommendation rows to CSV held in an in-memory bytes buffer.

    Args:
        recommendations: Iterable of rows, each an iterable of
            ``(employee id, employee name, recommended programs)`` values.

    Returns:
        io.BytesIO: UTF-8 encoded CSV (header row first), positioned at
        offset 0 so it can be read immediately.
    """
    # csv.writer emits str, so it cannot target a BytesIO directly; build the
    # text first and encode once at the end.
    text_buffer = io.StringIO(newline="")
    writer = csv.writer(text_buffer)
    writer.writerow(["Employee ID", "Employee Name", "Recommended Programs"])
    writer.writerows(recommendations)
    return io.BytesIO(text_buffer.getvalue().encode("utf-8"))
# Analyze employee data, recommend training programs, and draw the relationship graph.
def analyze_data(employee_file, program_file, threshold=0.5):
    """Recommend training programs per employee and chart the matches.

    Both uploads are read as CSV. Each employee's ``current_skills`` text and
    each program's ``skills_acquired`` text are embedded with the module-level
    ``model`` and compared by cosine similarity; a program is recommended to
    an employee when the similarity exceeds ``threshold``.

    Args:
        employee_file: Uploaded employee CSV, given either as a path string or
            as an object whose ``name`` attribute is the path. Columns used:
            ``employee_id``, ``employee_name``, ``current_skills``.
        program_file: Uploaded program CSV in the same form. Columns used:
            ``program_name``, ``duration``, ``skills_acquired``.
        threshold: Similarity a program must exceed to be recommended.

    Returns:
        tuple: ``(summary_text, figure, csv_path)`` where ``summary_text`` has
        one line per employee, ``figure`` is the matplotlib figure of the
        employee/program graph, and ``csv_path`` is the path of a temporary
        CSV file holding the recommendation rows.

    Raises:
        gr.Error: If either file has not been uploaded.
    """
    if employee_file is None or program_file is None:
        raise gr.Error("Please upload both the employee file and the program file.")

    # Accept both a plain path string and a file wrapper exposing `.name`.
    employee_df = pd.read_csv(getattr(employee_file, "name", employee_file))
    program_df = pd.read_csv(getattr(program_file, "name", program_file))

    # Blank cells load as NaN floats, which the encoder cannot tokenize.
    employee_skills = employee_df['current_skills'].fillna("").astype(str).tolist()
    program_skills = program_df['skills_acquired'].fillna("").astype(str).tolist()
    similarities = cosine_similarity(
        model.encode(employee_skills),
        model.encode(program_skills),
    )

    employee_names = employee_df['employee_name'].tolist()
    program_names = program_df['program_name'].tolist()
    program_labels = [
        f"{name} ({duration})"
        for name, duration in zip(program_names, program_df['duration'])
    ]

    graph = nx.Graph()
    graph.add_nodes_from(employee_names, type='employee')
    graph.add_nodes_from(program_names, type='program')

    recommendations = []
    recommendation_rows = []  # rows destined for the CSV download
    employee_ids = employee_df['employee_id'].tolist()
    # Walk the similarity matrix by position so a non-default DataFrame index
    # cannot misalign rows, and derive text, CSV rows and edges in one pass.
    for row_pos, (emp_id, emp_name) in enumerate(zip(employee_ids, employee_names)):
        matched = [
            col_pos
            for col_pos, score in enumerate(similarities[row_pos])
            if score > threshold
        ]
        graph.add_edges_from((emp_name, program_names[col_pos]) for col_pos in matched)

        if matched:
            joined = ", ".join(program_labels[col_pos] for col_pos in matched)
            recommendations.append(f"์ง์ {emp_name}์ ์ถ์ฒ ํ๋ก๊ทธ๋จ: {joined}")
            recommendation_rows.append([emp_id, emp_name, joined])
        else:
            recommendations.append(f"์ง์ {emp_name}์๊ฒ ์ ํฉํ ํ๋ก๊ทธ๋จ์ด ์์ต๋๋ค.")
            recommendation_rows.append([emp_id, emp_name, "์ ํฉํ ํ๋ก๊ทธ๋จ ์์"])

    # Draw on an explicit figure/axes instead of pyplot's implicit "current
    # figure", which is process-global state shared between requests.
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_axes((0, 0, 1, 1))
    pos = nx.spring_layout(graph)
    nx.draw(graph, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=3000, font_size=10, font_weight='bold', edge_color='gray')
    ax.set_title("์ง์๊ณผ ํ๋ก๊ทธ๋จ ๊ฐ์ ๊ด๊ณ", fontsize=14, fontweight='bold')
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated clicks do not accumulate
    # open figures; the returned Figure object itself stays usable.
    plt.close(fig)

    # The gr.File output needs a path on disk, not an in-memory buffer.
    csv_buffer = save_recommendations_to_csv(recommendation_rows)
    with tempfile.NamedTemporaryFile(
        mode="wb", prefix="recommendations_", suffix=".csv", delete=False
    ) as csv_file:
        csv_file.write(csv_buffer.getvalue())

    return "\n".join(recommendations), fig, csv_file.name
# Gradio Blocks UI: uploads and text summary on the left, chart and CSV
# download on the right.
with gr.Blocks(css=".gradio-button {background-color: #6c757d; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    # gr.Markdown takes no `unsafe_allow_html` keyword (that is Streamlit's
    # API); passing it raises TypeError while the UI is being built.
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>๐ผ HybridRAG ์์คํ </h1>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3 style='color: #34495e;'>1. ์ง์ ๋ฐ ํ๋ก๊ทธ๋จ ๋ฐ์ดํฐ๋ฅผ ์ ๋ก๋ํ์ธ์</h3>")
            employee_file = gr.File(label="์ง์ ๋ฐ์ดํฐ ์ ๋ก๋", interactive=True)
            program_file = gr.File(label="๊ต์ก ํ๋ก๊ทธ๋จ ๋ฐ์ดํฐ ์ ๋ก๋", interactive=True)
            analyze_button = gr.Button("๋ถ์ ์์", elem_classes="gradio-button")
            output_text = gr.Textbox(label="๋ถ์ ๊ฒฐ๊ณผ", interactive=False, elem_classes="gradio-textbox")
        with gr.Column(scale=2):
            gr.Markdown("<h3 style='color: #34495e;'>2. ๋ถ์ ๊ฒฐ๊ณผ</h3>")
            chart_output = gr.Plot(label="์๊ฐํ ์ฐจํธ")
            csv_download = gr.File(label="์ถ์ฒ ๊ฒฐ๊ณผ ๋ค์ด๋ก๋")
    # Clicking the analyze button fills the summary text, the chart and the
    # downloadable file from analyze_data's three return values.
    analyze_button.click(analyze_data, inputs=[employee_file, program_file], outputs=[output_text, chart_output, csv_download])
# Start the Gradio app.
demo.launch()