"""Collators for T2S and S2A.

Copyright PolyAI Limited.
"""
from pathlib import Path
from typing import List, Union

import numpy as np
import torch

from utils.symbol_table import SymbolTable


class GlobalCollater:
    """Pads (speaker embedding, quantization codes, semantic tokens) examples
    to a common length and converts the batch to tensors.
    """

    def __init__(self, n_codes, n_semantic_codes):
        self.n_codes = n_codes
        self.sem_mask_id = n_semantic_codes

    def collate(self, batch):
        output = {
            'speaker': [],
            'tts_quantize_input': [],
            'tts_quantize_output': [],
            'quantize_mask': [],
            'f_names': [],
            'semantic_tokens': [],
            'quantization_lengths': [],
        }
        # Get the max length of everything
        max_len_q = 0
        for _, q_s, q_e, _, _ in batch:
            if len(q_s) > max_len_q:
                max_len_q = len(q_s)

            output['quantization_lengths'].append(len(q_s))

        # Pad each element, create mask
        for spkr, qs, qe, itm_name, s_tokens in batch:
            # Deal with quantizations
            q_mask = np.array(
                [False] * len(qs) + [True] * (max_len_q - len(qs)))
            qs = np.pad(
                qs,
                [[0, max_len_q - len(qs)], [0, 0]],
                constant_values=self.n_codes
            )
            qe = np.pad(
                qe,
                [[0, max_len_q - len(qe)], [0, 0]],
                constant_values=self.n_codes
            )

            # Deal with semantics
            s_tokens = s_tokens.flatten()
            s_tokens = np.pad(
                s_tokens,
                (0, max_len_q - len(s_tokens)),
                constant_values=self.sem_mask_id
            )

            # Speaker padding
            spkr = np.concatenate(
                (spkr, np.zeros((max_len_q - len(spkr), 512))))

            # Aggregate
            output['speaker'].append(spkr)
            output['tts_quantize_input'].append(qs)
            output['tts_quantize_output'].append(qe)
            output['quantize_mask'].append(q_mask)
            output['f_names'].append(itm_name)
            output['semantic_tokens'].append(s_tokens)

        for k in output.keys():
            if k == 'f_names':
                continue
            output[k] = np.array(output[k])
            if 'mask' in k:
                output[k] = torch.BoolTensor(output[k])
            elif k in [
                'tts_quantize_input', 'tts_quantize_output',
                'semantic_tokens', 'quantization_lengths'
            ]:
                output[k] = torch.LongTensor(output[k])
            else:
                output[k] = torch.FloatTensor(output[k])
        return output
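
# Illustrative usage (a sketch; assumes a Dataset whose __getitem__ returns
# (speaker_embedding, quantize_in, quantize_out, item_name, semantic_tokens)
# tuples matching the unpacking in `collate` above, and hypothetical values
# for n_codes / n_semantic_codes):
#
#     collater = GlobalCollater(n_codes=1024, n_semantic_codes=1024)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=8, collate_fn=collater.collate)
#     batch = next(iter(loader))
#     batch['tts_quantize_input']   # LongTensor, padded with n_codes
#     batch['quantize_mask']        # BoolTensor, True at padded positions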


class TextTokenCollater:
    """Turns text token sequences (plus an optional second-speaker sequence)
    into batches of token ids, adding <bos>/<eos> and speaker symbols.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        spkr_1_symbol: str = "spkr_1",
        spkr_2_symbol: str = "spkr_2",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.spkr_1_symbol = spkr_1_symbol
        self.spkr_2_symbol = spkr_2_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + ([spkr_1_symbol])
            + ([spkr_2_symbol])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def __call__(
        self, texts: List[str], texts_2: Union[None, List[str]] = None
    ) -> torch.Tensor:
        tokens_seqs = [[p for p in text] for text in texts]

        if texts_2 is None:
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq in tokens_seqs
            ]
        else:
            tokens_seqs_2 = [[p for p in text] for text in texts_2]
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + ([self.spkr_2_symbol])
                + list(seq_2)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
            ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        return tokens_batch
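
# Illustrative call (a sketch): with `texts_2` given, each example is laid out
# as <bos> spkr_1 <tokens of text> spkr_2 <tokens of text_2> <eos>.  Note that
# __call__ does not pad, so every assembled sequence in a batch must end up the
# same length, e.g.:
#
#     collater = TextTokenCollater(list("abc"))
#     ids = collater(["ab"], ["c"])   # tensor of shape (1, 7)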


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    """Builds a TextTokenCollater from a SymbolTable file of text tokens."""
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
    """Builds a collater whose vocabulary also contains the semantic-token
    ids 0 .. n_semantic_tokens - 1 (as strings)."""
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    for semantic_idx in range(n_semantic_tokens):
        unique_tokens.add(str(semantic_idx))

    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


if __name__ == '__main__':
    text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
    collater = get_text_semantic_token_collater(text_tokens_file)
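
    # Minimal smoke test (illustrative): the ids '0'..'1023' added by
    # get_text_semantic_token_collater are guaranteed to be in the symbol
    # table, so a toy pair built only from such ids always collates.
    # (Pre-split token lists also work, since __call__ only iterates entries.)
    texts = [['1', '2', '3']]
    semantic_tokens = [['7', '8']]
    token_ids = collater(texts, semantic_tokens)
    # Layout per example: <bos> spkr_1 1 2 3 spkr_2 7 8 <eos>  -> shape (1, 9)
    print(token_ids, token_ids.shape)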