Source code for mexca.text.transcription

"""Transcribe speech from audio to text.
"""

import argparse
import logging
import os
import re
import warnings
from dataclasses import asdict, replace
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import whisper
from intervaltree import Interval, IntervalTree
from tqdm import tqdm
from whisper.audio import SAMPLE_RATE

from mexca.data import AudioTranscription, SpeakerAnnotation, TranscriptionData
from mexca.utils import ClassInitMessage, optional_str, str2bool

# Silence the "shift" UserWarnings, which do not apply to this module.
# NOTE(review): the filter is installed at import time and affects every
# UserWarning raised in the process, not only those from this module.
warnings.simplefilter(action="ignore", category=UserWarning)


class AudioTranscriber:
    """Transcribe speech from audio to text.

    Parameters
    ----------
    whisper_model: str, optional, default='small'
        The name of the whisper model that is used for transcription. Available models are
        `['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large']`.
    device: str or torch.device, optional, default='cpu'
        The name of the device onto which the whisper model should be loaded and run.
        If CUDA support is available, this can be `'cuda'`, otherwise use `'cpu'` (the default).
    sentence_rule: str, optional
        A regular expression used to split segment transcripts into sentences.
        If `None` (default), it splits the text at all '.', '?', '!', and ':' characters
        that are followed by whitespace characters. It omits single or multiple words
        abbreviated with dots (e.g., 'Nr. ' and 'e.g. ').

    """

    def __init__(
        self,
        whisper_model: Optional[str] = "small",
        device: Optional[Union[str, torch.device]] = "cpu",
        sentence_rule: Optional[str] = None,
    ):
        self.logger = logging.getLogger(
            "mexca.text.transcription.AudioTranscriber"
        )
        self.whisper_model = whisper_model
        self.device = device
        # Lazy initialization: the model is loaded on first access of
        # the `transcriber` property.
        self._transcriber = None

        if not sentence_rule:
            self.sentence_rule = (
                r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!|:)\s"
            )
            self.logger.debug(
                'Using default sentence rule %s because "sentence_rule=None"',
                self.sentence_rule,
            )
        else:
            self.sentence_rule = sentence_rule

        self.logger.debug(ClassInitMessage())

    # Initialize pretrained models only when needed
    @property
    def transcriber(self) -> whisper.Whisper:
        """The loaded whisper model for audio transcription.

        The model is loaded with `whisper.load_model` on first access and
        cached until the property is deleted.
        """
        if self._transcriber is None:
            self._transcriber = whisper.load_model(
                self.whisper_model, self.device
            )
            self.logger.debug(
                "Initialized %s whisper model for audio transcription",
                self.whisper_model,
            )

        return self._transcriber

    # Delete pretrained models when not needed anymore
    @transcriber.deleter
    def transcriber(self):
        self._transcriber = None
        self.logger.debug(
            "Removed %s whisper model for audio transcription",
            self.whisper_model,
        )

    # pylint: disable=too-many-locals
    def apply(
        self,
        filepath: str,
        audio_annotation: SpeakerAnnotation,
        language: Optional[str] = None,
        options: Optional[whisper.DecodingOptions] = None,
        show_progress: bool = True,
    ) -> AudioTranscription:
        """Transcribe speech in an audio file to text.

        Transcribe each annotated speech segment in the audio file and split the transcription into sentences
        according to `sentence_rule`.

        Parameters
        ----------
        filepath: str
            Path to the audio file.
        audio_annotation: SpeakerAnnotation
            The audio annotation object returned by the `SpeakerIdentifier` component.
        language: str, optional, default=None
            The language that is transcribed. Ignored if `options.language` is not `None`.
        options: whisper.DecodingOptions, optional
            Options for transcribing the audio file. If `None`, the options from
            `get_default_options` are used: timestamps are predicted (`without_timestamps=False`),
            and the number format depends on whether CUDA is available:
            FP16 (half-precision floating points) if available,
            FP32 (single-precision floating points) otherwise.
        show_progress: bool, optional, default=True
            Whether a progress bar is displayed or not.

        Returns
        -------
        AudioTranscription
            A data class object containing transcribed speech segments split into sentences.
            Segments that fail to transcribe are logged and skipped.

        """
        if options is None:
            self.logger.debug(
                "Using default options for whisper: Timestamps enabled and FP16 only if CUDA is available"
            )
            options = self.get_default_options(language=language)
        elif options.language is None and language is not None:
            # `language` is only documented as ignored when `options.language`
            # is set; otherwise it must still take effect.
            options = replace(options, language=language)

        audio = torch.Tensor(whisper.load_audio(filepath))

        transcription = AudioTranscription(
            filename=filepath, segments=IntervalTree()
        )

        for i, seg in tqdm(
            enumerate(audio_annotation.segments),
            total=len(audio_annotation.segments),
            disable=not show_progress,
        ):
            # Convert segment boundaries (seconds) to sample indices
            start = int(seg.begin * SAMPLE_RATE)
            end = int(seg.end * SAMPLE_RATE)

            self.logger.debug(
                "Transcribing segment %s from %s to %s", i, seg.begin, seg.end
            )

            try:
                output = self.transcriber.transcribe(
                    audio[start:end],
                    word_timestamps=True,
                    verbose=None,
                    **asdict(options),
                )
            except RuntimeError as exc:
                if seg.end - seg.begin < 0.02:
                    self.logger.error(
                        "Audio waveform too short to be transcribed: %s", exc
                    )
                else:
                    # Report the actual error (e.g. an operator that is not
                    # implemented for the MPS device) instead of guessing.
                    self.logger.error(
                        "Segment %s could not be transcribed: %s", i, exc
                    )
                continue

            self.logger.debug(
                "Detected language: %s",
                whisper.tokenizer.LANGUAGES[output["language"]].title(),
            )

            text = output["text"].strip()

            # Split text into sentences
            sents = re.split(self.sentence_rule, text)

            self.logger.debug(
                "Segment text split into %s sentences", len(sents)
            )

            # Concatenate word timestamps from every whisper segment
            words = [
                word
                for segment in output["segments"]
                for word in segment.get("words", [])
            ]

            if words:
                self._add_sentences(transcription, seg, i, sents, words)

        del self.transcriber

        return transcription

    def _add_sentences(
        self,
        transcription: AudioTranscription,
        seg: Interval,
        index: int,
        sents: List[str],
        words: List[Dict[str, Union[str, float]]],
    ) -> None:
        """Align the sentences of one speech segment with word timestamps and add them to `transcription`.

        Sentences are matched to consecutive entries of `words` by counting
        whitespace-separated tokens. Word timestamps are relative to the
        segment start `seg.begin`.
        """
        idx = 0
        last = len(words) - 1

        for j, sent in enumerate(sents):
            n_words = len(sent.split())

            if n_words == 0:
                # Nothing to align (e.g., an empty transcript)
                continue

            if idx > last:
                self.logger.warning(
                    "Segment %s has more words than word timestamps; %s sentence(s) were not added to transcription",
                    index,
                    len(sents) - j,
                )
                break

            # Index of the sentence's last word, clamped so that a word-count
            # mismatch cannot index past the available timestamps
            end_idx = min(idx + n_words - 1, last)

            # Timestamp BEFORE the first word is spoken ('start')
            sent_start = self._get_timestamp(
                words, idx, timestamp_type="start"
            )
            # Timestamp AFTER the last word is spoken ('end')
            sent_end = self._get_timestamp(
                words, end_idx, timestamp_type="end"
            )

            self.logger.debug(
                "Processing sentence %s from %s to %s with text: %s",
                j,
                seg.begin + sent_start,
                seg.begin + sent_end,
                sent,
            )

            if (sent_end - sent_start) > 0:
                # Average word probability over the words of this sentence
                conf = self._get_avg_confidence(
                    words, idx, end_idx - idx + 1
                )

                transcription.segments.add(
                    Interval(
                        begin=seg.begin + sent_start,
                        end=seg.begin + sent_end,
                        data=TranscriptionData(
                            index=index,
                            text=sent,
                            speaker=seg.data.name,
                            confidence=conf,
                        ),
                    )
                )
            else:
                self.logger.warning(
                    "Sentence has duration <= 0 and was not added to transcription"
                )

            idx = end_idx + 1

    @staticmethod
    def get_default_options(
        language: Optional[str] = None,
    ) -> whisper.DecodingOptions:
        """Set default options for transcription.

        Sets language as well as `without_timestamps=False` (i.e., timestamps
        are predicted) and `fp16=torch.cuda.is_available()`.

        Parameters
        ----------
        language: str, optional, default=None
            The language that is transcribed.

        Returns
        -------
        whisper.DecodingOptions

        """
        return whisper.DecodingOptions(
            language=language,
            without_timestamps=False,
            fp16=torch.cuda.is_available(),
        )

    @staticmethod
    def _get_timestamp(
        word_timestamps: List[Dict[str, Union[str, float]]],
        idx: int,
        timestamp_type: str = "start",
    ) -> float:
        # Word-level timestamp (`timestamp_type` is 'start' or 'end') of the
        # word at position `idx` in the word list
        return word_timestamps[idx][timestamp_type]

    @staticmethod
    def _get_avg_confidence(
        word_timestamps: List[Dict[str, Union[str, float]]],
        idx: int,
        sentence_len: int,
    ) -> float:
        # Average transcription probability of the `sentence_len` words that
        # start at position `idx`. Returns NaN if there are no such words.
        if sentence_len <= 0:
            return np.nan

        sentence_words = word_timestamps[idx : idx + sentence_len]

        if not sentence_words:
            return np.nan

        return sum(word["probability"] for word in sentence_words) / len(
            sentence_words
        )
# Adapted from whisper.transcribe.cli
# See: https://github.com/openai/whisper/blob/main/whisper/transcribe.py
def cli():
    """Command line interface for audio transcription.

    Reads a speaker annotation (JSON) for an audio file, transcribes the
    annotated segments, and writes the result as JSON to `--outdir`.
    See `transcribe -h` for details.

    """
    parser = argparse.ArgumentParser(
        description="Transcribe annotated speech segments of an audio file with whisper.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "-f",
        "--filepath",
        type=str,
        required=True,
        help="Path to the audio file.",
    )
    parser.add_argument(
        "-a",
        "--annotation-path",
        type=str,
        required=True,
        dest="annotation_path",
        help="Path to the JSON speaker annotation of the audio file.",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        required=True,
        help="Directory in which the transcription JSON file is written.",
    )
    parser.add_argument(
        "--model",
        default="small",
        choices=whisper.available_models(),
        help="Name of the whisper model.",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device on which the whisper model is run.",
    )
    parser.add_argument(
        "--language",
        type=optional_str,
        default=None,
        help="Language that is transcribed.",
    )
    parser.add_argument(
        "--sentence-rule",
        type=optional_str,
        default=None,
        dest="sentence_rule",
        help="Regular expression used to split transcripts into sentences.",
    )
    parser.add_argument(
        "--show-progress",
        type=str2bool,
        default=True,
        dest="show_progress",
        help="Whether a progress bar is displayed.",
    )

    args = vars(parser.parse_args())

    transcriber = AudioTranscriber(
        whisper_model=args["model"],
        device=args["device"],
        sentence_rule=args["sentence_rule"],
    )

    # Reuse the class defaults instead of duplicating them here
    options = AudioTranscriber.get_default_options(language=args["language"])

    audio_annotation = SpeakerAnnotation.from_json(
        args["annotation_path"], extra_filename=args["filepath"]
    )

    output = transcriber.apply(
        args["filepath"],
        audio_annotation=audio_annotation,
        options=options,
        show_progress=args["show_progress"],
    )

    # Make sure the output directory exists before writing into it
    os.makedirs(args["outdir"], exist_ok=True)

    output.write_json(
        os.path.join(
            args["outdir"],
            os.path.splitext(os.path.basename(args["filepath"]))[0]
            + f"_{output.serialization_name()}.json",
        )
    )
if __name__ == "__main__":
    # Run the command line interface when the module is executed as a script.
    cli()