import json
import os
import re
import sys
import zlib
from typing import Callable, Optional, TextIO

# Mapping from language code to lower-case English language name.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}

# Languages whose subtitle words are concatenated without a separating space.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

system_encoding = sys.getdefaultencoding()

if system_encoding != "utf-8":

    def make_safe(string):
        # replaces any character not representable using the system default encoding with an '?',
        # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
        return string.encode(system_encoding, errors="replace").decode(system_encoding)

else:

    def make_safe(string):
        # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
        return string


def exact_div(x, y):
    """Return ``x // y``, asserting that ``y`` divides ``x`` exactly."""
    assert x % y == 0
    return x // y


def str2bool(string):
    """Convert the literal strings ``"True"`` / ``"False"`` to a bool.

    Raises:
        ValueError: if ``string`` is neither ``"True"`` nor ``"False"``.
    """
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def optional_int(string):
    """Parse ``string`` as an int, mapping the literal ``"None"`` to ``None``."""
    return None if string == "None" else int(string)


def optional_float(string):
    """Parse ``string`` as a float, mapping the literal ``"None"`` to ``None``."""
    return None if string == "None" else float(string)


def compression_ratio(text) -> float:
    """Return the ratio of the UTF-8 byte length of ``text`` to its zlib-compressed length."""
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a non-negative time in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: time offset in seconds; must be >= 0.
        always_include_hours: emit the ``HH:`` field even when it is zero.
        decimal_marker: separator between seconds and milliseconds.

    Returns:
        The formatted timestamp string, rounded to the nearest millisecond.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


class ResultWriter:
    """Base class for writers that dump a transcription result to a file.

    Subclasses set ``extension`` and implement ``write_result``.
    """

    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str, options: dict):
        """Write ``result`` to ``<output_dir>/<audio basename>.<extension>``."""
        audio_basename = os.path.basename(audio_path)
        audio_basename = os.path.splitext(audio_basename)[0]
        output_path = os.path.join(
            self.output_dir, audio_basename + "." + self.extension
        )

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options)

    def write_result(self, result: dict, file: TextIO, options: dict):
        """Serialize ``result`` into the open text file ``file``."""
        raise NotImplementedError


class WriteTXT(ResultWriter):
    """Write one stripped segment text per line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO, options: dict):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)


class SubtitlesWriter(ResultWriter):
    """Shared cue-generation logic for the subtitle formats (VTT, SRT).

    Subclasses set ``always_include_hours`` and ``decimal_marker`` to control
    how timestamps are rendered.
    """

    always_include_hours: bool
    decimal_marker: str

    def iterate_result(self, result: dict, options: dict):
        """Yield ``(start, end, text)`` cues, with timestamps already formatted.

        ``options`` must contain ``max_line_width``, ``max_line_count`` and
        ``highlight_words``. When the first segment has a ``"words"`` list the
        cues are assembled word by word (honouring line width / line count and,
        if requested, emitting one cue per word with that word wrapped in
        ``<u>...</u>``); otherwise one cue per segment is emitted.
        """
        raw_max_line_width: Optional[int] = options["max_line_width"]
        max_line_count: Optional[int] = options["max_line_count"]
        highlight_words: bool = options["highlight_words"]
        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
        # Unless both limits are given, every cue corresponds to one segment.
        preserve_segments = max_line_count is None or raw_max_line_width is None

        if not result["segments"]:
            return

        def iterate_subtitles():
            """Yield ``(words, times)`` per cue; ``times[i]`` is the
            ``(segment start, segment end, speaker)`` of ``words[i]``."""
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: list[dict] = []
            times = []
            last = result["segments"][0]["start"]
            for segment in result["segments"]:
                for i, original_timing in enumerate(segment["words"]):
                    timing = original_timing.copy()
                    long_pause = not preserve_segments
                    if "start" in timing:
                        long_pause = long_pause and timing["start"] - last > 3.0
                    else:
                        # words without timing can never signal a pause
                        long_pause = False
                    has_room = line_len + len(timing["word"]) <= max_line_width
                    seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                    if line_len > 0 and has_room and not long_pause and not seg_break:
                        # line continuation
                        line_len += len(timing["word"])
                    else:
                        # new line
                        timing["word"] = timing["word"].strip()
                        if (
                            len(subtitle) > 0
                            and max_line_count is not None
                            and (long_pause or line_count >= max_line_count)
                        ) or seg_break:
                            # subtitle break
                            yield subtitle, times
                            subtitle = []
                            times = []
                            line_count = 1
                        elif line_len > 0:
                            # line break
                            line_count += 1
                            timing["word"] = "\n" + timing["word"]
                        line_len = len(timing["word"].strip())
                    subtitle.append(timing)
                    times.append(
                        (segment["start"], segment["end"], segment.get("speaker"))
                    )
                    if "start" in timing:
                        last = timing["start"]
            if len(subtitle) > 0:
                yield subtitle, times

        if "words" in result["segments"][0]:
            for subtitle, times in iterate_subtitles():
                sstart, _, speaker = times[0]
                # A cue may span several segments when segments are not
                # preserved; it ends with the segment of its last word.
                ssend = times[-1][1]
                subtitle_start = self.format_timestamp(sstart)
                subtitle_end = self.format_timestamp(ssend)
                # Same joiner for plain and highlighted cues so that the text
                # does not change shape between them.
                joiner = "" if result["language"] in LANGUAGES_WITHOUT_SPACES else " "
                subtitle_text = joiner.join([word["word"] for word in subtitle])
                has_timing = any("start" in word for word in subtitle)

                # add [$SPEAKER_ID]: to each subtitle if speaker is available
                prefix = ""
                if speaker is not None:
                    prefix = f"[{speaker}]: "

                if highlight_words and has_timing:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        if "start" in this_word:
                            start = self.format_timestamp(this_word["start"])
                            end = self.format_timestamp(this_word["end"])
                            if last != start:
                                # fill the gap before this word with plain text
                                yield last, start, prefix + subtitle_text

                            # underline the current word, keeping any leading
                            # whitespace (e.g. a line break) outside the tag
                            yield start, end, prefix + joiner.join(
                                [
                                    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                    if j == i
                                    else word
                                    for j, word in enumerate(all_words)
                                ]
                            )
                            last = end
                else:
                    yield subtitle_start, subtitle_end, prefix + subtitle_text
        else:
            for segment in result["segments"]:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                # "-->" is the cue timing separator; keep it out of cue text
                segment_text = segment["text"].strip().replace("-->", "->")
                if "speaker" in segment:
                    segment_text = f"[{segment['speaker']}]: {segment_text}"
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format ``seconds`` using this writer's hour/decimal-marker settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )


class WriteVTT(SubtitlesWriter):
    """Write cues in WebVTT format."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: TextIO, options: dict):
        print("WEBVTT\n", file=file)
        for start, end, text in self.iterate_result(result, options):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteSRT(SubtitlesWriter):
    """Write cues in SubRip (SRT) format, numbered from 1."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: TextIO, options: dict):
        for i, (start, end, text) in enumerate(
            self.iterate_result(result, options), start=1
        ):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteTSV(ResultWriter):
    r"""
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO, options: dict):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            print(round(1000 * segment["start"]), file=file, end="\t")
            print(round(1000 * segment["end"]), file=file, end="\t")
            # tabs inside the text would break the column layout
            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)


class WriteAudacity(ResultWriter):
    """
    Write a transcript to a text file that audacity can import as labels.
    The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
    Yet this is not an audacity project but only a label file!

    Please note : Audacity uses seconds in timestamps not ms!
    Also there is no header expected.

    If speaker is provided it is prepended to the text between double square brackets [[]].
    """

    extension: str = "aud"

    def write_result(self, result: dict, file: TextIO, options: dict):
        # Label fields are tab-delimited: start, end, label text.
        ARROW = "\t"
        for segment in result["segments"]:
            print(segment["start"], file=file, end=ARROW)
            print(segment["end"], file=file, end=ARROW)
            print(
                (("[[" + segment["speaker"] + "]]") if "speaker" in segment else "")
                + segment["text"].strip().replace("\t", " "),
                file=file,
                flush=True,
            )


class WriteJSON(ResultWriter):
    """Dump the raw result dict as JSON (non-ASCII characters kept as-is)."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO, options: dict):
        json.dump(result, file, ensure_ascii=False)


def get_writer(
    output_format: str, output_dir: str
) -> Callable[[dict, str, dict], None]:
    """Return a writer callable for ``output_format`` that writes into ``output_dir``.

    The returned callable is invoked as ``writer(result, audio_path, options)``.
    ``"all"`` returns a callable that runs every standard writer (txt, vtt,
    srt, tsv, json); optional formats such as ``"aud"`` must be requested
    explicitly.

    Raises:
        KeyError: if ``output_format`` is not a known format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }
    optional_writers = {
        "aud": WriteAudacity,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # ``file`` is the audio path forwarded to ResultWriter.__call__; the
        # parameter name is kept for backward compatibility.
        def write_all(result: dict, file: str, options: dict):
            for writer in all_writers:
                writer(result, file, options)

        return write_all

    if output_format in optional_writers:
        return optional_writers[output_format](output_dir)
    return writers[output_format](output_dir)


def interpolate_nans(x, method='nearest'):
    """Fill missing values in ``x`` by interpolation, then forward/backward fill.

    Interpolation with ``method`` is only attempted when more than one
    non-null value is present; otherwise only ffill/bfill is applied.
    ``x`` is expected to expose ``notnull``/``interpolate``/``ffill``/``bfill``
    (presumably a pandas Series - confirm against callers).
    """
    if x.notnull().sum() > 1:
        return x.interpolate(method=method).ffill().bfill()
    else:
        return x.ffill().bfill()