hagenw committed
Commit f62c750
1 Parent(s): ba45a7b

Add expression model

Files changed (1)
  1. app.py +92 -34
app.py CHANGED
@@ -12,8 +12,9 @@ import audresample


 device = 0 if torch.cuda.is_available() else "cpu"
-model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
 duration = 1 # limit processing of audio
+age_gender_model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
+expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"


 class AgeGenderHead(nn.Module):
@@ -66,10 +67,55 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
         return hidden_states, logits_age, logits_gender


-
-# load model from hub
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-model = AgeGenderModel.from_pretrained(model_name)
+class ExpressionHead(nn.Module):
+    r"""Expression model head."""
+
+    def __init__(self, config):
+
+        super().__init__()
+
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.final_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class ExpressionModel(Wav2Vec2PreTrainedModel):
+    r"""speech expression model."""
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.config = config
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.classifier = ExpressionHead(config)
+        self.init_weights()
+
+    def forward(self, input_values):
+        outputs = self.wav2vec2(input_values)
+        hidden_states = outputs[0]
+        hidden_states = torch.mean(hidden_states, dim=1)
+        logits = self.classifier(hidden_states)
+
+        return hidden_states, logits
+
+
+# Load models from hub
+age_gender_processor = Wav2Vec2Processor.from_pretrained(age_gender_model_name)
+age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name)
+expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
+expression_model = ExpressionModel.from_pretrained(expression_model_name)


 def process_func(x: np.ndarray, sampling_rate: int) -> dict:
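
The new `ExpressionModel` mirrors `AgeGenderModel`: wav2vec 2.0 hidden states are mean-pooled over time and passed through a dropout/tanh projection head. As a minimal sketch (not part of the commit; it assumes the module-level names loaded above and runs on CPU), the expression branch can be exercised on one second of silence:

```python
import numpy as np
import torch

# Hypothetical smoke test: one second of silence at 16 kHz.
signal = np.zeros(16000, dtype=np.float32)

# Same preprocessing steps that process_func applies below.
y = expression_processor(signal, sampling_rate=16000)
y = torch.from_numpy(y["input_values"][0]).reshape(1, -1)

with torch.no_grad():
    hidden_states, logits = expression_model(y)

print(hidden_states.shape)  # (1, hidden_size): mean-pooled wav2vec 2.0 states
print(logits.shape)         # (1, 3): arousal, dominance, valence logits
```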
@@ -77,28 +123,38 @@ def process_func(x: np.ndarray, sampling_rate: int) -> dict:
     # run through processor to normalize signal
     # always returns a batch, so we just get the first entry
     # then we put it on the device
-    y = processor(x, sampling_rate=sampling_rate)
-    y = y['input_values'][0]
-    y = y.reshape(1, -1)
-    y = torch.from_numpy(y).to(device)
-
-    # run through model
-    with torch.no_grad():
-        y = model(y)
-        y = torch.hstack([y[1], y[2]])
-
-    # convert to numpy
-    y = y.detach().cpu().numpy()
-
-    # convert to dict
-    y = {
-        "age": 100 * y[0][0],
-        "female": y[0][1],
-        "male": y[0][2],
-        "child": y[0][3],
-    }
-
-    return y
+    results = []
+    for processor, model in zip(
+        [age_gender_processor, expression_processor],
+        [age_gender_model, expression_model],
+    ):
+        y = processor(x, sampling_rate=sampling_rate)
+        y = y['input_values'][0]
+        y = y.reshape(1, -1)
+        y = torch.from_numpy(y).to(device)
+
+        # run through model
+        with torch.no_grad():
+            y = model(y)
+            y = torch.hstack([y[1], y[2]])
+
+        # convert to numpy
+        y = y.detach().cpu().numpy()
+        results.append(y[0])
+
+    return (
+        100 * results[0][0],  # age
+        {
+            "female": results[0][1],
+            "male": results[0][2],
+            "child": results[0][3],
+        },
+        {
+            "arousal": results[1][0],
+            "dominance": results[1][1],
+            "valence": results[1][2],
+        }
+    )


 @spaces.GPU
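
Two details of the rewritten `process_func` are worth flagging: the `-> dict` annotation is now stale (the function returns a three-element tuple), and `AgeGenderModel.forward` returns three tensors while `ExpressionModel.forward` returns only two, so `torch.hstack([y[1], y[2]])` fits only the first model in the loop. A guard on the number of outputs would cover both; the helper below is a hypothetical sketch by the editor, not what the commit contains:

```python
import torch

def stack_model_outputs(y: tuple) -> torch.Tensor:
    # y is the tuple returned by either model's forward pass.
    if len(y) == 3:
        # AgeGenderModel: concatenate the age logit with the gender logits
        return torch.hstack([y[1], y[2]])
    # ExpressionModel: the logits already hold arousal, dominance, valence
    return y[1]
```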
@@ -117,17 +173,17 @@ def recognize(input_file):
     target_rate = 16000
     signal = audresample.resample(signal, sampling_rate, target_rate)

-    age_gender = process_func(signal, target_rate)
-    age = f"{round(age_gender['age'])} years"
-    gender = {k: v for k, v in age_gender.items() if k != "age"}
-    return age, gender
+    return process_func(signal, target_rate)


 outputs = gr.Label()
 title = "audEERING age and gender recognition"
 description = (
-    "Recognize age and gender of a microphone recording or audio file. "
-    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
+    "Speech analysis of an audio file or microphone recording. \n"
+    f"[{age_gender_model_name}](https://huggingface.co/{age_gender_model_name}) "
+    "is used for age and gender recognition, "
+    f"[{expression_model_name}](https://huggingface.co/{expression_model_name}) "
+    "is used for expression recognition."
 )
 allow_flagging = "never"
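
`recognize` now forwards the tuple from `process_func` unchanged, so the Age textbox receives a raw number (the model's normalized age output scaled by 100) instead of the earlier "NN years" string. If that formatting is still wanted, it could be restored inside `recognize`, for example (an assumption, not part of the commit):

```python
# Inside recognize(), after resampling (sketch, not part of the commit):
age, gender, expression = process_func(signal, target_rate)
return f"{round(age)} years", gender, expression
```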
 
@@ -159,8 +215,8 @@ with gr.Blocks() as demo:
     gr.Markdown(description)
     with gr.Tab(label="Speech analysis"):
         with gr.Row():
-            gr.Markdown("Only the first second of the audio is processed.")
             with gr.Column():
+                gr.Markdown("Only the first second of the audio is processed.")
                 input = gr.Audio(
                     sources=["upload", "microphone"],
                     type="filepath",
@@ -170,8 +226,10 @@
             with gr.Column():
                 output_age = gr.Textbox(label="Age")
                 output_gender = gr.Label(label="Gender")
+                output_expression = gr.Label(label="Expression")

-        submit_btn.click(recognize, input, [output_age, output_gender])
+        outputs = [output_age, output_gender, output_expression]
+        submit_btn.click(recognize, input, outputs)


 demo.launch(debug=True)
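
The three values returned by `recognize` map positionally onto the three output components now passed to `submit_btn.click`. Roughly (illustrative placeholder values and a hypothetical file path, not actual model output):

```python
age, gender, expression = recognize("example.wav")
# age        -> output_age        (gr.Textbox), e.g. 31.4
# gender     -> output_gender     (gr.Label),   e.g. {"female": 0.95, "male": 0.04, "child": 0.01}
# expression -> output_expression (gr.Label),   e.g. {"arousal": 0.51, "dominance": 0.48, "valence": 0.62}
```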
 