# hybridRAG / app.py (Hugging Face Space by soojeongcrystal)
# Last commit: "Update app.py" (ca15903), file size 12.4 kB
import csv
import io
import tempfile
from datetime import datetime, timedelta

import gradio as gr
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load the KoSentence-BERT model used to embed Korean text.
model = SentenceTransformer('jhgan/ko-sbert-sts')
# Register the bundled Nanum font and make it matplotlib's default so the
# Korean chart title/labels render.
font_path = "./NanumBarunGothic.ttf"
font_prop = fm.FontProperties(fname=font_path)  # NOTE(review): not referenced anywhere in the visible code
plt.rcParams['font.family'] = 'NanumBarunGothic'
plt.rcParams['font.sans-serif'] = ['NanumBarunGothic']
fm.fontManager.addfont(font_path)
# Module-level state shared between Gradio callbacks.
global_recommendations = None  # rows from the last hybrid_rag run; read by chat_response
global_csv_string = None  # CSV text of the last hybrid_rag run; read by download_csv
youtube_columns = None  # last auto-matched YouTube column mapping, set by select_youtube_columns
# CSV λ¬Έμžμ—΄ 생성 ν•¨μˆ˜
def create_csv_string(recommendations):
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["Employee ID", "Employee Name", "Recommended Programs", "Recommended YouTube Content"])
for rec in recommendations:
writer.writerow(rec)
return output.getvalue()
def create_chart(G, seed=None):
    """Draw graph *G* with a spring layout and return it as an in-memory PNG.

    Args:
        G: networkx graph; nodes are drawn with their labels.
        seed: Optional random seed forwarded to ``nx.spring_layout``. Pass an
            int to get the same layout for the same graph on every call; the
            default ``None`` keeps the original (random) behaviour.

    Returns:
        ``io.BytesIO`` positioned at offset 0 that contains PNG bytes.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        pos = nx.spring_layout(G, seed=seed)
        nx.draw(G, pos, ax=fig.gca(), with_labels=True, node_color='lightblue', node_size=3000,
                font_size=10, font_weight='bold', edge_color='gray')
        fig.gca().set_title("직원과 ν”„λ‘œκ·Έλž¨ κ°„μ˜ 관계", fontsize=14, fontweight='bold')
        fig.tight_layout(pad=1.0)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure even when layout/drawing/saving raises; otherwise
        # pyplot keeps it open for the lifetime of the process.
        plt.close(fig)
    buf.seek(0)
    return buf
# Column matching: map each required logical column name to a DataFrame column.
def auto_match_columns(df, required_cols):
    """Find the DataFrame column that best corresponds to each required name.

    Names are compared after normalization: the value is converted to ``str``,
    stripped, lower-cased, and spaces/hyphens are turned into underscores. As a
    result, ``"Employee ID"`` matches ``"employee_id"``. For each required name,
    an exact normalized match is preferred. Otherwise the first column whose
    normalized name contains the required name is used.

    Args:
        df: DataFrame whose ``columns`` are searched (labels may be non-strings).
        required_cols: Iterable of required logical column names.

    Returns:
        Dict mapping each required name to the matched column label, or
        ``None`` when no column matches.
    """
    def _normalize(name):
        return str(name).strip().lower().replace(" ", "_").replace("-", "_")

    normalized_columns = [(col, _normalize(col)) for col in df.columns]
    matched_cols = {}
    for req_col in required_cols:
        wanted = _normalize(req_col)
        # Exact match first, so e.g. "title" wins over "video_title".
        matched_col = next((col for col, norm in normalized_columns if norm == wanted), None)
        if matched_col is None:
            matched_col = next((col for col, norm in normalized_columns if wanted in norm), None)
        matched_cols[req_col] = matched_col
    return matched_cols
# Column validation: resolve the required columns in both uploaded tables.
def validate_and_get_columns(employee_df, program_df):
    """Locate the required columns in the employee and program DataFrames.

    Returns:
        A 3-tuple ``(error_message, employee_cols, program_cols)``.
        On success ``error_message`` is ``None`` and both mappings go from
        required name to actual column. On the first required name with no
        matching column (employee columns are checked before program
        columns), a Korean error message is returned together with
        ``None, None``.
    """
    employee_cols = auto_match_columns(employee_df, ["employee_id", "employee_name", "current_skills"])
    program_cols = auto_match_columns(program_df, ["program_name", "skills_acquired", "duration"])

    missing_employee = next((name for name, found in employee_cols.items() if found is None), None)
    if missing_employee is not None:
        return f"직원 λ°μ΄ν„°μ—μ„œ '{missing_employee}' 열을 선택할 수 μ—†μŠ΅λ‹ˆλ‹€. μ˜¬λ°”λ₯Έ 열을 μ„ νƒν•˜μ„Έμš”.", None, None

    missing_program = next((name for name, found in program_cols.items() if found is None), None)
    if missing_program is not None:
        return f"ν”„λ‘œκ·Έλž¨ λ°μ΄ν„°μ—μ„œ '{missing_program}' 열을 선택할 수 μ—†μŠ΅λ‹ˆλ‹€. μ˜¬λ°”λ₯Έ 열을 μ„ νƒν•˜μ„Έμš”.", None, None

    return None, employee_cols, program_cols
# YouTube column selection: populate the four column dropdowns from an upload.
def select_youtube_columns(youtube_file):
    """Build dropdown updates for the title/description/url/upload_date pickers.

    Args:
        youtube_file: Uploaded file object with a ``.name`` path, or ``None``.

    Returns:
        A list of four ``gr.Dropdown`` updates in the order title,
        description, url, upload_date. With no file, every dropdown is
        cleared. Otherwise every dropdown lists all CSV columns and is
        preselected with the auto-matched column.

    Side effects:
        Stores the auto-matched mapping in the module global ``youtube_columns``.
    """
    global youtube_columns
    if youtube_file is None:
        return [gr.Dropdown(choices=[], value="") for _ in range(4)]

    youtube_df = pd.read_csv(youtube_file.name)
    roles = ["title", "description", "url", "upload_date"]
    youtube_columns = auto_match_columns(youtube_df, roles)
    options = youtube_df.columns.tolist()
    return [gr.Dropdown(choices=options, value=youtube_columns.get(role, "")) for role in roles]
# 유튜브 μ½˜ν…μΈ  데이터 λ‘œλ“œ 및 처리 ν•¨μˆ˜
def load_youtube_content(file_path, title_col, description_col, url_col, upload_date_col):
youtube_df = pd.read_csv(file_path)
selected_columns = [col for col in [title_col, description_col, url_col, upload_date_col] if col]
youtube_df = youtube_df[selected_columns]
column_mapping = {
title_col: 'title',
description_col: 'description',
url_col: 'url',
upload_date_col: 'upload_date'
}
youtube_df.rename(columns=column_mapping, inplace=True)
if 'upload_date' in youtube_df.columns:
youtube_df['upload_date'] = pd.to_datetime(youtube_df['upload_date'], errors='coerce')
return youtube_df
# 유튜브 μ½˜ν…μΈ μ™€ ꡐ윑 ν”„λ‘œκ·Έλž¨ 맀칭 ν•¨μˆ˜
def match_youtube_content(program_skills, youtube_df, model):
if 'description' not in youtube_df.columns:
return None
youtube_embeddings = model.encode(youtube_df['description'].tolist())
program_embeddings = model.encode(program_skills)
similarities = cosine_similarity(program_embeddings, youtube_embeddings)
return similarities
def _cell_texts(df, column):
    """Return *column* of *df* as a list of strings; missing cells become ''."""
    # model.encode is fed text; empty CSV fields arrive as float NaN, so
    # normalize every cell to str first.
    return df[column].fillna("").astype(str).tolist()


def _write_temp_file(data, suffix):
    """Write *data* to a new temporary file and return its path.

    ``bytes`` are written in binary mode. ``str`` is written as ``utf-8-sig``
    (a BOM helps spreadsheet apps detect UTF-8 Korean text) with
    ``newline=""``, so csv CRLF terminators are not translated again. The file
    is kept (``delete=False``) because Gradio reads it after we return.
    """
    if isinstance(data, bytes):
        open_kwargs = {"mode": "wb"}
    else:
        open_kwargs = {"mode": "w", "encoding": "utf-8-sig", "newline": ""}
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False, **open_kwargs) as handle:
        handle.write(data)
    return handle.name


# Analyze employee data, recommend training programs (and optionally YouTube
# content), and produce the result table, relationship chart and CSV export.
def hybrid_rag(employee_file, program_file, youtube_file, title_col, description_col, url_col, upload_date_col):
    """Run the recommendation pipeline behind the "뢄석 μ‹œμž‘" button.

    Args:
        employee_file: Uploaded employee CSV (object with a ``.name`` path).
        program_file: Uploaded training-program CSV (object with a ``.name`` path).
        youtube_file: Optional uploaded YouTube CSV. When ``None``, the YouTube
            matching step is skipped.
        title_col, description_col, url_col, upload_date_col: Column names
            selected in the YouTube column dropdowns.

    Returns:
        A 4-tuple for ``[output_table, chart_output, csv_download, download_button]``:
        the recommendation DataFrame, a PNG file path for the chart, a visible
        ``gr.File`` pointing at the CSV export, and a visible download button.

    Raises:
        gr.Error: if a required file is missing, a required column cannot be
            matched, or the employee/program table has no rows.

    Side effects:
        Sets the module globals ``global_recommendations`` and ``global_csv_string``.
    """
    global global_recommendations
    global global_csv_string

    similarity_threshold = 0.5  # an employee/program pair must score strictly above this
    videos_per_program = 3  # YouTube items suggested per matched program

    if employee_file is None or program_file is None:
        raise gr.Error("직원 데이터와 ꡐ윑 ν”„λ‘œκ·Έλž¨ 데이터 νŒŒμΌμ„ λͺ¨λ‘ μ—…λ‘œλ“œν•˜μ„Έμš”.")

    # Load employee and program data and resolve their columns.
    employee_df = pd.read_csv(employee_file.name)
    program_df = pd.read_csv(program_file.name)
    error_msg, employee_cols, program_cols = validate_and_get_columns(employee_df, program_df)
    if error_msg:
        # The first output is a gr.DataFrame, which cannot render a bare
        # message string; show the problem as a Gradio error instead.
        raise gr.Error(error_msg)
    if employee_df.empty or program_df.empty:
        raise gr.Error("직원 λ˜λŠ” ꡐ윑 ν”„λ‘œκ·Έλž¨ 데이터에 행이 μ—†μŠ΅λ‹ˆλ‹€.")

    id_col = employee_cols["employee_id"]
    name_col = employee_cols["employee_name"]
    program_names = program_df[program_cols["program_name"]].tolist()
    program_labels = [
        f"{name} ({duration})"
        for name, duration in zip(program_names, program_df[program_cols["duration"]])
    ]
    program_skills = _cell_texts(program_df, program_cols["skills_acquired"])

    # Employee x program cosine similarity, indexed by row *position*.
    similarities = cosine_similarity(
        model.encode(_cell_texts(employee_df, employee_cols["current_skills"])),
        model.encode(program_skills),
    )

    # YouTube matching is optional: skip it when no file was uploaded.
    youtube_df = None
    youtube_similarities = None
    if youtube_file is not None:
        youtube_df = load_youtube_content(youtube_file.name, title_col, description_col, url_col, upload_date_col)
        youtube_similarities = match_youtube_content(program_skills, youtube_df, model)
    show_youtube = (
        youtube_similarities is not None
        and 'title' in youtube_df.columns
        and 'url' in youtube_df.columns
    )

    recommendation_rows = []
    matched_pairs = []  # (employee name, program name) edges for the chart
    for emp_pos, (_, employee) in enumerate(employee_df.iterrows()):
        employee_name = employee[name_col]
        recommended_programs = []
        recommended_youtube = []
        for prog_pos, program_label in enumerate(program_labels):
            if similarities[emp_pos][prog_pos] <= similarity_threshold:
                continue
            recommended_programs.append(program_label)
            matched_pairs.append((employee_name, program_names[prog_pos]))
            if show_youtube:
                top_indices = youtube_similarities[prog_pos].argsort()[-videos_per_program:][::-1]
                for idx in top_indices:
                    video = youtube_df.iloc[idx]
                    recommended_youtube.append(f"{video['title']} (URL: {video['url']})")
        if recommended_programs:
            # dict.fromkeys drops videos suggested by several programs while keeping order.
            youtube_text = ", ".join(dict.fromkeys(recommended_youtube)) or "μΆ”μ²œ μ½˜ν…μΈ  μ—†μŒ"
            recommendation_rows.append([employee[id_col], employee_name,
                                        ", ".join(recommended_programs), youtube_text])
        else:
            recommendation_rows.append([employee[id_col], employee_name,
                                        "μ ν•©ν•œ ν”„λ‘œκ·Έλž¨ μ—†μŒ", "μΆ”μ²œ μ½˜ν…μΈ  μ—†μŒ"])
    global_recommendations = recommendation_rows

    # Relationship graph: every employee and program is a node; edges are matches.
    G = nx.Graph()
    G.add_nodes_from(employee_df[name_col], type='employee')
    G.add_nodes_from(program_names, type='program')
    G.add_edges_from(matched_pairs)

    # gr.Image cannot take a BytesIO and gr.File expects a path rather than
    # file contents, so persist the chart PNG and the CSV export to temp files.
    chart_path = _write_temp_file(create_chart(G).getvalue(), ".png")
    global_csv_string = create_csv_string(recommendation_rows)
    csv_path = _write_temp_file(global_csv_string, ".csv")

    result_df = pd.DataFrame(recommendation_rows, columns=["Employee ID", "Employee Name", "Recommended Programs", "Recommended YouTube Content"])
    return result_df, chart_path, gr.File(value=csv_path, visible=True), gr.Button(value="CSV λ‹€μš΄λ‘œλ“œ", visible=True)
# Chat reply: look up the employee named in the message.
def chat_response(message, history):
    """Return the stored recommendations for the employee named in *message*.

    Args:
        message: Free-text chat input. It is matched case-insensitively
            against employee names from the last analysis. When several names
            occur in the message, the longest one wins.
        history: Chat history; unused, accepted for the (message, history)
            callback signature.

    Returns:
        A reply string:
        - a prompt to run the analysis when none has run yet;
        - the employee's programs and YouTube content when a name matches;
        - a not-found message otherwise.
    """
    global global_recommendations
    if global_recommendations is None:
        return "λ¨Όμ € '뢄석 μ‹œμž‘' λ²„νŠΌμ„ 눌러 데이터λ₯Ό λΆ„μ„ν•΄μ£Όμ„Έμš”."

    message_lower = (message or "").lower()
    best_name, best_row = None, None
    for row in global_recommendations:
        raw_name = row[1]
        # Names come straight from the uploaded CSV and may be missing (NaN)
        # or non-strings; str(nan) == "nan" would otherwise match e.g. "banana".
        if pd.isna(raw_name):
            continue
        name = str(raw_name).strip()
        # A blank name is a substring of every message, so never let it match.
        if not name or name.lower() not in message_lower:
            continue
        # Prefer the longest matching name so "Kimberly" is not shadowed by "Kim".
        if best_name is None or len(name) > len(best_name):
            best_name, best_row = name, row

    if best_row is None:
        return "μ£„μ†‘ν•©λ‹ˆλ‹€. ν•΄λ‹Ή μ§μ›μ˜ 정보λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. λ‹€λ₯Έ 직원 이름을 μž…λ ₯ν•΄μ£Όμ„Έμš”."
    return f"{best_name}λ‹˜μ—κ²Œ μΆ”μ²œλœ ν”„λ‘œκ·Έλž¨μ€ λ‹€μŒκ³Ό κ°™μŠ΅λ‹ˆλ‹€: {best_row[2]}\n\nμΆ”μ²œ 유튜브 μ½˜ν…μΈ : {best_row[3]}"
# CSV download: re-emit the last analysis as a downloadable file.
def download_csv():
    """Return a ``gr.File`` update pointing at a CSV file of the last analysis.

    ``gr.File`` expects a file path, not file contents, so the cached CSV text
    (``global_csv_string``) is first written to a new temporary ``.csv`` file.
    If no analysis has run yet, the file component is hidden instead.
    """
    global global_csv_string
    if global_csv_string is None:
        return gr.File(value=None, visible=False)
    # utf-8-sig adds a BOM so spreadsheet apps detect UTF-8 (Korean text);
    # newline="" keeps the csv module's CRLF terminators from being re-translated.
    # delete=False: Gradio reads the file after this function returns.
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False,
                                     encoding="utf-8-sig", newline="") as handle:
        handle.write(global_csv_string)
    return gr.File(value=handle.name, visible=True)
# Gradio UI layout and event wiring.
with gr.Blocks(css=".gradio-button {background-color: #007bff; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>πŸ’Ό HybridRAG μ‹œμŠ€ν…œ (유튜브 μ½˜ν…μΈ  포함)</h1>")
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h3 style='color: #34495e;'>1. 데이터λ₯Ό μ—…λ‘œλ“œν•˜μ„Έμš”</h3>")
            employee_file = gr.File(label="직원 데이터 μ—…λ‘œλ“œ", interactive=True)
            program_file = gr.File(label="ꡐ윑 ν”„λ‘œκ·Έλž¨ 데이터 μ—…λ‘œλ“œ", interactive=True)
            youtube_file = gr.File(label="유튜브 μ½˜ν…μΈ  데이터 μ—…λ‘œλ“œ", interactive=True)
            gr.Markdown("<h4 style='color: #34495e;'>유튜브 데이터 μ—΄ 선택</h4>")
            title_col = gr.Dropdown(label="제λͺ© μ—΄")
            description_col = gr.Dropdown(label="μ„€λͺ… μ—΄")
            url_col = gr.Dropdown(label="URL μ—΄")
            upload_date_col = gr.Dropdown(label="μ—…λ‘œλ“œ λ‚ μ§œ μ—΄")
            # Refill the four column dropdowns whenever a YouTube CSV is uploaded.
            youtube_file.change(select_youtube_columns, inputs=[youtube_file], outputs=[title_col, description_col, url_col, upload_date_col])
            analyze_button = gr.Button("뢄석 μ‹œμž‘", elem_classes="gradio-button")
            output_table = gr.DataFrame(label="뢄석 κ²°κ³Ό (ν…Œμ΄λΈ”)")
            csv_download = gr.File(label="μΆ”μ²œ κ²°κ³Ό λ‹€μš΄λ‘œλ“œ", visible=False)
            download_button = gr.Button("CSV λ‹€μš΄λ‘œλ“œ", visible=False)
        with gr.Column(scale=2, min_width=500):
            gr.Markdown("<h3 style='color: #34495e;'>2. 뢄석 κ²°κ³Ό 및 μ‹œκ°ν™”</h3>")
            chart_output = gr.Image(label="μ‹œκ°ν™” 차트")
            gr.Markdown("<h3 style='color: #34495e;'>3. 직원별 μΆ”μ²œ ν”„λ‘œκ·Έλž¨ 및 유튜브 μ½˜ν…μΈ  확인</h3>")
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="직원 이름을 μž…λ ₯ν•˜μ„Έμš”")
            clear = gr.Button("λŒ€ν™” λ‚΄μ—­ μ§€μš°κΈ°")

    # Analyze button: update the table, chart, CSV file and download button.
    analyze_button.click(
        hybrid_rag,
        inputs=[employee_file, program_file, youtube_file, title_col, description_col, url_col, upload_date_col],
        outputs=[output_table, chart_output, csv_download, download_button]
    )

    # CSV download button.
    download_button.click(download_csv, inputs=[], outputs=[csv_download])

    # Chat: chat_response returns only the reply text, but the Chatbot output
    # needs the full history, so append the (message, reply) pair here and
    # clear the textbox.
    def _respond(message, history):
        """Append the (message, reply) pair to *history* and clear the input box."""
        reply = chat_response(message, history)
        return "", (history or []) + [[message, reply]]

    msg.submit(_respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch the Gradio interface.
demo.launch()