hagenw committed
Commit d1ab157
1 Parent(s): 0aad17a

Try to split age and gender output

Files changed (1)
1. app.py +43 -18
app.py CHANGED
@@ -10,6 +10,10 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedMod
 import audiofile
 
 
+model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
+duration = 1  # limit processing of audio
+
+
 class ModelHead(nn.Module):
     r"""Classification head."""
 
@@ -63,7 +67,6 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
 
 # load model from hub
 device = 0 if torch.cuda.is_available() else "cpu"
-model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
 processor = Wav2Vec2Processor.from_pretrained(model_name)
 model = AgeGenderModel.from_pretrained(model_name)
 
@@ -98,23 +101,26 @@ def process_func(x: np.ndarray, sampling_rate: int) -> dict:
 
 
 @spaces.GPU
-def recognize(file):
+def recognize(file, output_selector):
     if file is None:
         raise gr.Error(
             "No audio file submitted! "
             "Please upload or record an audio file "
             "before submitting your request."
         )
-    signal, sampling_rate = audiofile.read(file)
+    signal, sampling_rate = audiofile.read(file, duration=duration)
     age_gender = process_func(signal, sampling_rate)
-    return age_gender
+    if output_selector == "age":
+        return age_gender["age"]
+    else:
+        return {k: v for k, v in age_gender.items() if k != "age"}
 
 
 outputs = gr.Label()
 title = "audEERING age and gender recognition"
 description = (
     "Recognize age and gender of a microphone recording or audio file. "
-    "Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
+    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
 )
 allow_flagging = "never"
 
@@ -127,16 +133,35 @@ allow_flagging = "never"
 #     allow_flagging=allow_flagging,
 # )
 
-file = gr.Interface(
-    fn=recognize,
-    inputs=gr.Audio(sources="upload", type="filepath", label="Audio file"),
-    outputs=outputs,
-    title=title,
-    description=description,
-    allow_flagging=allow_flagging,
-)
-
-# demo = gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])
-# demo.queue().launch()
-# demo.launch()
-file.launch()
+# file = gr.Interface(
+#     fn=recognize,
+#     inputs=gr.Audio(sources="upload", type="filepath", label="Audio file"),
+#     outputs=outputs,
+#     title=title,
+#     description=description,
+#     allow_flagging=allow_flagging,
+# )
+#
+# # demo = gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])
+# # demo.queue().launch()
+# # demo.launch()
+# file.launch()
+
+with gr.Blocks() as demo:
+    gr.Markdown(description)
+    with gr.Tab(label="Input"):
+        with gr.Row():
+            with gr.Column():
+                audio = gr.Audio(sources="upload", type="filepath", label="Audio file")
+                output_selector = gr.Dropdown(
+                    choices=["age", "gender"],
+                    label="Output",
+                    value="age",
+                )
+                submit_btn = gr.Button(value="Submit")
+            with gr.Column():
+                output_text = gr.Textbox(label="Output Text")
+
+    submit_btn.click(recognize, [audio, output_selector], [output_text])
+
+demo.launch(debug=True)
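
Note on the new duration argument: audiofile.read() decodes only the first duration seconds of the file, so inference cost stays bounded no matter how long the upload is. A minimal sketch of the effect (the file name is illustrative):

import audiofile

# Decode only the first second of audio; the returned signal is
# truncated accordingly, while the sampling rate is unchanged.
signal, sampling_rate = audiofile.read("speech.wav", duration=1)
print(signal.shape)  # roughly (sampling_rate,) samples for a mono file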
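
The output split in recognize() relies on process_func() returning one dict that holds the age prediction alongside the gender scores: selecting "age" returns the scalar, anything else strips the "age" key and returns the rest. A minimal sketch of that branching, assuming an output with one "age" entry plus per-gender scores (the key names and values below are hypothetical, not taken from the model):

# Hypothetical process_func() result; the actual keys depend on the model head.
age_gender = {"age": 0.42, "female": 0.91, "male": 0.08, "child": 0.01}

def split_output(age_gender, output_selector):
    # Mirrors the branching added to recognize() in this commit.
    if output_selector == "age":
        return age_gender["age"]
    return {k: v for k, v in age_gender.items() if k != "age"}

print(split_output(age_gender, "age"))     # 0.42
print(split_output(age_gender, "gender"))  # {'female': 0.91, 'male': 0.08, 'child': 0.01}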