utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import json
  2. import os
  3. import sys
  4. import zlib
  5. from typing import Callable, TextIO
  6. system_encoding = sys.getdefaultencoding()
  7. if system_encoding != "utf-8":
  8. def make_safe(string):
  9. # replaces any character not representable using the system default encoding with an '?',
  10. # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
  11. return string.encode(system_encoding, errors="replace").decode(system_encoding)
  12. else:
  13. def make_safe(string):
  14. # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
  15. return string
  16. def exact_div(x, y):
  17. assert x % y == 0
  18. return x // y
  19. def str2bool(string):
  20. str2val = {"True": True, "False": False}
  21. if string in str2val:
  22. return str2val[string]
  23. else:
  24. raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
  25. def optional_int(string):
  26. return None if string == "None" else int(string)
  27. def optional_float(string):
  28. return None if string == "None" else float(string)
  29. def compression_ratio(text) -> float:
  30. text_bytes = text.encode("utf-8")
  31. return len(text_bytes) / len(zlib.compress(text_bytes))
  32. def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
  33. assert seconds >= 0, "non-negative timestamp expected"
  34. milliseconds = round(seconds * 1000.0)
  35. hours = milliseconds // 3_600_000
  36. milliseconds -= hours * 3_600_000
  37. minutes = milliseconds // 60_000
  38. milliseconds -= minutes * 60_000
  39. seconds = milliseconds // 1_000
  40. milliseconds -= seconds * 1_000
  41. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  42. return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  43. class ResultWriter:
  44. extension: str
  45. def __init__(self, output_dir: str):
  46. self.output_dir = output_dir
  47. def __call__(self, result: dict, audio_path: str):
  48. audio_basename = os.path.basename(audio_path)
  49. output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
  50. with open(output_path, "w", encoding="utf-8") as f:
  51. self.write_result(result, file=f)
  52. def write_result(self, result: dict, file: TextIO):
  53. raise NotImplementedError
  54. class WriteTXT(ResultWriter):
  55. extension: str = "txt"
  56. def write_result(self, result: dict, file: TextIO):
  57. for segment in result["segments"]:
  58. print(segment['text'].strip(), file=file, flush=True)
  59. class WriteVTT(ResultWriter):
  60. extension: str = "vtt"
  61. def write_result(self, result: dict, file: TextIO):
  62. print("WEBVTT\n", file=file)
  63. for segment in result["segments"]:
  64. print(
  65. f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
  66. f"{segment['text'].strip().replace('-->', '->')}\n",
  67. file=file,
  68. flush=True,
  69. )
  70. class WriteSRT(ResultWriter):
  71. extension: str = "srt"
  72. def write_result(self, result: dict, file: TextIO):
  73. for i, segment in enumerate(result["segments"], start=1):
  74. # write srt lines
  75. print(
  76. f"{i}\n"
  77. f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
  78. f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
  79. f"{segment['text'].strip().replace('-->', '->')}\n",
  80. file=file,
  81. flush=True,
  82. )
  83. class WriteTSV(ResultWriter):
  84. """
  85. Write a transcript to a file in TSV (tab-separated values) format containing lines like:
  86. <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
  87. Using integer milliseconds as start and end times means there's no chance of interference from
  88. an environment setting a language encoding that causes the decimal in a floating point number
  89. to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
  90. """
  91. extension: str = "tsv"
  92. def write_result(self, result: dict, file: TextIO):
  93. print("start", "end", "text", sep="\t", file=file)
  94. for segment in result["segments"]:
  95. print(round(1000 * segment['start']), file=file, end="\t")
  96. print(round(1000 * segment['end']), file=file, end="\t")
  97. print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
  98. class WriteJSON(ResultWriter):
  99. extension: str = "json"
  100. def write_result(self, result: dict, file: TextIO):
  101. json.dump(result, file)
  102. def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
  103. writers = {
  104. "txt": WriteTXT,
  105. "vtt": WriteVTT,
  106. "srt": WriteSRT,
  107. "tsv": WriteTSV,
  108. "json": WriteJSON,
  109. }
  110. if output_format == "all":
  111. all_writers = [writer(output_dir) for writer in writers.values()]
  112. def write_all(result: dict, file: TextIO):
  113. for writer in all_writers:
  114. writer(result, file)
  115. return write_all
  116. return writers[output_format](output_dir)