wetdog commited on
Commit
1b8b2e7
1 Parent(s): 5b29906

add new vocoder model and denoiser

Browse files
Files changed (2) hide show
  1. infer_onnx.py +37 -5
  2. mel_spec_22khz_v2.onnx +3 -0
infer_onnx.py CHANGED
@@ -31,7 +31,7 @@ def process_text(i: int, text: str, device: torch.device):
31
 
32
  MODEL_PATH_MATCHA_MEL="matcha_multispeaker_cat_opset_15_10_steps.onnx"
33
  MODEL_PATH_MATCHA="matcha_hifigan_multispeaker_cat.onnx"
34
- MODEL_PATH_VOCOS="mel_spec_22khz.onnx"
35
  CONFIG_PATH="config_22khz.yaml"
36
 
37
  sess_options = onnxruntime.SessionOptions()
@@ -40,7 +40,7 @@ model_vocos = onnxruntime.InferenceSession(str(MODEL_PATH_VOCOS), sess_options=s
40
  model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
41
 
42
 
43
- def vocos_inference(mel):
44
 
45
  with open(CONFIG_PATH, "r") as f:
46
  config = yaml.safe_load(f)
@@ -63,6 +63,37 @@ def vocos_inference(mel):
63
  spectrogram = mag * (x + 1j * y)
64
  window = torch.hann_window(win_length)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Inverse stft
67
  pad = (win_length - hop_length) // 2
68
  spectrogram = torch.tensor(spectrogram)
@@ -92,7 +123,7 @@ def vocos_inference(mel):
92
  return y
93
 
94
 
95
- def tts(text:str, spk_id:int, temperature:float, length_scale:float):
96
  sid = np.array([int(spk_id)]) if spk_id is not None else None
97
  text_matcha , text_lengths = process_text(0,text,"cpu")
98
 
@@ -111,7 +142,7 @@ def tts(text:str, spk_id:int, temperature:float, length_scale:float):
111
 
112
  vocos_t0 = perf_counter()
113
  # vocos inference
114
- wavs_vocos = vocos_inference(mel)
115
  vocos_infer_secs = perf_counter() - vocos_t0
116
  print("Vocos inference time", vocos_infer_secs)
117
 
@@ -193,7 +224,8 @@ vits2_inference = gr.Interface(
193
  step=0.01,
194
  label="Length scale",
195
  info=f"Controls speech pace, larger values for slower pace and smaller values for faster pace",
196
- )
 
197
  ],
198
  outputs=[gr.Audio(label="Matcha vocos", interactive=False, type="filepath"),
199
  gr.Audio(label="Matcha hifigan", interactive=False, type="filepath")]
 
31
 
32
  MODEL_PATH_MATCHA_MEL="matcha_multispeaker_cat_opset_15_10_steps.onnx"
33
  MODEL_PATH_MATCHA="matcha_hifigan_multispeaker_cat.onnx"
34
+ MODEL_PATH_VOCOS="mel_spec_22khz_v2.onnx"
35
  CONFIG_PATH="config_22khz.yaml"
36
 
37
  sess_options = onnxruntime.SessionOptions()
 
40
  model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
41
 
42
 
43
+ def vocos_inference(mel,denoise):
44
 
45
  with open(CONFIG_PATH, "r") as f:
46
  config = yaml.safe_load(f)
 
63
  spectrogram = mag * (x + 1j * y)
64
  window = torch.hann_window(win_length)
65
 
66
+ if denoise:
67
+ # Vocoder bias
68
+ mel_rand = torch.zeros_like(torch.tensor(mel))
69
+ mag_bias, x_bias, y_bias = model_vocos.run(
70
+ None,
71
+ {
72
+ "mels": mel_rand.float().numpy()
73
+ },
74
+ )
75
+
76
+ # complex spectrogram from vocos output
77
+ spectrogram_bias = mag_bias * (x_bias + 1j * y_bias)
78
+
79
+ # Denoising
80
+ spec = torch.view_as_real(torch.tensor(spectrogram))
81
+ # get magnitude of vocos spectrogram
82
+ mag_spec = torch.sqrt(spec.pow(2).sum(-1))
83
+
84
+ # get magnitude of bias spectrogram
85
+ spec_bias = torch.view_as_real(torch.tensor(spectrogram_bias))
86
+ mag_spec_bias = torch.sqrt(spec_bias.pow(2).sum(-1))
87
+
88
+ # substract
89
+ strength = 0.0005
90
+ mag_spec_denoised = mag_spec - mag_spec_bias * strength
91
+ mag_spec_denoised = torch.clamp(mag_spec_denoised, 0.0)
92
+
93
+ # return to complex spectrogram from magnitude
94
+ angle = torch.atan2(spec[..., -1], spec[..., 0] )
95
+ spectrogram = torch.complex(mag_spec_denoised * torch.cos(angle), mag_spec_denoised * torch.sin(angle))
96
+
97
  # Inverse stft
98
  pad = (win_length - hop_length) // 2
99
  spectrogram = torch.tensor(spectrogram)
 
123
  return y
124
 
125
 
126
+ def tts(text:str, spk_id:int, temperature:float, length_scale:float, denoise:bool):
127
  sid = np.array([int(spk_id)]) if spk_id is not None else None
128
  text_matcha , text_lengths = process_text(0,text,"cpu")
129
 
 
142
 
143
  vocos_t0 = perf_counter()
144
  # vocos inference
145
+ wavs_vocos = vocos_inference(mel,denoise)
146
  vocos_infer_secs = perf_counter() - vocos_t0
147
  print("Vocos inference time", vocos_infer_secs)
148
 
 
224
  step=0.01,
225
  label="Length scale",
226
  info=f"Controls speech pace, larger values for slower pace and smaller values for faster pace",
227
+ ),
228
+ gr.Checkbox(label="Denoise", info="Removes model bias from vocos"),
229
  ],
230
  outputs=[gr.Audio(label="Matcha vocos", interactive=False, type="filepath"),
231
  gr.Audio(label="Matcha hifigan", interactive=False, type="filepath")]
mel_spec_22khz_v2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b02c479881f89a8320024436e986f64b11e82b1fd48046d4b695c5fd9fb84e7
3
+ size 53883652