File size: 2,230 Bytes
02f4a0c
 
fe5044d
 
02f4a0c
 
fe5044d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e88d5a6
fe5044d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os

import pandas as pd
import requests
from plotnine import aes, geom_point, ggplot, labs, theme_light
from shiny import module, reactive, render, ui


@module.ui
def query_output_ui():
    """Build the three-column layout: prompt input, score table, score plot.

    The input id "prompt" and the output ids "score_table" / "score_plot" are
    the names the paired server function reads from and renders into.
    """
    prompt_column = ui.column(
        4,
        ui.input_text("prompt", "Prompt", placeholder="Enter query"),
    )
    table_column = ui.column(4, ui.output_table("score_table"))
    plot_column = ui.column(4, ui.output_plot("score_plot"))

    return ui.row(prompt_column, table_column, plot_column)


@module.server
def query_output_server(input, output, session):
    """Connect the "prompt" text input to the score table and score plot.

    An empty prompt yields an all-zero placeholder instead of calling the API,
    so both outputs have something to render before the user types anything.
    """
    # Emotion labels used for the zero-filled placeholder; the placeholder is
    # built in the same nested shape that response_table() indexes into
    # (a one-element list wrapping one {"label", "score"} dict per emotion).
    placeholder_labels = (
        "neutral",
        "surprise",
        "fear",
        "anger",
        "disgust",
        "sadness",
        "joy",
    )

    @reactive.Calc
    def response_table():
        """Return the scores as a DataFrame with columns "sentiment" and "score"."""
        prompt = input.prompt()
        if prompt == "":
            resp = [[{"label": name, "score": 0} for name in placeholder_labels]]
        else:
            resp = query(prompt)

        # Only the first (and, for a single input, only) result list is used.
        entries = resp[0]
        labels = [entry["label"] for entry in entries]
        scores = [entry["score"] for entry in entries]
        return pd.DataFrame({"sentiment": labels, "score": scores})

    @output
    @render.plot
    def score_plot():
        """Plot the current scores, titled with the current prompt."""
        scores = response_table()
        return plot_response(scores, input.prompt())

    @output
    @render.table()
    def score_table():
        """Show the current scores as a table."""
        return response_table()


def plot_response(df, plot_title):
    """Build a dot plot of emotion scores, one row per sentiment.

    Args:
        df: table with "sentiment" and "score" columns.
        plot_title: text embedded (in quotes) in the plot title.

    Returns:
        A plotnine ggplot object; sentiments are ordered on the y axis by score.
    """
    mapping = aes(y="reorder(sentiment, score)", x="score")
    title = f'Prompt: "{plot_title}"'
    return (
        ggplot(df, mapping)
        + geom_point()
        + theme_light()
        + labs(title=title, y="Sentiment", x="Score")
    )


# Hugging Face Inference API endpoint for the emotion classifier.
_API_URL = (
    "https://api-inference.huggingface.co/models/"
    "j-hartmann/emotion-english-distilroberta-base"
)


def query(text, *, timeout=30.0):
    """Send *text* to the Hugging Face emotion classifier and return the JSON reply.

    Args:
        text: the prompt to classify.
        timeout: seconds to wait for the server (passed to ``requests.post``);
            without one, a stalled endpoint would block the caller forever.

    Returns:
        The decoded JSON body. Callers in this file index it as ``resp[0]``,
        a list of ``{"label": ..., "score": ...}`` dicts.

    Raises:
        KeyError: if the ``HF_API_KEY`` environment variable is not set.
        requests.HTTPError: if the API answers with a 4xx/5xx status, instead
            of handing the error JSON to callers that expect scores.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    headers = {"Authorization": "Bearer " + os.environ["HF_API_KEY"]}
    payload = {"inputs": text}
    response = requests.post(
        _API_URL, headers=headers, json=payload, timeout=timeout
    )
    # Fail loudly on error statuses (e.g. 503 while the model loads) rather
    # than returning an error dict that breaks downstream indexing.
    response.raise_for_status()
    return response.json()