Kamtera committed on
Commit
93561d2
1 Parent(s): 0402ed0

Create train_glowtts.py

Browse files
Files changed (1) hide show
  1. train_glowtts.py +117 -0
train_glowtts.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Trainer: Where the ✨️ happens.
4
+ # TrainerArgs: Defines the set of arguments of the Trainer.
5
+ from trainer import Trainer, TrainerArgs
6
+
7
+ # GlowTTSConfig: all model related values for training, validating and testing.
8
+ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
9
+
10
+ # BaseDatasetConfig: defines name, formatter and path of the dataset.
+ # CharactersConfig: defines the character, punctuation and phoneme sets plus the special tokens.
11
+ from TTS.tts.configs.shared_configs import BaseDatasetConfig , CharactersConfig
12
+ from TTS.config.shared_configs import BaseAudioConfig
13
+ from TTS.tts.datasets import load_tts_samples
14
+ from TTS.tts.models.glow_tts import GlowTTS
15
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
16
+ from TTS.utils.audio import AudioProcessor
17
+
18
+ # we use the same path as this script as our training folder.
19
+ output_path = os.path.dirname(os.path.abspath(__file__))
20
+
21
+ # DEFINE DATASET CONFIG
22
+ # Set the Persian TTS dataset (read with the "mozilla" formatter) as our target dataset and define its path.
23
+ # You can also use a simple Dict to define the dataset and pass it to your custom formatter.
24
+
25
+
26
+ dataset_config = BaseDatasetConfig(
27
+ formatter="mozilla", meta_file_train="metadata.csv", path="/kaggle/input/persian-tts-dataset"
28
+ )
29
+
30
+ audio_config = BaseAudioConfig(
31
+ sample_rate=22050,
32
+ do_trim_silence=True,
33
+ resample=False
34
+
35
+ )
36
+
37
+ character_config=CharactersConfig(
38
+ characters='ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآأؤإئًَُّ',
39
+ punctuations='!(),-.:;? ̠،؛؟‌<>',
40
+ phonemes='ˈˌːˑpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟaegiouwyɪʊ̩æɑɔəɚɛɝɨ̃ʉʌʍ0123456789"#$%*+/=ABCDEFGHIJKLMNOPRSTUVWXYZ[]^_{}',
41
+ pad="<PAD>",
42
+ eos="<EOS>",
43
+ bos="<BOS>",
44
+ blank="<BLNK>",
45
+ characters_class="TTS.tts.utils.text.characters.IPAPhonemes",
46
+ )
47
+ # INITIALIZE THE TRAINING CONFIGURATION
48
+ # Configure the model. Every config class inherits from BaseTTSConfig.
49
+ config = GlowTTSConfig(
50
+ batch_size=8,#batch_size=32,
51
+ eval_batch_size=4,#eval_batch_size=16,
52
+ num_loader_workers=0,
53
+ num_eval_loader_workers=0,
54
+ run_eval=True,
55
+ test_delay_epochs=-1,
56
+ epochs=1000,
57
+ save_step=1000,
58
+ text_cleaner="basic_cleaners",
59
+ use_phonemes=True,
60
+ phoneme_language="fa",
61
+ phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
62
+ characters=character_config,
63
+ print_step=25,
64
+ print_eval=False,
65
+ mixed_precision=True,
66
+ output_path=output_path,
67
+ datasets=[dataset_config],
68
+ audio=audio_config,
69
+ test_sentences=[
70
+ "سلطان محمود در زمستانی سخت به طلخک گفت که: با این جامه ی یک لا در این سرما چه می کنی ",
71
+ "مردی نزد بقالی آمد و گفت پیاز هم ده تا دهان بدان خو شبوی سازم.",
72
+ "از مال خود پاره ای گوشت بستان و زیره بایی معطّر بساز",
73
+ "یک بار هم از جهنم بگویید.",
74
+ "یکی اسبی به عاریت خواست"
75
+ ],
76
+
77
+
78
+ )
79
+
80
+ # INITIALIZE THE AUDIO PROCESSOR
81
+ # Audio processor is used for feature extraction and audio I/O.
82
+ # It mainly serves the dataloader and the training loggers.
83
+ ap = AudioProcessor.init_from_config(config)
84
+
85
+ # INITIALIZE THE TOKENIZER
86
+ # Tokenizer is used to convert text to sequences of token IDs.
87
+ # If characters are not defined in the config, default characters are passed to the config
88
+ tokenizer, config = TTSTokenizer.init_from_config(config)
89
+
90
+ # LOAD DATA SAMPLES
91
+ # Each sample is a list of `[text, audio_file_path, speaker_name]`.
92
+ # You can define your custom sample loader returning the list of samples.
93
+ # Or define your custom formatter and pass it to the `load_tts_samples`.
94
+ # Check `TTS.tts.datasets.load_tts_samples` for more details.
95
+ train_samples, eval_samples = load_tts_samples(
96
+ dataset_config,
97
+ eval_split=True,
98
+ eval_split_max_size=config.eval_split_max_size,
99
+ eval_split_size=config.eval_split_size,
100
+ #formatter=changizer
101
+ )
102
+
103
+ # INITIALIZE THE MODEL
104
+ # The model takes the config, the audio processor and the tokenizer, plus an optional speaker manager.
105
+ # Config defines the details of the model like the number of layers, the size of the embedding, etc.
106
+ # Speaker manager is used by multi-speaker models; None is passed here, so no speaker manager is used.
107
+ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
108
+
109
+ # INITIALIZE THE TRAINER
110
+ # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
111
+ # distributed training, etc.
112
+ trainer = Trainer(
113
+ TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
114
+ )
115
+
116
+ # AND... 3,2,1... 🚀
117
+ trainer.fit()