Spaces:
Runtime error
Runtime error
adds table
Browse files- app.py +24 -8
- setup.py +1 -1
- story_gen.py +54 -30
app.py
CHANGED
@@ -3,6 +3,7 @@ from story_gen import StoryGenerator
|
|
3 |
import plotly.figure_factory as ff
|
4 |
import plotly.express as px
|
5 |
import random
|
|
|
6 |
gen = StoryGenerator()
|
7 |
|
8 |
st.set_page_config(page_title='Storytelling ' +
|
@@ -14,12 +15,13 @@ if 'count' not in st.session_state or st.session_state.count == 6:
|
|
14 |
else:
|
15 |
st.session_state.count += 1
|
16 |
container_mode = st.sidebar.container()
|
|
|
17 |
container_param = st.sidebar.container()
|
18 |
container_button = st.sidebar.container()
|
19 |
mode = container_mode.radio(
|
20 |
"Select your mode",
|
21 |
('Create Statistics', 'Play Storytelling'), index=0)
|
22 |
-
story_till_now =
|
23 |
label='First Sentence',
|
24 |
value=random.choice([
|
25 |
'Hello, I\'m a language model,',
|
@@ -32,7 +34,7 @@ story_till_now = container_param.text_input(
|
|
32 |
]))
|
33 |
|
34 |
num_generation = container_param.slider(
|
35 |
-
label='Number of generation', min_value=1, max_value=100, value=
|
36 |
length = container_param.slider(label='Length of the generated sentence',
|
37 |
min_value=1, max_value=100, value=10, step=1)
|
38 |
if mode == 'Create Statistics':
|
@@ -49,19 +51,31 @@ if mode == 'Create Statistics':
|
|
49 |
if container_button.button('Analyse'):
|
50 |
gen.get_stats(story_till_now=story_till_now,
|
51 |
num_generation=num_generation, length=length, reaction_weight=reaction_weight, num_tests=num_tests)
|
52 |
-
if len(gen.stories) > 0:
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
st.markdown(f'### Story no. {si}:', unsafe_allow_html=False)
|
55 |
-
|
|
|
|
|
|
|
|
|
56 |
data=gen.stats_df[gen.stats_df.sentence_no==3]
|
57 |
fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
58 |
st.plotly_chart(fig, use_container_width=True)
|
59 |
fig2 = px.box(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
60 |
st.plotly_chart(fig2, use_container_width=True)
|
61 |
else:
|
62 |
-
|
63 |
elif mode == 'Play Storytelling':
|
64 |
-
container_mode.write('Let\'s play storytelling.')
|
65 |
|
66 |
# # , placeholder="Start writing your story...")
|
67 |
# story_till_now = st.text_input(
|
@@ -78,4 +92,6 @@ elif mode == 'Play Storytelling':
|
|
78 |
st.text(story_till_now)
|
79 |
st.markdown(f'The last sentence has the "{emotion["label"]}" **Emotion** with a confidence score of {emotion["score"]}.')
|
80 |
else:
|
81 |
-
|
|
|
|
|
|
3 |
import plotly.figure_factory as ff
|
4 |
import plotly.express as px
|
5 |
import random
|
6 |
+
import numpy as np
|
7 |
gen = StoryGenerator()
|
8 |
|
9 |
st.set_page_config(page_title='Storytelling ' +
|
|
|
15 |
else:
|
16 |
st.session_state.count += 1
|
17 |
container_mode = st.sidebar.container()
|
18 |
+
container_guide = st.sidebar.container()
|
19 |
container_param = st.sidebar.container()
|
20 |
container_button = st.sidebar.container()
|
21 |
mode = container_mode.radio(
|
22 |
"Select your mode",
|
23 |
('Create Statistics', 'Play Storytelling'), index=0)
|
24 |
+
story_till_now = st.text_input(
|
25 |
label='First Sentence',
|
26 |
value=random.choice([
|
27 |
'Hello, I\'m a language model,',
|
|
|
34 |
]))
|
35 |
|
36 |
num_generation = container_param.slider(
|
37 |
+
label='Number of generation', min_value=1, max_value=100, value=5, step=1)
|
38 |
length = container_param.slider(label='Length of the generated sentence',
|
39 |
min_value=1, max_value=100, value=10, step=1)
|
40 |
if mode == 'Create Statistics':
|
|
|
51 |
if container_button.button('Analyse'):
|
52 |
gen.get_stats(story_till_now=story_till_now,
|
53 |
num_generation=num_generation, length=length, reaction_weight=reaction_weight, num_tests=num_tests)
|
54 |
+
# if len(gen.stories) > 0:
|
55 |
+
# for si, story in enumerate(gen.stories):
|
56 |
+
# st.markdown(f'### Story no. {si}:', unsafe_allow_html=False)
|
57 |
+
# st.markdown(story, unsafe_allow_html=False)
|
58 |
+
# data=gen.stats_df[gen.stats_df.sentence_no==3]
|
59 |
+
# fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
60 |
+
# st.plotly_chart(fig, use_container_width=True)
|
61 |
+
# fig2 = px.box(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
62 |
+
# st.plotly_chart(fig2, use_container_width=True)
|
63 |
+
if len(gen.data) > 0:
|
64 |
+
for si, story in enumerate(gen.data):
|
65 |
st.markdown(f'### Story no. {si}:', unsafe_allow_html=False)
|
66 |
+
for i, sentence in enumerate(story):
|
67 |
+
col_sentence, col_emo = st.columns([3,1])
|
68 |
+
col_sentence.markdown(sentence['sentence'], unsafe_allow_html=False)
|
69 |
+
col_emo.markdown(f'{sentence["emotion"]} {np.round(sentence["confidence_score"], 3)}', unsafe_allow_html=False)
|
70 |
+
st.table(data=gen.stats_df, )
|
71 |
data=gen.stats_df[gen.stats_df.sentence_no==3]
|
72 |
fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
73 |
st.plotly_chart(fig, use_container_width=True)
|
74 |
fig2 = px.box(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
75 |
st.plotly_chart(fig2, use_container_width=True)
|
76 |
else:
|
77 |
+
container_guide.markdown('### You selected statistics. Now set your parameters and click the `Analyse` button.')
|
78 |
elif mode == 'Play Storytelling':
|
|
|
79 |
|
80 |
# # , placeholder="Start writing your story...")
|
81 |
# story_till_now = st.text_input(
|
|
|
92 |
st.text(story_till_now)
|
93 |
st.markdown(f'The last sentence has the "{emotion["label"]}" **Emotion** with a confidence score of {emotion["score"]}.')
|
94 |
else:
|
95 |
+
container_guide.markdown('### Write the first sentence and then hit the `Run` button')
|
96 |
+
# elif mode == 'Analyse Emotions':
|
97 |
+
# container_mode.write('Let\'s play storytelling.')
|
setup.py
CHANGED
@@ -4,7 +4,7 @@ with open("README.md", "r") as fh:
|
|
4 |
long_description = fh.read()
|
5 |
|
6 |
setuptools.setup(
|
7 |
-
name="
|
8 |
version="0.0.1",
|
9 |
author="Jitesh Gosar",
|
10 |
author_email="gosar95@gmail.com",
|
|
|
4 |
long_description = fh.read()
|
5 |
|
6 |
setuptools.setup(
|
7 |
+
name="storytelling",
|
8 |
version="0.0.1",
|
9 |
author="Jitesh Gosar",
|
10 |
author_email="gosar95@gmail.com",
|
story_gen.py
CHANGED
@@ -9,14 +9,14 @@ import pandas as pd
|
|
9 |
# import nltk
|
10 |
import re
|
11 |
|
|
|
12 |
class StoryGenerator:
|
13 |
def __init__(self):
|
14 |
self.initialise_models()
|
15 |
self.stats_df = pd.DataFrame(data=[], columns=[])
|
16 |
self.stories = []
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
def initialise_models(self):
|
21 |
start = time.time()
|
22 |
self.generator = pipeline('text-generation', model='gpt2')
|
@@ -32,11 +32,17 @@ class StoryGenerator:
|
|
32 |
self.clear_stats()
|
33 |
|
34 |
def clear_stories(self):
|
|
|
35 |
self.stories = []
|
36 |
|
37 |
def clear_stats(self):
|
38 |
self.stats_df = pd.DataFrame(data=[], columns=[])
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
@staticmethod
|
41 |
def get_num_token(text):
|
42 |
# return len(nltk.word_tokenize(text))
|
@@ -60,8 +66,7 @@ class StoryGenerator:
|
|
60 |
length, num_return_sequences=1)
|
61 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
62 |
new_sentence = story_till_now[last_length:]
|
63 |
-
|
64 |
-
emotion = max(emotions[0], key=lambda x: x['score'])
|
65 |
# printj.yellow(f'Sentence {i}:')
|
66 |
# story_to_print = f'{printj.ColorText.cyan(story_till_now[:last_length])}{printj.ColorText.green(story_till_now[last_length:])}\n'
|
67 |
# print(story_to_print)
|
@@ -76,6 +81,13 @@ class StoryGenerator:
|
|
76 |
stats_dict = dict()
|
77 |
num_reactions = 0
|
78 |
reaction_frequency = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
for i in range(num_generation):
|
80 |
# Text generation for User
|
81 |
last_length = len(story_till_now)
|
@@ -85,21 +97,26 @@ class StoryGenerator:
|
|
85 |
genreate_user_sentence = self.generator(story_till_now, max_length=self.get_num_token(
|
86 |
story_till_now)+length, num_return_sequences=1)
|
87 |
story_till_now = genreate_user_sentence[0]['generated_text']
|
88 |
-
|
89 |
|
90 |
printj.red.bold_on_white(f'loop: {i}; check emotion')
|
91 |
# Emotion self.classifier for User
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
show_emotion = False
|
96 |
else:
|
97 |
reaction_frequency = num_reactions/(i+1)
|
98 |
-
|
99 |
-
confidence_score=
|
100 |
-
if
|
101 |
num_reactions += 1
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
# Text generation for Robot
|
104 |
last_length = len(story_till_now)
|
105 |
printj.cyan(story_till_now)
|
@@ -109,46 +126,53 @@ class StoryGenerator:
|
|
109 |
story_till_now)+length, num_return_sequences=1)
|
110 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
111 |
new_sentence = story_till_now[last_length:]
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
stats_dict['sentence_no'] = i
|
117 |
-
stats_dict['
|
118 |
-
stats_dict['
|
119 |
-
stats_dict['
|
|
|
120 |
stats_dict['num_reactions'] = num_reactions
|
121 |
stats_dict['reaction_frequency'] = reaction_frequency
|
122 |
stats_dict['reaction_weight'] = reaction_weight
|
123 |
stats_df = pd.concat(
|
124 |
[stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])
|
125 |
-
return stats_df, story_till_now
|
126 |
|
127 |
def get_stats(self,
|
128 |
story_till_now="Hello, I'm a language model,",
|
129 |
num_generation=4,
|
130 |
length=20, reaction_weight=-1, num_tests=2):
|
131 |
use_random_w = reaction_weight == -1
|
132 |
-
self.stories = []
|
133 |
try:
|
134 |
-
num_rows = max(self.stats_df.
|
135 |
except Exception:
|
136 |
-
num_rows=0
|
137 |
-
for
|
138 |
if use_random_w:
|
139 |
# reaction_weight = np.random.random_sample()
|
140 |
reaction_weight = np.round(np.random.random_sample(), 1)
|
141 |
-
stats_df0, _story_till_now = self.auto_ist(
|
142 |
story_till_now=story_till_now,
|
143 |
num_generation=4,
|
144 |
length=20, reaction_weight=reaction_weight)
|
145 |
-
stats_df0.insert(loc=0, column='
|
146 |
|
147 |
-
# stats_df0['
|
148 |
self.stats_df = pd.concat([self.stats_df, stats_df0])
|
149 |
-
printj.yellow(f'
|
150 |
printj.green(stats_df0)
|
151 |
self.stories.append(_story_till_now)
|
|
|
152 |
self.stats_df = self.stats_df.reset_index(drop=True)
|
153 |
print(self.stats_df)
|
154 |
|
@@ -159,4 +183,4 @@ class StoryGenerator:
|
|
159 |
self.stats_df.to_excel(writer, sheet_name='IST')
|
160 |
|
161 |
# Close the Pandas Excel writer and output the Excel file.
|
162 |
-
writer.save()
|
|
|
9 |
# import nltk
|
10 |
import re
|
11 |
|
12 |
+
|
13 |
class StoryGenerator:
|
14 |
def __init__(self):
|
15 |
self.initialise_models()
|
16 |
self.stats_df = pd.DataFrame(data=[], columns=[])
|
17 |
self.stories = []
|
18 |
+
self.data = []
|
19 |
+
|
|
|
20 |
def initialise_models(self):
|
21 |
start = time.time()
|
22 |
self.generator = pipeline('text-generation', model='gpt2')
|
|
|
32 |
self.clear_stats()
|
33 |
|
34 |
def clear_stories(self):
|
35 |
+
self.data = []
|
36 |
self.stories = []
|
37 |
|
38 |
def clear_stats(self):
|
39 |
self.stats_df = pd.DataFrame(data=[], columns=[])
|
40 |
|
41 |
+
def get_emotion(self, text):
|
42 |
+
emotions = self.classifier(text)
|
43 |
+
emotion = max(emotions[0], key=lambda x: x['score'])
|
44 |
+
return emotion
|
45 |
+
|
46 |
@staticmethod
|
47 |
def get_num_token(text):
|
48 |
# return len(nltk.word_tokenize(text))
|
|
|
66 |
length, num_return_sequences=1)
|
67 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
68 |
new_sentence = story_till_now[last_length:]
|
69 |
+
emotion = self.get_emotion(new_sentence)
|
|
|
70 |
# printj.yellow(f'Sentence {i}:')
|
71 |
# story_to_print = f'{printj.ColorText.cyan(story_till_now[:last_length])}{printj.ColorText.green(story_till_now[last_length:])}\n'
|
72 |
# print(story_to_print)
|
|
|
81 |
stats_dict = dict()
|
82 |
num_reactions = 0
|
83 |
reaction_frequency = 0
|
84 |
+
emotion = self.get_emotion(story_till_now) # first line emotion
|
85 |
+
story_data = [{
|
86 |
+
'sentence': story_till_now,
|
87 |
+
'turn': 'first',
|
88 |
+
'emotion': emotion['label'],
|
89 |
+
'confidence_score': emotion['score'],
|
90 |
+
}]
|
91 |
for i in range(num_generation):
|
92 |
# Text generation for User
|
93 |
last_length = len(story_till_now)
|
|
|
97 |
genreate_user_sentence = self.generator(story_till_now, max_length=self.get_num_token(
|
98 |
story_till_now)+length, num_return_sequences=1)
|
99 |
story_till_now = genreate_user_sentence[0]['generated_text']
|
100 |
+
new_sentence_user = story_till_now[last_length:]
|
101 |
|
102 |
printj.red.bold_on_white(f'loop: {i}; check emotion')
|
103 |
# Emotion self.classifier for User
|
104 |
+
emotion_user = self.get_emotion(new_sentence_user)
|
105 |
+
if emotion_user['label'] == 'neutral':
|
106 |
+
show_emotion_user = False
|
|
|
107 |
else:
|
108 |
reaction_frequency = num_reactions/(i+1)
|
109 |
+
show_emotion_user = self.check_show_emotion(
|
110 |
+
confidence_score=emotion_user['score'], frequency=reaction_frequency, w=reaction_weight)
|
111 |
+
if show_emotion_user:
|
112 |
num_reactions += 1
|
113 |
|
114 |
+
story_data.append({
|
115 |
+
'sentence': new_sentence_user,
|
116 |
+
'turn': 'user',
|
117 |
+
'emotion': emotion_user['label'],
|
118 |
+
'confidence_score': emotion_user['score'],
|
119 |
+
})
|
120 |
# Text generation for Robot
|
121 |
last_length = len(story_till_now)
|
122 |
printj.cyan(story_till_now)
|
|
|
126 |
story_till_now)+length, num_return_sequences=1)
|
127 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
128 |
new_sentence = story_till_now[last_length:]
|
129 |
+
emotion = self.get_emotion(new_sentence)
|
130 |
+
|
131 |
+
story_data.append({
|
132 |
+
'sentence': new_sentence,
|
133 |
+
'turn': 'robot',
|
134 |
+
'emotion': emotion['label'],
|
135 |
+
'confidence_score': emotion['score'],
|
136 |
+
})
|
137 |
+
|
138 |
stats_dict['sentence_no'] = i
|
139 |
+
stats_dict['sentence'] = new_sentence_user
|
140 |
+
stats_dict['show_emotion'] = show_emotion_user
|
141 |
+
stats_dict['emotion_label'] = emotion_user['label']
|
142 |
+
stats_dict['emotion_score'] = emotion_user['score']
|
143 |
stats_dict['num_reactions'] = num_reactions
|
144 |
stats_dict['reaction_frequency'] = reaction_frequency
|
145 |
stats_dict['reaction_weight'] = reaction_weight
|
146 |
stats_df = pd.concat(
|
147 |
[stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])
|
148 |
+
return stats_df, story_till_now, story_data
|
149 |
|
150 |
def get_stats(self,
|
151 |
story_till_now="Hello, I'm a language model,",
|
152 |
num_generation=4,
|
153 |
length=20, reaction_weight=-1, num_tests=2):
|
154 |
use_random_w = reaction_weight == -1
|
155 |
+
# self.stories = []
|
156 |
try:
|
157 |
+
num_rows = max(self.stats_df.story_id)+1
|
158 |
except Exception:
|
159 |
+
num_rows = 0
|
160 |
+
for story_id in range(num_tests):
|
161 |
if use_random_w:
|
162 |
# reaction_weight = np.random.random_sample()
|
163 |
reaction_weight = np.round(np.random.random_sample(), 1)
|
164 |
+
stats_df0, _story_till_now, story_data = self.auto_ist(
|
165 |
story_till_now=story_till_now,
|
166 |
num_generation=4,
|
167 |
length=20, reaction_weight=reaction_weight)
|
168 |
+
stats_df0.insert(loc=0, column='story_id', value=story_id+num_rows)
|
169 |
|
170 |
+
# stats_df0['story_id'] = story_id
|
171 |
self.stats_df = pd.concat([self.stats_df, stats_df0])
|
172 |
+
printj.yellow(f'story_id: {story_id}')
|
173 |
printj.green(stats_df0)
|
174 |
self.stories.append(_story_till_now)
|
175 |
+
self.data.append(story_data)
|
176 |
self.stats_df = self.stats_df.reset_index(drop=True)
|
177 |
print(self.stats_df)
|
178 |
|
|
|
183 |
self.stats_df.to_excel(writer, sheet_name='IST')
|
184 |
|
185 |
# Close the Pandas Excel writer and output the Excel file.
|
186 |
+
writer.save()
|