# hybridRAG / app.py
# soojeongcrystal's picture
# Update app.py
# 3db0045 verified
# raw
# history blame
# No virus
# 3.86 kB
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
import matplotlib.pyplot as plt
# Load the Sentence-BERT model once at import time; the analysis function
# below reads it as the module-level name `model`.
_MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(_MODEL_NAME)
# Analyze employee data, recommend training programs, and draw the
# employee/program relationship graph.
def _uploaded_path(upload):
    """Return the filesystem path behind a file upload.

    Accepts a path (``str`` / ``os.PathLike``) or an object that exposes the
    path as ``.name``, which is what the original code expected.

    Raises:
        ValueError: if *upload* is ``None`` (nothing was uploaded).
    """
    import os

    if upload is None:
        raise ValueError(
            "Both the employee CSV and the program CSV must be uploaded."
        )
    # NOTE(review): which of the two shapes gr.File passes to the callback
    # presumably depends on the installed Gradio version -- confirm.
    try:
        return os.fspath(upload)
    except TypeError:
        return upload.name


def _skill_texts(frame, column):
    """Return *column* of *frame* as a list of strings ('' for missing cells)."""
    # Empty CSV cells are read as NaN floats; the encoder is only ever given
    # strings here (presumably it cannot handle NaN -- TODO confirm).
    return frame[column].fillna("").astype(str).tolist()


def _relationship_graph(employee_names, program_names, matches):
    """Build a graph with one node per name and one edge per matching pair.

    ``matches[i][j]`` is truthy when employee ``i`` matches program ``j``.
    """
    graph = nx.Graph()
    graph.add_nodes_from(employee_names, type='employee')
    graph.add_nodes_from(program_names, type='program')
    graph.add_edges_from(
        (employee_name, program_name)
        for employee_name, row in zip(employee_names, matches)
        for program_name, hit in zip(program_names, row)
        if hit
    )
    return graph


def _draw_graph(graph):
    """Render *graph* onto a new standalone matplotlib figure and return it."""
    from matplotlib.figure import Figure

    # A Figure created directly is not registered with pyplot's global figure
    # manager, so repeated calls do not accumulate open figures.
    figure = Figure(figsize=(10, 8))
    # A regular subplot (rather than the full-figure axes nx.draw creates when
    # no axes exist) lets tight_layout() leave room for the title.
    axes = figure.add_subplot(111)
    nx.draw(
        graph,
        nx.spring_layout(graph),
        ax=axes,
        with_labels=True,
        node_color='skyblue',
        node_size=2000,
        font_size=10,
        font_weight='bold',
    )
    # NOTE(review): the Korean UI strings in this function look mis-encoded;
    # they are kept exactly as found -- verify the file's encoding.
    axes.set_title("직원과 ν”„λ‘œκ·Έλž¨ κ°„μ˜ 관계")
    figure.tight_layout()
    return figure


def analyze_data(employee_file, program_file, threshold=0.5):
    """Recommend training programs per employee and plot the matches.

    Both CSVs are loaded with pandas. Each employee's ``current_skills`` and
    each program's ``skills_acquired`` text is embedded with the module-level
    ``model``; an employee/program pair is a match when the cosine similarity
    of their embeddings is strictly greater than *threshold*.

    Args:
        employee_file: Upload of the employee CSV (a path, or an object whose
            ``.name`` is the path). Columns read: ``employee_name``,
            ``current_skills``.
        program_file: Upload of the program CSV, same accepted forms. Columns
            read: ``program_name``, ``duration``, ``skills_acquired``.
        threshold: Similarity a pair must exceed to count as a match.

    Returns:
        A ``(result_text, figure)`` tuple: one line of text per employee, and
        a matplotlib figure of the employee/program graph.

    Raises:
        ValueError: if an upload is missing or either CSV has no rows.
        KeyError: if a required column is absent.
    """
    employee_df = pd.read_csv(_uploaded_path(employee_file))
    program_df = pd.read_csv(_uploaded_path(program_file))
    if employee_df.empty or program_df.empty:
        raise ValueError(
            "The employee CSV and the program CSV must each contain "
            "at least one row."
        )

    # Embed both skill columns and compare every employee with every program.
    similarities = cosine_similarity(
        model.encode(_skill_texts(employee_df, 'current_skills')),
        model.encode(_skill_texts(program_df, 'skills_acquired')),
    )
    # Computed once and shared by the text report and the graph, so the two
    # outputs cannot disagree. Rows and columns are positional, which keeps
    # this correct even when a DataFrame index is not 0..n-1.
    matches = similarities > threshold

    employee_names = employee_df['employee_name'].tolist()
    program_names = program_df['program_name'].tolist()
    program_labels = [
        f"{program_name} ({duration})"
        for program_name, duration in zip(program_names, program_df['duration'])
    ]

    # One report line per employee.
    report_lines = []
    for employee_name, row in zip(employee_names, matches):
        recommended = [label for label, hit in zip(program_labels, row) if hit]
        if recommended:
            report_lines.append(
                f"직원 {employee_name}의 μΆ”μ²œ ν”„λ‘œκ·Έλž¨: {', '.join(recommended)}"
            )
        else:
            report_lines.append(
                f"직원 {employee_name}μ—κ²Œ μ ν•©ν•œ ν”„λ‘œκ·Έλž¨μ΄ μ—†μŠ΅λ‹ˆλ‹€."
            )
    result_text = "\n".join(report_lines)

    graph = _relationship_graph(employee_names, program_names, matches)
    return result_text, _draw_graph(graph)
# Gradio callback: adapter between the UI click event and the analysis.
def main(employee_file, program_file):
    """Run the analysis on the two uploads and return its outputs unchanged."""
    outputs = analyze_data(employee_file, program_file)
    return outputs
# Gradio Blocks UI: a narrow left column with the uploads, the start button
# and the text result, and a wider right column with the chart.
# NOTE(review): the Korean UI strings below look mis-encoded; they are kept
# exactly as found -- verify the file's encoding.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# HybridRAG μ‹œμŠ€ν…œ")
            gr.Markdown("두 개의 CSV νŒŒμΌμ„ μ—…λ‘œλ“œν•˜μ—¬ 뢄석을 μ§„ν–‰ν•˜μ„Έμš”.")
            employee_file = gr.File(label="직원 데이터 μ—…λ‘œλ“œ")
            program_file = gr.File(label="ꡐ윑 ν”„λ‘œκ·Έλž¨ 데이터 μ—…λ‘œλ“œ")
            analyze_button = gr.Button("뢄석 μ‹œμž‘")
            output_text = gr.Textbox(label="뢄석 κ²°κ³Ό")
        with gr.Column(scale=2):
            gr.Markdown("### 정보 νŒ¨λ„")
            gr.Markdown("μ—…λ‘œλ“œλœ 데이터에 λŒ€ν•œ 뢄석 및 κ²°κ³Όλ₯Ό 여기에 ν‘œμ‹œν•©λ‹ˆλ‹€.")
            # Visualization chart output.
            chart_output = gr.Plot(label="μ‹œκ°ν™” 차트")

    # Register the click handler exactly once, wired to BOTH outputs. `main`
    # returns two values (text, figure); a second registration that listed
    # only the textbox would mismatch that return and would also make every
    # click run the whole analysis twice.
    analyze_button.click(
        main,
        inputs=[employee_file, program_file],
        outputs=[output_text, chart_output],
    )

# Start the Gradio interface.
demo.launch()