Source code for expert.data.annotation.speech_to_text

from __future__ import annotations

import bisect
import re
from os import PathLike
from typing import Dict, List, Tuple, Union, Optional

import torch
import whisper

import expert.data.annotation.transcribe as transcribe


def transcribe_video(
    video_path: Union[str, PathLike],
    lang: Optional[str] = "en",
    model: Optional[str] = "server",
    device: Optional[Union[torch.device, None]] = None,
) -> Dict:
    """Recognize speech in a video file and return a timestamped transcription.

    Args:
        video_path (Union[str, PathLike]): Path to the local video file.
        lang (Optional[str]): Recognition language, one of ['ru', 'en'].
            Defaults to 'en'.
        model (Optional[str]): Model configuration, one of ['server', 'local'].
            'server' loads the "medium" Whisper checkpoint, 'local' loads
            "base". Defaults to 'server'.
        device (Optional[Union[torch.device, None]]): Device the model is
            loaded on (GPU recommended). The CPU is used when None.
            Defaults to None.

    Returns:
        Dict: Whatever ``transcribe.transcribe_timestamped`` returns.
        NOTE(review): other helpers in this module read its "text" and
        "segments" keys; confirm the full schema against that function.

    Raises:
        NotImplementedError: If 'lang' is not equal to 'en' or 'ru'.
        NotImplementedError: If 'model' is not equal to 'server' or 'local'.
    """
    # Reject unsupported arguments before any model weights are loaded.
    if lang not in ["en", "ru"]:
        raise NotImplementedError("'lang' must be 'en' or 'ru'.")
    if model not in ["server", "local"]:
        raise NotImplementedError("'model' must be 'server' or 'local'.")

    target_device = device if device is not None else torch.device("cpu")

    # Only two configurations pass validation: the heavier checkpoint for the
    # server setup, the lighter one for local runs.
    checkpoint = "medium" if model == "server" else "base"
    recognizer = whisper.load_model(checkpoint, device=target_device)

    return transcribe.transcribe_timestamped(
        model=recognizer, audio=video_path, language=lang
    )
def get_all_words(transcribation: Dict) -> Tuple[List, str]:
    """Flatten a transcription into one list of word stamps plus the full text.

    Args:
        transcribation (Dict): Speech recognition result. Must provide a
            "text" entry and a "segments" entry whose items each hold a
            "words" list.

    Returns:
        Tuple[List, str]: All word stamps in segment order, and the value
        stored under "text".
    """
    full_text = transcribation["text"]
    all_words = [
        word
        for segment in transcribation["segments"]
        for word in segment["words"]
    ]
    return all_words, full_text
def get_phrases(all_words: list, duration: Optional[int] = 10) -> list:
    """Split transcribed text into phrases of roughly a fixed length.

    Words are consumed in order. Each phrase starts at the next unused word
    and keeps absorbing following words while its remaining time budget is
    positive, so a phrase may overshoot ``duration`` by its last word. A
    first word that already uses up the whole budget forms a phrase alone.

    The input list is only read; it is not modified.

    Args:
        all_words (List): All stamps with words from the transcribed text.
            Each item must provide "text", "start" and "end".
        duration (int, optional): Length of intervals for extracting phrases
            from speech, in the same units as the word stamps (presumably
            seconds). Defaults to 10.

    Returns:
        list: Dicts of the form ``{"time": [start, end], "text": phrase}``
        where ``start`` is the first word's start and ``end`` is the last
        included word's end.

    Raises:
        AssertionError: If fewer than two words are supplied.
    """
    if len(all_words) <= 1:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`, while callers still see the same exception type.
        raise AssertionError("Not enough words in text.")

    phrases = []
    total = len(all_words)
    idx = 0

    # Walk the list by index rather than popping from the front: that keeps
    # the caller's list intact and avoids the O(n) cost of every pop(0).
    while idx < total:
        first = all_words[idx]
        idx += 1

        texts = [first["text"]]
        end_time = first["end"]
        time_left = duration - (first["end"] - first["start"])

        while time_left > 0 and idx < total:
            word = all_words[idx]
            idx += 1
            texts.append(word["text"])
            # The gap between words counts against the budget as well.
            time_left -= word["end"] - end_time
            end_time = word["end"]

        # `end_time` always belongs to the last word of *this* phrase, even
        # when no extra word was absorbed.
        phrases.append(
            {"time": [first["start"], end_time], "text": " ".join(texts)}
        )

    return phrases
def get_sentences(all_words: Optional[List]) -> List[Dict]:
    """Group word stamps into sentences closed by '.', '!' or '?'.

    A sentence is opened by the first word seen after the previous sentence
    was closed, and it is closed by the next *later* word whose text ends
    with a terminator. Two consequences follow from that rule:

    * a terminator on the opening word does not close the sentence, so a
      one-word sentence is merged with the following one;
    * words after the last terminator are not returned.

    NOTE(review): both points mirror the long-standing behavior of this
    function; confirm with callers before changing them.

    Args:
        all_words (Optional[List]): All stamps with words from the
            transcribed text. Each item must provide "text", "start" and
            "end". ``None`` is treated like an empty list.

    Returns:
        List[Dict]: Dicts with "time_start" (start of the opening word),
        "text" (words joined by single spaces) and "time_end" (end of the
        closing word).
    """
    if not all_words:
        return []

    # `str.endswith` replaces the previous regex on the last character: no
    # escape-sequence pitfalls, and an empty text simply never terminates a
    # sentence instead of raising IndexError.
    terminators = (".", "!", "?")

    sentences = []
    pending_words = []
    pending_start = None

    for elem in all_words:
        text = elem["text"]

        if not pending_words:
            # Opening word: remember where the sentence starts. It is never
            # treated as the closing word, whatever it ends with.
            pending_start = elem["start"]
            pending_words.append(text)
            continue

        pending_words.append(text)
        if text.endswith(terminators):
            sentences.append(
                {
                    "time_start": pending_start,
                    "text": " ".join(pending_words),
                    "time_end": elem["end"],
                }
            )
            pending_words = []

    return sentences
def between_timestamps(all_words: List, start: float, end: float) -> str:
    """Get phrase between specific timestamps (start, finish) in seconds.

    The phrase begins with the closest word starting at or before ``start``
    (the first word when ``start`` precedes all of them) and contains every
    following word that ends strictly before ``end``. A word ending exactly
    at ``end`` is therefore not included.

    NOTE(review): the lookup assumes the word stamps are sorted by time;
    confirm this holds for every producer of ``all_words``.

    Args:
        all_words (List): All stamps with words from the transcribed text.
            Each item must provide "text", "start" and "end".
        start (float): Start timestamp of the interval (in seconds).
        end (float): End timestamp of the interval (in seconds).

    Returns:
        str: Phrase between timestamps, words joined by single spaces.
        Empty when no word qualifies or ``all_words`` is empty.

    Raises:
        AssertionError: If ``start`` is negative.
    """
    if not start >= 0:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`, while callers still see the same exception type.
        raise AssertionError("Innapropriate start stamp (negative value)")

    if not all_words:
        return ""

    starts = [elem["start"] for elem in all_words]
    ends = [elem["end"] for elem in all_words]

    # Closest left index for `start`: an exact hit is used as is, otherwise
    # step back to the word that began before `start`. Unlike a search that
    # never inspects the final element, this also resolves correctly when
    # `start` is at or beyond the last word's start.
    start_idx = bisect.bisect_left(starts, start)
    if start_idx == len(starts) or starts[start_idx] != start:
        start_idx = max(start_idx - 1, 0)

    # Closest right index for `end`, used as an exclusive slice bound: the
    # first word ending at or after `end`. It equals len(all_words) when
    # `end` lies past the last word, which keeps that last word.
    end_idx = bisect.bisect_left(ends, end)

    words = [elem["text"] for elem in all_words[start_idx:end_idx]]
    return " ".join(words)