erikjm's picture
Upload app.py
154b9c1 verified
raw
history blame contribute delete
No virus
7.19 kB
import gradio as gr
import os
from interface_utils import *
# Name of the maxim this labelling app collects judgements for; it is written
# into every saved label record and into the label file name.
maxim = 'quantity'

# One statement per radio group shown to the annotator.
submaxims = [
    "The response provides a sufficient amount of information.",
    "The response does not contain unnecessary details.",
]

# Answer options for each statement above (one independent list per submaxim).
checkbox_choices = [["Yes", "No", "NA"] for _ in submaxims]

# Conversation slices still awaiting labels.
conversation_data_sliced = load_from_jsonl('./data/conversations_unlabeled_sliced.jsonl')

# The UI pre-allocates this many Markdown slots so that any slice fits.
max_conversation_length = max(
    len(entry['transcript']) for entry in conversation_data_sliced
)

# First conversation slice shown when the app starts.
conversation = get_conversation(conversation_data_sliced)
def save_labels(conv_id, slice_idx, skipped, submaxim_0=None, submaxim_1=None):
    """Write one annotation record to ``../labels`` as a JSON file.

    Args:
        conv_id: Identifier of the conversation that was shown.
        slice_idx: Index of the slice within that conversation; converted with
            ``int()`` for the stored record, used verbatim in the file name.
        skipped: Whether the annotator skipped this slice.
        submaxim_0: Answer chosen for the first submaxim (``None`` if absent).
        submaxim_1: Answer chosen for the second submaxim (``None`` if absent).

    The file is named ``{maxim}_human_labels_{conv_id}_{slice_idx}.json``; an
    existing file with the same name is overwritten.
    """
    # Imported here because this file never imports ``json`` itself; it was
    # only reachable if ``from interface_utils import *`` happened to leak it.
    import json

    data = {
        'conv_id': conv_id,
        'slice_idx': int(slice_idx),
        'maxim': maxim,
        'skipped': skipped,
        'submaxim_0': submaxim_0,
        'submaxim_1': submaxim_1,
    }
    os.makedirs("../labels", exist_ok=True)
    label_path = f"../labels/{maxim}_human_labels_{conv_id}_{slice_idx}.json"
    with open(label_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
def update_interface(new_conversation):
    """Build the component updates that display ``new_conversation``.

    Args:
        new_conversation: Mapping with ``'conv_id'``, ``'slice_idx'`` and
            ``'transcript'``; each transcript entry is indexed by
            ``'speaker'`` and ``'response'``.

    Returns:
        A flat list ordered to match the click handlers' ``outputs``:
        ``[conv_id, slice_idx, *markdown_blocks, radio_0, radio_1, conv_len]``,
        with exactly ``max_conversation_length`` Markdown entries.
    """
    new_conv_id = new_conversation['conv_id']
    new_slice_idx = new_conversation['slice_idx']
    new_transcript = new_conversation['transcript']

    # Index of the turn being labelled. This must come from the *new*
    # transcript: using the module-level ``transcript`` (the first conversation
    # ever shown) highlights the wrong turn and can leave slots unset.
    last_idx = len(new_transcript) - 1

    markdown_blocks = []
    for i in range(max_conversation_length):
        if i < len(new_transcript) and new_transcript[i]['speaker'] != '':
            speaker = new_transcript[i]['speaker']
            response = new_transcript[i]['response']
            if i == last_idx:
                # The final turn is the one under review, so it is highlighted.
                block = gr.Markdown(f"""&nbsp;&nbsp;**{speaker}**: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<mark style="background-color: lightyellow">{response}</mark>""", visible=True)
            else:
                block = gr.Markdown(f"""&nbsp;&nbsp;**{speaker}**: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;{response}""", visible=True)
        else:
            # Unused slot (past the end of the transcript, or a blank speaker).
            block = gr.Markdown("", visible=False)
        markdown_blocks.append(block)

    # Reset both radio groups so the previous answers do not carry over.
    new_radio_0_base = gr.Radio(label=submaxims[0],
                                choices=checkbox_choices[0],
                                value=None,
                                visible=True)
    new_radio_1_base = gr.Radio(label=submaxims[1],
                                choices=checkbox_choices[1],
                                value=None,
                                visible=True)
    conv_len = gr.Number(value=len(new_transcript), visible=False)
    return [new_conv_id] + [new_slice_idx] + markdown_blocks + [new_radio_0_base] + [new_radio_1_base] + [conv_len]
def submit(*args):
    """Persist the annotator's answers and advance to another conversation.

    ``args`` carries the component values in the order of ``input_list``:
    conversation id, slice index, the Markdown slots, the two radio answers
    and finally the conversation length. Only the id, the slice index and the
    two radio answers are read.

    Returns:
        The update list produced by ``update_interface`` for the next slice.
    """
    conv_id, slice_idx = args[0], args[1]
    # The radios sit just before the trailing conv_len value.
    answer_0, answer_1 = args[-3], args[-2]
    save_labels(
        conv_id,
        slice_idx,
        skipped=False,
        submaxim_0=answer_0,
        submaxim_1=answer_1,
    )
    return update_interface(get_conversation(conversation_data_sliced))
def skip(*args):
    """Record the current slice as skipped and advance to another conversation.

    ``args`` carries the component values in the order of ``input_list``; only
    the first two (conversation id and slice index) are read.

    Returns:
        The update list produced by ``update_interface`` for the next slice.
    """
    conv_id = args[0]
    slice_idx = args[1]
    save_labels(conv_id, slice_idx, skipped=True)
    new_conversation = get_conversation(conversation_data_sliced)
    # update_interface accepts only the conversation; passing slice_idx as a
    # second positional argument raised TypeError on every Skip click.
    return update_interface(new_conversation)
# Page-level CSS. Defined before the Blocks context so it can be passed to the
# constructor rather than assigned to the instance after it has been built.
css = """
#textbox_id textarea {
background-color: white;
}
.bottom-aligned-group {
display: flex;
flex-direction: column;
justify-content: flex-end;
height: 100%;
}
"""

# NOTE(review): the layout nesting below was reconstructed from a copy of this
# file that had lost its indentation -- confirm it against the running app.
with gr.Blocks(theme=gr.themes.Default(), css=css) as interface:
    conv_id = conversation['conv_id']
    slice_idx = conversation['slice_idx']
    transcript = conversation['transcript']
    # Hidden carrier for the number of turns in the displayed conversation.
    conv_len = gr.Number(value=len(transcript), visible=False)
    markdown_blocks = [None] * max_conversation_length

    with gr.Column(scale=1, min_width=600):
        with gr.Group():
            gr.Markdown("""<span style='font-size: 16px;'>&nbsp;&nbsp;&nbsp;&nbsp;**Conversation** </span>""",
                        visible=True)
            # One Markdown slot per possible turn; unused slots are created
            # hidden so later conversations can fill them in.
            last_idx = len(transcript) - 1
            for i in range(max_conversation_length):
                # Blank-speaker turns are hidden, matching update_interface.
                if i < len(transcript) and transcript[i]['speaker'] != '':
                    speaker = transcript[i]['speaker']
                    response = transcript[i]['response']
                    if i == last_idx:
                        # The final turn is the one being labelled.
                        markdown_blocks[i] = gr.Markdown(f"""&nbsp;&nbsp;**{speaker}**: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<mark style="background-color: lightyellow">{response}</mark>""")
                    else:
                        markdown_blocks[i] = gr.Markdown(f"""&nbsp;&nbsp;**{speaker}**: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;{response}""")
                else:
                    markdown_blocks[i] = gr.Markdown("", visible=False)

        with gr.Row():
            with gr.Group(elem_classes="bottom-aligned-group"):
                speaker_adapted = gr.Markdown(
                    """<span style='font-size: 16px;'>&nbsp;&nbsp;&nbsp;&nbsp;**Labels** </span>""",
                    visible=True)
                radio_submaxim_0_base = gr.Radio(label=submaxims[0],
                                                 choices=checkbox_choices[0],
                                                 value=None,
                                                 visible=True)
                radio_submaxim_1_base = gr.Radio(label=submaxims[1],
                                                 choices=checkbox_choices[1],
                                                 value=None,
                                                 visible=True)

        submit_button = gr.Button("Submit")
        skip_button = gr.Button("Skip")

        # Hidden fields that round-trip the identity of the displayed slice.
        conv_id_element = gr.Text(value=conv_id, visible=False)
        slice_idx_element = gr.Text(value=slice_idx, visible=False)

    # The handlers index into this list positionally (first two entries and
    # the last three), so the order must stay exactly as written.
    input_list = [
        conv_id_element,
        slice_idx_element,
        *markdown_blocks,
        radio_submaxim_0_base,
        radio_submaxim_1_base,
        conv_len,
    ]
    # update_interface returns its values in this same order.
    output_list = list(input_list)

    submit_button.click(
        fn=submit,
        inputs=input_list,
        outputs=output_list,
    )
    skip_button.click(
        fn=skip,
        inputs=input_list,
        outputs=output_list,
    )

interface.launch()