Source code for mexca.audio.identification

"""Speech segment and speaker identification.
"""

import argparse
import logging
import os
from typing import Optional, Union

import torch
from pyannote.audio import Pipeline
from tqdm import tqdm

from mexca.data import SpeakerAnnotation
from mexca.utils import ClassInitMessage, bool_or_str, optional_int


class AuthenticationError(Exception):
    """Signal that authenticating to the HuggingFace Hub did not succeed.

    Parameters
    ----------
    msg : str
        Human-readable description of the failure; forwarded unchanged to
        :class:`Exception`.

    """

    def __init__(self, msg: str):
        super().__init__(msg)
class SpeakerIdentifier:
    """Identify speech segments and cluster speakers using speaker diarization.

    Wrapper class for ``pyannote.audio.SpeakerDiarization``. Uses pretrained
    speaker diarization model `pyannote/speaker-diarization-3.1` from
    HuggingFace.

    Parameters
    ----------
    num_speakers : int, optional
        Number of speakers to which speech segments will be assigned during
        the clustering (oracle speakers). If `None`, the number of speakers is
        estimated from the audio signal.
    device : torch.device, default=torch.device("cpu")
        The device on which the speaker diarization model is run.
    use_auth_token : bool or str, default=True
        Whether to use the HuggingFace authentication token stored on the
        machine (if bool) or a HuggingFace authentication token with access to
        the models ``pyannote/speaker-diarization`` and
        ``pyannote/segmentation`` (if str).

    Notes
    -----
    This class requires pretrained models for speaker diarization and
    segmentation from HuggingFace. To download the models accept the user
    conditions on `<hf.co/pyannote/speaker-diarization>`_ and
    `<hf.co/pyannote/segmentation>`_. Then generate an authentication token on
    `<hf.co/settings/tokens>`_.

    The pretrained pipeline is loaded lazily on first access of
    :attr:`pipeline` and released again at the end of a successful
    :meth:`apply` call.

    """

    def __init__(
        self,
        num_speakers: Optional[int] = None,
        device: torch.device = torch.device(type="cpu"),
        use_auth_token: Union[bool, str] = True,
    ):
        self.logger = logging.getLogger(
            "mexca.audio.identification.SpeakerIdentifier"
        )
        self.num_speakers = num_speakers
        self.device = device
        self.use_auth_token = use_auth_token
        # Lazy initialization: the pipeline is only loaded on first access
        self._pipeline = None

        self.logger.debug(ClassInitMessage())

    # Initialize pretrained models only when needed
    @property
    def pipeline(self) -> Pipeline:
        """The pretrained speaker diarization pipeline.

        See `pyannote.audio.SpeakerDiarization <https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/speaker_diarization.py#L56>`_
        for details.

        Raises
        ------
        EnvironmentError
            Re-raised (after logging) if loading the pretrained pipeline
            raises it.
        AuthenticationError
            If loading the pretrained pipeline returns `None`.

        """
        # Compare against the `None` sentinel explicitly instead of relying on
        # the truthiness of the pipeline object.
        if self._pipeline is None:
            try:
                loaded = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1",
                    use_auth_token=self.use_auth_token,
                )
            except EnvironmentError as exc:
                self.logger.exception("EnvironmentError: %s", exc)
                raise

            # Raise inside a `try` so that `logger.exception` records the
            # error together with its traceback before it propagates.
            try:
                if loaded is None:
                    raise AuthenticationError(
                        'Could not download pretrained "pyannote/speaker-diarization-3.1" pipeline; please provide a valid authentication token'
                    )
            except AuthenticationError as exc:
                self.logger.exception("Error: %s", exc)
                raise

            loaded.to(self.device)
            # Cache the pipeline only once it is fully initialized, so that a
            # failed device transfer is retried on the next access instead of
            # returning a pipeline that never reached `self.device`.
            self._pipeline = loaded
            self.logger.debug("Initialized speaker diarization pipeline")

        return self._pipeline

    # Delete pretrained models when not needed anymore
    @pipeline.deleter
    def pipeline(self):
        self._pipeline = None
        self.logger.debug("Removed speaker diarization pipeline")

    def apply(
        self, filepath: str, show_progress: bool = True
    ) -> SpeakerAnnotation:
        """Identify speech segments and speakers.

        Parameters
        ----------
        filepath : str
            Path to the audio file.
        show_progress: bool, default=True
            Enables the display of a progress bar.

        Returns
        -------
        SpeakerAnnotation
            A data class object that contains detected speech segments and
            speakers.

        """
        # Init progress bars
        progress_bar_embeddings = tqdm(delay=1, disable=not show_progress)
        progress_bar_segments = tqdm(disable=True)

        # Custom hook for speaker diarization pipeline to update progress bars
        # Requires named parameter `file` which is not used
        # pylint: disable=unused-argument
        def hook(
            name: str,
            _,
            file: Optional[str] = None,
            total: Optional[int] = None,
            completed: Optional[int] = None,
        ):
            # Ignore calls without progress information (missing or zero)
            if not completed or not total:
                return

            if not progress_bar_embeddings.total and name == "embeddings":
                # First embeddings call with progress info: set the bar total
                self.logger.info("Calculating speaker embeddings")
                progress_bar_embeddings.reset(total=total)
                self.logger.debug("Processing batch %s", completed)
            elif not progress_bar_segments.total and name == "segmentation":
                # The total of the segmentation bar is never set in this
                # method, so this message is logged for every segmentation
                # call that carries progress information.
                self.logger.info("Detecting speech segments")

            if name == "embeddings":
                progress_bar_embeddings.update(1)
                self.logger.debug("Processing batch %s", completed + 1)

        try:
            annotation, embeddings = self.pipeline(
                filepath,
                num_speakers=self.num_speakers,
                return_embeddings=True,
                hook=hook,
            )
        finally:
            # Close the progress bars even if loading or running the pipeline
            # fails, so they are not left open.
            progress_bar_embeddings.close()
            progress_bar_segments.close()

        # Release the pretrained pipeline; it is reloaded on the next access
        del self.pipeline

        self.logger.debug("Detected speaker chart: %s", annotation.chart())

        # Update URI to point to a valid file (otherwise pydantic throws an error)
        annotation.uri = filepath

        annotation = annotation.rename_labels(generator="int")

        # Pair the i-th embedding with the i-th label of the relabeled
        # annotation.
        speaker_average_embeddings = {
            lbl: embeddings[i] for i, lbl in enumerate(annotation.labels())
        }

        return SpeakerAnnotation.from_pyannote(
            annotation, speaker_average_embeddings
        )
def cli():
    """Run speech segment and speaker identification from the command line.

    Parses the command line arguments, applies a :class:`SpeakerIdentifier`
    to the given audio file, and writes the result as a JSON file into the
    output directory. The output file name is the audio file's base name
    (without extension) followed by an underscore, the serialization name of
    the result, and the ``.json`` extension.

    See `identify-speakers -h` for details.

    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("-f", "--filepath", type=str, required=True)
    parser.add_argument("-o", "--outdir", type=str, required=True)
    parser.add_argument(
        "--num-speakers",
        type=optional_int,
        default=None,
        dest="num_speakers",
    )
    parser.add_argument(
        "--use-auth-token",
        type=bool_or_str,
        default=True,
        dest="use_auth_token",
    )

    options = vars(parser.parse_args())
    audio_path = options["filepath"]

    # Build the identifier and run it on the audio file in one step
    result = SpeakerIdentifier(
        num_speakers=options["num_speakers"],
        use_auth_token=options["use_auth_token"],
    ).apply(audio_path)

    # Derive the output file name from the audio file's base name
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    out_name = stem + f"_{result.serialization_name()}.json"

    result.write_json(os.path.join(options["outdir"], out_name))
# Run the command line interface when this module is executed as a script.
if __name__ == "__main__": cli()