# hybridRAG / app.py
# Source: Hugging Face file viewer (author: soojeongcrystal; commit 5ad04e8,
# "Update app.py", verified; 4.81 kB; no virus detected).
import csv
import datetime
import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load the Sentence-BERT model once at import time; analyze_data() uses this
# module-level instance to embed both employee skills and program skills.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Serialize the recommendation results to CSV (in-memory, returned as BytesIO).
def save_recommendations_to_csv(recommendations):
    """Serialize recommendation rows into an in-memory UTF-8 encoded CSV.

    Args:
        recommendations: Iterable of rows; each row is an iterable of
            (employee id, employee name, recommended programs).

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that holds the header row
        followed by one CSV line per recommendation row.
    """
    # csv.writer produces str, so it cannot write into a BytesIO directly
    # (that raises TypeError). Build the text first, then encode it once.
    text_buffer = io.StringIO(newline="")
    writer = csv.writer(text_buffer)
    writer.writerow(["Employee ID", "Employee Name", "Recommended Programs"])
    writer.writerows(recommendations)
    return io.BytesIO(text_buffer.getvalue().encode("utf-8"))
# Analyze the employee data, recommend training programs, and draw the graph.
def analyze_data(employee_file, program_file, similarity_threshold=0.5):
    """Recommend training programs per employee and plot the matches.

    Args:
        employee_file: Uploaded employee CSV, either an object exposing the
            path as ``.name`` or a plain path string. The CSV must provide the
            ``employee_id``, ``employee_name`` and ``current_skills`` columns.
        program_file: Uploaded program CSV (same accepted forms). The CSV must
            provide the ``program_name``, ``duration`` and ``skills_acquired``
            columns.
        similarity_threshold: A program is recommended to an employee when the
            cosine similarity of their skill embeddings is strictly greater
            than this value.

    Returns:
        A tuple ``(summary_text, figure, csv_path)``: one summary line per
        employee joined by newlines, the matplotlib figure of the
        employee/program graph, and the path of a CSV file with the same
        recommendations.
    """
    # Load the employee and program tables. The upload may arrive as a
    # tempfile-like object (path in ``.name``) or directly as a path string.
    employee_df = pd.read_csv(getattr(employee_file, "name", employee_file))
    program_df = pd.read_csv(getattr(program_file, "name", program_file))

    # Embed employee skills and program learning outcomes. Blank cells are read
    # as NaN floats, so normalise everything to str before encoding.
    employee_skills = employee_df['current_skills'].fillna("").astype(str).tolist()
    program_skills = program_df['skills_acquired'].fillna("").astype(str).tolist()
    employee_embeddings = model.encode(employee_skills)
    program_embeddings = model.encode(program_skills)

    # similarities[e][p] is the score of employee row e against program row p.
    similarities = cosine_similarity(employee_embeddings, program_embeddings)

    # Every employee and program is a node, even if it ends up with no edge.
    G = nx.Graph()
    for employee_name in employee_df['employee_name']:
        G.add_node(employee_name, type='employee')
    for program_name in program_df['program_name']:
        G.add_node(program_name, type='program')

    programs = list(zip(program_df['program_name'], program_df['duration']))
    employees = zip(employee_df['employee_id'], employee_df['employee_name'])

    recommendations = []       # human-readable summary lines
    recommendation_rows = []   # rows written to the CSV download
    # enumerate() yields row positions, which is what the similarity matrix is
    # indexed by; DataFrame index labels are not guaranteed to match them.
    # One pass builds the text, the CSV rows and the graph edges together.
    for emp_pos, (employee_id, employee_name) in enumerate(employees):
        recommended_programs = []
        for prog_pos, (program_name, duration) in enumerate(programs):
            if similarities[emp_pos][prog_pos] > similarity_threshold:
                recommended_programs.append(f"{program_name} ({duration})")
                G.add_edge(employee_name, program_name)

        if recommended_programs:
            recommendation = f"์ง์› {employee_name}์˜ ์ถ”์ฒœ ํ”„๋กœ๊ทธ๋žจ: {', '.join(recommended_programs)}"
            recommendation_rows.append([employee_id, employee_name, ", ".join(recommended_programs)])
        else:
            recommendation = f"์ง์› {employee_name}์—๊ฒŒ ์ ํ•ฉํ•œ ํ”„๋กœ๊ทธ๋žจ์ด ์—†์Šต๋‹ˆ๋‹ค."
            recommendation_rows.append([employee_id, employee_name, "์ ํ•ฉํ•œ ํ”„๋กœ๊ทธ๋žจ ์—†์Œ"])
        recommendations.append(recommendation)

    # Draw the graph. The figure is detached from pyplot afterwards so repeated
    # clicks do not accumulate open figures; the Figure object stays usable.
    fig = plt.figure(figsize=(10, 8))
    try:
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=3000, font_size=10, font_weight='bold', edge_color='gray')
        plt.title("์ง์›๊ณผ ํ”„๋กœ๊ทธ๋žจ ๊ฐ„์˜ ๊ด€๊ณ„", fontsize=14, fontweight='bold')
        plt.tight_layout()
    finally:
        plt.close(fig)

    # The file output component is given a path, so persist the in-memory CSV
    # to a temporary file (delete=False: it must outlive this call).
    csv_buffer = save_recommendations_to_csv(recommendation_rows)
    with tempfile.NamedTemporaryFile(mode="wb", prefix="recommendations_", suffix=".csv", delete=False) as csv_file:
        csv_file.write(csv_buffer.getvalue())

    return "\n".join(recommendations), fig, csv_file.name
# Gradio Blocks layout: upload controls on the left, analysis results on the right.
with gr.Blocks(css=".gradio-button {background-color: #6c757d; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    # gr.Markdown has no `unsafe_allow_html` parameter (that is a Streamlit
    # argument), so it is not passed here.
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>๐Ÿ’ผ HybridRAG ์‹œ์Šคํ…œ</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3 style='color: #34495e;'>1. ์ง์› ๋ฐ ํ”„๋กœ๊ทธ๋žจ ๋ฐ์ดํ„ฐ๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”</h3>")
            employee_file = gr.File(label="์ง์› ๋ฐ์ดํ„ฐ ์—…๋กœ๋“œ", interactive=True)
            program_file = gr.File(label="๊ต์œก ํ”„๋กœ๊ทธ๋žจ ๋ฐ์ดํ„ฐ ์—…๋กœ๋“œ", interactive=True)
            analyze_button = gr.Button("๋ถ„์„ ์‹œ์ž‘", elem_classes="gradio-button")
            output_text = gr.Textbox(label="๋ถ„์„ ๊ฒฐ๊ณผ", interactive=False, elem_classes="gradio-textbox")
        with gr.Column(scale=2):
            gr.Markdown("<h3 style='color: #34495e;'>2. ๋ถ„์„ ๊ฒฐ๊ณผ</h3>")
            chart_output = gr.Plot(label="์‹œ๊ฐํ™” ์ฐจํŠธ")
            csv_download = gr.File(label="์ถ”์ฒœ ๊ฒฐ๊ณผ ๋‹ค์šด๋กœ๋“œ")
    # Clicking the analyze button refreshes the summary text, the chart and
    # the downloadable file.
    analyze_button.click(analyze_data, inputs=[employee_file, program_file], outputs=[output_text, chart_output, csv_download])
# Start the Gradio interface.
demo.launch()