JERNGOC commited on
Commit
d19290a
1 Parent(s): caaa2db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+ from sklearn.preprocessing import StandardScaler
5
+ from sklearn.ensemble import VotingClassifier, StackingClassifier
6
+ from sklearn.linear_model import LogisticRegression
7
+ from sklearn.tree import DecisionTreeClassifier
8
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
9
+ from sklearn.svm import SVC
10
+ from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, classification_report
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+
14
+ # 設定 Streamlit 介面標題
15
+ st.title('分類模型比較:堆疊與投票分類器')
16
+
17
+ # 讓使用者上傳資料
18
+ uploaded_file = st.file_uploader("請上傳 CSV 檔案", type=["csv"])
19
+
20
+ if uploaded_file is not None:
21
+ df = pd.read_csv(uploaded_file)
22
+
23
+ # 定義特徵與目標變數
24
+ X = df.drop(columns=['Target_goal'])
25
+ y = df['Target_goal']
26
+
27
+ # 分割數據集
28
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
29
+
30
+ # 標準化數據
31
+ scaler = StandardScaler()
32
+ X_train = scaler.fit_transform(X_train)
33
+ X_test = scaler.transform(X_test)
34
+
35
+ # 定義基礎模型
36
+ estimators = [
37
+ ('lr', LogisticRegression()),
38
+ ('dt', DecisionTreeClassifier()),
39
+ ('rf', RandomForestClassifier()),
40
+ ('gb', GradientBoostingClassifier()),
41
+ ('svc', SVC(probability=True))
42
+ ]
43
+
44
+ # 堆疊分類器
45
+ stacking_clf = StackingClassifier(
46
+ estimators=estimators,
47
+ final_estimator=LogisticRegression()
48
+ )
49
+ stacking_clf.fit(X_train, y_train)
50
+ y_pred_stack = stacking_clf.predict(X_test)
51
+ y_pred_stack_proba = stacking_clf.predict_proba(X_test)[:, 1]
52
+
53
+ # 堆疊分類器準確性
54
+ accuracy_stack = accuracy_score(y_test, y_pred_stack)
55
+ st.write(f'堆疊分類器的準確性: {accuracy_stack:.2f}')
56
+
57
+ # 堆疊分類器的分類報告
58
+ st.write("堆疊分類器的分類報告:")
59
+ st.text(classification_report(y_test, y_pred_stack))
60
+
61
+ # 投票分類器
62
+ voting_clf = VotingClassifier(
63
+ estimators=estimators,
64
+ voting='soft'
65
+ )
66
+ voting_clf.fit(X_train, y_train)
67
+ y_pred_vote = voting_clf.predict(X_test)
68
+ y_pred_vote_proba = voting_clf.predict_proba(X_test)[:, 1]
69
+
70
+ # 投票分類器準確性
71
+ accuracy_vote = accuracy_score(y_test, y_pred_vote)
72
+ st.write(f'投票分類器的準確性: {accuracy_vote:.2f}')
73
+
74
+ # 投票分類器的分類報告
75
+ st.write("投票分類器的分類報告:")
76
+ st.text(classification_report(y_test, y_pred_vote))
77
+
78
+ # 混淆矩陣可視化
79
+ st.write("堆疊分類器的混淆矩陣:")
80
+ conf_matrix_stack = confusion_matrix(y_test, y_pred_stack)
81
+ fig, ax = plt.subplots()
82
+ sns.heatmap(conf_matrix_stack, annot=True, fmt='d', cmap='Blues', ax=ax)
83
+ ax.set_title('堆疊分類器的混淆矩陣')
84
+ st.pyplot(fig)
85
+
86
+ st.write("投票分類器的混淆矩陣:")
87
+ conf_matrix_vote = confusion_matrix(y_test, y_pred_vote)
88
+ fig, ax = plt.subplots()
89
+ sns.heatmap(conf_matrix_vote, annot=True, fmt='d', cmap='Blues', ax=ax)
90
+ ax.set_title('投票分類器的混淆矩陣')
91
+ st.pyplot(fig)
92
+
93
+ # ROC 曲線
94
+ fpr_stack, tpr_stack, _ = roc_curve(y_test, y_pred_stack_proba)
95
+ roc_auc_stack = auc(fpr_stack, tpr_stack)
96
+
97
+ fpr_vote, tpr_vote, _ = roc_curve(y_test, y_pred_vote_proba)
98
+ roc_auc_vote = auc(fpr_vote, tpr_vote)
99
+
100
+ fig, ax = plt.subplots()
101
+ ax.plot(fpr_stack, tpr_stack, color='blue', lw=2, label='堆疊分類器 (AUC = %0.2f)' % roc_auc_stack)
102
+ ax.plot(fpr_vote, tpr_vote, color='red', lw=2, label='投票分類器 (AUC = %0.2f)' % roc_auc_vote)
103
+ ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
104
+ ax.set_xlim([0.0, 1.0])
105
+ ax.set_ylim([0.0, 1.05])
106
+ ax.set_xlabel('假陽性率(False Positive Rate)')
107
+ ax.set_ylabel('真陽性率(True Positive Rate)')
108
+ ax.set_title('ROC 曲線')
109
+ ax.legend(loc="lower right")
110
+ st.pyplot(fig)