"""Speech tokenizer class. Copyright PolyAI Limited. """ import logging import os import numpy as np import torch import torchaudio from speechtokenizer import SpeechTokenizer as ST from modules.tokenizer import BaseTokenizer class SpeechTokenizer(BaseTokenizer): def __init__(self, config_path: str, ckpt_path: str): self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") self.model = ST.load_from_checkpoint( config_path, ckpt_path).to(self.device) self.model.eval() def encode_file( self, folder_path: str, destination_folder: str, filename: str): dest_path = os.path.join( destination_folder, "semantic", os.path.splitext(filename)[0] + ".npy" ) dest_path2 = os.path.join( destination_folder, "acoustic", os.path.splitext(filename)[0] + ".npy" ) if os.path.exists(dest_path) and os.path.exists(dest_path2): pass else: self._create_subfolders(destination_folder=destination_folder) file_path = os.path.join(folder_path, filename) wav_info = torchaudio.info(file_path) wav_dur_sec = wav_info.num_frames / wav_info.sample_rate if wav_dur_sec > 60: logging.info( f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec," "can cause CUDA OOM" ) return wav, sr = torchaudio.load(file_path) if sr != self.model.sample_rate: logging.warning( "Wav sample rate %(wav_sr)s does not match the model" "sampling rate %(model_sr)s. Resampling audio", {"wav_sr": sr, "model_sr": self.model.sample_rate}, ) wav = torchaudio.functional.resample( wav, sr, self.model.sample_rate) wav = wav.unsqueeze(0) wav = wav.to(self.device) # Extract discrete codes from SpeechTokenizer with torch.no_grad(): codes = self.model.encode(wav) # codes: (n_q, B, T) semantic_tokens = codes[0, 0, :] acoustic_tokens = codes[1:, 0, :] # Save the encoding as .npy dest_path = os.path.join( destination_folder, "acoustic", os.path.splitext(filename)[0] + ".npy" ) np.save(dest_path, acoustic_tokens.cpu().numpy()) dest_path = os.path.join( destination_folder, "semantic", os.path.splitext(filename)[0] + ".npy" ) np.save(dest_path, semantic_tokens.cpu().numpy()) @staticmethod def _create_subfolders(destination_folder: str): if not os.path.exists(destination_folder + "/acoustic"): os.makedirs(destination_folder + "/acoustic") if not os.path.exists(destination_folder + "/semantic"): os.makedirs(destination_folder + "/semantic")