Mahiruoshi committed on
Commit 4de73fc
1 Parent(s): 22980ff

Upload 73 files
.gitattributes CHANGED
@@ -30,3 +30,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  image/梁芷柔.png filter=lfs diff=lfs merge=lfs -text
+ cleaners/JapaneseCleaner.dll filter=lfs diff=lfs merge=lfs -text
+ cleaners/sys.dic filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,124 +1,162 @@
- import re
- import gradio as gr
  import torch
- import unicodedata
  import commons
  import utils
- from models import SynthesizerTrn
- from text import text_to_sequence
-
- config_json = "checkpoints/paimeng/config.json"
- pth_path = "checkpoints/paimeng/model.pth"
-
-
- def get_text(text, hps, cleaned=False):
-     if cleaned:
-         text_norm = text_to_sequence(text, hps.symbols, [])
-     else:
-         text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
-     if hps.data.add_blank:
          text_norm = commons.intersperse(text_norm, 0)
      text_norm = torch.LongTensor(text_norm)
      return text_norm

-
- def get_label(text, label):
-     if f'[{label}]' in text:
-         return True, text.replace(f'[{label}]', '')
-     else:
-         return False, text
-
-
- def clean_text(text):
-     print(text)
-     jap = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]') # 匹配日文
-     text = unicodedata.normalize('NFKC', text)
-     text = f"[JA]{text}[JA]" if jap.search(text) else f"[ZH]{text}[ZH]"
-     return text
-
-
- def load_model(config_json, pth_path):
-     dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     hps_ms = utils.get_hparams_from_file(f"{config_json}")
-     n_speakers = hps_ms.data.n_speakers if 'n_speakers' in hps_ms.data.keys() else 0
-     n_symbols = len(hps_ms.symbols) if 'symbols' in hps_ms.keys() else 0
-     net_g_ms = SynthesizerTrn(
-         n_symbols,
-         hps_ms.data.filter_length // 2 + 1,
-         hps_ms.train.segment_size // hps_ms.data.hop_length,
-         n_speakers=n_speakers,
-         **hps_ms.model).to(dev)
-     _ = net_g_ms.eval()
-     _ = utils.load_checkpoint(pth_path, net_g_ms)
-     return net_g_ms
-
- net_g_ms = load_model(config_json, pth_path)
-
- def selection(speaker):
-     if speaker == "南小鸟":
-         spk = 0
-         return spk
-
-     elif speaker == "园田海未":
-         spk = 1
-         return spk
-
-     elif speaker == "小泉花阳":
-         spk = 2
-         return spk
-
-     elif speaker == "星空凛":
-         spk = 3
-         return spk
-
-     elif speaker == "东条希":
-         spk = 4
-         return spk
-
-     elif speaker == "矢泽妮可":
-         spk = 5
-         return spk
-
-     elif speaker == "绚濑绘里":
-         spk = 6
-         return spk
-
-     elif speaker == "西木野真姬":
-         spk = 7
-         return spk
-
-     elif speaker == "高坂穗乃果":
-         spk = 8
-         return spk
-
- def infer(text,speaker_id, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
-     text = clean_text(text)
-     speaker_id = int(selection(speaker_id))
-     dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     hps_ms = utils.get_hparams_from_file(f"{config_json}")
-     with torch.no_grad():
-         stn_tst = get_text(text, hps_ms, cleaned=False)
-         x_tst = stn_tst.unsqueeze(0).to(dev)
-         x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
-         sid = torch.LongTensor([speaker_id]).to(dev)
-         audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=n_scale, noise_scale_w=n_scale_w, length_scale=l_scale)[0][
-             0, 0].data.cpu().float().numpy()
-     return (hps_ms.data.sampling_rate, audio)
-
- idols = ["南小鸟","园田海未","小泉花阳","星空凛","东条希","矢泽妮可","绚濑绘里","西木野真姬","高坂穗乃果"]
- app = gr.Blocks()
- with app:
-     with gr.Tabs():
-
-         with gr.TabItem("Basic"):
-
-             tts_input1 = gr.TextArea(label="请输入纯中文或纯日文", value="大家好")
-             para_input1 = gr.Slider(minimum= 0.01,maximum=1.0,label="更改噪声比例", value=0.667)
-             para_input2 = gr.Slider(minimum= 0.01,maximum=1.0,label="更改噪声偏差", value=0.8)
-             para_input3 = gr.Slider(minimum= 0.1,maximum=10,label="更改时间比例", value=1)
-             tts_submit = gr.Button("Generate", variant="primary")
-             speaker1 = gr.Dropdown(label="选择说话人",choices=idols, value="高坂穗乃果", interactive=True)
-             tts_output2 = gr.Audio(label="Output")
-
-             tts_submit.click(infer, [tts_input1,speaker1,para_input1,para_input2,para_input3], [tts_output2])
  app.launch()
+ import logging
+ logging.getLogger('numba').setLevel(logging.WARNING)
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ logging.getLogger('urllib3').setLevel(logging.WARNING)
+ from text import text_to_sequence
+ import numpy as np
+ from scipy.io import wavfile
  import torch
+ import json
  import commons
  import utils
+ import sys
+ import pathlib
+ import onnxruntime as ort
+ import gradio as gr
+ import argparse
+ import time
+ import os
+ from scipy.io.wavfile import write
+
+ def is_japanese(string):
+     for ch in string:
+         if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
+             return True
+     return False
+
+ def is_english(string):
+     import re
+     pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
+     if pattern.fullmatch(string):
+         return True
+     else:
+         return False
+
+ def to_numpy(tensor: torch.Tensor):
+     return tensor.detach().cpu().numpy() if tensor.requires_grad \
+         else tensor.detach().numpy()
+
+ def get_symbols_from_json(path):
+     assert os.path.isfile(path)
+     with open(path, 'r') as f:
+         data = json.load(f)
+     return data['symbols']
+
+ def sle(language,text):
+     text = text.replace('\n','。').replace(' ',',')
+     if language == "中文":
+         tts_input1 = "[ZH]" + text + "[ZH]"
+         return tts_input1
+     elif language == "自动":
+         tts_input1 = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
+         return tts_input1
+     elif language == "日文":
+         tts_input1 = "[JA]" + text + "[JA]"
+         return tts_input1
+     elif language == "英文":
+         tts_input1 = "[EN]" + text + "[EN]"
+         return tts_input1
+     elif language == "手动":
+         return text
+
+ def get_text(text,hps_ms):
+     text_norm = text_to_sequence(text,hps_ms.data.text_cleaners)
+     if hps_ms.data.add_blank:
          text_norm = commons.intersperse(text_norm, 0)
      text_norm = torch.LongTensor(text_norm)
      return text_norm

+ def create_tts_fn(ort_sess, speaker_id):
+     def tts_fn(text , language, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
+         text =sle(language,text)
+         seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
+         if hps.data.add_blank:
+             seq = commons.intersperse(seq, 0)
+         with torch.no_grad():
+             x = np.array([seq], dtype=np.int64)
+             x_len = np.array([x.shape[1]], dtype=np.int64)
+             sid = np.array([speaker_id], dtype=np.int64)
+             scales = np.array([n_scale, n_scale_w, l_scale], dtype=np.float32)
+             scales.resize(1, 3)
+             ort_inputs = {
+                 'input': x,
+                 'input_lengths': x_len,
+                 'scales': scales,
+                 'sid': sid
+             }
+             t1 = time.time()
+             audio = np.squeeze(ort_sess.run(None, ort_inputs))
+             audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
+             audio = np.clip(audio, -32767.0, 32767.0)
+             t2 = time.time()
+             spending_time = "推理时间:"+str(t2-t1)+"s"
+             print(spending_time)
+         return (hps.data.sampling_rate, audio)
+     return tts_fn
+
+
+ if __name__ == '__main__':
+     symbols = get_symbols_from_json('checkpoints/Nijigasaki/config.json')
+     hps = utils.get_hparams_from_file('checkpoints/Nijigasaki/config.json')
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     models = []
+     schools = ["ShojoKageki-Nijigasaki","ShojoKageki","Nijigasaki"]
+     lan = ["中文","日文","自动","手动"]
+     with open("checkpoints/info.json", "r", encoding="utf-8") as f:
+         models_info = json.load(f)
+     for i in models_info:
+         school = models_info[i]
+         speakers = school["speakers"]
+         checkpoint = school["checkpoint"]
+         phone_dict = {
+             symbol: i for i, symbol in enumerate(symbols)
+         }
+         ort_sess = ort.InferenceSession(checkpoint)
+         content = []
+         for j in speakers:
+             sid = int(speakers[j]['sid'])
+             title = school
+             example = speakers[j]['speech']
+             name = speakers[j]["name"]
+             content.append((sid, name, title, example, create_tts_fn(ort_sess, sid)))
+         models.append(content)
+
+     with gr.Blocks() as app:
+         gr.Markdown(
+             "# <center> vits-models\n"
+         )
+         with gr.Tabs():
+             for i in schools:
+                 with gr.TabItem(i):
+                     for (sid, name, title, example, tts_fn) in models[schools.index(i)]:
+                         with gr.TabItem(name):
+                             '''
+                             with gr.Row():
+                                 gr.Markdown(
+                                     '<div align="center">'
+                                     f'<a><strong>{name}</strong></a>'
+                                     f'<img style="width:auto;height:300px;" src="file/{sid}.png">'
+                                     '</div>'
+                                 )
+                             '''
+                             with gr.Row():
+                                 with gr.Column():
+                                     with gr.Row():
+                                         with gr.Column():
+                                             gr.Markdown(
+                                                 '<div align="center">'
+                                                 f'<a><strong>{name}</strong></a>'
+                                                 f'<img style="width:auto;height:400px;" src="file/image/{name}.png">'
+                                                 '</div>'
+                                             )
+                                             input2 = gr.Dropdown(label="Language", choices=lan, value="自动", interactive=True)
+                                         with gr.Column():
+                                             input1 = gr.TextArea(label="Text", value=example)
+                                             input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.6)
+                                             input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.668)
+                                             input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
+                                             btnVC = gr.Button("Submit")
+                                             output1 = gr.Audio(label="采样率22050")
+
+                                     btnVC.click(tts_fn, inputs=[input1, input2, input4, input5, input6], outputs=[output1])
  app.launch()
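The exported ONNX checkpoints take four inputs ('input', 'input_lengths', 'scales', 'sid'), as create_tts_fn above shows. A minimal standalone sketch of the same call path outside Gradio; the file paths follow this commit, while the input text, speaker id, and output filename are illustrative assumptions:

# Sketch only: drive one exported ONNX checkpoint with the same input contract
# used by create_tts_fn in app.py.
import numpy as np
import onnxruntime as ort
from scipy.io.wavfile import write

import commons
import utils
from text import text_to_sequence

hps = utils.get_hparams_from_file("checkpoints/NIjigasaki/config.json")
sess = ort.InferenceSession("checkpoints/NIjigasaki/model.onnx")

seq = text_to_sequence("[ZH]你好[ZH]", cleaner_names=hps.data.text_cleaners)
if hps.data.add_blank:
    seq = commons.intersperse(seq, 0)  # insert blank token between symbols

ort_inputs = {
    "input": np.array([seq], dtype=np.int64),
    "input_lengths": np.array([len(seq)], dtype=np.int64),
    "scales": np.array([[0.667, 0.8, 1.0]], dtype=np.float32),  # noise, noise_w, length
    "sid": np.array([0], dtype=np.int64),  # assumed speaker id
}
audio = np.squeeze(sess.run(None, ort_inputs))
audio = audio * 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6  # same peak scaling as app.py
write("out.wav", hps.data.sampling_rate, audio.astype(np.int16))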
attentions.py CHANGED
@@ -1,300 +1,392 @@
1
  import math
 
2
  import torch
3
  from torch import nn
4
  from torch.nn import functional as F
5
 
6
  import commons
7
  from modules import LayerNorm
8
-
9
 
10
  class Encoder(nn.Module):
11
- def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
12
- super().__init__()
13
- self.hidden_channels = hidden_channels
14
- self.filter_channels = filter_channels
15
- self.n_heads = n_heads
16
- self.n_layers = n_layers
17
- self.kernel_size = kernel_size
18
- self.p_dropout = p_dropout
19
- self.window_size = window_size
20
-
21
- self.drop = nn.Dropout(p_dropout)
22
- self.attn_layers = nn.ModuleList()
23
- self.norm_layers_1 = nn.ModuleList()
24
- self.ffn_layers = nn.ModuleList()
25
- self.norm_layers_2 = nn.ModuleList()
26
- for i in range(self.n_layers):
27
- self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
28
- self.norm_layers_1.append(LayerNorm(hidden_channels))
29
- self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
30
- self.norm_layers_2.append(LayerNorm(hidden_channels))
31
-
32
- def forward(self, x, x_mask):
33
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
34
- x = x * x_mask
35
- for i in range(self.n_layers):
36
- y = self.attn_layers[i](x, x, attn_mask)
37
- y = self.drop(y)
38
- x = self.norm_layers_1[i](x + y)
39
-
40
- y = self.ffn_layers[i](x, x_mask)
41
- y = self.drop(y)
42
- x = self.norm_layers_2[i](x + y)
43
- x = x * x_mask
44
- return x
45
 
46
 
47
  class Decoder(nn.Module):
48
- def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
49
- super().__init__()
50
- self.hidden_channels = hidden_channels
51
- self.filter_channels = filter_channels
52
- self.n_heads = n_heads
53
- self.n_layers = n_layers
54
- self.kernel_size = kernel_size
55
- self.p_dropout = p_dropout
56
- self.proximal_bias = proximal_bias
57
- self.proximal_init = proximal_init
58
-
59
- self.drop = nn.Dropout(p_dropout)
60
- self.self_attn_layers = nn.ModuleList()
61
- self.norm_layers_0 = nn.ModuleList()
62
- self.encdec_attn_layers = nn.ModuleList()
63
- self.norm_layers_1 = nn.ModuleList()
64
- self.ffn_layers = nn.ModuleList()
65
- self.norm_layers_2 = nn.ModuleList()
66
- for i in range(self.n_layers):
67
- self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
68
- self.norm_layers_0.append(LayerNorm(hidden_channels))
69
- self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
70
- self.norm_layers_1.append(LayerNorm(hidden_channels))
71
- self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
72
- self.norm_layers_2.append(LayerNorm(hidden_channels))
73
-
74
- def forward(self, x, x_mask, h, h_mask):
75
- """
76
  x: decoder input
77
  h: encoder output
78
  """
79
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
80
- encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
81
- x = x * x_mask
82
- for i in range(self.n_layers):
83
- y = self.self_attn_layers[i](x, x, self_attn_mask)
84
- y = self.drop(y)
85
- x = self.norm_layers_0[i](x + y)
86
-
87
- y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
88
- y = self.drop(y)
89
- x = self.norm_layers_1[i](x + y)
90
-
91
- y = self.ffn_layers[i](x, x_mask)
92
- y = self.drop(y)
93
- x = self.norm_layers_2[i](x + y)
94
- x = x * x_mask
95
- return x
 
96
 
97
 
98
  class MultiHeadAttention(nn.Module):
99
- def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
100
- super().__init__()
101
- assert channels % n_heads == 0
102
-
103
- self.channels = channels
104
- self.out_channels = out_channels
105
- self.n_heads = n_heads
106
- self.p_dropout = p_dropout
107
- self.window_size = window_size
108
- self.heads_share = heads_share
109
- self.block_length = block_length
110
- self.proximal_bias = proximal_bias
111
- self.proximal_init = proximal_init
112
- self.attn = None
113
-
114
- self.k_channels = channels // n_heads
115
- self.conv_q = nn.Conv1d(channels, channels, 1)
116
- self.conv_k = nn.Conv1d(channels, channels, 1)
117
- self.conv_v = nn.Conv1d(channels, channels, 1)
118
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
119
- self.drop = nn.Dropout(p_dropout)
120
-
121
- if window_size is not None:
122
- n_heads_rel = 1 if heads_share else n_heads
123
- rel_stddev = self.k_channels**-0.5
124
- self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
125
- self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
126
-
127
- nn.init.xavier_uniform_(self.conv_q.weight)
128
- nn.init.xavier_uniform_(self.conv_k.weight)
129
- nn.init.xavier_uniform_(self.conv_v.weight)
130
- if proximal_init:
131
- with torch.no_grad():
132
- self.conv_k.weight.copy_(self.conv_q.weight)
133
- self.conv_k.bias.copy_(self.conv_q.bias)
134
-
135
- def forward(self, x, c, attn_mask=None):
136
- q = self.conv_q(x)
137
- k = self.conv_k(c)
138
- v = self.conv_v(c)
139
-
140
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
141
-
142
- x = self.conv_o(x)
143
- return x
144
-
145
- def attention(self, query, key, value, mask=None):
146
- # reshape [b, d, t] -> [b, n_h, t, d_k]
147
- b, d, t_s, t_t = (*key.size(), query.size(2))
148
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151
-
152
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
153
- if self.window_size is not None:
154
- assert t_s == t_t, "Relative attention is only available for self-attention."
155
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156
- rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
157
- scores_local = self._relative_position_to_absolute_position(rel_logits)
158
- scores = scores + scores_local
159
- if self.proximal_bias:
160
- assert t_s == t_t, "Proximal bias is only available for self-attention."
161
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
162
- if mask is not None:
163
- scores = scores.masked_fill(mask == 0, -1e4)
164
- if self.block_length is not None:
165
- assert t_s == t_t, "Local attention is only available for self-attention."
166
- block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
167
- scores = scores.masked_fill(block_mask == 0, -1e4)
168
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
169
- p_attn = self.drop(p_attn)
170
- output = torch.matmul(p_attn, value)
171
- if self.window_size is not None:
172
- relative_weights = self._absolute_position_to_relative_position(p_attn)
173
- value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
174
- output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
175
- output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
176
- return output, p_attn
177
-
178
- def _matmul_with_relative_values(self, x, y):
179
- """
180
  x: [b, h, l, m]
181
  y: [h or 1, m, d]
182
  ret: [b, h, l, d]
183
  """
184
- ret = torch.matmul(x, y.unsqueeze(0))
185
- return ret
186
 
187
- def _matmul_with_relative_keys(self, x, y):
188
- """
189
  x: [b, h, l, d]
190
  y: [h or 1, m, d]
191
  ret: [b, h, l, m]
192
  """
193
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
194
- return ret
195
-
196
- def _get_relative_embeddings(self, relative_embeddings, length):
197
- max_relative_position = 2 * self.window_size + 1
198
- # Pad first before slice to avoid using cond ops.
199
- pad_length = max(length - (self.window_size + 1), 0)
200
- slice_start_position = max((self.window_size + 1) - length, 0)
201
- slice_end_position = slice_start_position + 2 * length - 1
202
- if pad_length > 0:
203
- padded_relative_embeddings = F.pad(
204
- relative_embeddings,
205
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
206
- else:
207
- padded_relative_embeddings = relative_embeddings
208
- used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
209
- return used_relative_embeddings
210
-
211
- def _relative_position_to_absolute_position(self, x):
212
- """
213
  x: [b, h, l, 2*l-1]
214
  ret: [b, h, l, l]
215
  """
216
- batch, heads, length, _ = x.size()
217
- # Concat columns of pad to shift from relative to absolute indexing.
218
- x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
219
-
220
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
221
- x_flat = x.view([batch, heads, length * 2 * length])
222
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
223
-
224
- # Reshape and slice out the padded elements.
225
- x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
226
- return x_final
227
-
228
- def _absolute_position_to_relative_position(self, x):
229
- """
230
  x: [b, h, l, l]
231
  ret: [b, h, l, 2*l-1]
232
  """
233
- batch, heads, length, _ = x.size()
234
- # padd along column
235
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
236
- x_flat = x.view([batch, heads, length**2 + length*(length -1)])
237
- # add 0's in the beginning that will skew the elements after reshape
238
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
239
- x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
240
- return x_final
241
-
242
- def _attention_bias_proximal(self, length):
243
- """Bias for self-attention to encourage attention to close positions.
244
  Args:
245
  length: an integer scalar.
246
  Returns:
247
  a Tensor with shape [1, 1, length, length]
248
  """
249
- r = torch.arange(length, dtype=torch.float32)
250
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
251
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
 
252
 
253
 
254
  class FFN(nn.Module):
255
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
256
- super().__init__()
257
- self.in_channels = in_channels
258
- self.out_channels = out_channels
259
- self.filter_channels = filter_channels
260
- self.kernel_size = kernel_size
261
- self.p_dropout = p_dropout
262
- self.activation = activation
263
- self.causal = causal
264
-
265
- if causal:
266
- self.padding = self._causal_padding
267
- else:
268
- self.padding = self._same_padding
269
-
270
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
271
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
272
- self.drop = nn.Dropout(p_dropout)
273
-
274
- def forward(self, x, x_mask):
275
- x = self.conv_1(self.padding(x * x_mask))
276
- if self.activation == "gelu":
277
- x = x * torch.sigmoid(1.702 * x)
278
- else:
279
- x = torch.relu(x)
280
- x = self.drop(x)
281
- x = self.conv_2(self.padding(x * x_mask))
282
- return x * x_mask
283
-
284
- def _causal_padding(self, x):
285
- if self.kernel_size == 1:
286
- return x
287
- pad_l = self.kernel_size - 1
288
- pad_r = 0
289
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
290
- x = F.pad(x, commons.convert_pad_shape(padding))
291
- return x
292
-
293
- def _same_padding(self, x):
294
- if self.kernel_size == 1:
295
- return x
296
- pad_l = (self.kernel_size - 1) // 2
297
- pad_r = self.kernel_size // 2
298
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
299
- x = F.pad(x, commons.convert_pad_shape(padding))
300
- return x
1
  import math
2
+
3
  import torch
4
  from torch import nn
5
  from torch.nn import functional as F
6
 
7
  import commons
8
  from modules import LayerNorm
9
+
10
 
11
  class Encoder(nn.Module):
12
+ def __init__(self,
13
+ hidden_channels,
14
+ filter_channels,
15
+ n_heads,
16
+ n_layers,
17
+ kernel_size=1,
18
+ p_dropout=0.,
19
+ window_size=4,
20
+ **kwargs):
21
+ super().__init__()
22
+ self.hidden_channels = hidden_channels
23
+ self.filter_channels = filter_channels
24
+ self.n_heads = n_heads
25
+ self.n_layers = n_layers
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.window_size = window_size
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.attn_layers = nn.ModuleList()
32
+ self.norm_layers_1 = nn.ModuleList()
33
+ self.ffn_layers = nn.ModuleList()
34
+ self.norm_layers_2 = nn.ModuleList()
35
+ for i in range(self.n_layers):
36
+ self.attn_layers.append(
37
+ MultiHeadAttention(hidden_channels,
38
+ hidden_channels,
39
+ n_heads,
40
+ p_dropout=p_dropout,
41
+ window_size=window_size))
42
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
43
+ self.ffn_layers.append(
44
+ FFN(hidden_channels,
45
+ hidden_channels,
46
+ filter_channels,
47
+ kernel_size,
48
+ p_dropout=p_dropout))
49
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
50
+
51
+ def forward(self, x, x_mask):
52
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
53
+ x = x * x_mask
54
+ for i in range(self.n_layers):
55
+ y = self.attn_layers[i](x, x, attn_mask)
56
+ y = self.drop(y)
57
+ x = self.norm_layers_1[i](x + y)
58
+
59
+ y = self.ffn_layers[i](x, x_mask)
60
+ y = self.drop(y)
61
+ x = self.norm_layers_2[i](x + y)
62
+ x = x * x_mask
63
+ return x
64
 
65
 
66
  class Decoder(nn.Module):
67
+ def __init__(self,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size=1,
73
+ p_dropout=0.,
74
+ proximal_bias=False,
75
+ proximal_init=True,
76
+ **kwargs):
77
+ super().__init__()
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.proximal_bias = proximal_bias
85
+ self.proximal_init = proximal_init
86
+
87
+ self.drop = nn.Dropout(p_dropout)
88
+ self.self_attn_layers = nn.ModuleList()
89
+ self.norm_layers_0 = nn.ModuleList()
90
+ self.encdec_attn_layers = nn.ModuleList()
91
+ self.norm_layers_1 = nn.ModuleList()
92
+ self.ffn_layers = nn.ModuleList()
93
+ self.norm_layers_2 = nn.ModuleList()
94
+ for i in range(self.n_layers):
95
+ self.self_attn_layers.append(
96
+ MultiHeadAttention(hidden_channels,
97
+ hidden_channels,
98
+ n_heads,
99
+ p_dropout=p_dropout,
100
+ proximal_bias=proximal_bias,
101
+ proximal_init=proximal_init))
102
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
103
+ self.encdec_attn_layers.append(
104
+ MultiHeadAttention(hidden_channels,
105
+ hidden_channels,
106
+ n_heads,
107
+ p_dropout=p_dropout))
108
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
109
+ self.ffn_layers.append(
110
+ FFN(hidden_channels,
111
+ hidden_channels,
112
+ filter_channels,
113
+ kernel_size,
114
+ p_dropout=p_dropout,
115
+ causal=True))
116
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
117
+
118
+ def forward(self, x, x_mask, h, h_mask):
119
+ """
120
  x: decoder input
121
  h: encoder output
122
  """
123
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
124
+ device=x.device, dtype=x.dtype)
125
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
126
+ x = x * x_mask
127
+ for i in range(self.n_layers):
128
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
129
+ y = self.drop(y)
130
+ x = self.norm_layers_0[i](x + y)
131
+
132
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
133
+ y = self.drop(y)
134
+ x = self.norm_layers_1[i](x + y)
135
+
136
+ y = self.ffn_layers[i](x, x_mask)
137
+ y = self.drop(y)
138
+ x = self.norm_layers_2[i](x + y)
139
+ x = x * x_mask
140
+ return x
141
 
142
 
143
  class MultiHeadAttention(nn.Module):
144
+ def __init__(self,
145
+ channels,
146
+ out_channels,
147
+ n_heads,
148
+ p_dropout=0.,
149
+ window_size=None,
150
+ heads_share=True,
151
+ block_length=None,
152
+ proximal_bias=False,
153
+ proximal_init=False):
154
+ super().__init__()
155
+ assert channels % n_heads == 0
156
+
157
+ self.channels = channels
158
+ self.out_channels = out_channels
159
+ self.n_heads = n_heads
160
+ self.p_dropout = p_dropout
161
+ self.window_size = window_size
162
+ self.heads_share = heads_share
163
+ self.block_length = block_length
164
+ self.proximal_bias = proximal_bias
165
+ self.proximal_init = proximal_init
166
+ self.attn = None
167
+
168
+ self.k_channels = channels // n_heads
169
+ self.conv_q = nn.Conv1d(channels, channels, 1)
170
+ self.conv_k = nn.Conv1d(channels, channels, 1)
171
+ self.conv_v = nn.Conv1d(channels, channels, 1)
172
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
173
+ self.drop = nn.Dropout(p_dropout)
174
+
175
+ if window_size is not None:
176
+ n_heads_rel = 1 if heads_share else n_heads
177
+ rel_stddev = self.k_channels**-0.5
178
+ self.emb_rel_k = nn.Parameter(
179
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
180
+ * rel_stddev)
181
+ self.emb_rel_v = nn.Parameter(
182
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
183
+ * rel_stddev)
184
+
185
+ nn.init.xavier_uniform_(self.conv_q.weight)
186
+ nn.init.xavier_uniform_(self.conv_k.weight)
187
+ nn.init.xavier_uniform_(self.conv_v.weight)
188
+ if proximal_init:
189
+ with torch.no_grad():
190
+ self.conv_k.weight.copy_(self.conv_q.weight)
191
+ self.conv_k.bias.copy_(self.conv_q.bias)
192
+
193
+ def forward(self, x, c, attn_mask=None):
194
+ q = self.conv_q(x)
195
+ k = self.conv_k(c)
196
+ v = self.conv_v(c)
197
+
198
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
199
+
200
+ x = self.conv_o(x)
201
+ return x
202
+
203
+ def attention(self, query, key, value, mask=None):
204
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
205
+ b, d, t_s, t_t = (*key.size(), query.size(2))
206
+ query = query.view(b, self.n_heads, self.k_channels,
207
+ t_t).transpose(2, 3)
208
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
209
+ value = value.view(b, self.n_heads, self.k_channels,
210
+ t_s).transpose(2, 3)
211
+
212
+ scores = torch.matmul(query / math.sqrt(self.k_channels),
213
+ key.transpose(-2, -1))
214
+ if self.window_size is not None:
215
+ msg = "Relative attention is only available for self-attention."
216
+ assert t_s == t_t, msg
217
+ key_relative_embeddings = self._get_relative_embeddings(
218
+ self.emb_rel_k, t_s)
219
+ rel_logits = self._matmul_with_relative_keys(
220
+ query / math.sqrt(self.k_channels), key_relative_embeddings)
221
+ scores_local = self._relative_position_to_absolute_position(
222
+ rel_logits)
223
+ scores = scores + scores_local
224
+ if self.proximal_bias:
225
+ msg = "Proximal bias is only available for self-attention."
226
+ assert t_s == t_t, msg
227
+ scores = scores + self._attention_bias_proximal(t_s).to(
228
+ device=scores.device, dtype=scores.dtype)
229
+ if mask is not None:
230
+ scores = scores.masked_fill(mask == 0, -1e4)
231
+ if self.block_length is not None:
232
+ msg = "Local attention is only available for self-attention."
233
+ assert t_s == t_t, msg
234
+ block_mask = torch.ones_like(scores).triu(
235
+ -self.block_length).tril(self.block_length)
236
+ scores = scores.masked_fill(block_mask == 0, -1e4)
237
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
238
+ p_attn = self.drop(p_attn)
239
+ output = torch.matmul(p_attn, value)
240
+ if self.window_size is not None:
241
+ relative_weights = self._absolute_position_to_relative_position(
242
+ p_attn)
243
+ value_relative_embeddings = self._get_relative_embeddings(
244
+ self.emb_rel_v, t_s)
245
+ output = output + self._matmul_with_relative_values(
246
+ relative_weights, value_relative_embeddings)
247
+ output = output.transpose(2, 3).contiguous().view(
248
+ b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
249
+ return output, p_attn
250
+
251
+ def _matmul_with_relative_values(self, x, y):
252
+ """
253
  x: [b, h, l, m]
254
  y: [h or 1, m, d]
255
  ret: [b, h, l, d]
256
  """
257
+ ret = torch.matmul(x, y.unsqueeze(0))
258
+ return ret
259
 
260
+ def _matmul_with_relative_keys(self, x, y):
261
+ """
262
  x: [b, h, l, d]
263
  y: [h or 1, m, d]
264
  ret: [b, h, l, m]
265
  """
266
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
267
+ return ret
268
+
269
+ def _get_relative_embeddings(self, relative_embeddings, length):
270
+ max_relative_position = 2 * self.window_size + 1
271
+ # Pad first before slice to avoid using cond ops.
272
+ pad_length = max(length - (self.window_size + 1), 0)
273
+ slice_start_position = max((self.window_size + 1) - length, 0)
274
+ slice_end_position = slice_start_position + 2 * length - 1
275
+ if pad_length > 0:
276
+ padded_relative_embeddings = F.pad(
277
+ relative_embeddings,
278
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length],
279
+ [0, 0]]))
280
+ else:
281
+ padded_relative_embeddings = relative_embeddings
282
+ used_relative_embeddings = padded_relative_embeddings[:,
283
+ slice_start_position:
284
+ slice_end_position]
285
+ return used_relative_embeddings
286
+
287
+ def _relative_position_to_absolute_position(self, x):
288
+ """
289
  x: [b, h, l, 2*l-1]
290
  ret: [b, h, l, l]
291
  """
292
+ batch, heads, length, _ = x.size()
293
+ # Concat columns of pad to shift from relative to absolute indexing.
294
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
295
+ 1]]))
296
+
297
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
298
+ x_flat = x.view([batch, heads, length * 2 * length])
299
+ x_flat = F.pad(
300
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0,
301
+ length - 1]]))
302
+
303
+ # Reshape and slice out the padded elements.
304
+ x_final = x_flat.view([batch, heads, length + 1,
305
+ 2 * length - 1])[:, :, :length, length - 1:]
306
+ return x_final
307
+
308
+ def _absolute_position_to_relative_position(self, x):
309
+ """
310
  x: [b, h, l, l]
311
  ret: [b, h, l, 2*l-1]
312
  """
313
+ batch, heads, length, _ = x.size()
314
+ # padd along column
315
+ x = F.pad(
316
+ x,
317
+ commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
318
+ length - 1]]))
319
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
320
+ # add 0's in the beginning that will skew the elements after reshape
321
+ x_flat = F.pad(
322
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
323
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
324
+ return x_final
325
+
326
+ def _attention_bias_proximal(self, length):
327
+ """Bias for self-attention to encourage attention to close positions.
328
  Args:
329
  length: an integer scalar.
330
  Returns:
331
  a Tensor with shape [1, 1, length, length]
332
  """
333
+ r = torch.arange(length, dtype=torch.float32)
334
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
335
+ return torch.unsqueeze(
336
+ torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
337
 
338
 
339
  class FFN(nn.Module):
340
+ def __init__(self,
341
+ in_channels,
342
+ out_channels,
343
+ filter_channels,
344
+ kernel_size,
345
+ p_dropout=0.,
346
+ activation=None,
347
+ causal=False):
348
+ super().__init__()
349
+ self.in_channels = in_channels
350
+ self.out_channels = out_channels
351
+ self.filter_channels = filter_channels
352
+ self.kernel_size = kernel_size
353
+ self.p_dropout = p_dropout
354
+ self.activation = activation
355
+ self.causal = causal
356
+
357
+ if causal:
358
+ self.padding = self._causal_padding
359
+ else:
360
+ self.padding = self._same_padding
361
+
362
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
363
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
364
+ self.drop = nn.Dropout(p_dropout)
365
+
366
+ def forward(self, x, x_mask):
367
+ x = self.conv_1(self.padding(x * x_mask))
368
+ if self.activation == "gelu":
369
+ x = x * torch.sigmoid(1.702 * x)
370
+ else:
371
+ x = torch.relu(x)
372
+ x = self.drop(x)
373
+ x = self.conv_2(self.padding(x * x_mask))
374
+ return x * x_mask
375
+
376
+ def _causal_padding(self, x):
377
+ if self.kernel_size == 1:
378
+ return x
379
+ pad_l = self.kernel_size - 1
380
+ pad_r = 0
381
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
382
+ x = F.pad(x, commons.convert_pad_shape(padding))
383
+ return x
384
+
385
+ def _same_padding(self, x):
386
+ if self.kernel_size == 1:
387
+ return x
388
+ pad_l = (self.kernel_size - 1) // 2
389
+ pad_r = self.kernel_size // 2
390
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
391
+ x = F.pad(x, commons.convert_pad_shape(padding))
392
+ return x
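The changes to attentions.py are a pure re-indentation of the original VITS modules; behavior is unchanged. A quick shape check for the reformatted Encoder, as a sketch only: it assumes this commit's attentions.py, commons.py, and modules.py are importable, and the sizes are arbitrary (they mirror the configs below).

# Sketch: instantiate the Encoder and confirm output shape matches input shape.
import torch
import commons
from attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1)
x = torch.randn(2, 192, 50)                   # [batch, hidden_channels, frames]
lengths = torch.tensor([50, 30])
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, 50), 1).float()  # [batch, 1, frames]
y = enc(x, x_mask)
print(y.shape)                                # torch.Size([2, 192, 50])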
checkpoints/Default/config.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 32,
11
+ "fp16_run": true,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8192,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0
18
+ },
19
+ "data": {
20
+ "training_files":"/www/training/dataset/train_with_paimeng2.txt",
21
+ "validation_files":"/www/training/dataset/val_filelist.txt",
22
+ "text_cleaners":["cjke_cleaners"],
23
+ "max_wav_value": 32768.0,
24
+ "sampling_rate": 22050,
25
+ "filter_length": 1024,
26
+ "hop_length": 256,
27
+ "win_length": 1024,
28
+ "n_mel_channels": 80,
29
+ "mel_fmin": 0.0,
30
+ "mel_fmax": null,
31
+ "add_blank": true,
32
+ "n_speakers": 50,
33
+ "cleaned_text": true
34
+ },
35
+ "model": {
36
+ "inter_channels": 192,
37
+ "hidden_channels": 192,
38
+ "filter_channels": 768,
39
+ "n_heads": 2,
40
+ "n_layers": 6,
41
+ "kernel_size": 3,
42
+ "p_dropout": 0.1,
43
+ "resblock": "1",
44
+ "resblock_kernel_sizes": [3,7,11],
45
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46
+ "upsample_rates": [8,8,2,2],
47
+ "upsample_initial_channel": 512,
48
+ "upsample_kernel_sizes": [16,16,4,4],
49
+ "n_layers_q": 3,
50
+ "use_spectral_norm": false,
51
+ "gin_channels": 256
52
+ },
53
+ "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "A", "E", "I", "N", "O", "Q", "U", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "r", "s", "t", "u", "v", "w", "y", "z", "\u0283", "\u02a7", "\u02a6", "\u026f", "\u0279", "\u0259", "\u0265", "\u207c", "\u02b0", "`", "\u2192", "\u2193", "\u2191", " "]
54
+ }
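The new app.py reads only a handful of fields from these configs: the symbol table, data.text_cleaners, data.add_blank, and data.sampling_rate. A small sketch of how they are consumed, assuming the repo's utils module:

# Sketch only: load the config the way app.py does and inspect the fields it uses.
import json
import utils

hps = utils.get_hparams_from_file("checkpoints/Default/config.json")
print(hps.data.text_cleaners)   # ['cjke_cleaners']
print(hps.data.add_blank)       # True
print(hps.data.sampling_rate)   # 22050

with open("checkpoints/Default/config.json", encoding="utf-8") as f:
    print(len(json.load(f)["symbols"]))  # size of the symbol table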
checkpoints/Default/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6008ba1612a7e6fbefbdd633d07d6e8db07bebf6bcf1a4bb803e1dff636c5fcb
+ size 120734883
checkpoints/NIjigasaki/config.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 32,
11
+ "fp16_run": true,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8192,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0
18
+ },
19
+ "data": {
20
+ "training_files":"/www/training/dataset/train_with_paimeng2.txt",
21
+ "validation_files":"/www/training/dataset/val_filelist.txt",
22
+ "text_cleaners":["cjke_cleaners"],
23
+ "max_wav_value": 32768.0,
24
+ "sampling_rate": 22050,
25
+ "filter_length": 1024,
26
+ "hop_length": 256,
27
+ "win_length": 1024,
28
+ "n_mel_channels": 80,
29
+ "mel_fmin": 0.0,
30
+ "mel_fmax": null,
31
+ "add_blank": true,
32
+ "n_speakers": 50,
33
+ "cleaned_text": true
34
+ },
35
+ "model": {
36
+ "inter_channels": 192,
37
+ "hidden_channels": 192,
38
+ "filter_channels": 768,
39
+ "n_heads": 2,
40
+ "n_layers": 6,
41
+ "kernel_size": 3,
42
+ "p_dropout": 0.1,
43
+ "resblock": "1",
44
+ "resblock_kernel_sizes": [3,7,11],
45
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46
+ "upsample_rates": [8,8,2,2],
47
+ "upsample_initial_channel": 512,
48
+ "upsample_kernel_sizes": [16,16,4,4],
49
+ "n_layers_q": 3,
50
+ "use_spectral_norm": false,
51
+ "gin_channels": 256
52
+ },
53
+ "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "A", "E", "I", "N", "O", "Q", "U", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "r", "s", "t", "u", "v", "w", "y", "z", "\u0283", "\u02a7", "\u02a6", "\u026f", "\u0279", "\u0259", "\u0265", "\u207c", "\u02b0", "`", "\u2192", "\u2193", "\u2191", " "]
54
+ }
checkpoints/NIjigasaki/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcdfabd68e081f0b9b0b2ac7600fd6d6124102607718680fcd8611cee9d5a2da
+ size 120734883
checkpoints/ShojoKageki/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6008ba1612a7e6fbefbdd633d07d6e8db07bebf6bcf1a4bb803e1dff636c5fcb
+ size 120734883
checkpoints/Starlight/config.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 32,
11
+ "fp16_run": true,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8192,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0
18
+ },
19
+ "data": {
20
+ "training_files":"/www/training/dataset/train_with_paimeng2.txt",
21
+ "validation_files":"/www/training/dataset/val_filelist.txt",
22
+ "text_cleaners":["cjke_cleaners"],
23
+ "max_wav_value": 32768.0,
24
+ "sampling_rate": 22050,
25
+ "filter_length": 1024,
26
+ "hop_length": 256,
27
+ "win_length": 1024,
28
+ "n_mel_channels": 80,
29
+ "mel_fmin": 0.0,
30
+ "mel_fmax": null,
31
+ "add_blank": true,
32
+ "n_speakers": 50,
33
+ "cleaned_text": true
34
+ },
35
+ "model": {
36
+ "inter_channels": 192,
37
+ "hidden_channels": 192,
38
+ "filter_channels": 768,
39
+ "n_heads": 2,
40
+ "n_layers": 6,
41
+ "kernel_size": 3,
42
+ "p_dropout": 0.1,
43
+ "resblock": "1",
44
+ "resblock_kernel_sizes": [3,7,11],
45
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46
+ "upsample_rates": [8,8,2,2],
47
+ "upsample_initial_channel": 512,
48
+ "upsample_kernel_sizes": [16,16,4,4],
49
+ "n_layers_q": 3,
50
+ "use_spectral_norm": false,
51
+ "gin_channels": 256
52
+ },
53
+ "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "A", "E", "I", "N", "O", "Q", "U", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "r", "s", "t", "u", "v", "w", "y", "z", "\u0283", "\u02a7", "\u02a6", "\u026f", "\u0279", "\u0259", "\u0265", "\u207c", "\u02b0", "`", "\u2192", "\u2193", "\u2191", " "]
54
+ }
checkpoints/Starlight/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3657d9607364481af521ce67ca1d6a3d3e710b28dca7d2c2bbe19f44b3b67e4a
+ size 120734883
checkpoints/info.json CHANGED
@@ -167,10 +167,10 @@
167
  "name": "高咲侑"
168
  }
169
  },
170
- "checkpoint": "checkpoints/tmp/model.pth"
171
 
172
  },
173
- "Seisho-betterchinese":{
174
  "speakers":{
175
  "華恋":{
176
  "sid": 21,
@@ -288,7 +288,7 @@
288
  "name": "墨小菊"
289
  }
290
  },
291
- "checkpoint": "checkpoints/ShojoKageki/model.pth"
292
  },
293
  "Nijigasaki":{
294
  "speakers":{
@@ -353,72 +353,6 @@
353
  "name": "高咲侑"
354
  }
355
  },
356
- "checkpoint": "checkpoints/paimeng/model.pth"
357
- },
358
- "Nijigasaki-biaobei":{
359
- "speakers":{
360
- "歩夢":{
361
- "sid": 1,
362
- "speech": "みなさん、はじめまして。上原歩夢です。",
363
- "name": "歩夢"
364
- },
365
- "かすみ":{
366
- "sid": 2,
367
- "speech": "みんなのアイドルかすみんだよー。",
368
- "name": "かすみ"
369
- },
370
- "しずく":{
371
- "sid": 3,
372
- "speech": "みなさん、こんにちは。しずくです。",
373
- "name": "しずく"
374
- },
375
- "果林":{
376
- "sid": 4,
377
- "speech": "ハーイ。 朝香果林よ。よろしくね",
378
- "name": "果林"
379
- },
380
- "愛":{
381
- "sid": 5,
382
- "speech": "ちっすー。アタシは愛。",
383
- "name": "愛"
384
- },
385
- "せつ菜":{
386
- "sid": 7,
387
- "speech": "絶えぬ命は,常世に在らず。終わらぬ芝居も,夢幻のごとく。儚く燃えゆく,さだめであれば。舞台に刻まん,刹那の瞬き。",
388
- "name": "せつ菜"
389
- },
390
- "エマ":{
391
- "sid": 8,
392
- "speech": "こんにちは、エマです。自然溢れるスイスからやってきましたっ。",
393
- "name": "エマ"
394
- },
395
- "璃奈":{
396
- "sid": 9,
397
- "speech": "私、天王寺璃奈。とってもきゅーとな女の子。ホントだよ?",
398
- "name": "璃奈"
399
- },
400
- "栞子":{
401
- "sid": 10,
402
- "speech": "みなさん、初めまして。三船栞子と申します。",
403
- "name": "栞子"
404
- },
405
- "ランジュ":{
406
- "sid": 11,
407
- "speech": "你好啊,我是钟岚珠。",
408
- "name": "ランジュ"
409
- },
410
- "ミア":{
411
- "sid": 12,
412
- "speech": "ボクはミア・テイラー。",
413
- "name": "ミア"
414
- },
415
- "高咲侑":{
416
- "sid": 0,
417
- "speech": "只选一个做不到啊",
418
- "name": "高咲侑"
419
- }
420
- },
421
- "checkpoint": "checkpoints/biaobei/model.pth"
422
  }
423
-
424
  }
 
167
  "name": "高咲侑"
168
  }
169
  },
170
+ "checkpoint": "checkpoints/Default/model.onnx"
171
 
172
  },
173
+ "ShojoKageki":{
174
  "speakers":{
175
  "華恋":{
176
  "sid": 21,
 
288
  "name": "墨小菊"
289
  }
290
  },
291
+ "checkpoint": "checkpoints/ShojoKageki/model.onnx"
292
  },
293
  "Nijigasaki":{
294
  "speakers":{
 
353
  "name": "高咲侑"
354
  }
355
  },
356
+ "checkpoint": "checkpoints/NIjigasaki/model.onnx"
357
  }
 
358
  }
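checkpoints/info.json now maps each school to an ONNX checkpoint plus a speakers table (sid, speech, name), which app.py walks to build its tabs. A short sketch that prints the same mapping; illustrative only:

# Sketch: list every school, its ONNX checkpoint, and its speaker-id table.
import json

with open("checkpoints/info.json", "r", encoding="utf-8") as f:
    models_info = json.load(f)

for school, entry in models_info.items():
    print(school, "->", entry["checkpoint"])
    for speaker, meta in entry["speakers"].items():
        print(f"  sid={meta['sid']:>3}  {meta['name']}")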
cleaners/JapaneseCleaner.dll ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a659eb68d12d4a88ef7dfde6086b9974cd4d43634f7e4bfe710d5537cdd61a75
+ size 3097600
cleaners/char.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:888ee94c5a8a7a26d24ab3f1b7155441351954fd51ea06b4a2f78bd742492b2f
+ size 262496
cleaners/matrix.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62fd16b4f64c851d5dc352ef0d5740c5fc83ddc7c203b2b0b1fc5271969a14ce
+ size 3792262
cleaners/sys.dic ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca57d9029691a70a5dfb99afc2844180256161d7130da65b1a867510e129b9a6
+ size 103073776
cleaners/unk.dic ADDED
Binary file (5.69 kB).
 
commons.py CHANGED
@@ -1,97 +1,161 @@
1
  import math
 
2
  import torch
3
  from torch.nn import functional as F
4
- import torch.jit
5
 
6
 
7
- def script_method(fn, _rcb=None):
8
- return fn
 
 
9
 
10
 
11
- def script(obj, optimize=True, _frames_up=0, _rcb=None):
12
- return obj
13
 
14
 
15
- torch.jit.script_method = script_method
16
- torch.jit.script = script
 
17
 
18
 
19
- def init_weights(m, mean=0.0, std=0.01):
20
- classname = m.__class__.__name__
21
- if classname.find("Conv") != -1:
22
- m.weight.data.normal_(mean, std)
23
 
24
 
25
- def get_padding(kernel_size, dilation=1):
26
- return int((kernel_size*dilation - dilation)/2)
 
 
 
 
27
 
28
 
29
- def intersperse(lst, item):
30
- result = [item] * (len(lst) * 2 + 1)
31
- result[1::2] = lst
32
- return result
 
 
 
 
 
33
 
34
 
35
  def slice_segments(x, ids_str, segment_size=4):
36
- ret = torch.zeros_like(x[:, :, :segment_size])
37
- for i in range(x.size(0)):
38
- idx_str = ids_str[i]
39
- idx_end = idx_str + segment_size
40
- ret[i] = x[i, :, idx_str:idx_end]
41
- return ret
42
 
43
 
44
  def rand_slice_segments(x, x_lengths=None, segment_size=4):
45
- b, d, t = x.size()
46
- if x_lengths is None:
47
- x_lengths = t
48
- ids_str_max = x_lengths - segment_size + 1
49
- ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
50
- ret = slice_segments(x, ids_str, segment_size)
51
- return ret, ids_str
52
 
53
 
54
  def subsequent_mask(length):
55
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
56
- return mask
57
 
58
 
59
  @torch.jit.script
60
  def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
61
- n_channels_int = n_channels[0]
62
- in_act = input_a + input_b
63
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
64
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
65
- acts = t_act * s_act
66
- return acts
67
 
68
 
69
- def convert_pad_shape(pad_shape):
70
- l = pad_shape[::-1]
71
- pad_shape = [item for sublist in l for item in sublist]
72
- return pad_shape
73
 
74
 
75
  def sequence_mask(length, max_length=None):
76
- if max_length is None:
77
- max_length = length.max()
78
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
79
- return x.unsqueeze(0) < length.unsqueeze(1)
80
 
81
 
82
  def generate_path(duration, mask):
83
- """
84
  duration: [b, 1, t_x]
85
  mask: [b, 1, t_y, t_x]
86
  """
87
- device = duration.device
88
-
89
- b, _, t_y, t_x = mask.shape
90
- cum_duration = torch.cumsum(duration, -1)
91
-
92
- cum_duration_flat = cum_duration.view(b * t_x)
93
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
94
- path = path.view(b, t_x, t_y)
95
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
96
- path = path.unsqueeze(1).transpose(2,3) * mask
97
- return path
1
  import math
2
+
3
  import torch
4
  from torch.nn import functional as F
 
5
 
6
 
7
+ def init_weights(m, mean=0.0, std=0.01):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ m.weight.data.normal_(mean, std)
11
 
12
 
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size * dilation - dilation) / 2)
15
 
16
 
17
+ def convert_pad_shape(pad_shape):
18
+ pad_shape = [item for sublist in reversed(pad_shape) for item in sublist]
19
+ return pad_shape
20
 
21
 
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
 
27
 
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += 0.5 * (torch.exp(2. * logs_p) +
32
+ ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
33
+ return kl
34
 
35
 
36
+ def rand_gumbel(shape):
37
+ """Sample from the Gumbel distribution, protect from overflows."""
38
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
39
+ return -torch.log(-torch.log(uniform_samples))
40
+
41
+
42
+ def rand_gumbel_like(x):
43
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
44
+ return g
45
 
46
 
47
  def slice_segments(x, ids_str, segment_size=4):
48
+ ret = torch.zeros_like(x[:, :, :segment_size])
49
+ for i in range(x.size(0)):
50
+ idx_str = ids_str[i]
51
+ idx_end = idx_str + segment_size
52
+ ret[i] = x[i, :, idx_str:idx_end]
53
+ return ret
54
 
55
 
56
  def rand_slice_segments(x, x_lengths=None, segment_size=4):
57
+ b, d, t = x.size()
58
+ if x_lengths is None:
59
+ x_lengths = t
60
+ ids_str_max = x_lengths - segment_size + 1
61
+ ids_str = (torch.rand([b]).to(device=x.device) *
62
+ ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length,
68
+ channels,
69
+ min_timescale=1.0,
70
+ max_timescale=1.0e4):
71
+ position = torch.arange(length, dtype=torch.float)
72
+ num_timescales = channels // 2
73
+ log_timescale_increment = (
74
+ math.log(float(max_timescale) / float(min_timescale)) /
75
+ (num_timescales - 1))
76
+ inv_timescales = min_timescale * torch.exp(
77
+ torch.arange(num_timescales, dtype=torch.float) *
78
+ -log_timescale_increment)
79
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
80
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
81
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
82
+ signal = signal.view(1, channels, length)
83
+ return signal
84
+
85
+
86
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
87
+ b, channels, length = x.size()
88
+ signal = get_timing_signal_1d(length, channels, min_timescale,
89
+ max_timescale)
90
+ return x + signal.to(dtype=x.dtype, device=x.device)
91
+
92
+
93
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale,
96
+ max_timescale)
97
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
 
99
 
100
  def subsequent_mask(length):
101
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
+ return mask
103
 
104
 
105
  @torch.jit.script
106
  def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
+ n_channels_int = n_channels[0]
108
+ in_act = input_a + input_b
109
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
+ acts = t_act * s_act
112
+ return acts
113
 
114
 
115
+ def shift_1d(x):
116
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
117
+ return x
 
118
 
119
 
120
  def sequence_mask(length, max_length=None):
121
+ if max_length is None:
122
+ max_length = length.max()
123
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
124
+ return x.unsqueeze(0) < length.unsqueeze(1)
125
 
126
 
127
  def generate_path(duration, mask):
128
+ """
129
  duration: [b, 1, t_x]
130
  mask: [b, 1, t_y, t_x]
131
  """
132
+ device = duration.device
133
+
134
+ b, _, t_y, t_x = mask.shape
135
+ cum_duration = torch.cumsum(duration, -1)
136
+
137
+ cum_duration_flat = cum_duration.view(b * t_x)
138
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
139
+ path = path.view(b, t_x, t_y)
140
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
141
+ ]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item()**norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm**(1. / norm_type)
161
+ return total_norm
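Two of these helpers carry most of the weight elsewhere in the commit: intersperse (blank-token insertion when add_blank is true) and sequence_mask (length-based padding masks). A tiny demonstration, assuming this commons.py is importable:

# Sketch: show what intersperse and sequence_mask actually produce.
import torch
import commons

print(commons.intersperse([5, 6, 7], 0))      # [0, 5, 0, 6, 0, 7, 0]

lengths = torch.tensor([3, 5])
print(commons.sequence_mask(lengths, 5).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)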
data_utils.py ADDED
@@ -0,0 +1,307 @@
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ import torchaudio
6
+ import torch.utils.data
7
+
8
+ import commons
9
+ from mel_processing import spectrogram_torch
10
+ from utils import load_filepaths_and_text
11
+
12
+
13
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
14
+ """
15
+ 1) loads audio, speaker_id, text pairs
16
+ 2) normalizes text and converts them to sequences of integers
17
+ 3) computes spectrograms from audio files.
18
+ """
19
+ def __init__(self, audiopaths_sid_text, hparams):
20
+ self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
21
+ # self.text_cleaners = hparams.text_cleaners
22
+ self.max_wav_value = hparams.max_wav_value
23
+ self.sampling_rate = hparams.sampling_rate
24
+ self.filter_length = hparams.filter_length
25
+ self.hop_length = hparams.hop_length
26
+ self.win_length = hparams.win_length
27
+ self.sampling_rate = hparams.sampling_rate
28
+ self.src_sampling_rate = getattr(hparams, "src_sampling_rate",
29
+ self.sampling_rate)
30
+
31
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
32
+
33
+ self.add_blank = hparams.add_blank
34
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
35
+ self.max_text_len = getattr(hparams, "max_text_len", 190)
36
+
37
+ phone_file = getattr(hparams, "phone_table", None)
38
+ self.phone_dict = None
39
+ if phone_file is not None:
40
+ self.phone_dict = {}
41
+ with open(phone_file) as fin:
42
+ for line in fin:
43
+ arr = line.strip().split()
44
+ self.phone_dict[arr[0]] = int(arr[1])
45
+
46
+ speaker_file = getattr(hparams, "speaker_table", None)
47
+ self.speaker_dict = None
48
+ if speaker_file is not None:
49
+ self.speaker_dict = {}
50
+ with open(speaker_file) as fin:
51
+ for line in fin:
52
+ arr = line.strip().split()
53
+ self.speaker_dict[arr[0]] = int(arr[1])
54
+
55
+ random.seed(1234)
56
+ random.shuffle(self.audiopaths_sid_text)
57
+ self._filter()
58
+
59
+ def _filter(self):
60
+ """
61
+ Filter text & store spec lengths
62
+ """
63
+ # Store spectrogram lengths for Bucketing
64
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
65
+ # spec_length = wav_length // hop_length
66
+
67
+ audiopaths_sid_text_new = []
68
+ lengths = []
69
+ for item in self.audiopaths_sid_text:
70
+ audiopath = item[0]
71
+ # filename|text or filename|speaker|text
72
+ text = item[1] if len(item) == 2 else item[2]
73
+ if self.min_text_len <= len(text) and len(
74
+ text) <= self.max_text_len:
75
+ audiopaths_sid_text_new.append(item)
76
+ lengths.append(
77
+ int(
78
+ os.path.getsize(audiopath) * self.sampling_rate /
79
+ self.src_sampling_rate) // (2 * self.hop_length))
80
+ self.audiopaths_sid_text = audiopaths_sid_text_new
81
+ self.lengths = lengths
82
+
83
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
84
+ audiopath = audiopath_sid_text[0]
85
+ if len(audiopath_sid_text) == 2: # filename|text
86
+ sid = 0
87
+ text = audiopath_sid_text[1]
88
+ else: # filename|speaker|text
89
+ sid = self.speaker_dict[audiopath_sid_text[1]]
90
+ text = audiopath_sid_text[2]
91
+ text = self.get_text(text)
92
+ spec, wav = self.get_audio(audiopath)
93
+ sid = self.get_sid(sid)
94
+ return (text, spec, wav, sid)
95
+
96
+ def get_audio(self, filename):
97
+ audio, sampling_rate = torchaudio.load(filename, normalize=False)
98
+ if sampling_rate != self.sampling_rate:
99
+ audio = audio.to(torch.float)
100
+ audio = torchaudio.transforms.Resample(sampling_rate,
101
+ self.sampling_rate)(audio)
102
+ audio = audio.to(torch.int16)
103
+ audio = audio[0] # Get the first channel
104
+ audio_norm = audio / self.max_wav_value
105
+ audio_norm = audio_norm.unsqueeze(0)
106
+ spec = spectrogram_torch(audio_norm,
107
+ self.filter_length,
108
+ self.sampling_rate,
109
+ self.hop_length,
110
+ self.win_length,
111
+ center=False)
112
+ spec = torch.squeeze(spec, 0)
113
+ return spec, audio_norm
114
+
115
+ def get_text(self, text):
116
+ text_norm = [self.phone_dict[phone] for phone in text.split()]
117
+ if self.add_blank:
118
+ text_norm = commons.intersperse(text_norm, 0)
119
+ text_norm = torch.LongTensor(text_norm)
120
+ return text_norm
121
+
122
+ def get_sid(self, sid):
123
+ sid = torch.LongTensor([int(sid)])
124
+ return sid
125
+
126
+ def __getitem__(self, index):
127
+ return self.get_audio_text_speaker_pair(
128
+ self.audiopaths_sid_text[index])
129
+
130
+ def __len__(self):
131
+ return len(self.audiopaths_sid_text)
132
+
133
+
134
+ class TextAudioSpeakerCollate():
135
+ """ Zero-pads model inputs and targets
136
+ """
137
+ def __init__(self, return_ids=False):
138
+ self.return_ids = return_ids
139
+
140
+ def __call__(self, batch):
141
+ """Collates a training batch from normalized text, audio and speaker identities
142
+ PARAMS
143
+ ------
144
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
145
+ """
146
+ # Right zero-pad all one-hot text sequences to max input length
147
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor(
148
+ [x[1].size(1) for x in batch]),
149
+ dim=0,
150
+ descending=True)
151
+
152
+ max_text_len = max([len(x[0]) for x in batch])
153
+ max_spec_len = max([x[1].size(1) for x in batch])
154
+ max_wav_len = max([x[2].size(1) for x in batch])
155
+
156
+ text_lengths = torch.LongTensor(len(batch))
157
+ spec_lengths = torch.LongTensor(len(batch))
158
+ wav_lengths = torch.LongTensor(len(batch))
159
+ sid = torch.LongTensor(len(batch))
160
+
161
+ text_padded = torch.LongTensor(len(batch), max_text_len)
162
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0),
163
+ max_spec_len)
164
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
165
+ text_padded.zero_()
166
+ spec_padded.zero_()
167
+ wav_padded.zero_()
168
+ for i in range(len(ids_sorted_decreasing)):
169
+ row = batch[ids_sorted_decreasing[i]]
170
+
171
+ text = row[0]
172
+ text_padded[i, :text.size(0)] = text
173
+ text_lengths[i] = text.size(0)
174
+
175
+ spec = row[1]
176
+ spec_padded[i, :, :spec.size(1)] = spec
177
+ spec_lengths[i] = spec.size(1)
178
+
179
+ wav = row[2]
180
+ wav_padded[i, :, :wav.size(1)] = wav
181
+ wav_lengths[i] = wav.size(1)
182
+
183
+ sid[i] = row[3]
184
+
185
+ if self.return_ids:
186
+ return (text_padded, text_lengths, spec_padded, spec_lengths,
187
+ wav_padded, wav_lengths, sid, ids_sorted_decreasing)
188
+ return (text_padded, text_lengths, spec_padded, spec_lengths,
189
+ wav_padded, wav_lengths, sid)
190
+
191
+
192
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler
193
+ ):
194
+ """
195
+ Maintain similar input lengths in a batch.
196
+ Length groups are specified by boundaries.
197
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either
198
+ {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
199
+
200
+ It removes samples which are not included in the boundaries.
201
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1
202
+ or length(x) > b3 are discarded.
203
+ """
204
+ def __init__(self,
205
+ dataset,
206
+ batch_size,
207
+ boundaries,
208
+ num_replicas=None,
209
+ rank=None,
210
+ shuffle=True):
211
+ super().__init__(dataset,
212
+ num_replicas=num_replicas,
213
+ rank=rank,
214
+ shuffle=shuffle)
215
+ self.lengths = dataset.lengths
216
+ self.batch_size = batch_size
217
+ self.boundaries = boundaries
218
+
219
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
220
+ self.total_size = sum(self.num_samples_per_bucket)
221
+ self.num_samples = self.total_size // self.num_replicas
222
+
223
+ def _create_buckets(self):
224
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
225
+ for i in range(len(self.lengths)):
226
+ length = self.lengths[i]
227
+ idx_bucket = self._bisect(length)
228
+ if idx_bucket != -1:
229
+ buckets[idx_bucket].append(i)
230
+
231
+ for i in range(len(buckets) - 1, 0, -1):
232
+ if len(buckets[i]) == 0:
233
+ buckets.pop(i)
234
+ self.boundaries.pop(i + 1)
235
+
236
+ num_samples_per_bucket = []
237
+ for i in range(len(buckets)):
238
+ len_bucket = len(buckets[i])
239
+ total_batch_size = self.num_replicas * self.batch_size
240
+ rem = (total_batch_size -
241
+ (len_bucket % total_batch_size)) % total_batch_size
242
+ num_samples_per_bucket.append(len_bucket + rem)
243
+ return buckets, num_samples_per_bucket
244
+
245
+ def __iter__(self):
246
+ # deterministically shuffle based on epoch
247
+ g = torch.Generator()
248
+ g.manual_seed(self.epoch)
249
+
250
+ indices = []
251
+ if self.shuffle:
252
+ for bucket in self.buckets:
253
+ indices.append(
254
+ torch.randperm(len(bucket), generator=g).tolist())
255
+ else:
256
+ for bucket in self.buckets:
257
+ indices.append(list(range(len(bucket))))
258
+
259
+ batches = []
260
+ for i in range(len(self.buckets)):
261
+ bucket = self.buckets[i]
262
+ len_bucket = len(bucket)
263
+ ids_bucket = indices[i]
264
+ num_samples_bucket = self.num_samples_per_bucket[i]
265
+
266
+ # add extra samples to make it evenly divisible
267
+ rem = num_samples_bucket - len_bucket
268
+ ids_bucket = ids_bucket + ids_bucket * (
269
+ rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
270
+
271
+ # subsample
272
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
273
+
274
+ # batching
275
+ for j in range(len(ids_bucket) // self.batch_size):
276
+ batch = [
277
+ bucket[idx]
278
+ for idx in ids_bucket[j * self.batch_size:(j + 1) *
279
+ self.batch_size]
280
+ ]
281
+ batches.append(batch)
282
+
283
+ if self.shuffle:
284
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
285
+ batches = [batches[i] for i in batch_ids]
286
+ self.batches = batches
287
+
288
+ assert len(self.batches) * self.batch_size == self.num_samples
289
+ return iter(self.batches)
290
+
291
+ def _bisect(self, x, lo=0, hi=None):
292
+ if hi is None:
293
+ hi = len(self.boundaries) - 1
294
+
295
+ if hi > lo:
296
+ mid = (hi + lo) // 2
297
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
298
+ return mid
299
+ elif x <= self.boundaries[mid]:
300
+ return self._bisect(x, lo, mid)
301
+ else:
302
+ return self._bisect(x, mid + 1, hi)
303
+ else:
304
+ return -1
305
+
306
+ def __len__(self):
307
+ return self.num_samples // self.batch_size
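Taken together, TextAudioSpeakerLoader, TextAudioSpeakerCollate and DistributedBucketSampler plug straight into a PyTorch DataLoader. A minimal wiring sketch follows; the config path, the filelist path and the presence of phone_table/speaker_table entries in hps.data are assumptions for illustration, not part of this upload.

from torch.utils.data import DataLoader

import utils
from data_utils import (TextAudioSpeakerLoader, TextAudioSpeakerCollate,
                        DistributedBucketSampler)

# Assumed paths; hps.data must define phone_table (and speaker_table for
# filelists in the filename|speaker|text format) for get_text() to work.
hps = utils.get_hparams_from_file("checkpoints/paimeng/config.json")
dataset = TextAudioSpeakerLoader("filelists/train.txt", hps.data)

# Buckets keep similarly sized spectrograms together; boundaries are in frames.
sampler = DistributedBucketSampler(dataset,
                                   batch_size=16,
                                   boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
                                   num_replicas=1,
                                   rank=0,
                                   shuffle=True)
loader = DataLoader(dataset,
                    num_workers=2,
                    pin_memory=True,
                    collate_fn=TextAudioSpeakerCollate(),
                    batch_sampler=sampler)

for text, text_len, spec, spec_len, wav, wav_len, sid in loader:
    break  # one zero-padded, length-sorted batch per iteration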
export_onnx.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import sys
19
+
20
+ import torch
21
+
22
+ from models import SynthesizerTrn
23
+ import utils
24
+
25
+ try:
26
+ import onnxruntime as ort
27
+ except ImportError:
28
+ print('Please install onnxruntime!')
29
+ sys.exit(1)
30
+
31
+
32
+ def to_numpy(tensor):
33
+ return tensor.detach().cpu().numpy() if tensor.requires_grad \
34
+ else tensor.detach().numpy()
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser(description='export onnx model')
39
+ parser.add_argument('--checkpoint', required=True, help='checkpoint')
40
+ parser.add_argument('--cfg', required=True, help='config file')
41
+ parser.add_argument('--onnx_model', required=True, help='onnx model name')
42
+ # parser.add_argument('--phone_table',
43
+ # required=True,
44
+ # help='input phone dict')
45
+ # parser.add_argument('--speaker_table', default=None, help='speaker table')
46
+ # parser.add_argument("--speaker_num", required=True,
47
+ # type=int, help="speaker num")
48
+ parser.add_argument(
49
+ '--providers',
50
+ required=False,
51
+ default='CPUExecutionProvider',
52
+ choices=['CUDAExecutionProvider', 'CPUExecutionProvider'],
53
+ help='the model to send request to')
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
58
+ def get_data_from_cfg(cfg_path: str):
59
+ assert os.path.isfile(cfg_path)
60
+ with open(cfg_path, 'r') as f:
61
+ data = json.load(f)
62
+ symbols = data["symbols"]
63
+ speaker_num = data["data"]["n_speakers"]
64
+ return len(symbols), speaker_num
65
+
66
+
67
+ def main():
68
+ args = get_args()
69
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
70
+
71
+ hps = utils.get_hparams_from_file(args.cfg)
72
+ # with open(args.phone_table) as p_f:
73
+ # phone_num = len(p_f.readlines()) + 1
74
+ # num_speakers = 1
75
+ # if args.speaker_table is not None:
76
+ # num_speakers = len(open(args.speaker_table).readlines()) + 1
77
+ phone_num, num_speakers = get_data_from_cfg(args.cfg)
78
+ net_g = SynthesizerTrn(phone_num,
79
+ hps.data.filter_length // 2 + 1,
80
+ hps.train.segment_size // hps.data.hop_length,
81
+ n_speakers=num_speakers,
82
+ **hps.model)
83
+ utils.load_checkpoint(args.checkpoint, net_g, None)
84
+ net_g.forward = net_g.export_forward
85
+ net_g.eval()
86
+
87
+ seq = torch.randint(low=0, high=phone_num, size=(1, 10), dtype=torch.long)
88
+ seq_len = torch.IntTensor([seq.size(1)]).long()
89
+
90
+ # noise (controls the degree of emotional variation), length (controls overall speaking speed), noisew (controls variation in phoneme duration)
91
+ # Reference: https://github.com/gbxh/genshinTTS
92
+ scales = torch.FloatTensor([0.667, 1.0, 0.8])
93
+ # make triton dynamic shape happy
94
+ scales = scales.unsqueeze(0)
95
+ sid = torch.IntTensor([0]).long()
96
+
97
+ dummy_input = (seq, seq_len, scales, sid)
98
+ torch.onnx.export(model=net_g,
99
+ args=dummy_input,
100
+ f=args.onnx_model,
101
+ input_names=['input', 'input_lengths', 'scales', 'sid'],
102
+ output_names=['output'],
103
+ dynamic_axes={
104
+ 'input': {
105
+ 0: 'batch',
106
+ 1: 'phonemes'
107
+ },
108
+ 'input_lengths': {
109
+ 0: 'batch'
110
+ },
111
+ 'scales': {
112
+ 0: 'batch'
113
+ },
114
+ 'sid': {
115
+ 0: 'batch'
116
+ },
117
+ 'output': {
118
+ 0: 'batch',
119
+ 1: 'audio',
120
+ 2: 'audio_length'
121
+ }
122
+ },
123
+ opset_version=13,
124
+ verbose=False)
125
+
126
+ # Verify onnx precision
127
+ torch_output = net_g(seq, seq_len, scales, sid)
128
+ providers = [args.providers]
129
+ ort_sess = ort.InferenceSession(args.onnx_model, providers=providers)
130
+ ort_inputs = {
131
+ 'input': to_numpy(seq),
132
+ 'input_lengths': to_numpy(seq_len),
133
+ 'scales': to_numpy(scales),
134
+ 'sid': to_numpy(sid),
135
+ }
136
+ onnx_output = ort_sess.run(None, ort_inputs)
137
+
138
+
139
+ if __name__ == '__main__':
140
+ main()
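After the export finishes, the graph can be inspected without any of the VITS code; a standalone sketch is below (the model filename is an assumption, any path passed to --onnx_model works the same way).

import onnxruntime as ort

# List the exported inputs/outputs; the dynamic axes declared above show up
# as symbolic dimensions ('batch', 'phonemes', 'audio', ...).
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input ", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output", out.name, out.shape, out.type)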
inference.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+
17
+ import numpy as np
18
+ from scipy.io import wavfile
19
+ import torch
20
+
21
+ import commons
22
+ from models import SynthesizerTrn
23
+ import utils
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser(description='inference')
28
+ parser.add_argument('--checkpoint', required=True, help='checkpoint')
29
+ parser.add_argument('--cfg', required=True, help='config file')
30
+ parser.add_argument('--outdir', required=True, help='output directory')
31
+ parser.add_argument('--phone_table',
32
+ required=True,
33
+ help='input phone dict')
34
+ parser.add_argument('--speaker_table', default=None, help='speaker table')
35
+ parser.add_argument('--test_file', required=True, help='test file')
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+ print(args)
43
+ phone_dict = {}
44
+ with open(args.phone_table) as p_f:
45
+ for line in p_f:
46
+ phone_id = line.strip().split()
47
+ phone_dict[phone_id[0]] = int(phone_id[1])
48
+ speaker_dict = {}
49
+ if args.speaker_table is not None:
50
+ with open(args.speaker_table) as p_f:
51
+ for line in p_f:
52
+ arr = line.strip().split()
53
+ assert len(arr) == 2
54
+ speaker_dict[arr[0]] = int(arr[1])
55
+ hps = utils.get_hparams_from_file(args.cfg)
56
+
57
+ net_g = SynthesizerTrn(
58
+ len(phone_dict) + 1,
59
+ hps.data.filter_length // 2 + 1,
60
+ hps.train.segment_size // hps.data.hop_length,
61
+ n_speakers=len(speaker_dict) + 1, # 0 is kept for unknown speaker
62
+ **hps.model).cuda()
63
+ net_g.eval()
64
+ utils.load_checkpoint(args.checkpoint, net_g, None)
65
+
66
+ with open(args.test_file) as fin:
67
+ for line in fin:
68
+ arr = line.strip().split("|")
69
+ audio_path = arr[0]
70
+ if len(arr) == 2:
71
+ sid = 0
72
+ text = arr[1]
73
+ else:
74
+ sid = speaker_dict[arr[1]]
75
+ text = arr[2]
76
+ seq = [phone_dict[symbol] for symbol in text.split()]
77
+ if hps.data.add_blank:
78
+ seq = commons.intersperse(seq, 0)
79
+ seq = torch.LongTensor(seq)
80
+ with torch.no_grad():
81
+ x = seq.cuda().unsqueeze(0)
82
+ x_length = torch.LongTensor([seq.size(0)]).cuda()
83
+ sid = torch.LongTensor([sid]).cuda()
84
+ audio = net_g.infer(
85
+ x,
86
+ x_length,
87
+ sid=sid,
88
+ noise_scale=.667,
89
+ noise_scale_w=0.8,
90
+ length_scale=1)[0][0, 0].data.cpu().float().numpy()
91
+ audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
92
+ audio = np.clip(audio, -32767.0, 32767.0)
93
+ wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
94
+ hps.data.sampling_rate, audio.astype(np.int16))
95
+
96
+
97
+ if __name__ == '__main__':
98
+ main()
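inference.py reads a phone table, an optional speaker table and a pipe-delimited test file; the snippet below writes tiny illustrative versions of each (every symbol, speaker name and file name is made up) so the expected formats are explicit.

# Illustrative only: symbols, speaker names and paths are placeholders.
with open("phones.txt", "w") as f:      # --phone_table: "<phone> <id>" per line
    f.write("sil 0\nni3 1\nhao3 2\n")
with open("speakers.txt", "w") as f:    # --speaker_table: "<speaker> <id>" per line
    f.write("paimeng 1\n")
with open("test.txt", "w") as f:        # --test_file: "wav_name|speaker|phones"
    f.write("demo_001.wav|paimeng|sil ni3 hao3 sil\n")

# python inference.py --checkpoint G.pth --cfg config.json --outdir out \
#     --phone_table phones.txt --speaker_table speakers.txt --test_file test.txt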
inference_onnx.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ from text import text_to_sequence
17
+ import numpy as np
18
+ from scipy.io import wavfile
19
+ import torch
20
+ import json
21
+ import commons
22
+ import utils
23
+ import sys
24
+ import pathlib
25
+
26
+ try:
27
+ import onnxruntime as ort
28
+ except ImportError:
29
+ print('Please install onnxruntime!')
30
+ sys.exit(1)
31
+
32
+
33
+ def to_numpy(tensor: torch.Tensor):
34
+ return tensor.detach().cpu().numpy() if tensor.requires_grad \
35
+ else tensor.detach().numpy()
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='inference')
40
+ parser.add_argument('--onnx_model', required=True, help='onnx model')
41
+ parser.add_argument('--cfg', required=True, help='config file')
42
+ parser.add_argument('--outdir', default="onnx_output",
43
+ help='ouput directory')
44
+ # parser.add_argument('--phone_table',
45
+ # required=True,
46
+ # help='input phone dict')
47
+ # parser.add_argument('--speaker_table', default=None, help='speaker table')
48
+ parser.add_argument('--test_file', required=True, help='test file')
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def get_symbols_from_json(path):
54
+ import os
55
+ assert os.path.isfile(path)
56
+ with open(path, 'r') as f:
57
+ data = json.load(f)
58
+ return data['symbols']
59
+
60
+
61
+ def main():
62
+ args = get_args()
63
+ print(args)
64
+ if not pathlib.Path(args.outdir).exists():
65
+ pathlib.Path(args.outdir).mkdir(exist_ok=True, parents=True)
66
+ # phones =
67
+ symbols = get_symbols_from_json(args.cfg)
68
+ phone_dict = {
69
+ symbol: i for i, symbol in enumerate(symbols)
70
+ }
71
+
72
+ # speaker_dict = {}
73
+ # if args.speaker_table is not None:
74
+ # with open(args.speaker_table) as p_f:
75
+ # for line in p_f:
76
+ # arr = line.strip().split()
77
+ # assert len(arr) == 2
78
+ # speaker_dict[arr[0]] = int(arr[1])
79
+ hps = utils.get_hparams_from_file(args.cfg)
80
+
81
+ ort_sess = ort.InferenceSession(args.onnx_model)
82
+
83
+ with open(args.test_file) as fin:
84
+ for line in fin:
85
+ arr = line.strip().split("|")
86
+ audio_path = arr[0]
87
+
88
+ # TODO: control the speaker id
89
+ sid = 3
90
+ text = '[ZH]你好,重庆市位于四川省东边[ZH]'
91
+ # else:
92
+ # sid = speaker_dict[arr[1]]
93
+ # text = arr[2]
94
+ seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners
95
+ )
96
+ if hps.data.add_blank:
97
+ seq = commons.intersperse(seq, 0)
98
+
99
+ # if hps.data.add_blank:
100
+ # seq = commons.intersperse(seq, 0)
101
+ with torch.no_grad():
102
+ # x = torch.LongTensor([seq])
103
+ # x_len = torch.IntTensor([x.size(1)]).long()
104
+ # sid = torch.LongTensor([sid]).long()
105
+ # scales = torch.FloatTensor([0.667, 1.0, 1])
106
+ # # make triton dynamic shape happy
107
+ # scales = scales.unsqueeze(0)
108
+
109
+ # use numpy to replace torch
110
+ x = np.array([seq], dtype=np.int64)
111
+ x_len = np.array([x.shape[1]], dtype=np.int64)
112
+ sid = np.array([sid], dtype=np.int64)
113
+ # noise (controls the degree of emotional variation), length (controls overall speaking speed), noisew (controls variation in phoneme duration)
114
+ # Reference: https://github.com/gbxh/genshinTTS
115
+ scales = np.array([0.667, 0.8, 1], dtype=np.float32)
116
+ # scales = scales[np.newaxis, :]
117
+ # scales.reshape(1, -1)
118
+ scales.resize(1, 3)
119
+
120
+ ort_inputs = {
121
+ 'input': x,
122
+ 'input_lengths': x_len,
123
+ 'scales': scales,
124
+ 'sid': sid
125
+ }
126
+
127
+ # ort_inputs = {
128
+ # 'input': to_numpy(x),
129
+ # 'input_lengths': to_numpy(x_len),
130
+ # 'scales': to_numpy(scales),
131
+ # 'sid': to_numpy(sid)
132
+ # }
133
+ import time
134
+ # start_time = time.time()
135
+ start_time = time.perf_counter()
136
+ audio = np.squeeze(ort_sess.run(None, ort_inputs))
137
+ audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
138
+ audio = np.clip(audio, -32767.0, 32767.0)
139
+ end_time = time.perf_counter()
140
+ # end_time = time.time()
141
+ print("infer time cost: ", end_time - start_time, "s")
142
+
143
+ wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
144
+ hps.data.sampling_rate, audio.astype(np.int16))
145
+
146
+
147
+ if __name__ == '__main__':
148
+ main()
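Stripped of the file parsing, the ONNX call above reduces to a few lines; the helper below is a hypothetical refactor of that loop, not an API shipped with this upload. Note the scale array here follows this script's [0.667, 0.8, 1] ordering, while export_onnx.py builds [0.667, 1.0, 0.8], so the meaning of each slot should be checked against export_forward.

import numpy as np
import onnxruntime as ort


def synthesize(sess: ort.InferenceSession, seq, sid, scales=(0.667, 0.8, 1.0)):
    """Run one utterance (a list of symbol ids, blanks already interspersed)
    through the exported VITS graph and return the float waveform."""
    x = np.array([seq], dtype=np.int64)
    ort_inputs = {
        'input': x,
        'input_lengths': np.array([x.shape[1]], dtype=np.int64),
        'scales': np.array([scales], dtype=np.float32),
        'sid': np.array([sid], dtype=np.int64),
    }
    return np.squeeze(sess.run(None, ort_inputs))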
local_run.py ADDED
@@ -0,0 +1,137 @@
1
+ import argparse
2
+ from text import text_to_sequence
3
+ import numpy as np
4
+ from scipy.io import wavfile
5
+ import torch
6
+ import json
7
+ import commons
8
+ import utils
9
+ import sys
10
+ import pathlib
11
+ from flask import Flask, request
12
+ import threading
13
+ import onnxruntime as ort
14
+ import time
15
+ from pydub import AudioSegment
16
+ import io
17
+ import os
18
+ from transformers import AutoTokenizer, AutoModel
19
+ import tkinter as tk
20
+ from tkinter import scrolledtext
21
+ from scipy.io.wavfile import write
22
+ def get_args():
23
+ parser = argparse.ArgumentParser(description='inference')
24
+ parser.add_argument('--onnx_model', default = './moe/model.onnx')
25
+ parser.add_argument('--cfg', default="./moe/config_v.json")
26
+ parser.add_argument('--outdir', default="./moe",
27
+ help='ouput folder')
28
+ parser.add_argument('--audio',
29
+ type=str,
30
+ help='你要替换的音频文件的,假设这些音频文件为temp1、temp2、temp3......',
31
+ default = 'D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
32
+ parser.add_argument('--ChatGLM',default = "./moe",
33
+ help='https://github.com/THUDM/ChatGLM-6B')
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+ def to_numpy(tensor: torch.Tensor):
38
+ return tensor.detach().cpu().numpy() if tensor.requires_grad \
39
+ else tensor.detach().numpy()
40
+
41
+ def get_symbols_from_json(path):
42
+ import os
43
+ assert os.path.isfile(path)
44
+ with open(path, 'r') as f:
45
+ data = json.load(f)
46
+ return data['symbols']
47
+
48
+ args = get_args()
49
+ symbols = get_symbols_from_json(args.cfg)
50
+ phone_dict = {
51
+ symbol: i for i, symbol in enumerate(symbols)
52
+ }
53
+ hps = utils.get_hparams_from_file(args.cfg)
54
+ ort_sess = ort.InferenceSession(args.onnx_model)
55
+
56
+ def is_japanese(string):
57
+ for ch in string:
58
+ if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
59
+ return True
60
+ return False
61
+
62
+ def infer(text):
63
+ # choose the character you want
64
+ sid = 7
65
+ text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
66
+ #seq = text_to_sequence(text, symbols=hps.symbols, cleaner_names=hps.data.text_cleaners)
67
+ seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
68
+ if hps.data.add_blank:
69
+ seq = commons.intersperse(seq, 0)
70
+ with torch.no_grad():
71
+ x = np.array([seq], dtype=np.int64)
72
+ x_len = np.array([x.shape[1]], dtype=np.int64)
73
+ sid = np.array([sid], dtype=np.int64)
74
+ scales = np.array([0.667, 0.7, 1], dtype=np.float32)
75
+ scales.resize(1, 3)
76
+ ort_inputs = {
77
+ 'input': x,
78
+ 'input_lengths': x_len,
79
+ 'scales': scales,
80
+ 'sid': sid
81
+ }
82
+ t1 = time.time()
83
+ audio = np.squeeze(ort_sess.run(None, ort_inputs))
84
+ audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
85
+ audio = np.clip(audio, -32767.0, 32767.0)
86
+ bytes_wav = bytes()
87
+ byte_io = io.BytesIO(bytes_wav)
88
+ wavfile.write(args.audio + '.wav',hps.data.sampling_rate, audio.astype(np.int16))
89
+ i = 0
90
+ while i < 19:
91
+ i +=1
92
+ cmd = 'ffmpeg -y -i ' + args.audio + '.wav' + ' -ar 44100 '+ args.audio.replace('temp','temp'+str(i))
93
+ os.system(cmd)
94
+ t2 = time.time()
95
+ print("Inference time:",(t2 - t1),"s")
96
+ return text
97
+ tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
98
+ #8G GPU
99
+ model = AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True).half().quantize(4).cuda()
100
+ history = []
101
+ def send_message():
102
+ global history
103
+ message = input_box.get("1.0", "end-1c") # get the text the user typed
104
+ t1 = time.time()
105
+ if message == 'clear':
106
+ history = []
107
+ else:
108
+ response, new_history = model.chat(tokenizer, message, history)
109
+ response = response.replace(" ",'').replace("\n",'.')
110
+ text = infer(response)
111
+ text = text.replace('[JA]','').replace('[ZH]','')
112
+ chat_box.configure(state='normal') # make the chat box writable
113
+ chat_box.insert(tk.END, "You: " + message + "\n") # show the user's message in the chat box
114
+ chat_box.insert(tk.END, "Tamao: " + text + "\n") # show the chatbot's reply in the chat box
115
+ chat_box.configure(state='disabled') # set the chat box back to read-only
116
+ input_box.delete("1.0", tk.END) # clear the input box
117
+ t2 = time.time()
118
+ print("Total time:",(t2 - t1),"s")
119
+
120
+ root = tk.Tk()
121
+ root.title("Tamao")
122
+
123
+ # create the chat box
124
+ chat_box = scrolledtext.ScrolledText(root, width=50, height=10)
125
+ chat_box.configure(state='disabled') # the chat box starts as read-only
126
+ chat_box.pack(side=tk.TOP, fill=tk.BOTH, padx=10, pady=10, expand=True)
127
+
128
+ # create the input box and send button
129
+ input_frame = tk.Frame(root)
130
+ input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)
131
+ input_box = tk.Text(input_frame, height=3, width=50) # set the input box width to 50
132
+ input_box.pack(side=tk.LEFT, fill=tk.X, padx=10, expand=True)
133
+ send_button = tk.Button(input_frame, text="Send", command=send_message)
134
+ send_button.pack(side=tk.RIGHT, padx=10)
135
+
136
+ # run the main loop
137
+ root.mainloop()
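infer() above shells out to ffmpeg once per temp file to get 44.1 kHz copies. A possible alternative, sketched here as an assumption rather than a change to the upload, is to resample in-process with torchaudio, which data_utils.py already depends on.

import numpy as np
import torch
import torchaudio


def resample_to_44100(audio_int16: np.ndarray, src_rate: int, dst_path: str):
    """Resample the synthesized int16 waveform and write it without ffmpeg."""
    wav = torch.from_numpy(audio_int16.astype(np.float32) / 32767.0).unsqueeze(0)
    wav_44k = torchaudio.transforms.Resample(src_rate, 44100)(wav)
    torchaudio.save(dst_path, wav_44k, 44100)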
losses.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+
4
+ def feature_loss(fmap_r, fmap_g):
5
+ loss = 0
6
+ for dr, dg in zip(fmap_r, fmap_g):
7
+ for rl, gl in zip(dr, dg):
8
+ rl = rl.float().detach()
9
+ gl = gl.float()
10
+ loss += torch.mean(torch.abs(rl - gl))
11
+
12
+ return loss * 2
13
+
14
+
15
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
+ loss = 0
17
+ r_losses = []
18
+ g_losses = []
19
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
+ dr = dr.float()
21
+ dg = dg.float()
22
+ r_loss = torch.mean((1 - dr)**2)
23
+ g_loss = torch.mean(dg**2)
24
+ loss += (r_loss + g_loss)
25
+ r_losses.append(r_loss.item())
26
+ g_losses.append(g_loss.item())
27
+
28
+ return loss, r_losses, g_losses
29
+
30
+
31
+ def generator_loss(disc_outputs):
32
+ loss = 0
33
+ gen_losses = []
34
+ for dg in disc_outputs:
35
+ dg = dg.float()
36
+ l = torch.mean((1 - dg)**2)
37
+ gen_losses.append(l)
38
+ loss += l
39
+
40
+ return loss, gen_losses
41
+
42
+
43
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
+ """
45
+ z_p, logs_q: [b, h, t_t]
46
+ m_p, logs_p: [b, h, t_t]
47
+ """
48
+ z_p = z_p.float()
49
+ logs_q = logs_q.float()
50
+ m_p = m_p.float()
51
+ logs_p = logs_p.float()
52
+ z_mask = z_mask.float()
53
+
54
+ kl = logs_p - logs_q - 0.5
55
+ kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
56
+ kl = torch.sum(kl * z_mask)
57
+ l = kl / torch.sum(z_mask)
58
+ return l
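The shape conventions in these losses are easy to get wrong; a quick self-contained check with random tensors (the shapes are arbitrary) is sketched below.

import torch
from losses import kl_loss, discriminator_loss, generator_loss

b, h, t = 2, 192, 50                         # arbitrary batch, channels, frames
z_p, logs_q = torch.randn(b, h, t), torch.randn(b, h, t)
m_p, logs_p = torch.randn(b, h, t), torch.randn(b, h, t)
z_mask = torch.ones(b, 1, t)                 # every frame counts as valid
print(kl_loss(z_p, logs_q, m_p, logs_p, z_mask))   # scalar tensor

real_scores = [torch.randn(b, 10), torch.randn(b, 20)]
fake_scores = [torch.randn(b, 10), torch.randn(b, 20)]
loss_d, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)
loss_g, gen_losses = generator_loss(fake_scores)
print(loss_d.item(), loss_g.item())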
main.py ADDED
@@ -0,0 +1,251 @@
1
+ import logging
2
+ logging.getLogger('numba').setLevel(logging.WARNING)
3
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
4
+ logging.getLogger('urllib3').setLevel(logging.WARNING)
5
+ from text import text_to_sequence
6
+ import numpy as np
7
+ from scipy.io import wavfile
8
+ import torch
9
+ import json
10
+ import commons
11
+ import utils
12
+ import sys
13
+ import pathlib
14
+ import onnxruntime as ort
15
+ import gradio as gr
16
+ import argparse
17
+ import time
18
+ import os
19
+ import io
20
+ from scipy.io.wavfile import write
21
+ from flask import Flask, request
22
+ from threading import Thread
23
+ import openai
24
+ import requests
25
+ class VitsGradio:
26
+ def __init__(self):
27
+ self.lan = ["中文","日文","自动"]
28
+ self.chatapi = ["gpt-3.5-turbo","gpt3"]
29
+ self.modelPaths = []
30
+ for root,dirs,files in os.walk("checkpoints"):
31
+ for dir in dirs:
32
+ self.modelPaths.append(dir)
33
+ with gr.Blocks() as self.Vits:
34
+ with gr.Tab("调试用"):
35
+ with gr.Row():
36
+ with gr.Column():
37
+ with gr.Row():
38
+ with gr.Column():
39
+ self.text = gr.TextArea(label="Text", value="你好")
40
+ with gr.Accordion(label="测试api", open=False):
41
+ self.local_chat1 = gr.Checkbox(value=False, label="使用网址+文本进行模拟")
42
+ self.url_input = gr.TextArea(label="键入测试", value="http://127.0.0.1:8080/chat?Text=")
43
+ butto = gr.Button("模拟前端抓取语音文件")
44
+ btnVC = gr.Button("测试tts+对话程序")
45
+ with gr.Column():
46
+ output2 = gr.TextArea(label="回复")
47
+ output1 = gr.Audio(label="采样率22050")
48
+ output3 = gr.outputs.File(label="44100hz: output.wav")
49
+ butto.click(self.Simul, inputs=[self.text, self.url_input], outputs=[output2,output3])
50
+ btnVC.click(self.tts_fn, inputs=[self.text], outputs=[output1,output2])
51
+ with gr.Tab("控制面板"):
52
+ with gr.Row():
53
+ with gr.Column():
54
+ with gr.Row():
55
+ with gr.Column():
56
+ self.api_input1 = gr.TextArea(label="输入api-key或本地存储说话模型的路径", value="https://platform.openai.com/account/api-keys")
57
+ with gr.Accordion(label="chatbot选择", open=False):
58
+ self.api_input2 = gr.Checkbox(value=True, label="采用gpt3.5")
59
+ self.local_chat1 = gr.Checkbox(value=False, label="启动本地chatbot")
60
+ self.local_chat2 = gr.Checkbox(value=True, label="是否量化")
61
+ res = gr.TextArea()
62
+ Botselection = gr.Button("完成chatbot设定")
63
+ Botselection.click(self.check_bot, inputs=[self.api_input1,self.api_input2,self.local_chat1,self.local_chat2], outputs = [res])
64
+ self.input1 = gr.Dropdown(label = "模型", choices = self.modelPaths, value = self.modelPaths[0], type = "value")
65
+ self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
66
+ with gr.Column():
67
+ btnVC = gr.Button("完成vits TTS端设定")
68
+ self.input3 = gr.Dropdown(label="Speaker", choices=list(range(101)), value=0, interactive=True)
69
+ self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.267)
70
+ self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.7)
71
+ self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
72
+ statusa = gr.TextArea()
73
+ btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs = [statusa])
74
+
75
+ def Simul(self,text,url_input):
76
+ web = url_input + text
77
+ res = requests.get(web)
78
+ music = res.content
79
+ with open('output.wav', 'wb') as code:
80
+ code.write(music)
81
+ file_path = "output.wav"
82
+ return web,file_path
83
+
84
+
85
+ def chatgpt(self,text):
86
+ self.messages.append({"role": "user", "content": text},)
87
+ chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages= self.messages)
88
+ reply = chat.choices[0].message.content
89
+ return reply
90
+
91
+ def ChATGLM(self,text):
92
+ if text == 'clear':
93
+ self.history = []
94
+ response, new_history = self.model.chat(self.tokenizer, text, self.history)
95
+ response = response.replace(" ",'').replace("\n",'.')
96
+ self.history = new_history
97
+ return response
98
+
99
+ def gpt3_chat(self,text):
100
+ call_name = "Waifu"
101
+ openai.api_key = args.key
102
+ identity = ""
103
+ start_sequence = '\n'+str(call_name)+':'
104
+ restart_sequence = "\nYou: "
105
+ if 1 == 1:
106
+ prompt0 = text # current prompt
107
+ if text == 'quit':
108
+ return prompt0
109
+ prompt = identity + prompt0 + start_sequence
110
+ response = openai.Completion.create(
111
+ model="text-davinci-003",
112
+ prompt=prompt,
113
+ temperature=0.5,
114
+ max_tokens=1000,
115
+ top_p=1.0,
116
+ frequency_penalty=0.5,
117
+ presence_penalty=0.0,
118
+ stop=["\nYou:"]
119
+ )
120
+ return response['choices'][0]['text'].strip()
121
+
122
+ def check_bot(self,api_input1,api_input2,local_chat1,local_chat2):
123
+ if local_chat1:
124
+ from transformers import AutoTokenizer, AutoModel
125
+ self.tokenizer = AutoTokenizer.from_pretrained(api_input1, trust_remote_code=True)
126
+ if local_chat2:
127
+ self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True).half().cuda()
128
+ else:
129
+ self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True)
130
+ self.history = []
131
+ else:
132
+ self.messages = []
133
+ openai.api_key = api_input1
134
+ return "Finished"
135
+
136
+ def is_japanese(self,string):
137
+ for ch in string:
138
+ if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
139
+ return True
140
+ return False
141
+
142
+ def is_english(self,string):
143
+ import re
144
+ pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
145
+ if pattern.fullmatch(string):
146
+ return True
147
+ else:
148
+ return False
149
+
150
+ def get_symbols_from_json(self,path):
151
+ assert os.path.isfile(path)
152
+ with open(path, 'r') as f:
153
+ data = json.load(f)
154
+ return data['symbols']
155
+
156
+ def sle(self,language,text):
157
+ text = text.replace('\n','。').replace(' ',',')
158
+ if language == "中文":
159
+ tts_input1 = "[ZH]" + text + "[ZH]"
160
+ return tts_input1
161
+ elif language == "自动":
162
+ tts_input1 = f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
163
+ return tts_input1
164
+ elif language == "日文":
165
+ tts_input1 = "[JA]" + text + "[JA]"
166
+ return tts_input1
167
+
168
+ def get_text(self,text,hps_ms):
169
+ text_norm = text_to_sequence(text,hps_ms.data.text_cleaners)
170
+ if hps_ms.data.add_blank:
171
+ text_norm = commons.intersperse(text_norm, 0)
172
+ text_norm = torch.LongTensor(text_norm)
173
+ return text_norm
174
+
175
+ def create_tts_fn(self,path, input2, input3, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
176
+ self.symbols = self.get_symbols_from_json(f"checkpoints/{path}/config.json")
177
+ self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
178
+ phone_dict = {
179
+ symbol: i for i, symbol in enumerate(self.symbols)
180
+ }
181
+ self.ort_sess = ort.InferenceSession(f"checkpoints/{path}/model.onnx")
182
+ self.language = input2
183
+ self.speaker_id = input3
184
+ self.n_scale = n_scale
185
+ self.n_scale_w = n_scale_w
186
+ self.l_scale = l_scale
187
+ print(self.language,self.speaker_id,self.n_scale)
188
+ return 'success'
189
+
190
+ def tts_fn(self,text):
191
+ if self.local_chat1:
192
+ text = self.chatgpt(text)
193
+ elif self.api_input2:
194
+ text = self.ChATGLM(text)
195
+ else:
196
+ text = self.gpt3_chat(text)
197
+ print(text)
198
+ text =self.sle(self.language,text)
199
+ seq = text_to_sequence(text, cleaner_names=self.hps.data.text_cleaners)
200
+ if self.hps.data.add_blank:
201
+ seq = commons.intersperse(seq, 0)
202
+ with torch.no_grad():
203
+ x = np.array([seq], dtype=np.int64)
204
+ x_len = np.array([x.shape[1]], dtype=np.int64)
205
+ sid = np.array([self.speaker_id], dtype=np.int64)
206
+ scales = np.array([self.n_scale, self.n_scale_w, self.l_scale], dtype=np.float32)
207
+ scales.resize(1, 3)
208
+ ort_inputs = {
209
+ 'input': x,
210
+ 'input_lengths': x_len,
211
+ 'scales': scales,
212
+ 'sid': sid
213
+ }
214
+ t1 = time.time()
215
+ audio = np.squeeze(self.ort_sess.run(None, ort_inputs))
216
+ audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
217
+ audio = np.clip(audio, -32767.0, 32767.0)
218
+ t2 = time.time()
219
+ spending_time = "Inference time: "+str(t2-t1)+"s"
220
+ print(spending_time)
221
+ bytes_wav = bytes()
222
+ byte_io = io.BytesIO(bytes_wav)
223
+ wavfile.write('moe/temp1.wav',self.hps.data.sampling_rate, audio.astype(np.int16))
224
+ cmd = 'ffmpeg -y -i ' + 'moe/temp1.wav' + ' -ar 44100 ' + 'moe/temp2.wav'
225
+ os.system(cmd)
226
+ return (self.hps.data.sampling_rate, audio),text.replace('[JA]','').replace('[ZH]','')
227
+
228
+ app = Flask(__name__)
229
+ print("Starting deployment")
230
+ grVits = VitsGradio()
231
+
232
+ @app.route('/chat')
233
+ def text_api():
234
+ message = request.args.get('Text','')
235
+ audio,text = grVits.tts_fn(message)
236
+ text = text.replace('[JA]','').replace('[ZH]','')
237
+ with open('moe/temp2.wav','rb') as bit:
238
+ wav_bytes = bit.read()
239
+ headers = {
240
+ 'Content-Type': 'audio/wav',
241
+ 'Text': text.encode('utf-8')}
242
+ return wav_bytes, 200, headers
243
+
244
+ def gradio_interface():
245
+ return grVits.Vits.launch()
246
+
247
+ if __name__ == '__main__':
248
+ api_thread = Thread(target=app.run, args=("0.0.0.0", 8080))
249
+ gradio_thread = Thread(target=gradio_interface)
250
+ api_thread.start()
251
+ gradio_thread.start()
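The /chat route returns raw wav bytes with the chatbot reply in a 'Text' response header; a minimal client sketch is below (host and port follow the app.run call above, the output filename is arbitrary, and non-ASCII header values may not round-trip cleanly).

import requests

resp = requests.get("http://127.0.0.1:8080/chat", params={"Text": "你好"})
resp.raise_for_status()
with open("reply.wav", "wb") as f:    # 44.1 kHz wav written by the ffmpeg step
    f.write(resp.content)
print("reply text:", resp.headers.get("Text"))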
mel_processing.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+
6
+ MAX_WAV_VALUE = 32768.0
7
+
8
+
9
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
10
+ """
11
+ PARAMS
12
+ ------
13
+ C: compression factor
14
+ """
15
+ return torch.log(torch.clamp(x, min=clip_val) * C)
16
+
17
+
18
+ def dynamic_range_decompression_torch(x, C=1):
19
+ """
20
+ PARAMS
21
+ ------
22
+ C: compression factor used to compress
23
+ """
24
+ return torch.exp(x) / C
25
+
26
+
27
+ def spectral_normalize_torch(magnitudes):
28
+ output = dynamic_range_compression_torch(magnitudes)
29
+ return output
30
+
31
+
32
+ def spectral_de_normalize_torch(magnitudes):
33
+ output = dynamic_range_decompression_torch(magnitudes)
34
+ return output
35
+
36
+
37
+ mel_basis = {}
38
+ hann_window = {}
39
+
40
+
41
+ def spectrogram_torch(y,
42
+ n_fft,
43
+ sampling_rate,
44
+ hop_size,
45
+ win_size,
46
+ center=False):
47
+ if torch.min(y) < -1.:
48
+ print('min value is ', torch.min(y))
49
+ if torch.max(y) > 1.:
50
+ print('max value is ', torch.max(y))
51
+
52
+ global hann_window
53
+ dtype_device = str(y.dtype) + '_' + str(y.device)
54
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
55
+ if wnsize_dtype_device not in hann_window:
56
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
57
+ dtype=y.dtype, device=y.device)
58
+
59
+ y = F.pad(y.unsqueeze(1),
60
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
61
+ mode='reflect')
62
+ y = y.squeeze(1)
63
+
64
+ spec = torch.stft(y,
65
+ n_fft,
66
+ hop_length=hop_size,
67
+ win_length=win_size,
68
+ window=hann_window[wnsize_dtype_device],
69
+ center=center,
70
+ pad_mode='reflect',
71
+ normalized=False,
72
+ onesided=True)
73
+
74
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
75
+ return spec
76
+
77
+
78
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
79
+ global mel_basis
80
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
81
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
82
+ if fmax_dtype_device not in mel_basis:
83
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
84
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
85
+ dtype=spec.dtype, device=spec.device)
86
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
87
+ spec = spectral_normalize_torch(spec)
88
+ return spec
89
+
90
+
91
+ def mel_spectrogram_torch(y,
92
+ n_fft,
93
+ num_mels,
94
+ sampling_rate,
95
+ hop_size,
96
+ win_size,
97
+ fmin,
98
+ fmax,
99
+ center=False):
100
+ if torch.min(y) < -1.:
101
+ print('min value is ', torch.min(y))
102
+ if torch.max(y) > 1.:
103
+ print('max value is ', torch.max(y))
104
+
105
+ global mel_basis, hann_window
106
+ dtype_device = str(y.dtype) + '_' + str(y.device)
107
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
108
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
109
+ if fmax_dtype_device not in mel_basis:
110
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
111
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
112
+ dtype=y.dtype, device=y.device)
113
+ if wnsize_dtype_device not in hann_window:
114
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
115
+ dtype=y.dtype, device=y.device)
116
+
117
+ y = F.pad(y.unsqueeze(1),
118
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
119
+ mode='reflect')
120
+ y = y.squeeze(1)
121
+
122
+ spec = torch.stft(y,
123
+ n_fft,
124
+ hop_length=hop_size,
125
+ win_length=win_size,
126
+ window=hann_window[wnsize_dtype_device],
127
+ center=center,
128
+ pad_mode='reflect',
129
+ normalized=False,
130
+ onesided=True)
131
+
132
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
133
+
134
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
135
+ spec = spectral_normalize_torch(spec)
136
+
137
+ return spec
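Both spectrogram functions expect a float waveform in [-1, 1] shaped (batch, samples); the sketch below uses typical 22.05 kHz VITS settings, which are assumptions here since the real values come from the config.

import torch
from mel_processing import spectrogram_torch, spec_to_mel_torch

y = torch.rand(1, 22050) * 2 - 1        # one second of fake audio in [-1, 1]
spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050,
                         hop_size=256, win_size=1024, center=False)
mel = spec_to_mel_torch(spec, n_fft=1024, num_mels=80,
                        sampling_rate=22050, fmin=0.0, fmax=None)
print(spec.shape, mel.shape)            # (1, 513, frames), (1, 80, frames)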
models.py CHANGED
@@ -2,18 +2,25 @@ import math
2
 
3
  import torch
4
  from torch import nn
5
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
6
  from torch.nn import functional as F
 
7
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 
8
 
9
- import attentions
10
  import commons
11
  import modules
 
12
  from commons import init_weights, get_padding
13
 
14
 
15
  class StochasticDurationPredictor(nn.Module):
16
- def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
 
 
 
 
 
 
17
  super().__init__()
18
  filter_channels = in_channels # it needs to be removed from future version.
19
  self.in_channels = in_channels
@@ -27,25 +34,39 @@ class StochasticDurationPredictor(nn.Module):
27
  self.flows = nn.ModuleList()
28
  self.flows.append(modules.ElementwiseAffine(2))
29
  for i in range(n_flows):
30
- self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
 
31
  self.flows.append(modules.Flip())
32
 
33
  self.post_pre = nn.Conv1d(1, filter_channels, 1)
34
  self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
35
- self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
 
 
 
36
  self.post_flows = nn.ModuleList()
37
  self.post_flows.append(modules.ElementwiseAffine(2))
38
  for i in range(4):
39
- self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
 
40
  self.post_flows.append(modules.Flip())
41
 
42
  self.pre = nn.Conv1d(in_channels, filter_channels, 1)
43
  self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
44
- self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
 
 
 
45
  if gin_channels != 0:
46
  self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
47
 
48
- def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
 
 
 
 
 
 
49
  x = torch.detach(x)
50
  x = self.pre(x)
51
  if g is not None:
@@ -62,7 +83,8 @@ class StochasticDurationPredictor(nn.Module):
62
  h_w = self.post_pre(w)
63
  h_w = self.post_convs(h_w, x_mask)
64
  h_w = self.post_proj(h_w) * x_mask
65
- e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
 
66
  z_q = e_q
67
  for flow in self.post_flows:
68
  z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -70,8 +92,11 @@ class StochasticDurationPredictor(nn.Module):
70
  z_u, z1 = torch.split(z_q, [1, 1], 1)
71
  u = torch.sigmoid(z_u) * x_mask
72
  z0 = (w - u) * x_mask
73
- logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
74
- logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
 
 
 
75
 
76
  logdet_tot = 0
77
  z0, logdet = self.log_flow(z0, x_mask)
@@ -80,12 +105,14 @@ class StochasticDurationPredictor(nn.Module):
80
  for flow in flows:
81
  z, logdet = flow(z, x_mask, g=x, reverse=reverse)
82
  logdet_tot = logdet_tot + logdet
83
- nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
 
84
  return nll + logq # [b]
85
  else:
86
  flows = list(reversed(self.flows))
87
  flows = flows[:-2] + [flows[-1]] # remove a useless vflow
88
- z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
 
89
  for flow in flows:
90
  z = flow(z, x_mask, g=x, reverse=reverse)
91
  z0, z1 = torch.split(z, [1, 1], 1)
@@ -94,7 +121,12 @@ class StochasticDurationPredictor(nn.Module):
94
 
95
 
96
  class DurationPredictor(nn.Module):
97
- def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
 
 
 
 
 
98
  super().__init__()
99
 
100
  self.in_channels = in_channels
@@ -104,9 +136,15 @@ class DurationPredictor(nn.Module):
104
  self.gin_channels = gin_channels
105
 
106
  self.drop = nn.Dropout(p_dropout)
107
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
 
 
 
108
  self.norm_1 = modules.LayerNorm(filter_channels)
109
- self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
 
 
 
110
  self.norm_2 = modules.LayerNorm(filter_channels)
111
  self.proj = nn.Conv1d(filter_channels, 1, 1)
112
 
@@ -131,15 +169,8 @@ class DurationPredictor(nn.Module):
131
 
132
 
133
  class TextEncoder(nn.Module):
134
- def __init__(self,
135
- n_vocab,
136
- out_channels,
137
- hidden_channels,
138
- filter_channels,
139
- n_heads,
140
- n_layers,
141
- kernel_size,
142
- p_dropout):
143
  super().__init__()
144
  self.n_vocab = n_vocab
145
  self.out_channels = out_channels
@@ -150,24 +181,19 @@ class TextEncoder(nn.Module):
150
  self.kernel_size = kernel_size
151
  self.p_dropout = p_dropout
152
 
153
- if self.n_vocab != 0:
154
- self.emb = nn.Embedding(n_vocab, hidden_channels)
155
- nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
156
-
157
- self.encoder = attentions.Encoder(
158
- hidden_channels,
159
- filter_channels,
160
- n_heads,
161
- n_layers,
162
- kernel_size,
163
- p_dropout)
164
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
165
 
166
  def forward(self, x, x_lengths):
167
- if self.n_vocab != 0:
168
- x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
169
  x = torch.transpose(x, 1, -1) # [b, h, t]
170
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
 
171
 
172
  x = self.encoder(x * x_mask, x_mask)
173
  stats = self.proj(x) * x_mask
@@ -197,8 +223,13 @@ class ResidualCouplingBlock(nn.Module):
197
  self.flows = nn.ModuleList()
198
  for i in range(n_flows):
199
  self.flows.append(
200
- modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
201
- gin_channels=gin_channels, mean_only=True))
 
 
 
 
 
202
  self.flows.append(modules.Flip())
203
 
204
  def forward(self, x, x_mask, g=None, reverse=False):
@@ -230,11 +261,16 @@ class PosteriorEncoder(nn.Module):
230
  self.gin_channels = gin_channels
231
 
232
  self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
233
- self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
 
 
 
 
234
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
235
 
236
  def forward(self, x, x_lengths, g=None):
237
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
 
238
  x = self.pre(x) * x_mask
239
  x = self.enc(x, x_mask, g=g)
240
  stats = self.proj(x) * x_mask
@@ -244,24 +280,40 @@ class PosteriorEncoder(nn.Module):
244
 
245
 
246
  class Generator(torch.nn.Module):
247
- def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
248
- upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
 
 
 
 
 
 
 
249
  super(Generator, self).__init__()
250
  self.num_kernels = len(resblock_kernel_sizes)
251
  self.num_upsamples = len(upsample_rates)
252
- self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
 
 
 
 
253
  resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
254
 
255
  self.ups = nn.ModuleList()
256
  for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
257
- self.ups.append(weight_norm(
258
- ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
259
- k, u, padding=(k - u) // 2)))
 
 
 
 
260
 
261
  self.resblocks = nn.ModuleList()
262
  for i in range(len(self.ups)):
263
- ch = upsample_initial_channel // (2 ** (i + 1))
264
- for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
 
265
  self.resblocks.append(resblock(ch, k, d))
266
 
267
  self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@@ -300,17 +352,37 @@ class Generator(torch.nn.Module):
300
 
301
 
302
  class DiscriminatorP(torch.nn.Module):
303
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
 
 
 
 
304
  super(DiscriminatorP, self).__init__()
305
  self.period = period
306
  self.use_spectral_norm = use_spectral_norm
307
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
308
  self.convs = nn.ModuleList([
309
- norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
310
- norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
311
- norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
312
- norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
313
- norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  ])
315
  self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
316
 
@@ -339,7 +411,7 @@ class DiscriminatorP(torch.nn.Module):
339
  class DiscriminatorS(torch.nn.Module):
340
  def __init__(self, use_spectral_norm=False):
341
  super(DiscriminatorS, self).__init__()
342
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
343
  self.convs = nn.ModuleList([
344
  norm_f(Conv1d(1, 16, 15, 1, padding=7)),
345
  norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
@@ -370,7 +442,10 @@ class MultiPeriodDiscriminator(torch.nn.Module):
370
  periods = [2, 3, 5, 7, 11]
371
 
372
  discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
373
- discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
 
 
 
374
  self.discriminators = nn.ModuleList(discs)
375
 
376
  def forward(self, y, y_hat):
@@ -391,9 +466,8 @@ class MultiPeriodDiscriminator(torch.nn.Module):
391
 
392
  class SynthesizerTrn(nn.Module):
393
  """
394
- Synthesizer for Training
395
- """
396
-
397
  def __init__(self,
398
  n_vocab,
399
  spec_channels,
@@ -435,32 +509,116 @@ class SynthesizerTrn(nn.Module):
435
  self.segment_size = segment_size
436
  self.n_speakers = n_speakers
437
  self.gin_channels = gin_channels
 
 
 
438
 
439
  self.use_sdp = use_sdp
440
 
441
- self.enc_p = TextEncoder(n_vocab,
442
- inter_channels,
443
- hidden_channels,
444
- filter_channels,
445
- n_heads,
446
- n_layers,
447
- kernel_size,
448
- p_dropout)
449
- self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
450
- upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
451
- self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
 
 
 
 
 
 
452
  gin_channels=gin_channels)
453
- self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
 
 
 
 
 
454
 
455
  if use_sdp:
456
- self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
 
 
 
 
 
457
  else:
458
- self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
 
 
 
 
459
 
460
  if n_speakers > 1:
461
  self.emb_g = nn.Embedding(n_speakers, gin_channels)
462
 
463
- def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
464
  x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
465
  if self.n_speakers > 0:
466
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@@ -468,25 +626,41 @@ class SynthesizerTrn(nn.Module):
468
  g = None
469
 
470
  if self.use_sdp:
471
- logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
 
 
 
 
472
  else:
473
  logw = self.dp(x, x_mask, g=g)
474
  w = torch.exp(logw) * x_mask * length_scale
475
  w_ceil = torch.ceil(w)
476
  y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
477
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
 
478
  attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
479
  attn = commons.generate_path(w_ceil, attn_mask)
480
 
481
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
482
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1,
483
- 2) # [b, t', t], [b, t, d] -> [b, d, t']
 
484
 
485
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
486
  z = self.flow(z_p, y_mask, g=g, reverse=True)
487
  o = self.dec((z * y_mask)[:, :, :max_len], g=g)
488
  return o, attn, y_mask, (z, z_p, m_p, logs_p)
489
 
 
 
 
 
 
 
 
 
 
 
490
  def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
491
  assert self.n_speakers > 0, "n_speakers have to be larger than 0."
492
  g_src = self.emb_g(sid_src).unsqueeze(-1)
 
2
 
3
  import torch
4
  from torch import nn
 
5
  from torch.nn import functional as F
6
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
7
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
8
+ import monotonic_align
9
 
 
10
  import commons
11
  import modules
12
+ import attentions
13
  from commons import init_weights, get_padding
14
 
15
 
16
  class StochasticDurationPredictor(nn.Module):
17
+ def __init__(self,
18
+ in_channels,
19
+ filter_channels,
20
+ kernel_size,
21
+ p_dropout,
22
+ n_flows=4,
23
+ gin_channels=0):
24
  super().__init__()
25
  filter_channels = in_channels # it needs to be removed from future version.
26
  self.in_channels = in_channels
 
34
  self.flows = nn.ModuleList()
35
  self.flows.append(modules.ElementwiseAffine(2))
36
  for i in range(n_flows):
37
+ self.flows.append(
38
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
39
  self.flows.append(modules.Flip())
40
 
41
  self.post_pre = nn.Conv1d(1, filter_channels, 1)
42
  self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
43
+ self.post_convs = modules.DDSConv(filter_channels,
44
+ kernel_size,
45
+ n_layers=3,
46
+ p_dropout=p_dropout)
47
  self.post_flows = nn.ModuleList()
48
  self.post_flows.append(modules.ElementwiseAffine(2))
49
  for i in range(4):
50
+ self.post_flows.append(
51
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
52
  self.post_flows.append(modules.Flip())
53
 
54
  self.pre = nn.Conv1d(in_channels, filter_channels, 1)
55
  self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
56
+ self.convs = modules.DDSConv(filter_channels,
57
+ kernel_size,
58
+ n_layers=3,
59
+ p_dropout=p_dropout)
60
  if gin_channels != 0:
61
  self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
62
 
63
+ def forward(self,
64
+ x,
65
+ x_mask,
66
+ w=None,
67
+ g=None,
68
+ reverse=False,
69
+ noise_scale=1.0):
70
  x = torch.detach(x)
71
  x = self.pre(x)
72
  if g is not None:
 
83
  h_w = self.post_pre(w)
84
  h_w = self.post_convs(h_w, x_mask)
85
  h_w = self.post_proj(h_w) * x_mask
86
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(
87
+ device=x.device, dtype=x.dtype) * x_mask
88
  z_q = e_q
89
  for flow in self.post_flows:
90
  z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
 
92
  z_u, z1 = torch.split(z_q, [1, 1], 1)
93
  u = torch.sigmoid(z_u) * x_mask
94
  z0 = (w - u) * x_mask
95
+ logdet_tot_q += torch.sum(
96
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
97
+ logq = torch.sum(
98
+ -0.5 * (math.log(2 * math.pi) +
99
+ (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
100
 
101
  logdet_tot = 0
102
  z0, logdet = self.log_flow(z0, x_mask)
 
105
  for flow in flows:
106
  z, logdet = flow(z, x_mask, g=x, reverse=reverse)
107
  logdet_tot = logdet_tot + logdet
108
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) +
109
+ (z**2)) * x_mask, [1, 2]) - logdet_tot
110
  return nll + logq # [b]
111
  else:
112
  flows = list(reversed(self.flows))
113
  flows = flows[:-2] + [flows[-1]] # remove a useless vflow
114
+ z = torch.randn(x.size(0), 2, x.size(2)).to(
115
+ device=x.device, dtype=x.dtype) * noise_scale
116
  for flow in flows:
117
  z = flow(z, x_mask, g=x, reverse=reverse)
118
  z0, z1 = torch.split(z, [1, 1], 1)
 
121
 
122
 
123
  class DurationPredictor(nn.Module):
124
+ def __init__(self,
125
+ in_channels,
126
+ filter_channels,
127
+ kernel_size,
128
+ p_dropout,
129
+ gin_channels=0):
130
  super().__init__()
131
 
132
  self.in_channels = in_channels
 
136
  self.gin_channels = gin_channels
137
 
138
  self.drop = nn.Dropout(p_dropout)
139
+ self.conv_1 = nn.Conv1d(in_channels,
140
+ filter_channels,
141
+ kernel_size,
142
+ padding=kernel_size // 2)
143
  self.norm_1 = modules.LayerNorm(filter_channels)
144
+ self.conv_2 = nn.Conv1d(filter_channels,
145
+ filter_channels,
146
+ kernel_size,
147
+ padding=kernel_size // 2)
148
  self.norm_2 = modules.LayerNorm(filter_channels)
149
  self.proj = nn.Conv1d(filter_channels, 1, 1)
150
 
 
169
 
170
 
171
  class TextEncoder(nn.Module):
172
+ def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels,
173
+ n_heads, n_layers, kernel_size, p_dropout):
 
 
 
 
 
 
 
174
  super().__init__()
175
  self.n_vocab = n_vocab
176
  self.out_channels = out_channels
 
181
  self.kernel_size = kernel_size
182
  self.p_dropout = p_dropout
183
 
184
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
185
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
186
+
187
+ self.encoder = attentions.Encoder(hidden_channels, filter_channels,
188
+ n_heads, n_layers, kernel_size,
189
+ p_dropout)
 
 
 
 
 
190
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
191
 
192
  def forward(self, x, x_lengths):
193
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
 
194
  x = torch.transpose(x, 1, -1) # [b, h, t]
195
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
196
+ 1).to(x.dtype)
197
 
198
  x = self.encoder(x * x_mask, x_mask)
199
  stats = self.proj(x) * x_mask
 
223
  self.flows = nn.ModuleList()
224
  for i in range(n_flows):
225
  self.flows.append(
226
+ modules.ResidualCouplingLayer(channels,
227
+ hidden_channels,
228
+ kernel_size,
229
+ dilation_rate,
230
+ n_layers,
231
+ gin_channels=gin_channels,
232
+ mean_only=True))
233
  self.flows.append(modules.Flip())
234
 
235
  def forward(self, x, x_mask, g=None, reverse=False):
 
261
  self.gin_channels = gin_channels
262
 
263
  self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
264
+ self.enc = modules.WN(hidden_channels,
265
+ kernel_size,
266
+ dilation_rate,
267
+ n_layers,
268
+ gin_channels=gin_channels)
269
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
270
 
271
  def forward(self, x, x_lengths, g=None):
272
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
273
+ 1).to(x.dtype)
274
  x = self.pre(x) * x_mask
275
  x = self.enc(x, x_mask, g=g)
276
  stats = self.proj(x) * x_mask
 
280
 
281
 
282
  class Generator(torch.nn.Module):
283
+ def __init__(self,
284
+ initial_channel,
285
+ resblock,
286
+ resblock_kernel_sizes,
287
+ resblock_dilation_sizes,
288
+ upsample_rates,
289
+ upsample_initial_channel,
290
+ upsample_kernel_sizes,
291
+ gin_channels=0):
292
  super(Generator, self).__init__()
293
  self.num_kernels = len(resblock_kernel_sizes)
294
  self.num_upsamples = len(upsample_rates)
295
+ self.conv_pre = Conv1d(initial_channel,
296
+ upsample_initial_channel,
297
+ 7,
298
+ 1,
299
+ padding=3)
300
  resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
301
 
302
  self.ups = nn.ModuleList()
303
  for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
304
+ self.ups.append(
305
+ weight_norm(
306
+ ConvTranspose1d(upsample_initial_channel // (2**i),
307
+ upsample_initial_channel // (2**(i + 1)),
308
+ k,
309
+ u,
310
+ padding=(k - u) // 2)))
311
 
312
  self.resblocks = nn.ModuleList()
313
  for i in range(len(self.ups)):
314
+ ch = upsample_initial_channel // (2**(i + 1))
315
+ for j, (k, d) in enumerate(
316
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)):
317
  self.resblocks.append(resblock(ch, k, d))
318
 
319
  self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
 
352
 
353
 
354
  class DiscriminatorP(torch.nn.Module):
355
+ def __init__(self,
356
+ period,
357
+ kernel_size=5,
358
+ stride=3,
359
+ use_spectral_norm=False):
360
  super(DiscriminatorP, self).__init__()
361
  self.period = period
362
  self.use_spectral_norm = use_spectral_norm
363
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
364
  self.convs = nn.ModuleList([
365
+ norm_f(
366
+ Conv2d(1,
367
+ 32, (kernel_size, 1), (stride, 1),
368
+ padding=(get_padding(kernel_size, 1), 0))),
369
+ norm_f(
370
+ Conv2d(32,
371
+ 128, (kernel_size, 1), (stride, 1),
372
+ padding=(get_padding(kernel_size, 1), 0))),
373
+ norm_f(
374
+ Conv2d(128,
375
+ 512, (kernel_size, 1), (stride, 1),
376
+ padding=(get_padding(kernel_size, 1), 0))),
377
+ norm_f(
378
+ Conv2d(512,
379
+ 1024, (kernel_size, 1), (stride, 1),
380
+ padding=(get_padding(kernel_size, 1), 0))),
381
+ norm_f(
382
+ Conv2d(1024,
383
+ 1024, (kernel_size, 1),
384
+ 1,
385
+ padding=(get_padding(kernel_size, 1), 0))),
386
  ])
387
  self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
388
 
 
411
  class DiscriminatorS(torch.nn.Module):
412
  def __init__(self, use_spectral_norm=False):
413
  super(DiscriminatorS, self).__init__()
414
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
415
  self.convs = nn.ModuleList([
416
  norm_f(Conv1d(1, 16, 15, 1, padding=7)),
417
  norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
 
442
  periods = [2, 3, 5, 7, 11]
443
 
444
  discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
445
+ discs = discs + [
446
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm)
447
+ for i in periods
448
+ ]
449
  self.discriminators = nn.ModuleList(discs)
450
 
451
  def forward(self, y, y_hat):
 
466
 
467
  class SynthesizerTrn(nn.Module):
468
  """
469
+ Synthesizer for Training
470
+ """
 
471
  def __init__(self,
472
  n_vocab,
473
  spec_channels,
 
509
  self.segment_size = segment_size
510
  self.n_speakers = n_speakers
511
  self.gin_channels = gin_channels
512
+ if self.n_speakers != 0:
513
+ message = "gin_channels must be nonzero for multiple speakers"
514
+ assert gin_channels != 0, message
515
 
516
  self.use_sdp = use_sdp
517
 
518
+ self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels,
519
+ filter_channels, n_heads, n_layers,
520
+ kernel_size, p_dropout)
521
+ self.dec = Generator(inter_channels,
522
+ resblock,
523
+ resblock_kernel_sizes,
524
+ resblock_dilation_sizes,
525
+ upsample_rates,
526
+ upsample_initial_channel,
527
+ upsample_kernel_sizes,
528
+ gin_channels=gin_channels)
529
+ self.enc_q = PosteriorEncoder(spec_channels,
530
+ inter_channels,
531
+ hidden_channels,
532
+ 5,
533
+ 1,
534
+ 16,
535
  gin_channels=gin_channels)
536
+ self.flow = ResidualCouplingBlock(inter_channels,
537
+ hidden_channels,
538
+ 5,
539
+ 1,
540
+ 4,
541
+ gin_channels=gin_channels)
542
 
543
  if use_sdp:
544
+ self.dp = StochasticDurationPredictor(hidden_channels,
545
+ 192,
546
+ 3,
547
+ 0.5,
548
+ 4,
549
+ gin_channels=gin_channels)
550
  else:
551
+ self.dp = DurationPredictor(hidden_channels,
552
+ 256,
553
+ 3,
554
+ 0.5,
555
+ gin_channels=gin_channels)
556
 
557
  if n_speakers > 1:
558
  self.emb_g = nn.Embedding(n_speakers, gin_channels)
559
 
560
+ def forward(self, x, x_lengths, y, y_lengths, sid=None):
561
+
562
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
563
+ if self.n_speakers > 0:
564
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
565
+ else:
566
+ g = None
567
+
568
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
569
+ z_p = self.flow(z, y_mask, g=g)
570
+
571
+ with torch.no_grad():
572
+ # negative cross-entropy
573
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
574
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1],
575
+ keepdim=True) # [b, 1, t_s]
576
+ neg_cent2 = torch.matmul(
577
+ -0.5 * (z_p**2).transpose(1, 2),
578
+ s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
579
+ neg_cent3 = torch.matmul(
580
+ z_p.transpose(1, 2),
581
+ (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
582
+ neg_cent4 = torch.sum(-0.5 * (m_p**2) * s_p_sq_r, [1],
583
+ keepdim=True) # [b, 1, t_s]
584
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
585
+
586
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
587
+ y_mask, -1)
588
+ attn = monotonic_align.maximum_path(
589
+ neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
590
+
591
+ w = attn.sum(2)
592
+ if self.use_sdp:
593
+ l_length = self.dp(x, x_mask, w, g=g)
594
+ l_length = l_length / torch.sum(x_mask)
595
+ else:
596
+ logw_ = torch.log(w + 1e-6) * x_mask
597
+ logw = self.dp(x, x_mask, g=g)
598
+ l_length = torch.sum(
599
+ (logw - logw_)**2, [1, 2]) / torch.sum(x_mask) # for averaging
600
+
601
+ # expand prior
602
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1,
603
+ 2)).transpose(1, 2)
604
+ logs_p = torch.matmul(attn.squeeze(1),
605
+ logs_p.transpose(1, 2)).transpose(1, 2)
606
+
607
+ z_slice, ids_slice = commons.rand_slice_segments(
608
+ z, y_lengths, self.segment_size)
609
+ o = self.dec(z_slice, g=g)
610
+ return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p,
611
+ logs_p, m_q,
612
+ logs_q)
613
+
614
+ def infer(self,
615
+ x,
616
+ x_lengths,
617
+ sid=None,
618
+ noise_scale=1,
619
+ length_scale=1,
620
+ noise_scale_w=1.,
621
+ max_len=None):
622
  x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
623
  if self.n_speakers > 0:
624
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
 
626
  g = None
627
 
628
  if self.use_sdp:
629
+ logw = self.dp(x,
630
+ x_mask,
631
+ g=g,
632
+ reverse=True,
633
+ noise_scale=noise_scale_w)
634
  else:
635
  logw = self.dp(x, x_mask, g=g)
636
  w = torch.exp(logw) * x_mask * length_scale
637
  w_ceil = torch.ceil(w)
638
  y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
639
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None),
640
+ 1).to(x_mask.dtype)
641
  attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
642
  attn = commons.generate_path(w_ceil, attn_mask)
643
 
644
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
645
+ 1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
646
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(
647
+ 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
648
 
649
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
650
  z = self.flow(z_p, y_mask, g=g, reverse=True)
651
  o = self.dec((z * y_mask)[:, :, :max_len], g=g)
652
  return o, attn, y_mask, (z, z_p, m_p, logs_p)
653
 
654
+ def export_forward(self, x, x_lengths, scales, sid):
655
+ # shape of scales: Bx3, make triton happy
656
+ audio, *_ = self.infer(x,
657
+ x_lengths,
658
+ sid,
659
+ noise_scale=scales[0][0],
660
+ length_scale=scales[0][1],
661
+ noise_scale_w=scales[0][2])
662
+ return audio
663
+
664
  def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
665
  assert self.n_speakers > 0, "n_speakers have to be larger than 0."
666
  g_src = self.emb_g(sid_src).unsqueeze(-1)
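For orientation, here is a minimal sketch of how the reformatted SynthesizerTrn.infer (and the new export_forward wrapper) would typically be driven at inference time. The names net_g and stn_tst are placeholders for a loaded model and an already-encoded symbol sequence; they are not part of this commit.

import torch

def synthesize(net_g, stn_tst, sid=0,
               noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0):
    # Run one utterance through SynthesizerTrn.infer and return a numpy waveform.
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0)                         # [1, t]
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])  # [1]
        speaker = torch.LongTensor([sid])                    # [1]
        audio, attn, y_mask, _ = net_g.infer(
            x_tst, x_tst_lengths, sid=speaker,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale)
    return audio[0, 0].cpu().numpy()

# export_forward packs the same three knobs into a Bx3 tensor, in the order
# [noise_scale, length_scale, noise_scale_w]:
#   scales = torch.FloatTensor([[0.667, 1.0, 0.8]])
#   audio = net_g.export_forward(x_tst, x_tst_lengths, scales, speaker)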
modules.py CHANGED
@@ -1,8 +1,8 @@
1
  import math
 
2
  import torch
3
  from torch import nn
4
  from torch.nn import functional as F
5
-
6
  from torch.nn import Conv1d
7
  from torch.nn.utils import weight_norm, remove_weight_norm
8
 
@@ -10,197 +10,249 @@ import commons
10
  from commons import init_weights, get_padding
11
  from transforms import piecewise_rational_quadratic_transform
12
 
13
-
14
  LRELU_SLOPE = 0.1
15
 
16
 
17
  class LayerNorm(nn.Module):
18
- def __init__(self, channels, eps=1e-5):
19
- super().__init__()
20
- self.channels = channels
21
- self.eps = eps
22
 
23
- self.gamma = nn.Parameter(torch.ones(channels))
24
- self.beta = nn.Parameter(torch.zeros(channels))
 
 
 
 
 
25
 
26
- def forward(self, x):
27
- x = x.transpose(1, -1)
28
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
29
- return x.transpose(1, -1)
30
 
31
-
32
  class ConvReluNorm(nn.Module):
33
- def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
34
- super().__init__()
35
- self.in_channels = in_channels
36
- self.hidden_channels = hidden_channels
37
- self.out_channels = out_channels
38
- self.kernel_size = kernel_size
39
- self.n_layers = n_layers
40
- self.p_dropout = p_dropout
41
- assert n_layers > 1, "Number of layers should be larger than 0."
42
-
43
- self.conv_layers = nn.ModuleList()
44
- self.norm_layers = nn.ModuleList()
45
- self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
46
- self.norm_layers.append(LayerNorm(hidden_channels))
47
- self.relu_drop = nn.Sequential(
48
- nn.ReLU(),
49
- nn.Dropout(p_dropout))
50
- for _ in range(n_layers-1):
51
- self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
52
- self.norm_layers.append(LayerNorm(hidden_channels))
53
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
54
- self.proj.weight.data.zero_()
55
- self.proj.bias.data.zero_()
56
-
57
- def forward(self, x, x_mask):
58
- x_org = x
59
- for i in range(self.n_layers):
60
- x = self.conv_layers[i](x * x_mask)
61
- x = self.norm_layers[i](x)
62
- x = self.relu_drop(x)
63
- x = x_org + self.proj(x)
64
- return x * x_mask
 
 
65
 
66
 
67
  class DDSConv(nn.Module):
68
- """
69
  Dilated and Depth-Separable Convolution
70
  """
71
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
72
- super().__init__()
73
- self.channels = channels
74
- self.kernel_size = kernel_size
75
- self.n_layers = n_layers
76
- self.p_dropout = p_dropout
77
-
78
- self.drop = nn.Dropout(p_dropout)
79
- self.convs_sep = nn.ModuleList()
80
- self.convs_1x1 = nn.ModuleList()
81
- self.norms_1 = nn.ModuleList()
82
- self.norms_2 = nn.ModuleList()
83
- for i in range(n_layers):
84
- dilation = kernel_size ** i
85
- padding = (kernel_size * dilation - dilation) // 2
86
- self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
87
- groups=channels, dilation=dilation, padding=padding
88
- ))
89
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
90
- self.norms_1.append(LayerNorm(channels))
91
- self.norms_2.append(LayerNorm(channels))
92
-
93
- def forward(self, x, x_mask, g=None):
94
- if g is not None:
95
- x = x + g
96
- for i in range(self.n_layers):
97
- y = self.convs_sep[i](x * x_mask)
98
- y = self.norms_1[i](y)
99
- y = F.gelu(y)
100
- y = self.convs_1x1[i](y)
101
- y = self.norms_2[i](y)
102
- y = F.gelu(y)
103
- y = self.drop(y)
104
- x = x + y
105
- return x * x_mask
 
106
 
107
 
108
  class WN(torch.nn.Module):
109
- def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
110
- super(WN, self).__init__()
111
- assert(kernel_size % 2 == 1)
112
- self.hidden_channels =hidden_channels
113
- self.kernel_size = kernel_size,
114
- self.dilation_rate = dilation_rate
115
- self.n_layers = n_layers
116
- self.gin_channels = gin_channels
117
- self.p_dropout = p_dropout
118
-
119
- self.in_layers = torch.nn.ModuleList()
120
- self.res_skip_layers = torch.nn.ModuleList()
121
- self.drop = nn.Dropout(p_dropout)
122
-
123
- if gin_channels != 0:
124
- cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
125
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
126
-
127
- for i in range(n_layers):
128
- dilation = dilation_rate ** i
129
- padding = int((kernel_size * dilation - dilation) / 2)
130
- in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
131
- dilation=dilation, padding=padding)
132
- in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
133
- self.in_layers.append(in_layer)
134
-
135
- # last one is not necessary
136
- if i < n_layers - 1:
137
- res_skip_channels = 2 * hidden_channels
138
- else:
139
- res_skip_channels = hidden_channels
140
-
141
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
142
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
143
- self.res_skip_layers.append(res_skip_layer)
144
-
145
- def forward(self, x, x_mask, g=None, **kwargs):
146
- output = torch.zeros_like(x)
147
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
148
-
149
- if g is not None:
150
- g = self.cond_layer(g)
151
-
152
- for i in range(self.n_layers):
153
- x_in = self.in_layers[i](x)
154
- if g is not None:
155
- cond_offset = i * 2 * self.hidden_channels
156
- g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
157
- else:
158
- g_l = torch.zeros_like(x_in)
159
-
160
- acts = commons.fused_add_tanh_sigmoid_multiply(
161
- x_in,
162
- g_l,
163
- n_channels_tensor)
164
- acts = self.drop(acts)
165
-
166
- res_skip_acts = self.res_skip_layers[i](acts)
167
- if i < self.n_layers - 1:
168
- res_acts = res_skip_acts[:,:self.hidden_channels,:]
169
- x = (x + res_acts) * x_mask
170
- output = output + res_skip_acts[:,self.hidden_channels:,:]
171
- else:
172
- output = output + res_skip_acts
173
- return output * x_mask
174
-
175
- def remove_weight_norm(self):
176
- if self.gin_channels != 0:
177
- torch.nn.utils.remove_weight_norm(self.cond_layer)
178
- for l in self.in_layers:
179
- torch.nn.utils.remove_weight_norm(l)
180
- for l in self.res_skip_layers:
181
- torch.nn.utils.remove_weight_norm(l)
 
 
 
 
 
182
 
183
 
184
  class ResBlock1(torch.nn.Module):
185
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
186
  super(ResBlock1, self).__init__()
187
  self.convs1 = nn.ModuleList([
188
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
189
- padding=get_padding(kernel_size, dilation[0]))),
190
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
191
- padding=get_padding(kernel_size, dilation[1]))),
192
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
193
- padding=get_padding(kernel_size, dilation[2])))
 
 
 
 
194
  ])
195
  self.convs1.apply(init_weights)
196
 
197
  self.convs2 = nn.ModuleList([
198
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
199
- padding=get_padding(kernel_size, 1))),
200
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
- padding=get_padding(kernel_size, 1))),
202
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
- padding=get_padding(kernel_size, 1)))
 
 
 
 
204
  ])
205
  self.convs2.apply(init_weights)
206
 
@@ -230,10 +282,20 @@ class ResBlock2(torch.nn.Module):
230
  def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
231
  super(ResBlock2, self).__init__()
232
  self.convs = nn.ModuleList([
233
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
234
- padding=get_padding(kernel_size, dilation[0]))),
235
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
236
- padding=get_padding(kernel_size, dilation[1])))
 
 
 
237
  ])
238
  self.convs.apply(init_weights)
239
 
@@ -254,134 +316,154 @@ class ResBlock2(torch.nn.Module):
254
 
255
 
256
  class Log(nn.Module):
257
- def forward(self, x, x_mask, reverse=False, **kwargs):
258
- if not reverse:
259
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
260
- logdet = torch.sum(-y, [1, 2])
261
- return y, logdet
262
- else:
263
- x = torch.exp(x) * x_mask
264
- return x
265
-
266
 
267
  class Flip(nn.Module):
268
- def forward(self, x, *args, reverse=False, **kwargs):
269
- x = torch.flip(x, [1])
270
- if not reverse:
271
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
272
- return x, logdet
273
- else:
274
- return x
275
 
276
 
277
  class ElementwiseAffine(nn.Module):
278
- def __init__(self, channels):
279
- super().__init__()
280
- self.channels = channels
281
- self.m = nn.Parameter(torch.zeros(channels,1))
282
- self.logs = nn.Parameter(torch.zeros(channels,1))
283
-
284
- def forward(self, x, x_mask, reverse=False, **kwargs):
285
- if not reverse:
286
- y = self.m + torch.exp(self.logs) * x
287
- y = y * x_mask
288
- logdet = torch.sum(self.logs * x_mask, [1,2])
289
- return y, logdet
290
- else:
291
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
292
- return x
293
 
294
 
295
  class ResidualCouplingLayer(nn.Module):
296
- def __init__(self,
297
- channels,
298
- hidden_channels,
299
- kernel_size,
300
- dilation_rate,
301
- n_layers,
302
- p_dropout=0,
303
- gin_channels=0,
304
- mean_only=False):
305
- assert channels % 2 == 0, "channels should be divisible by 2"
306
- super().__init__()
307
- self.channels = channels
308
- self.hidden_channels = hidden_channels
309
- self.kernel_size = kernel_size
310
- self.dilation_rate = dilation_rate
311
- self.n_layers = n_layers
312
- self.half_channels = channels // 2
313
- self.mean_only = mean_only
314
-
315
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
316
- self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
317
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
318
- self.post.weight.data.zero_()
319
- self.post.bias.data.zero_()
320
-
321
- def forward(self, x, x_mask, g=None, reverse=False):
322
- x0, x1 = torch.split(x, [self.half_channels]*2, 1)
323
- h = self.pre(x0) * x_mask
324
- h = self.enc(h, x_mask, g=g)
325
- stats = self.post(h) * x_mask
326
- if not self.mean_only:
327
- m, logs = torch.split(stats, [self.half_channels]*2, 1)
328
- else:
329
- m = stats
330
- logs = torch.zeros_like(m)
331
-
332
- if not reverse:
333
- x1 = m + x1 * torch.exp(logs) * x_mask
334
- x = torch.cat([x0, x1], 1)
335
- logdet = torch.sum(logs, [1,2])
336
- return x, logdet
337
- else:
338
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
339
- x = torch.cat([x0, x1], 1)
340
- return x
 
 
341
 
342
 
343
  class ConvFlow(nn.Module):
344
- def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
345
- super().__init__()
346
- self.in_channels = in_channels
347
- self.filter_channels = filter_channels
348
- self.kernel_size = kernel_size
349
- self.n_layers = n_layers
350
- self.num_bins = num_bins
351
- self.tail_bound = tail_bound
352
- self.half_channels = in_channels // 2
353
-
354
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
355
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
356
- self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
357
- self.proj.weight.data.zero_()
358
- self.proj.bias.data.zero_()
359
-
360
- def forward(self, x, x_mask, g=None, reverse=False):
361
- x0, x1 = torch.split(x, [self.half_channels]*2, 1)
362
- h = self.pre(x0)
363
- h = self.convs(h, x_mask, g=g)
364
- h = self.proj(h) * x_mask
365
-
366
- b, c, t = x0.shape
367
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
368
-
369
- unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
370
- unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
371
- unnormalized_derivatives = h[..., 2 * self.num_bins:]
372
-
373
- x1, logabsdet = piecewise_rational_quadratic_transform(x1,
374
- unnormalized_widths,
375
- unnormalized_heights,
376
- unnormalized_derivatives,
377
- inverse=reverse,
378
- tails='linear',
379
- tail_bound=self.tail_bound
380
- )
381
-
382
- x = torch.cat([x0, x1], 1) * x_mask
383
- logdet = torch.sum(logabsdet * x_mask, [1,2])
384
- if not reverse:
385
- return x, logdet
386
- else:
387
- return x
 
 
 
 
1
  import math
2
+
3
  import torch
4
  from torch import nn
5
  from torch.nn import functional as F
 
6
  from torch.nn import Conv1d
7
  from torch.nn.utils import weight_norm, remove_weight_norm
8
 
 
10
  from commons import init_weights, get_padding
11
  from transforms import piecewise_rational_quadratic_transform
12
 
 
13
  LRELU_SLOPE = 0.1
14
 
15
 
16
  class LayerNorm(nn.Module):
17
+ def __init__(self, channels, eps=1e-5):
18
+ super().__init__()
19
+ self.channels = channels
20
+ self.eps = eps
21
 
22
+ self.gamma = nn.Parameter(torch.ones(channels))
23
+ self.beta = nn.Parameter(torch.zeros(channels))
24
+
25
+ def forward(self, x):
26
+ x = x.transpose(1, -1)
27
+ x = F.layer_norm(x, (self.channels, ), self.gamma, self.beta, self.eps)
28
+ return x.transpose(1, -1)
29
 
 
 
 
 
30
 
 
31
  class ConvReluNorm(nn.Module):
32
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
33
+ n_layers, p_dropout):
34
+ super().__init__()
35
+ self.in_channels = in_channels
36
+ self.hidden_channels = hidden_channels
37
+ self.out_channels = out_channels
38
+ self.kernel_size = kernel_size
39
+ self.n_layers = n_layers
40
+ self.p_dropout = p_dropout
41
+ assert n_layers > 1, "Number of layers should be larger than 1."
42
+
43
+ self.conv_layers = nn.ModuleList()
44
+ self.norm_layers = nn.ModuleList()
45
+ self.conv_layers.append(
46
+ nn.Conv1d(in_channels,
47
+ hidden_channels,
48
+ kernel_size,
49
+ padding=kernel_size // 2))
50
+ self.norm_layers.append(LayerNorm(hidden_channels))
51
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
52
+ for _ in range(n_layers - 1):
53
+ self.conv_layers.append(
54
+ nn.Conv1d(hidden_channels,
55
+ hidden_channels,
56
+ kernel_size,
57
+ padding=kernel_size // 2))
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
60
+ self.proj.weight.data.zero_()
61
+ self.proj.bias.data.zero_()
62
+
63
+ def forward(self, x, x_mask):
64
+ x_org = x
65
+ for i in range(self.n_layers):
66
+ x = self.conv_layers[i](x * x_mask)
67
+ x = self.norm_layers[i](x)
68
+ x = self.relu_drop(x)
69
+ x = x_org + self.proj(x)
70
+ return x * x_mask
71
 
72
 
73
  class DDSConv(nn.Module):
74
+ """
75
  Dilated and Depth-Separable Convolution
76
  """
77
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
78
+ super().__init__()
79
+ self.channels = channels
80
+ self.kernel_size = kernel_size
81
+ self.n_layers = n_layers
82
+ self.p_dropout = p_dropout
83
+
84
+ self.drop = nn.Dropout(p_dropout)
85
+ self.convs_sep = nn.ModuleList()
86
+ self.convs_1x1 = nn.ModuleList()
87
+ self.norms_1 = nn.ModuleList()
88
+ self.norms_2 = nn.ModuleList()
89
+ for i in range(n_layers):
90
+ dilation = kernel_size**i
91
+ padding = (kernel_size * dilation - dilation) // 2
92
+ self.convs_sep.append(
93
+ nn.Conv1d(channels,
94
+ channels,
95
+ kernel_size,
96
+ groups=channels,
97
+ dilation=dilation,
98
+ padding=padding))
99
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
100
+ self.norms_1.append(LayerNorm(channels))
101
+ self.norms_2.append(LayerNorm(channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ if g is not None:
105
+ x = x + g
106
+ for i in range(self.n_layers):
107
+ y = self.convs_sep[i](x * x_mask)
108
+ y = self.norms_1[i](y)
109
+ y = F.gelu(y)
110
+ y = self.convs_1x1[i](y)
111
+ y = self.norms_2[i](y)
112
+ y = F.gelu(y)
113
+ y = self.drop(y)
114
+ x = x + y
115
+ return x * x_mask
116
 
117
 
118
  class WN(torch.nn.Module):
119
+ def __init__(self,
120
+ hidden_channels,
121
+ kernel_size,
122
+ dilation_rate,
123
+ n_layers,
124
+ gin_channels=0,
125
+ p_dropout=0):
126
+ super(WN, self).__init__()
127
+ assert (kernel_size % 2 == 1)
128
+ self.hidden_channels = hidden_channels
129
+ self.kernel_size = kernel_size,
130
+ self.dilation_rate = dilation_rate
131
+ self.n_layers = n_layers
132
+ self.gin_channels = gin_channels
133
+ self.p_dropout = p_dropout
134
+
135
+ self.in_layers = torch.nn.ModuleList()
136
+ self.res_skip_layers = torch.nn.ModuleList()
137
+ self.drop = nn.Dropout(p_dropout)
138
+
139
+ if gin_channels != 0:
140
+ cond_layer = torch.nn.Conv1d(gin_channels,
141
+ 2 * hidden_channels * n_layers, 1)
142
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
143
+ name='weight')
144
+
145
+ for i in range(n_layers):
146
+ dilation = dilation_rate**i
147
+ padding = int((kernel_size * dilation - dilation) / 2)
148
+ in_layer = torch.nn.Conv1d(hidden_channels,
149
+ 2 * hidden_channels,
150
+ kernel_size,
151
+ dilation=dilation,
152
+ padding=padding)
153
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
154
+ self.in_layers.append(in_layer)
155
+
156
+ # last one is not necessary
157
+ if i < n_layers - 1:
158
+ res_skip_channels = 2 * hidden_channels
159
+ else:
160
+ res_skip_channels = hidden_channels
161
+
162
+ res_skip_layer = torch.nn.Conv1d(hidden_channels,
163
+ res_skip_channels, 1)
164
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
165
+ name='weight')
166
+ self.res_skip_layers.append(res_skip_layer)
167
+
168
+ def forward(self, x, x_mask, g=None, **kwargs):
169
+ output = torch.zeros_like(x)
170
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
171
+
172
+ if g is not None:
173
+ g = self.cond_layer(g)
174
+
175
+ for i in range(self.n_layers):
176
+ x_in = self.in_layers[i](x)
177
+ if g is not None:
178
+ cond_offset = i * 2 * self.hidden_channels
179
+ g_l = g[:,
180
+ cond_offset:cond_offset + 2 * self.hidden_channels, :]
181
+ else:
182
+ g_l = torch.zeros_like(x_in)
183
+
184
+ acts = commons.fused_add_tanh_sigmoid_multiply(
185
+ x_in, g_l, n_channels_tensor)
186
+ acts = self.drop(acts)
187
+
188
+ res_skip_acts = self.res_skip_layers[i](acts)
189
+ if i < self.n_layers - 1:
190
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
191
+ x = (x + res_acts) * x_mask
192
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
193
+ else:
194
+ output = output + res_skip_acts
195
+ return output * x_mask
196
+
197
+ def remove_weight_norm(self):
198
+ if self.gin_channels != 0:
199
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
200
+ for l in self.in_layers:
201
+ torch.nn.utils.remove_weight_norm(l)
202
+ for l in self.res_skip_layers:
203
+ torch.nn.utils.remove_weight_norm(l)
204
 
205
 
206
  class ResBlock1(torch.nn.Module):
207
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
208
  super(ResBlock1, self).__init__()
209
  self.convs1 = nn.ModuleList([
210
+ weight_norm(
211
+ Conv1d(channels,
212
+ channels,
213
+ kernel_size,
214
+ 1,
215
+ dilation=dilation[0],
216
+ padding=get_padding(kernel_size, dilation[0]))),
217
+ weight_norm(
218
+ Conv1d(channels,
219
+ channels,
220
+ kernel_size,
221
+ 1,
222
+ dilation=dilation[1],
223
+ padding=get_padding(kernel_size, dilation[1]))),
224
+ weight_norm(
225
+ Conv1d(channels,
226
+ channels,
227
+ kernel_size,
228
+ 1,
229
+ dilation=dilation[2],
230
+ padding=get_padding(kernel_size, dilation[2])))
231
  ])
232
  self.convs1.apply(init_weights)
233
 
234
  self.convs2 = nn.ModuleList([
235
+ weight_norm(
236
+ Conv1d(channels,
237
+ channels,
238
+ kernel_size,
239
+ 1,
240
+ dilation=1,
241
+ padding=get_padding(kernel_size, 1))),
242
+ weight_norm(
243
+ Conv1d(channels,
244
+ channels,
245
+ kernel_size,
246
+ 1,
247
+ dilation=1,
248
+ padding=get_padding(kernel_size, 1))),
249
+ weight_norm(
250
+ Conv1d(channels,
251
+ channels,
252
+ kernel_size,
253
+ 1,
254
+ dilation=1,
255
+ padding=get_padding(kernel_size, 1)))
256
  ])
257
  self.convs2.apply(init_weights)
258
 
 
282
  def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
283
  super(ResBlock2, self).__init__()
284
  self.convs = nn.ModuleList([
285
+ weight_norm(
286
+ Conv1d(channels,
287
+ channels,
288
+ kernel_size,
289
+ 1,
290
+ dilation=dilation[0],
291
+ padding=get_padding(kernel_size, dilation[0]))),
292
+ weight_norm(
293
+ Conv1d(channels,
294
+ channels,
295
+ kernel_size,
296
+ 1,
297
+ dilation=dilation[1],
298
+ padding=get_padding(kernel_size, dilation[1])))
299
  ])
300
  self.convs.apply(init_weights)
301
 
 
316
 
317
 
318
  class Log(nn.Module):
319
+ def forward(self, x, x_mask, reverse=False, **kwargs):
320
+ if not reverse:
321
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
322
+ logdet = torch.sum(-y, [1, 2])
323
+ return y, logdet
324
+ else:
325
+ x = torch.exp(x) * x_mask
326
+ return x
327
+
328
 
329
  class Flip(nn.Module):
330
+ def forward(self, x, *args, reverse=False, **kwargs):
331
+ x = torch.flip(x, [1])
332
+ if not reverse:
333
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
334
+ return x, logdet
335
+ else:
336
+ return x
337
 
338
 
339
  class ElementwiseAffine(nn.Module):
340
+ def __init__(self, channels):
341
+ super().__init__()
342
+ self.channels = channels
343
+ self.m = nn.Parameter(torch.zeros(channels, 1))
344
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
345
+
346
+ def forward(self, x, x_mask, reverse=False, **kwargs):
347
+ if not reverse:
348
+ y = self.m + torch.exp(self.logs) * x
349
+ y = y * x_mask
350
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
351
+ return y, logdet
352
+ else:
353
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
354
+ return x
355
 
356
 
357
  class ResidualCouplingLayer(nn.Module):
358
+ def __init__(self,
359
+ channels,
360
+ hidden_channels,
361
+ kernel_size,
362
+ dilation_rate,
363
+ n_layers,
364
+ p_dropout=0,
365
+ gin_channels=0,
366
+ mean_only=False):
367
+ assert channels % 2 == 0, "channels should be divisible by 2"
368
+ super().__init__()
369
+ self.channels = channels
370
+ self.hidden_channels = hidden_channels
371
+ self.kernel_size = kernel_size
372
+ self.dilation_rate = dilation_rate
373
+ self.n_layers = n_layers
374
+ self.half_channels = channels // 2
375
+ self.mean_only = mean_only
376
+
377
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
378
+ self.enc = WN(hidden_channels,
379
+ kernel_size,
380
+ dilation_rate,
381
+ n_layers,
382
+ p_dropout=p_dropout,
383
+ gin_channels=gin_channels)
384
+ self.post = nn.Conv1d(hidden_channels,
385
+ self.half_channels * (2 - mean_only), 1)
386
+ self.post.weight.data.zero_()
387
+ self.post.bias.data.zero_()
388
+
389
+ def forward(self, x, x_mask, g=None, reverse=False):
390
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
391
+ h = self.pre(x0) * x_mask
392
+ h = self.enc(h, x_mask, g=g)
393
+ stats = self.post(h) * x_mask
394
+ if not self.mean_only:
395
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
396
+ else:
397
+ m = stats
398
+ logs = torch.zeros_like(m)
399
+
400
+ if not reverse:
401
+ x1 = m + x1 * torch.exp(logs) * x_mask
402
+ x = torch.cat([x0, x1], 1)
403
+ logdet = torch.sum(logs, [1, 2])
404
+ return x, logdet
405
+ else:
406
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
407
+ x = torch.cat([x0, x1], 1)
408
+ return x
409
 
410
 
411
  class ConvFlow(nn.Module):
412
+ def __init__(self,
413
+ in_channels,
414
+ filter_channels,
415
+ kernel_size,
416
+ n_layers,
417
+ num_bins=10,
418
+ tail_bound=5.0):
419
+ super().__init__()
420
+ self.in_channels = in_channels
421
+ self.filter_channels = filter_channels
422
+ self.kernel_size = kernel_size
423
+ self.n_layers = n_layers
424
+ self.num_bins = num_bins
425
+ self.tail_bound = tail_bound
426
+ self.half_channels = in_channels // 2
427
+
428
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
429
+ self.convs = DDSConv(filter_channels,
430
+ kernel_size,
431
+ n_layers,
432
+ p_dropout=0.)
433
+ self.proj = nn.Conv1d(filter_channels,
434
+ self.half_channels * (num_bins * 3 - 1), 1)
435
+ self.proj.weight.data.zero_()
436
+ self.proj.bias.data.zero_()
437
+
438
+ def forward(self, x, x_mask, g=None, reverse=False):
439
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
440
+ h = self.pre(x0)
441
+ h = self.convs(h, x_mask, g=g)
442
+ h = self.proj(h) * x_mask
443
+
444
+ b, c, t = x0.shape
445
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3,
446
+ 2) # [b, cx?, t] -> [b, c, t, ?]
447
+
448
+ unnormalized_widths = h[..., :self.num_bins] / math.sqrt(
449
+ self.filter_channels)
450
+ unnormalized_heights = h[...,
451
+ self.num_bins:2 * self.num_bins] / math.sqrt(
452
+ self.filter_channels)
453
+ unnormalized_derivatives = h[..., 2 * self.num_bins:]
454
+
455
+ x1, logabsdet = piecewise_rational_quadratic_transform(
456
+ x1,
457
+ unnormalized_widths,
458
+ unnormalized_heights,
459
+ unnormalized_derivatives,
460
+ inverse=reverse,
461
+ tails='linear',
462
+ tail_bound=self.tail_bound)
463
+
464
+ x = torch.cat([x0, x1], 1) * x_mask
465
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
466
+ if not reverse:
467
+ return x, logdet
468
+ else:
469
+ return x
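As a quick illustration of the reformatted ConvFlow above, this standalone shape check mirrors its reshape/split of the projection output; the tensors are random placeholders and filter_channels=192 is an assumed value, not taken from this commit.

import math
import torch

b, c, t, num_bins = 2, 1, 50, 10                # c = half_channels = in_channels // 2
filter_channels = 192                           # assumed, matches a typical VITS config
h = torch.randn(b, c * (num_bins * 3 - 1), t)   # stands in for self.proj(...) * x_mask
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*(3K-1), t] -> [b, c, t, 3K-1]
unnormalized_widths = h[..., :num_bins] / math.sqrt(filter_channels)               # [b, c, t, K]
unnormalized_heights = h[..., num_bins:2 * num_bins] / math.sqrt(filter_channels)  # [b, c, t, K]
unnormalized_derivatives = h[..., 2 * num_bins:]                                   # [b, c, t, K - 1]
print(unnormalized_widths.shape, unnormalized_heights.shape, unnormalized_derivatives.shape)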
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  Cython==0.29.21
2
- romajitable
3
  librosa==0.8.0
4
  matplotlib==3.3.1
5
  numpy==1.21.6
@@ -9,7 +8,6 @@ tensorboard==2.3.0
9
  torch
10
  torchvision
11
  Unidecode==1.1.1
12
- pyopenjtalk==0.2.0
13
  jamo==0.4.1
14
  pypinyin==0.44.0
15
  jieba==0.42.1
@@ -17,4 +15,8 @@ cn2an==0.5.17
17
  jieba==0.42.1
18
  ipython==7.34.0
19
  gradio==3.4.1
20
- openai
 
 
 
 
 
1
  Cython==0.29.21
 
2
  librosa==0.8.0
3
  matplotlib==3.3.1
4
  numpy==1.21.6
 
8
  torch
9
  torchvision
10
  Unidecode==1.1.1
 
11
  jamo==0.4.1
12
  pypinyin==0.44.0
13
  jieba==0.42.1
 
15
  jieba==0.42.1
16
  ipython==7.34.0
17
  gradio==3.4.1
18
+ openai
19
+ pydub
20
+ inflect
21
+ eng_to_ipa
22
+ onnxruntime
text/__init__.py CHANGED
@@ -1,8 +1,14 @@
1
  """ from https://github.com/keithito/tacotron """
2
  from text import cleaners
 
3
 
4
 
5
- def text_to_sequence(text, symbols, cleaner_names):
 
6
  '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
7
  Args:
8
  text: string to convert to a sequence
@@ -10,8 +16,6 @@ def text_to_sequence(text, symbols, cleaner_names):
10
  Returns:
11
  List of integers corresponding to the symbols in the text
12
  '''
13
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
14
-
15
  sequence = []
16
 
17
  clean_text = _clean_text(text, cleaner_names)
@@ -23,6 +27,26 @@ def text_to_sequence(text, symbols, cleaner_names):
23
  return sequence
24
 
25

 
 
 
26
  def _clean_text(text, cleaner_names):
27
  for name in cleaner_names:
28
  cleaner = getattr(cleaners, name)
 
1
  """ from https://github.com/keithito/tacotron """
2
  from text import cleaners
3
+ from text.symbols import symbols
4
 
5
 
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+
11
+ def text_to_sequence(text, cleaner_names):
12
  '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13
  Args:
14
  text: string to convert to a sequence
 
16
  Returns:
17
  List of integers corresponding to the symbols in the text
18
  '''
 
 
19
  sequence = []
20
 
21
  clean_text = _clean_text(text, cleaner_names)
 
27
  return sequence
28
 
29
 
30
+ def cleaned_text_to_sequence(cleaned_text):
31
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
32
+ Args:
33
+ text: string to convert to a sequence
34
+ Returns:
35
+ List of integers corresponding to the symbols in the text
36
+ '''
37
+ sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
38
+ return sequence
39
+
40
+
41
+ def sequence_to_text(sequence):
42
+ '''Converts a sequence of IDs back to a string'''
43
+ result = ''
44
+ for symbol_id in sequence:
45
+ s = _id_to_symbol[symbol_id]
46
+ result += s
47
+ return result
48
+
49
+
50
  def _clean_text(text, cleaner_names):
51
  for name in cleaner_names:
52
  cleaner = getattr(cleaners, name)
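Note that text_to_sequence now resolves symbols through the module-level table, so its signature drops the old symbols argument. A hedged caller sketch, assuming the full text package (including the bundled Japanese cleaner DLL it imports) is available on the current machine:

from text import text_to_sequence, sequence_to_text

# Encode a [ZH]-tagged string with one of the cleaners shipped in this repository,
# then decode the ids back to the cleaned symbol string.
seq = text_to_sequence('[ZH]你好[ZH]', ['zh_ja_mixture_cleaners'])
print(seq)
print(sequence_to_text(seq))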
text/cleaners.py CHANGED
@@ -1,33 +1,21 @@
1
  import re
2
- from text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
 
3
  from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4
 
5
- def japanese_cleaners(text):
6
- from text.japanese import japanese_to_romaji_with_accent
7
- text = japanese_to_romaji_with_accent(text)
8
- if re.match('[A-Za-z]', text[-1]):
9
- text += '.'
10
  return text
11
 
 
 
 
 
12
 
13
  def japanese_cleaners2(text):
14
  return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
15
 
16
-
17
- def korean_cleaners(text):
18
- '''Pipeline for Korean text'''
19
- from text.korean import latin_to_hangul, number_to_hangul, divide_hangul
20
- text = latin_to_hangul(text)
21
- text = number_to_hangul(text)
22
- text = divide_hangul(text)
23
- if re.match('[\u3131-\u3163]', text[-1]):
24
- text += '.'
25
- return text
26
-
27
-
28
  def chinese_cleaners(text):
29
  '''Pipeline for Chinese text'''
30
- from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo
31
  text = number_to_chinese(text)
32
  text = chinese_to_bopomofo(text)
33
  text = latin_to_bopomofo(text)
@@ -35,10 +23,7 @@ def chinese_cleaners(text):
35
  text += '。'
36
  return text
37
 
38
-
39
  def zh_ja_mixture_cleaners(text):
40
- from text.mandarin import chinese_to_romaji
41
- from text.japanese import japanese_to_romaji_with_accent
42
  chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
43
  japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
44
  for chinese_text in chinese_texts:
@@ -53,53 +38,25 @@ def zh_ja_mixture_cleaners(text):
53
  text += '.'
54
  return text
55
 
56
-
57
- def sanskrit_cleaners(text):
58
- text = text.replace('॥', '।').replace('ॐ', 'ओम्')
59
- if text[-1] != '।':
60
- text += ' ।'
61
- return text
62
-
63
-
64
- def cjks_cleaners(text):
65
- from text.mandarin import chinese_to_lazy_ipa
66
- from text.japanese import japanese_to_ipa
67
- from text.korean import korean_to_lazy_ipa
68
- from text.sanskrit import devanagari_to_ipa
69
- chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
70
- japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
71
- korean_texts = re.findall(r'\[KO\].*?\[KO\]', text)
72
- sanskrit_texts = re.findall(r'\[SA\].*?\[SA\]', text)
73
- for chinese_text in chinese_texts:
74
- cleaned_text = chinese_to_lazy_ipa(chinese_text[4:-4])
75
- text = text.replace(chinese_text, cleaned_text+' ', 1)
76
- for japanese_text in japanese_texts:
77
- cleaned_text = japanese_to_ipa(japanese_text[4:-4])
78
- text = text.replace(japanese_text, cleaned_text+' ', 1)
79
- for korean_text in korean_texts:
80
- cleaned_text = korean_to_lazy_ipa(korean_text[4:-4])
81
- text = text.replace(korean_text, cleaned_text+' ', 1)
82
- for sanskrit_text in sanskrit_texts:
83
- cleaned_text = devanagari_to_ipa(sanskrit_text[4:-4])
84
- text = text.replace(sanskrit_text, cleaned_text+' ', 1)
85
- text = text[:-1]
86
- if re.match(r'[^\.,!\?\-…~]', text[-1]):
87
- text += '.'
88
- return text
89
-
90
  def cjke_cleaners(text):
91
  chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
92
  japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
 
93
  for chinese_text in chinese_texts:
94
  cleaned_text = chinese_to_lazy_ipa(chinese_text[4:-4])
95
  cleaned_text = cleaned_text.replace(
96
  'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')
97
  text = text.replace(chinese_text, cleaned_text+' ', 1)
98
  for japanese_text in japanese_texts:
99
- cleaned_text = japanese_to_ipa(japanese_text[4:-4])
100
  cleaned_text = cleaned_text.replace('ʧ', 'tʃ').replace(
101
  'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')
102
  text = text.replace(japanese_text, cleaned_text+' ', 1)
 
 
 
 
 
103
  text = text[:-1]
104
  if re.match(r'[^\.,!\?\-…~]', text[-1]):
105
  text += '.'
 
1
  import re
2
+ from text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
3
+ from text.japanese import clean_japanese, japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
4
  from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
5
 
6
+ def none_cleaner(text):
 
 
 
 
7
  return text
8
 
9
+ def japanese_cleaners(text):
10
+ text = clean_japanese(text)
11
+ text = re.sub(r'([A-Za-z])$', r'\1.', text)
12
+ return text
13
 
14
  def japanese_cleaners2(text):
15
  return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def chinese_cleaners(text):
18
  '''Pipeline for Chinese text'''
 
19
  text = number_to_chinese(text)
20
  text = chinese_to_bopomofo(text)
21
  text = latin_to_bopomofo(text)
 
23
  text += '。'
24
  return text
25
 
 
26
  def zh_ja_mixture_cleaners(text):
 
 
27
  chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
28
  japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
29
  for chinese_text in chinese_texts:
 
38
  text += '.'
39
  return text
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def cjke_cleaners(text):
42
  chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
43
  japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
44
+ english_texts = re.findall(r'\[EN\].*?\[EN\]', text)
45
  for chinese_text in chinese_texts:
46
  cleaned_text = chinese_to_lazy_ipa(chinese_text[4:-4])
47
  cleaned_text = cleaned_text.replace(
48
  'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')
49
  text = text.replace(chinese_text, cleaned_text+' ', 1)
50
  for japanese_text in japanese_texts:
51
+ cleaned_text = clean_japanese(japanese_text[4:-4])
52
  cleaned_text = cleaned_text.replace('ʧ', 'tʃ').replace(
53
  'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')
54
  text = text.replace(japanese_text, cleaned_text+' ', 1)
55
+ for english_text in english_texts:
56
+ cleaned_text = english_to_ipa2(english_text[4:-4])
57
+ cleaned_text = cleaned_text.replace('ɑ', 'a').replace(
58
+ 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')
59
+ text = text.replace(english_text, cleaned_text+' ', 1)
60
  text = text[:-1]
61
  if re.match(r'[^\.,!\?\-…~]', text[-1]):
62
  text += '.'
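A short usage sketch for the updated cjke_cleaners, which now also converts [EN]-tagged spans through english_to_ipa2; it assumes the whole text package, with its DLL-backed Japanese cleaner and the new eng_to_ipa/inflect dependencies, can be imported.

from text.cleaners import cjke_cleaners

mixed = '[ZH]你好[ZH][JA]こんにちは[JA][EN]Hello world[EN]'
# Returns one IPA-like string; the English vowels ɑ ɔ ɛ ɪ ʊ are folded to a o e i u as above.
print(cjke_cleaners(mixed))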
text/english.py ADDED
@@ -0,0 +1,188 @@
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ '''
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ '''
14
+
15
+
16
+ # Regular expression matching whitespace:
17
+
18
+
19
+ import re
20
+ import inflect
21
+ from unidecode import unidecode
22
+ import eng_to_ipa as ipa
23
+ _inflect = inflect.engine()
24
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29
+ _number_re = re.compile(r'[0-9]+')
30
+
31
+ # List of (regular expression, replacement) pairs for abbreviations:
32
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33
+ ('mrs', 'misess'),
34
+ ('mr', 'mister'),
35
+ ('dr', 'doctor'),
36
+ ('st', 'saint'),
37
+ ('co', 'company'),
38
+ ('jr', 'junior'),
39
+ ('maj', 'major'),
40
+ ('gen', 'general'),
41
+ ('drs', 'doctors'),
42
+ ('rev', 'reverend'),
43
+ ('lt', 'lieutenant'),
44
+ ('hon', 'honorable'),
45
+ ('sgt', 'sergeant'),
46
+ ('capt', 'captain'),
47
+ ('esq', 'esquire'),
48
+ ('ltd', 'limited'),
49
+ ('col', 'colonel'),
50
+ ('ft', 'fort'),
51
+ ]]
52
+
53
+
54
+ # List of (ipa, lazy ipa) pairs:
55
+ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56
+ ('r', 'ɹ'),
57
+ ('æ', 'e'),
58
+ ('ɑ', 'a'),
59
+ ('ɔ', 'o'),
60
+ ('ð', 'z'),
61
+ ('θ', 's'),
62
+ ('ɛ', 'e'),
63
+ ('ɪ', 'i'),
64
+ ('ʊ', 'u'),
65
+ ('ʒ', 'ʥ'),
66
+ ('ʤ', 'ʥ'),
67
+ ('ˈ', '↓'),
68
+ ]]
69
+
70
+ # List of (ipa, lazy ipa2) pairs:
71
+ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72
+ ('r', 'ɹ'),
73
+ ('ð', 'z'),
74
+ ('θ', 's'),
75
+ ('ʒ', 'ʑ'),
76
+ ('ʤ', 'dʑ'),
77
+ ('ˈ', '↓'),
78
+ ]]
79
+
80
+ # List of (ipa, ipa2) pairs
81
+ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82
+ ('r', 'ɹ'),
83
+ ('ʤ', 'dʒ'),
84
+ ('ʧ', 'tʃ')
85
+ ]]
86
+
87
+
88
+ def expand_abbreviations(text):
89
+ for regex, replacement in _abbreviations:
90
+ text = re.sub(regex, replacement, text)
91
+ return text
92
+
93
+
94
+ def collapse_whitespace(text):
95
+ return re.sub(r'\s+', ' ', text)
96
+
97
+
98
+ def _remove_commas(m):
99
+ return m.group(1).replace(',', '')
100
+
101
+
102
+ def _expand_decimal_point(m):
103
+ return m.group(1).replace('.', ' point ')
104
+
105
+
106
+ def _expand_dollars(m):
107
+ match = m.group(1)
108
+ parts = match.split('.')
109
+ if len(parts) > 2:
110
+ return match + ' dollars' # Unexpected format
111
+ dollars = int(parts[0]) if parts[0] else 0
112
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113
+ if dollars and cents:
114
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115
+ cent_unit = 'cent' if cents == 1 else 'cents'
116
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117
+ elif dollars:
118
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119
+ return '%s %s' % (dollars, dollar_unit)
120
+ elif cents:
121
+ cent_unit = 'cent' if cents == 1 else 'cents'
122
+ return '%s %s' % (cents, cent_unit)
123
+ else:
124
+ return 'zero dollars'
125
+
126
+
127
+ def _expand_ordinal(m):
128
+ return _inflect.number_to_words(m.group(0))
129
+
130
+
131
+ def _expand_number(m):
132
+ num = int(m.group(0))
133
+ if num > 1000 and num < 3000:
134
+ if num == 2000:
135
+ return 'two thousand'
136
+ elif num > 2000 and num < 2010:
137
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
138
+ elif num % 100 == 0:
139
+ return _inflect.number_to_words(num // 100) + ' hundred'
140
+ else:
141
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142
+ else:
143
+ return _inflect.number_to_words(num, andword='')
144
+
145
+
146
+ def normalize_numbers(text):
147
+ text = re.sub(_comma_number_re, _remove_commas, text)
148
+ text = re.sub(_pounds_re, r'\1 pounds', text)
149
+ text = re.sub(_dollars_re, _expand_dollars, text)
150
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
152
+ text = re.sub(_number_re, _expand_number, text)
153
+ return text
154
+
155
+
156
+ def mark_dark_l(text):
157
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158
+
159
+
160
+ def english_to_ipa(text):
161
+ text = unidecode(text).lower()
162
+ text = expand_abbreviations(text)
163
+ text = normalize_numbers(text)
164
+ phonemes = ipa.convert(text)
165
+ phonemes = collapse_whitespace(phonemes)
166
+ return phonemes
167
+
168
+
169
+ def english_to_lazy_ipa(text):
170
+ text = english_to_ipa(text)
171
+ for regex, replacement in _lazy_ipa:
172
+ text = re.sub(regex, replacement, text)
173
+ return text
174
+
175
+
176
+ def english_to_ipa2(text):
177
+ text = english_to_ipa(text)
178
+ text = mark_dark_l(text)
179
+ for regex, replacement in _ipa_to_ipa2:
180
+ text = re.sub(regex, replacement, text)
181
+ return text.replace('...', '…')
182
+
183
+
184
+ def english_to_lazy_ipa2(text):
185
+ text = english_to_ipa(text)
186
+ for regex, replacement in _lazy_ipa2:
187
+ text = re.sub(regex, replacement, text)
188
+ return text
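A quick check of the new English front end (eng_to_ipa, inflect and unidecode must be installed per requirements.txt; importing anything under text/ also runs text/__init__.py and therefore loads the Japanese cleaner DLL):

from text.english import normalize_numbers, english_to_ipa2

print(normalize_numbers('It cost $2.50 on the 3rd of May'))  # currency and ordinals are spelled out
print(english_to_ipa2('Hello world'))                        # abbreviation-expanded, lower-cased IPA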
text/japanese.py CHANGED
@@ -1,6 +1,18 @@
1
  import re
2
  from unidecode import unidecode
3
- import pyopenjtalk
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  # Regular expression matching Japanese without punctuation marks:
 
1
  import re
2
  from unidecode import unidecode
3
+ from unidecode import unidecode
4
+ import ctypes
5
+
6
+ dll = ctypes.cdll.LoadLibrary('cleaners/JapaneseCleaner.dll')
7
+ dll.CreateOjt.restype = ctypes.c_uint64
8
+ dll.PluginMain.restype = ctypes.c_uint64
9
+ floder = ctypes.create_unicode_buffer("cleaners")
10
+ dll.CreateOjt(floder)
11
+
12
+ def clean_japanese(text):
13
+ input_wchar_pointer = ctypes.create_unicode_buffer(text)
14
+ result = ctypes.wstring_at(dll.PluginMain(input_wchar_pointer))
15
+ return result
16
 
17
 
18
  # Regular expression matching Japanese without punctuation marks:
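Because the new cleaner loads cleaners/JapaneseCleaner.dll through ctypes at import time, text.japanese only imports on Windows with that file present at the relative path. A hypothetical guard (not part of this commit) that fails with a clearer message:

import os

def load_japanese_cleaner():
    # Import text.japanese only when the bundled cleaner DLL can actually be loaded.
    dll_path = os.path.join('cleaners', 'JapaneseCleaner.dll')
    if os.name != 'nt' or not os.path.isfile(dll_path):
        raise RuntimeError(dll_path + ' is required (Windows only); '
                           'run from the repository root so the relative path resolves')
    from text.japanese import clean_japanese
    return clean_japanese

# clean = load_japanese_cleaner()
# print(clean('こんにちは'))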
text/symbols.py ADDED
@@ -0,0 +1,67 @@
 
 
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+ _pad = '_'
5
+ _punctuation = ',.!?-~…'
6
+ _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
7
+ '''
8
+ # japanese_cleaners2
9
+ _pad = '_'
10
+ _punctuation = ',.!?-~…'
11
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
12
+ '''
13
+
14
+ '''# korean_cleaners
15
+ _pad = '_'
16
+ _punctuation = ',.!?…~'
17
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
18
+ '''
19
+
20
+ '''# chinese_cleaners
21
+ _pad = '_'
22
+ _punctuation = ',。!?—…'
23
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
24
+ '''
25
+
26
+
27
+ '''# sanskrit_cleaners
28
+ _pad = '_'
29
+ _punctuation = '।'
30
+ _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
31
+ '''
32
+
33
+ '''# cjks_cleaners
34
+ _pad = '_'
35
+ _punctuation = ',.!?-~…'
36
+ _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
37
+ '''
38
+
39
+ '''# thai_cleaners
40
+ _pad = '_'
41
+ _punctuation = '.!? '
42
+ _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
43
+ '''
44
+
45
+ '''# cjke_cleaners2
46
+ _pad = '_'
47
+ _punctuation = ',.!?-~…'
48
+ _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
49
+ '''
50
+
51
+ '''# shanghainese_cleaners
52
+ _pad = '_'
53
+ _punctuation = ',.!?…'
54
+ _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
55
+ '''
56
+
57
+ '''# chinese_dialect_cleaners
58
+ _pad = '_'
59
+ _punctuation = ',.!?~…─'
60
+ _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚αᴀᴇ↑↓∅ⱼ '
61
+ '''
62
+
63
+ # Export all symbols:
64
+ symbols = [_pad] + list(_punctuation) + list(_letters)
65
+
66
+ # Special symbol ids
67
+ SPACE_ID = symbols.index(" ")
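A standalone round trip through the symbol table defined above, mirroring the _symbol_to_id/_id_to_symbol maps in text/__init__.py without importing the package (which would also pull in the Japanese cleaner DLL); the sample string is chosen so that every character is in the table.

_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

cleaned = 'sayounara.'
ids = [_symbol_to_id[ch] for ch in cleaned if ch in _symbol_to_id]  # unknown symbols are dropped
assert ''.join(_id_to_symbol[i] for i in ids) == cleaned
print(ids, symbols.index(' '))  # the last value is what text/symbols.py exposes as SPACE_ID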
train.py ADDED
@@ -0,0 +1,328 @@
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import torch.multiprocessing as mp
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.cuda.amp import autocast, GradScaler
11
+
12
+ import commons
13
+ import utils
14
+ from data_utils import (TextAudioSpeakerLoader, TextAudioSpeakerCollate,
15
+ DistributedBucketSampler)
16
+ from models import (
17
+ SynthesizerTrn,
18
+ MultiPeriodDiscriminator,
19
+ )
20
+ from losses import (generator_loss, discriminator_loss, feature_loss, kl_loss)
21
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
22
+
23
+ torch.backends.cudnn.benchmark = True
24
+ global_step = 0
25
+
26
+
27
+ def main():
28
+ """Assume Single Node Multi GPUs Training Only"""
29
+ assert torch.cuda.is_available(), "CPU training is not allowed."
30
+
31
+ n_gpus = torch.cuda.device_count()
32
+ hps = utils.get_hparams()
33
+ mp.spawn(run, nprocs=n_gpus, args=(
34
+ n_gpus,
35
+ hps,
36
+ ))
37
+
38
+
39
+ def run(rank, n_gpus, hps):
40
+ global global_step
41
+ if rank == 0:
42
+ logger = utils.get_logger(hps.model_dir)
43
+ logger.info(hps)
44
+ utils.check_git_hash(hps.model_dir)
45
+ writer = SummaryWriter(log_dir=hps.model_dir)
46
+ writer_eval = SummaryWriter(
47
+ log_dir=os.path.join(hps.model_dir, "eval"))
48
+
49
+ dist.init_process_group(backend='nccl',
50
+ init_method='env://',
51
+ world_size=n_gpus,
52
+ rank=rank)
53
+ torch.manual_seed(hps.train.seed)
54
+ torch.cuda.set_device(rank)
55
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
56
+ train_sampler = DistributedBucketSampler(
57
+ train_dataset,
58
+ hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000],
59
+ num_replicas=n_gpus,
60
+ rank=rank,
61
+ shuffle=True)
62
+ collate_fn = TextAudioSpeakerCollate()
63
+ train_loader = DataLoader(train_dataset,
64
+ num_workers=8,
65
+ shuffle=False,
66
+ pin_memory=True,
67
+ collate_fn=collate_fn,
68
+ batch_sampler=train_sampler)
69
+ if rank == 0:
70
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files,
71
+ hps.data)
72
+ eval_loader = DataLoader(eval_dataset,
73
+ num_workers=8,
74
+ shuffle=False,
75
+ batch_size=hps.train.batch_size,
76
+ pin_memory=True,
77
+ drop_last=False,
78
+ collate_fn=collate_fn)
79
+
80
+ net_g = SynthesizerTrn(hps.data.num_phones,
81
+ hps.data.filter_length // 2 + 1,
82
+ hps.train.segment_size // hps.data.hop_length,
83
+ n_speakers=hps.data.n_speakers,
84
+ **hps.model).cuda(rank)
85
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
86
+ optim_g = torch.optim.AdamW(net_g.parameters(),
87
+ hps.train.learning_rate,
88
+ betas=hps.train.betas,
89
+ eps=hps.train.eps)
90
+ optim_d = torch.optim.AdamW(net_d.parameters(),
91
+ hps.train.learning_rate,
92
+ betas=hps.train.betas,
93
+ eps=hps.train.eps)
94
+ net_g = DDP(net_g, device_ids=[rank])
95
+ net_d = DDP(net_d, device_ids=[rank])
96
+
97
+ try:
98
+ _, _, _, epoch_str = utils.load_checkpoint(
99
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
100
+ optim_g)
101
+ _, _, _, epoch_str = utils.load_checkpoint(
102
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
103
+ optim_d)
104
+ global_step = (epoch_str - 1) * len(train_loader)
105
+ except Exception as e:
106
+ epoch_str = 1
107
+ global_step = 0
108
+
109
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
110
+ optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
111
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
112
+ optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
113
+
114
+ scaler = GradScaler(enabled=hps.train.fp16_run)
115
+
116
+ for epoch in range(epoch_str, hps.train.epochs + 1):
117
+ if rank == 0:
118
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d],
119
+ [optim_g, optim_d], [scheduler_g, scheduler_d],
120
+ scaler, [train_loader, eval_loader], logger,
121
+ [writer, writer_eval])
122
+ else:
123
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d],
124
+ [optim_g, optim_d], [scheduler_g, scheduler_d],
125
+ scaler, [train_loader, None], None, None)
126
+ scheduler_g.step()
127
+ scheduler_d.step()
128
+
129
+
130
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
131
+ loaders, logger, writers):
132
+ net_g, net_d = nets
133
+ optim_g, optim_d = optims
134
+ scheduler_g, scheduler_d = schedulers
135
+ train_loader, eval_loader = loaders
136
+ if writers is not None:
137
+ writer, writer_eval = writers
138
+
139
+ train_loader.batch_sampler.set_epoch(epoch)
140
+ global global_step
141
+
142
+ net_g.train()
143
+ net_d.train()
144
+ for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths,
145
+ speakers) in enumerate(train_loader):
146
+ x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
147
+ rank, non_blocking=True)
148
+ spec, spec_lengths = spec.cuda(
149
+ rank, non_blocking=True), spec_lengths.cuda(rank,
150
+ non_blocking=True)
151
+ y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
152
+ rank, non_blocking=True)
153
+ speakers = speakers.cuda(rank, non_blocking=True)
154
+
155
+ with autocast(enabled=hps.train.fp16_run):
156
+ y_hat, l_length, attn, ids_slice, x_mask, z_mask, (
157
+ z, z_p, m_p, logs_p, m_q,
158
+ logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)
159
+
160
+ mel = spec_to_mel_torch(spec, hps.data.filter_length,
161
+ hps.data.n_mel_channels,
162
+ hps.data.sampling_rate, hps.data.mel_fmin,
163
+ hps.data.mel_fmax)
164
+ y_mel = commons.slice_segments(
165
+ mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
166
+ y_hat_mel = mel_spectrogram_torch(
167
+ y_hat.squeeze(1), hps.data.filter_length,
168
+ hps.data.n_mel_channels, hps.data.sampling_rate,
169
+ hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin,
170
+ hps.data.mel_fmax)
171
+
172
+ y = commons.slice_segments(y, ids_slice * hps.data.hop_length,
173
+ hps.train.segment_size) # slice
174
+
175
+ # Discriminator
176
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
177
+ with autocast(enabled=False):
178
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
179
+ y_d_hat_r, y_d_hat_g)
180
+ loss_disc_all = loss_disc
181
+ optim_d.zero_grad()
182
+ scaler.scale(loss_disc_all).backward()
183
+ scaler.unscale_(optim_d)
184
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
185
+ scaler.step(optim_d)
186
+
187
+ with autocast(enabled=hps.train.fp16_run):
188
+ # Generator
189
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
190
+ with autocast(enabled=False):
191
+ loss_dur = torch.sum(l_length.float())
192
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
193
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p,
194
+ z_mask) * hps.train.c_kl
195
+
196
+ loss_fm = feature_loss(fmap_r, fmap_g)
197
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
198
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
199
+ optim_g.zero_grad()
200
+ scaler.scale(loss_gen_all).backward()
201
+ scaler.unscale_(optim_g)
202
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
203
+ scaler.step(optim_g)
204
+ scaler.update()
205
+
206
+ if rank == 0:
207
+ if global_step % hps.train.log_interval == 0:
208
+ lr = optim_g.param_groups[0]['lr']
209
+ losses = [
210
+ loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl
211
+ ]
212
+ logger.info('Train Epoch: {} [{:.0f}%]'.format(
213
+ epoch, 100. * batch_idx / len(train_loader)))
214
+ logger.info([x.item() for x in losses] + [global_step, lr])
215
+
216
+ scalar_dict = {
217
+ "loss/g/total": loss_gen_all,
218
+ "loss/d/total": loss_disc_all,
219
+ "learning_rate": lr,
220
+ "grad_norm_d": grad_norm_d,
221
+ "grad_norm_g": grad_norm_g
222
+ }
223
+ scalar_dict.update({
224
+ "loss/g/fm": loss_fm,
225
+ "loss/g/mel": loss_mel,
226
+ "loss/g/dur": loss_dur,
227
+ "loss/g/kl": loss_kl
228
+ })
229
+
230
+ scalar_dict.update({
231
+ "loss/g/{}".format(i): v
232
+ for i, v in enumerate(losses_gen)
233
+ })
234
+ scalar_dict.update({
235
+ "loss/d_r/{}".format(i): v
236
+ for i, v in enumerate(losses_disc_r)
237
+ })
238
+ scalar_dict.update({
239
+ "loss/d_g/{}".format(i): v
240
+ for i, v in enumerate(losses_disc_g)
241
+ })
242
+ image_dict = {
243
+ "slice/mel_org":
244
+ utils.plot_spectrogram_to_numpy(
245
+ y_mel[0].data.cpu().numpy()),
246
+ "slice/mel_gen":
247
+ utils.plot_spectrogram_to_numpy(
248
+ y_hat_mel[0].data.cpu().numpy()),
249
+ "all/mel":
250
+ utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
251
+ "all/attn":
252
+ utils.plot_alignment_to_numpy(attn[0,
253
+ 0].data.cpu().numpy())
254
+ }
255
+ utils.summarize(writer=writer,
256
+ global_step=global_step,
257
+ images=image_dict,
258
+ scalars=scalar_dict)
259
+
260
+ if global_step % hps.train.eval_interval == 0:
261
+ evaluate(hps, net_g, eval_loader, writer_eval)
262
+ utils.save_checkpoint(
263
+ net_g, optim_g, hps.train.learning_rate, epoch,
264
+ os.path.join(hps.model_dir,
265
+ "G_{}.pth".format(global_step)))
266
+ utils.save_checkpoint(
267
+ net_d, optim_d, hps.train.learning_rate, epoch,
268
+ os.path.join(hps.model_dir,
269
+ "D_{}.pth".format(global_step)))
270
+ global_step += 1
271
+
272
+ if rank == 0:
273
+ logger.info('====> Epoch: {}'.format(epoch))
274
+
275
+
276
+ def evaluate(hps, generator, eval_loader, writer_eval):
277
+ generator.eval()
278
+ with torch.no_grad():
279
+ for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths,
280
+ speakers) in enumerate(eval_loader):
281
+ x, x_lengths = x.cuda(0), x_lengths.cuda(0)
282
+ spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
283
+ y, y_lengths = y.cuda(0), y_lengths.cuda(0)
284
+ speakers = speakers.cuda(0)
285
+
286
+ # remove else
287
+ x = x[:1]
288
+ x_lengths = x_lengths[:1]
289
+ spec = spec[:1]
290
+ spec_lengths = spec_lengths[:1]
291
+ y = y[:1]
292
+ y_lengths = y_lengths[:1]
293
+ speakers = speakers[:1]
294
+ break
295
+ y_hat, attn, mask, *_ = generator.module.infer(x,
296
+ x_lengths,
297
+ speakers,
298
+ max_len=1000)
299
+ y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
300
+
301
+ mel = spec_to_mel_torch(spec, hps.data.filter_length,
302
+ hps.data.n_mel_channels,
303
+ hps.data.sampling_rate, hps.data.mel_fmin,
304
+ hps.data.mel_fmax)
305
+ y_hat_mel = mel_spectrogram_torch(
306
+ y_hat.squeeze(1).float(), hps.data.filter_length,
307
+ hps.data.n_mel_channels, hps.data.sampling_rate,
308
+ hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin,
309
+ hps.data.mel_fmax)
310
+ image_dict = {
311
+ "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
312
+ }
313
+ audio_dict = {"gen/audio": y_hat[0, :, :y_hat_lengths[0]]}
314
+ if global_step == 0:
315
+ image_dict.update(
316
+ {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
317
+ audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]})
318
+
319
+ utils.summarize(writer=writer_eval,
320
+ global_step=global_step,
321
+ images=image_dict,
322
+ audios=audio_dict,
323
+ audio_sampling_rate=hps.data.sampling_rate)
324
+ generator.train()
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
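train.py drives single-node multi-GPU DDP training: main() spawns one run() per GPU and run() rendezvouses through dist.init_process_group(init_method='env://'), so MASTER_ADDR and MASTER_PORT must be set before launch. One hedged way to launch it is sketched below; the config, data paths, experiment name and port are all placeholders, and the CLI flags come from utils.get_hparams() later in this commit.

# Hypothetical launch wrapper; every path, name and port below is a placeholder.
import os
import sys

# Required by init_method='env://' for the NCCL process group.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Flags parsed by utils.get_hparams() (see utils.py below).
sys.argv = [
    "train.py",
    "-c", "configs/base.json",
    "-m", "exp/vits_demo",
    "--train_data", "data/train.txt",
    "--val_data", "data/val.txt",
    "--phone_table", "data/phones.txt",
    "--speaker_table", "data/speakers.txt",
]

import train  # noqa: E402  (imported after the environment is prepared)

if __name__ == "__main__":   # guard needed because mp.spawn re-imports __main__
    train.main()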
transforms.py CHANGED
@@ -1,67 +1,60 @@
 
1
  import torch
2
  from torch.nn import functional as F
3
 
4
- import numpy as np
5
-
6
-
7
  DEFAULT_MIN_BIN_WIDTH = 1e-3
8
  DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
  DEFAULT_MIN_DERIVATIVE = 1e-3
10
 
11
 
12
- def piecewise_rational_quadratic_transform(inputs,
13
- unnormalized_widths,
14
- unnormalized_heights,
15
- unnormalized_derivatives,
16
- inverse=False,
17
- tails=None,
18
- tail_bound=1.,
19
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21
- min_derivative=DEFAULT_MIN_DERIVATIVE):
 
22
 
23
  if tails is None:
24
  spline_fn = rational_quadratic_spline
25
  spline_kwargs = {}
26
  else:
27
  spline_fn = unconstrained_rational_quadratic_spline
28
- spline_kwargs = {
29
- 'tails': tails,
30
- 'tail_bound': tail_bound
31
- }
32
 
33
  outputs, logabsdet = spline_fn(
34
- inputs=inputs,
35
- unnormalized_widths=unnormalized_widths,
36
- unnormalized_heights=unnormalized_heights,
37
- unnormalized_derivatives=unnormalized_derivatives,
38
- inverse=inverse,
39
- min_bin_width=min_bin_width,
40
- min_bin_height=min_bin_height,
41
- min_derivative=min_derivative,
42
- **spline_kwargs
43
- )
44
  return outputs, logabsdet
45
 
46
 
47
  def searchsorted(bin_locations, inputs, eps=1e-6):
48
- bin_locations[..., -1] += eps
49
- return torch.sum(
50
- inputs[..., None] >= bin_locations,
51
- dim=-1
52
- ) - 1
53
-
54
-
55
- def unconstrained_rational_quadratic_spline(inputs,
56
- unnormalized_widths,
57
- unnormalized_heights,
58
- unnormalized_derivatives,
59
- inverse=False,
60
- tails='linear',
61
- tail_bound=1.,
62
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64
- min_derivative=DEFAULT_MIN_DERIVATIVE):
65
  inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66
  outside_interval_mask = ~inside_interval_mask
67
 
@@ -72,33 +65,41 @@ def unconstrained_rational_quadratic_spline(inputs,
72
  unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73
  constant = np.log(np.exp(1 - min_derivative) - 1)
74
  unnormalized_derivatives[..., 0] = constant
75
- unnormalized_derivatives[..., -1] = constant
76
 
77
  outputs[outside_interval_mask] = inputs[outside_interval_mask]
78
  logabsdet[outside_interval_mask] = 0
79
  else:
80
  raise RuntimeError('{} tails are not implemented.'.format(tails))
81
 
82
- outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83
- inputs=inputs[inside_interval_mask],
84
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
- inverse=inverse,
88
- left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89
- min_bin_width=min_bin_width,
90
- min_bin_height=min_bin_height,
91
- min_derivative=min_derivative
92
- )
 
 
 
 
93
 
94
  return outputs, logabsdet
95
 
 
96
  def rational_quadratic_spline(inputs,
97
  unnormalized_widths,
98
  unnormalized_heights,
99
  unnormalized_derivatives,
100
  inverse=False,
101
- left=0., right=1., bottom=0., top=1.,
 
 
 
102
  min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103
  min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104
  min_derivative=DEFAULT_MIN_DERIVATIVE):
@@ -118,7 +119,7 @@ def rational_quadratic_spline(inputs,
118
  cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119
  cumwidths = (right - left) * cumwidths + left
120
  cumwidths[..., 0] = left
121
- cumwidths[..., -1] = right
122
  widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123
 
124
  derivatives = min_derivative + F.softplus(unnormalized_derivatives)
@@ -129,7 +130,7 @@ def rational_quadratic_spline(inputs,
129
  cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130
  cumheights = (top - bottom) * cumheights + bottom
131
  cumheights[..., 0] = bottom
132
- cumheights[..., -1] = top
133
  heights = cumheights[..., 1:] - cumheights[..., :-1]
134
 
135
  if inverse:
@@ -145,20 +146,20 @@ def rational_quadratic_spline(inputs,
145
  input_delta = delta.gather(-1, bin_idx)[..., 0]
146
 
147
  input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
 
149
 
150
  input_heights = heights.gather(-1, bin_idx)[..., 0]
151
 
152
  if inverse:
153
- a = (((inputs - input_cumheights) * (input_derivatives
154
- + input_derivatives_plus_one
155
- - 2 * input_delta)
156
- + input_heights * (input_delta - input_derivatives)))
157
- b = (input_heights * input_derivatives
158
- - (inputs - input_cumheights) * (input_derivatives
159
- + input_derivatives_plus_one
160
- - 2 * input_delta))
161
- c = - input_delta * (inputs - input_cumheights)
162
 
163
  discriminant = b.pow(2) - 4 * a * c
164
  assert (discriminant >= 0).all()
@@ -167,27 +168,33 @@ def rational_quadratic_spline(inputs,
167
  outputs = root * input_bin_widths + input_cumwidths
168
 
169
  theta_one_minus_theta = root * (1 - root)
170
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171
- * theta_one_minus_theta)
172
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173
- + 2 * input_delta * theta_one_minus_theta
174
- + input_derivatives * (1 - root).pow(2))
175
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
 
 
 
176
 
177
  return outputs, -logabsdet
178
  else:
179
  theta = (inputs - input_cumwidths) / input_bin_widths
180
  theta_one_minus_theta = theta * (1 - theta)
181
 
182
- numerator = input_heights * (input_delta * theta.pow(2)
183
- + input_derivatives * theta_one_minus_theta)
184
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185
- * theta_one_minus_theta)
 
186
  outputs = input_cumheights + numerator / denominator
187
 
188
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189
- + 2 * input_delta * theta_one_minus_theta
190
- + input_derivatives * (1 - theta).pow(2))
191
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
 
 
192
 
193
  return outputs, logabsdet
 
1
+ import numpy as np
2
  import torch
3
  from torch.nn import functional as F
4
 
 
 
 
5
  DEFAULT_MIN_BIN_WIDTH = 1e-3
6
  DEFAULT_MIN_BIN_HEIGHT = 1e-3
7
  DEFAULT_MIN_DERIVATIVE = 1e-3
8
 
9
 
10
+ def piecewise_rational_quadratic_transform(
11
+ inputs,
12
+ unnormalized_widths,
13
+ unnormalized_heights,
14
+ unnormalized_derivatives,
15
+ inverse=False,
16
+ tails=None,
17
+ tail_bound=1.,
18
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
19
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
20
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
21
 
22
  if tails is None:
23
  spline_fn = rational_quadratic_spline
24
  spline_kwargs = {}
25
  else:
26
  spline_fn = unconstrained_rational_quadratic_spline
27
+ spline_kwargs = {'tails': tails, 'tail_bound': tail_bound}
 
 
 
28
 
29
  outputs, logabsdet = spline_fn(
30
+ inputs=inputs,
31
+ unnormalized_widths=unnormalized_widths,
32
+ unnormalized_heights=unnormalized_heights,
33
+ unnormalized_derivatives=unnormalized_derivatives,
34
+ inverse=inverse,
35
+ min_bin_width=min_bin_width,
36
+ min_bin_height=min_bin_height,
37
+ min_derivative=min_derivative,
38
+ **spline_kwargs)
 
39
  return outputs, logabsdet
40
 
41
 
42
  def searchsorted(bin_locations, inputs, eps=1e-6):
43
+ bin_locations[..., bin_locations.size(-1) - 1] += eps
44
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
45
+
46
+
47
+ def unconstrained_rational_quadratic_spline(
48
+ inputs,
49
+ unnormalized_widths,
50
+ unnormalized_heights,
51
+ unnormalized_derivatives,
52
+ inverse=False,
53
+ tails='linear',
54
+ tail_bound=1.,
55
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
56
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
57
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
 
 
58
  inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
59
  outside_interval_mask = ~inside_interval_mask
60
 
 
65
  unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
66
  constant = np.log(np.exp(1 - min_derivative) - 1)
67
  unnormalized_derivatives[..., 0] = constant
68
+ unnormalized_derivatives[..., unnormalized_derivatives.size(-1) - 1] = constant
69
 
70
  outputs[outside_interval_mask] = inputs[outside_interval_mask]
71
  logabsdet[outside_interval_mask] = 0
72
  else:
73
  raise RuntimeError('{} tails are not implemented.'.format(tails))
74
 
75
+ outputs[inside_interval_mask], logabsdet[
76
+ inside_interval_mask] = rational_quadratic_spline(
77
+ inputs=inputs[inside_interval_mask],
78
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
79
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
80
+ unnormalized_derivatives=unnormalized_derivatives[
81
+ inside_interval_mask, :],
82
+ inverse=inverse,
83
+ left=-tail_bound,
84
+ right=tail_bound,
85
+ bottom=-tail_bound,
86
+ top=tail_bound,
87
+ min_bin_width=min_bin_width,
88
+ min_bin_height=min_bin_height,
89
+ min_derivative=min_derivative)
90
 
91
  return outputs, logabsdet
92
 
93
+
94
  def rational_quadratic_spline(inputs,
95
  unnormalized_widths,
96
  unnormalized_heights,
97
  unnormalized_derivatives,
98
  inverse=False,
99
+ left=0.,
100
+ right=1.,
101
+ bottom=0.,
102
+ top=1.,
103
  min_bin_width=DEFAULT_MIN_BIN_WIDTH,
104
  min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
105
  min_derivative=DEFAULT_MIN_DERIVATIVE):
 
119
  cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
120
  cumwidths = (right - left) * cumwidths + left
121
  cumwidths[..., 0] = left
122
+ cumwidths[..., cumwidths.size(-1) - 1] = right
123
  widths = cumwidths[..., 1:] - cumwidths[..., :-1]
124
 
125
  derivatives = min_derivative + F.softplus(unnormalized_derivatives)
 
130
  cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
131
  cumheights = (top - bottom) * cumheights + bottom
132
  cumheights[..., 0] = bottom
133
+ cumheights[..., cumheights.size(-1) - 1] = top
134
  heights = cumheights[..., 1:] - cumheights[..., :-1]
135
 
136
  if inverse:
 
146
  input_delta = delta.gather(-1, bin_idx)[..., 0]
147
 
148
  input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
149
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[...,
150
+ 0]
151
 
152
  input_heights = heights.gather(-1, bin_idx)[..., 0]
153
 
154
  if inverse:
155
+ a = (
156
+ ((inputs - input_cumheights) *
157
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
158
+ + input_heights * (input_delta - input_derivatives)))
159
+ b = (
160
+ input_heights * input_derivatives - (inputs - input_cumheights) *
161
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta))
162
+ c = -input_delta * (inputs - input_cumheights)
 
163
 
164
  discriminant = b.pow(2) - 4 * a * c
165
  assert (discriminant >= 0).all()
 
168
  outputs = root * input_bin_widths + input_cumwidths
169
 
170
  theta_one_minus_theta = root * (1 - root)
171
+ denominator = input_delta + (
172
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
173
+ * theta_one_minus_theta)
174
+ derivative_numerator = input_delta.pow(2) * (
175
+ input_derivatives_plus_one * root.pow(2) +
176
+ 2 * input_delta * theta_one_minus_theta + input_derivatives *
177
+ (1 - root).pow(2))
178
+ logabsdet = torch.log(
179
+ derivative_numerator) - 2 * torch.log(denominator)
180
 
181
  return outputs, -logabsdet
182
  else:
183
  theta = (inputs - input_cumwidths) / input_bin_widths
184
  theta_one_minus_theta = theta * (1 - theta)
185
 
186
+ numerator = input_heights * (input_delta * theta.pow(2) +
187
+ input_derivatives * theta_one_minus_theta)
188
+ denominator = input_delta + (
189
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
190
+ * theta_one_minus_theta)
191
  outputs = input_cumheights + numerator / denominator
192
 
193
+ derivative_numerator = input_delta.pow(2) * (
194
+ input_derivatives_plus_one * theta.pow(2) +
195
+ 2 * input_delta * theta_one_minus_theta + input_derivatives *
196
+ (1 - theta).pow(2))
197
+ logabsdet = torch.log(
198
+ derivative_numerator) - 2 * torch.log(denominator)
199
 
200
  return outputs, logabsdet
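The rewritten transforms.py keeps the same piecewise rational-quadratic spline used inside the flow and only reflows it. A shape-level sketch of calling it directly is below; the batch size, number of bins and tolerance are arbitrary, and the round trip relies on the spline's closed-form inverse.

# Smoke test for the spline transform; all sizes here are arbitrary.
import torch
from transforms import piecewise_rational_quadratic_transform

batch, num_bins = 4, 10
inputs = torch.rand(batch) * 2 - 1                 # inside [-tail_bound, tail_bound]
widths = torch.randn(batch, num_bins)
heights = torch.randn(batch, num_bins)
# With tails='linear' the function pads the derivatives by one on each side,
# so num_bins - 1 interior knot slopes are supplied here.
derivs = torch.randn(batch, num_bins - 1)

outputs, logabsdet = piecewise_rational_quadratic_transform(
    inputs, widths, heights, derivs,
    inverse=False, tails='linear', tail_bound=1.0)

# The inverse pass should (numerically) undo the forward pass.
recovered, _ = piecewise_rational_quadratic_transform(
    outputs, widths, heights, derivs,
    inverse=True, tails='linear', tail_bound=1.0)
print(torch.allclose(recovered, inputs, atol=1e-4))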
utils.py CHANGED
@@ -1,8 +1,278 @@
 
 
 
1
  import logging
2
- from json import loads
3
- from torch import load, FloatTensor
4
- from numpy import float32
5
- import librosa
6
 
7
 
8
  class HParams():
@@ -35,42 +305,3 @@ class HParams():
35
 
36
  def __repr__(self):
37
  return self.__dict__.__repr__()
38
-
39
-
40
- def load_checkpoint(checkpoint_path, model):
41
- checkpoint_dict = load(checkpoint_path, map_location='cpu')
42
- iteration = checkpoint_dict['iteration']
43
- saved_state_dict = checkpoint_dict['model']
44
- if hasattr(model, 'module'):
45
- state_dict = model.module.state_dict()
46
- else:
47
- state_dict = model.state_dict()
48
- new_state_dict = {}
49
- for k, v in state_dict.items():
50
- try:
51
- new_state_dict[k] = saved_state_dict[k]
52
- except:
53
- logging.info("%s is not in the checkpoint" % k)
54
- new_state_dict[k] = v
55
- pass
56
- if hasattr(model, 'module'):
57
- model.module.load_state_dict(new_state_dict)
58
- else:
59
- model.load_state_dict(new_state_dict)
60
- logging.info("Loaded checkpoint '{}' (iteration {})".format(
61
- checkpoint_path, iteration))
62
- return
63
-
64
-
65
- def get_hparams_from_file(config_path):
66
- with open(config_path, "r") as f:
67
- data = f.read()
68
- config = loads(data)
69
-
70
- hparams = HParams(**config)
71
- return hparams
72
-
73
-
74
- def load_audio_to_torch(full_path, target_sampling_rate):
75
- audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
76
- return FloatTensor(audio.astype(float32))
 
1
+ import argparse
2
+ import glob
3
+ import json
4
  import logging
5
+ import os
6
+ import subprocess
7
+ import sys
8
+
9
+ import numpy as np
10
+ from scipy.io.wavfile import read
11
+ import torch
12
+
13
+ MATPLOTLIB_FLAG = False
14
+
15
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16
+ logger = logging
17
+
18
+
19
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
20
+ assert os.path.isfile(checkpoint_path)
21
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
22
+ iteration = checkpoint_dict['iteration']
23
+ learning_rate = checkpoint_dict['learning_rate']
24
+ if optimizer is not None:
25
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
26
+ saved_state_dict = checkpoint_dict['model']
27
+ if hasattr(model, 'module'):
28
+ state_dict = model.module.state_dict()
29
+ else:
30
+ state_dict = model.state_dict()
31
+ new_state_dict = {}
32
+ for k, v in state_dict.items():
33
+ try:
34
+ new_state_dict[k] = saved_state_dict[k]
35
+ except Exception as e:
36
+ logger.info("%s is not in the checkpoint" % k)
37
+ new_state_dict[k] = v
38
+ if hasattr(model, 'module'):
39
+ model.module.load_state_dict(new_state_dict)
40
+ else:
41
+ model.load_state_dict(new_state_dict)
42
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
43
+ checkpoint_path, iteration))
44
+ return model, optimizer, learning_rate, iteration
45
+
46
+
47
+ def save_checkpoint(model, optimizer, learning_rate, iteration,
48
+ checkpoint_path):
49
+ logger.info(
50
+ "Saving model and optimizer state at iteration {} to {}".format(
51
+ iteration, checkpoint_path))
52
+ if hasattr(model, 'module'):
53
+ state_dict = model.module.state_dict()
54
+ else:
55
+ state_dict = model.state_dict()
56
+ torch.save(
57
+ {
58
+ 'model': state_dict,
59
+ 'iteration': iteration,
60
+ 'optimizer': optimizer.state_dict(),
61
+ 'learning_rate': learning_rate
62
+ }, checkpoint_path)
63
+
64
+
65
+ def summarize(
66
+ writer,
67
+ global_step,
68
+ scalars={}, # noqa
69
+ histograms={}, # noqa
70
+ images={}, # noqa
71
+ audios={}, # noqa
72
+ audio_sampling_rate=22050):
73
+ for k, v in scalars.items():
74
+ writer.add_scalar(k, v, global_step)
75
+ for k, v in histograms.items():
76
+ writer.add_histogram(k, v, global_step)
77
+ for k, v in images.items():
78
+ writer.add_image(k, v, global_step, dataformats='HWC')
79
+ for k, v in audios.items():
80
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
81
+
82
+
83
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
84
+ f_list = glob.glob(os.path.join(dir_path, regex))
85
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
86
+ x = f_list[-1]
87
+ print(x)
88
+ return x
89
+
90
+
91
+ def plot_spectrogram_to_numpy(spectrogram):
92
+ global MATPLOTLIB_FLAG
93
+ if not MATPLOTLIB_FLAG:
94
+ import matplotlib
95
+ matplotlib.use("Agg")
96
+ MATPLOTLIB_FLAG = True
97
+ mpl_logger = logging.getLogger('matplotlib')
98
+ mpl_logger.setLevel(logging.WARNING)
99
+ import matplotlib.pylab as plt
100
+ import numpy as np
101
+
102
+ fig, ax = plt.subplots(figsize=(10, 2))
103
+ im = ax.imshow(spectrogram,
104
+ aspect="auto",
105
+ origin="lower",
106
+ interpolation='none')
107
+ plt.colorbar(im, ax=ax)
108
+ plt.xlabel("Frames")
109
+ plt.ylabel("Channels")
110
+ plt.tight_layout()
111
+
112
+ fig.canvas.draw()
113
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
114
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
115
+ plt.close()
116
+ return data
117
+
118
+
119
+ def plot_alignment_to_numpy(alignment, info=None):
120
+ global MATPLOTLIB_FLAG
121
+ if not MATPLOTLIB_FLAG:
122
+ import matplotlib
123
+ matplotlib.use("Agg")
124
+ MATPLOTLIB_FLAG = True
125
+ mpl_logger = logging.getLogger('matplotlib')
126
+ mpl_logger.setLevel(logging.WARNING)
127
+ import matplotlib.pylab as plt
128
+ import numpy as np
129
+
130
+ fig, ax = plt.subplots(figsize=(6, 4))
131
+ im = ax.imshow(alignment.transpose(),
132
+ aspect='auto',
133
+ origin='lower',
134
+ interpolation='none')
135
+ fig.colorbar(im, ax=ax)
136
+ xlabel = 'Decoder timestep'
137
+ if info is not None:
138
+ xlabel += '\n\n' + info
139
+ plt.xlabel(xlabel)
140
+ plt.ylabel('Encoder timestep')
141
+ plt.tight_layout()
142
+
143
+ fig.canvas.draw()
144
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
145
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
146
+ plt.close()
147
+ return data
148
+
149
+
150
+ def load_wav_to_torch(full_path):
151
+ sampling_rate, data = read(full_path)
152
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
153
+
154
+
155
+ def load_filepaths_and_text(filename, split="|"):
156
+ with open(filename, encoding='utf-8') as f:
157
+ filepaths_and_text = [line.strip().split(split) for line in f]
158
+ return filepaths_and_text
159
+
160
+
161
+ def get_hparams(init=True):
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument('-c',
164
+ '--config',
165
+ type=str,
166
+ default="./configs/base.json",
167
+ help='JSON file for configuration')
168
+ parser.add_argument('-m',
169
+ '--model',
170
+ type=str,
171
+ required=True,
172
+ help='Model name')
173
+ parser.add_argument('--train_data',
174
+ type=str,
175
+ required=True,
176
+ help='train data')
177
+ parser.add_argument('--val_data', type=str, required=True, help='val data')
178
+ parser.add_argument('--phone_table',
179
+ type=str,
180
+ required=True,
181
+ help='phone table')
182
+ parser.add_argument('--speaker_table',
183
+ type=str,
184
+ default=None,
185
+ help='speaker table, required for multiple speakers')
186
+
187
+ args = parser.parse_args()
188
+ model_dir = args.model
189
+
190
+ if not os.path.exists(model_dir):
191
+ os.makedirs(model_dir)
192
+
193
+ config_path = args.config
194
+ config_save_path = os.path.join(model_dir, "config.json")
195
+ if init:
196
+ with open(config_path, "r", encoding='utf8') as f:
197
+ data = f.read()
198
+ with open(config_save_path, "w", encoding='utf8') as f:
199
+ f.write(data)
200
+ else:
201
+ with open(config_save_path, "r", encoding='utf8') as f:
202
+ data = f.read()
203
+ config = json.loads(data)
204
+ config['data']['training_files'] = args.train_data
205
+ config['data']['validation_files'] = args.val_data
206
+ config['data']['phone_table'] = args.phone_table
207
+ # 0 is kept for blank
208
+ config['data']['num_phones'] = len(open(args.phone_table).readlines()) + 1
209
+ if args.speaker_table is not None:
210
+ config['data']['speaker_table'] = args.speaker_table
211
+ # 0 is kept for unknown speaker
212
+ config['data']['n_speakers'] = len(
213
+ open(args.speaker_table).readlines()) + 1
214
+ else:
215
+ config['data']['n_speakers'] = 0
216
+
217
+ hparams = HParams(**config)
218
+ hparams.model_dir = model_dir
219
+ return hparams
220
+
221
+
222
+ def get_hparams_from_dir(model_dir):
223
+ config_save_path = os.path.join(model_dir, "config.json")
224
+ with open(config_save_path, "r") as f:
225
+ data = f.read()
226
+ config = json.loads(data)
227
+
228
+ hparams = HParams(**config)
229
+ hparams.model_dir = model_dir
230
+ return hparams
231
+
232
+
233
+ def get_hparams_from_file(config_path):
234
+ with open(config_path, "r") as f:
235
+ data = f.read()
236
+ config = json.loads(data)
237
+
238
+ hparams = HParams(**config)
239
+ return hparams
240
+
241
+
242
+ def check_git_hash(model_dir):
243
+ source_dir = os.path.dirname(os.path.realpath(__file__))
244
+ if not os.path.exists(os.path.join(source_dir, ".git")):
245
+ logger.warn('''{} is not a git repository, therefore hash value
246
+ comparison will be ignored.'''.format(source_dir))
247
+ return
248
+
249
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
250
+
251
+ path = os.path.join(model_dir, "githash")
252
+ if os.path.exists(path):
253
+ saved_hash = open(path).read()
254
+ if saved_hash != cur_hash:
255
+ logger.warn(
256
+ "git hash values are different. {}(saved) != {}(current)".
257
+ format(saved_hash[:8], cur_hash[:8]))
258
+ else:
259
+ open(path, "w").write(cur_hash)
260
+
261
+
262
+ def get_logger(model_dir, filename="train.log"):
263
+ global logger
264
+ logger = logging.getLogger(os.path.basename(model_dir))
265
+ logger.setLevel(logging.INFO)
266
+
267
+ formatter = logging.Formatter(
268
+ "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
269
+ if not os.path.exists(model_dir):
270
+ os.makedirs(model_dir)
271
+ h = logging.FileHandler(os.path.join(model_dir, filename))
272
+ h.setLevel(logging.INFO)
273
+ h.setFormatter(formatter)
274
+ logger.addHandler(h)
275
+ return logger
276
 
277
 
278
  class HParams():
 
305
 
306
  def __repr__(self):
307
  return self.__dict__.__repr__()
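utils.py now returns (model, optimizer, learning_rate, iteration) from load_checkpoint and adds latest_checkpoint_path, get_logger and the get_hparams* helpers. A short hedged sketch of resuming the generator from the newest checkpoint with these helpers; the experiment directory is a placeholder and assumes a config.json previously written there by get_hparams().

# Hypothetical resume snippet; "exp/vits_demo" is a placeholder directory.
import torch
import utils
from models import SynthesizerTrn

hps = utils.get_hparams_from_dir("exp/vits_demo")   # reads exp/vits_demo/config.json
logger = utils.get_logger(hps.model_dir)

net_g = SynthesizerTrn(hps.data.num_phones,
                       hps.data.filter_length // 2 + 1,
                       hps.train.segment_size // hps.data.hop_length,
                       n_speakers=hps.data.n_speakers,
                       **hps.model)
optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate,
                            betas=hps.train.betas, eps=hps.train.eps)

ckpt = utils.latest_checkpoint_path(hps.model_dir, "G_*.pth")
net_g, optim_g, lr, iteration = utils.load_checkpoint(ckpt, net_g, optim_g)
logger.info("resumed from %s at iteration %d (lr=%s)", ckpt, iteration, lr)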