diff --git a/.github/workflows/dockerhub-description.yml b/.github/workflows/dockerhub-description.yml index 0367b21..1301449 100644 --- a/.github/workflows/dockerhub-description.yml +++ b/.github/workflows/dockerhub-description.yml @@ -7,7 +7,7 @@ on: - README.md - .github/workflows/dockerhub-description.yml jobs: - dockerHubDescription: + dockerHubDescriptionKaldi: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -16,5 +16,16 @@ jobs: with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_PASSWORD }} - repository: lintoai/linto-platform-stt - readme-filepath: ./README.md + repository: lintoai/linto-stt-kaldi + readme-filepath: ./kaldi/README.md + dockerHubDescriptionWhisper: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Docker Hub Description + uses: peter-evans/dockerhub-description@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + repository: lintoai/linto-stt-whisper + readme-filepath: ./whisper/README.md diff --git a/.gitignore b/.gitignore index 0b8d9ad..06b349b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ start_container.sh .env* test/* +tmp* +__pycache__ \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile index 95e42b0..cd1ad07 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,51 +1,72 @@ +def buildDockerfile(main_folder, dockerfilePath, image_name, version, changedFiles) { + if (changedFiles.contains(main_folder) || changedFiles.contains('celery_app') || changedFiles.contains('http_server') || changedFiles.contains('websocket') || changedFiles.contains('document')) { + echo "Building Dockerfile for ${image_name} with version ${version} (using ${dockerfilePath})" + + script { + def image = docker.build(image_name, "-f ${dockerfilePath} .") + + docker.withRegistry('https://registry.hub.docker.com', 'docker-hub-credentials') { + if (version == 'latest-unstable') { + image.push('latest-unstable') + } else { + image.push('latest') + image.push(version) + } + } + } + } +} + pipeline { agent any environment { - DOCKER_HUB_REPO = "lintoai/linto-platform-stt" - DOCKER_HUB_CRED = 'docker-hub-credentials' - - VERSION = '' + DOCKER_HUB_REPO_KALDI = "lintoai/linto-stt-kaldi" + DOCKER_HUB_REPO_WHISPER = "lintoai/linto-stt-whisper" } - - stages{ - stage('Docker build for master branch'){ - when{ + + stages { + stage('Docker build for master branch') { + when { branch 'master' } steps { echo 'Publishing latest' script { - image = docker.build(env.DOCKER_HUB_REPO) - VERSION = sh( + def changedFiles = sh(returnStdout: true, script: 'git diff --name-only HEAD^ HEAD').trim() + echo "My changed files: ${changedFiles}" + + version_kaldi = sh( + returnStdout: true, + script: "awk -v RS='' '/#/ {print; exit}' kaldi/RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'" + ).trim() + + version_whisper = sh( returnStdout: true, - script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'" + script: "awk -v RS='' '/#/ {print; exit}' whisper/RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'" ).trim() - docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) { - image.push("${VERSION}") - image.push('latest') - } + buildDockerfile('kaldi', 'kaldi/Dockerfile', env.DOCKER_HUB_REPO_KALDI, version_kaldi, changedFiles) + buildDockerfile('whisper', 'whisper/Dockerfile.ctranslate2', env.DOCKER_HUB_REPO_WHISPER, version_whisper, changedFiles) } } } - stage('Docker build for next (unstable) branch'){ - when{ + stage('Docker 
build for next (unstable) branch') { + when { branch 'next' } steps { echo 'Publishing unstable' script { - image = docker.build(env.DOCKER_HUB_REPO) - VERSION = sh( - returnStdout: true, - script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'" - ).trim() - docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) { - image.push('latest-unstable') - } + def changedFiles = sh(returnStdout: true, script: 'git diff --name-only HEAD^ HEAD').trim() + echo "My changed files: ${changedFiles}" + + version = 'latest-unstable' + + buildDockerfile('kaldi', 'kaldi/Dockerfile', env.DOCKER_HUB_REPO_KALDI, version, changedFiles) + buildDockerfile('whisper', 'whisper/Dockerfile.ctranslate2', env.DOCKER_HUB_REPO_WHISPER, version, changedFiles) } } } - }// end stages + } } \ No newline at end of file diff --git a/Makefile b/Makefile index 71be1a8..24db387 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .DEFAULT_GOAL := help -target_dirs := stt http_server celery_app +target_dirs := kaldi/stt whisper/stt http_server celery_app help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index ec70060..09009fe 100644 --- a/README.md +++ b/README.md @@ -1,224 +1,12 @@ -# LINTO-PLATFORM-STT -LinTO-platform-stt is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack). +# LinTO-STT -LinTO-platform-stt can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. +LinTO-STT is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack), +which can currently work with Speech-To-Text (STT) models. +The following families of STT models are currently supported (please refer to respective documentation for more details): +* [Kaldi models](kaldi/README.md) +* [Whisper models](whisper/README.md) -## Pre-requisites - -### Hardware -To run the transcription models you'll need: -* At least 7Go of disk space to build the docker image. -* Up to 7GB of RAM depending on the model used. -* One CPU per worker. Inference time scales on CPU performances. - -### Model -LinTO-Platform-STT accepts two kinds of models: -* LinTO Acoustic and Languages models. -* Vosk models. - -We provide home-cured models (v2) on [dl.linto.ai](https://doc.linto.ai/docs/developpers/apis/ASR/models). -Or you can also use Vosk models available [here](https://alphacephei.com/vosk/models). - -### Docker -The transcription service requires docker up and running. - -### (micro-service) Service broker and shared folder -The STT only entry point in task mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS. -On addition, as to prevent large audio from transiting through the message broker, STT-Worker use a shared storage folder (SHARED_FOLDER). - -## Deploy linto-platform-stt - -**1- First step is to build or pull the image:** - -```bash -git clone https://github.com/linto-ai/linto-platform-stt.git -cd linto-platform-stt -docker build . -t linto-platform-stt:latest -``` -or - -```bash -docker pull lintoai/linto-platform-stt -``` - -**2- Download the models** - -Have the acoustic and language model ready at AM_PATH and LM_PATH if you are using LinTO models. If you are using a Vosk model, have it ready at MODEL. 
- -**3- Fill the .env** - -```bash -cp .envdefault .env -``` - -| PARAMETER | DESCRIPTION | EXEMPLE | -|---|---|---| -| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | http\|task\|websocket | -| MODEL_TYPE | Type of STT model used. | lin\|vosk | -| ENABLE_STREAMING | Using http serving mode, enable the /streaming websocket route | true\|false | -| SERVICE_NAME | Using the task mode, set the queue's name for task processing | my-stt | -| SERVICE_BROKER | Using the task mode, URL of the message broker | redis://my-broker:6379 | -| BROKER_PASS | Using the task mode, broker password | my-password | -| STREAMING_PORT | Using the websocket mode, the listening port for ingoing WS connexions. | 80 | -| CONCURRENCY | Maximum number of parallel requests | >1 | - -### Serving mode -![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png) - -STT can be used three ways: -* Through an [HTTP API](#http-server) using the **http**'s mode. -* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. -* Through a [websocket server](#websocket-server) **websocket**'s mode. - -Mode is specified using the .env value or environment variable ```SERVING_MODE```. -```bash -SERVICE_MODE=http -``` -### HTTP Server -The HTTP serving mode deploys a HTTP server and a swagger-ui to allow transcription request on a dedicated route. - -The SERVICE_MODE value in the .env should be set to ```http```. - -```bash -docker run --rm \ --p HOST_SERVING_PORT:80 \ --v AM_PATH:/opt/AM \ --v LM_PATH:/opt/LM \ ---env-file .env \ -linto-platform-stt:latest -``` - -This will run a container providing an [HTTP API](#http-api) binded on the host HOST_SERVING_PORT port. - -**Parameters:** -| Variables | Description | Example | -|:-|:-|:-| -| HOST_SERVING_PORT | Host serving port | 80 | -| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 | -| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | -| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | - -### Micro-service within LinTO-Platform stack -The HTTP serving mode connect a celery worker to a message broker. - -The SERVICE_MODE value in the .env should be set to ```task```. - ->LinTO-platform-stt can be deployed within the linto-platform-stack through the use of linto-platform-services-manager. Used this way, the container spawn celery worker waiting for transcription task on a message broker. ->LinTO-platform-stt in task mode is not intended to be launch manually. ->However, if you intent to connect it to your custom message's broker here are the parameters: - -You need a message broker up and running at MY_SERVICE_BROKER. 
- -```bash -docker run --rm \ --v AM_PATH:/opt/AM \ --v LM_PATH:/opt/LM \ --v SHARED_AUDIO_FOLDER:/opt/audio \ ---env-file .env \ -linto-platform-stt:latest -``` - -**Parameters:** -| Variables | Description | Example | -|:-|:-|:-| -| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 | -| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | -| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | -| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model | - - -### Websocket Server -Websocket server's mode deploy a streaming transcription service only. - -The SERVICE_MODE value in the .env should be set to ```websocket```. - -Usage is the same as the [http streaming API](#/streaming) - -## Usages -### HTTP API -#### /healthcheck -Returns the state of the API - -Method: GET - -Returns "1" if healthcheck passes. - -#### /transcribe -Transcription API - -* Method: POST -* Response content: text/plain or application/json -* File: An Wave file 16b 16Khz - -Return the transcripted text using "text/plain" or a json object when using "application/json" structure as followed: -```json -{ - "text" : "This is the transcription", - "words" : [ - {"word":"This", "start": 0.123, "end": 0.453, "conf": 0.9}, - ... - ] - "confidence-score": 0.879 -} -``` - -#### /streaming -The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true. - -The route accepts websocket connexions. Exchanges are structured as followed: -1. Client send a json {"config": {"sample_rate":16000}}. -2. Client send audio chunk (go to 3- ) or {"eof" : 1} (go to 5-). -3. Server send either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}. -4. Back to 2- -5. Server send a final result and close the connexion. - -> Connexion will be closed and the worker will be freed if no chunk are received for 10s. - -#### /docs -The /docs route offers a OpenAPI/swagger interface. - -### Through the message broker - -STT-Worker accepts requests with the following arguments: -```file_path: str, with_metadata: bool``` - -* file_path: Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path} -* with_metadata: If True, words timestamps and confidence will be computed and returned. If false, the fields will be empty. - -#### Return format -On a successfull transcription the returned object is a json object structured as follow: -```json -{ - "text" : "this is the transcription as text", - "words": [ - { - "word" : "this", - "start": 0.0, - "end": 0.124, - "conf": 1.0 - }, - ... - ], - "confidence-score": "" -} -``` - -* The text field contains the raw transcription. -* The word field contains each word with their time stamp and individual confidence. (Empty if with_metadata=False) -* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) - - -## Test -### Curl -You can test you http API using curl: -```bash -curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav" -``` +LinTO-STT can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. 
## License This project is developped under the AGPLv3 License (see LICENSE). - -## Acknowlegment. - -* [Vosk, speech recognition toolkit](https://alphacephei.com/vosk/). -* [Kaldi Speech Recognition Toolkit](https://github.com/kaldi-asr/kaldi) diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 9966250..0000000 --- a/RELEASE.md +++ /dev/null @@ -1,52 +0,0 @@ -# 3.3.2 -- Fixed use of stereo audio in http serving mode - -# 3.3.1 -- Fixed lin_to_vosk throwing an error on a already existing container. -- Corrected an error on the README regarding mounting model volumes. -- Code styling (PEP 8) - -# 3.3.0 -- Added optional streaming route to the http serving mode -- Added serving mode: websocket -- Added Dynamic model conversion allowing to use either Vosk Models or Linagora AM/LM models -- Changer Vosk dependency to alphacep/vosk -- Updated README.md - -# 3.2.1 -- Repository total rework. The goal being to have a simple transcription service embeddable within a micro-service infrastructure. -- Changed repository name from linto-platform-stt-standalone-worker to linto-platform-stt. -- Added celery connector for microservice integration. -- Added launch option to specify serving mode between task and http. -- Removed diarization functionnality. -- Removed punctuation functionnality. -- Removed Async requests/Job management. -- Updated README to reflect those changes. - -# 3.1.1 -- Change Pykaldi with vosk-API (no python wrapper for decoding function, no extrat packages during installation, c++ implementation based on kaldi functions) -- New feature: Compute a confidence score per transcription -- Fix minor bugs - -# 2.2.1 -- Fix minor bugs -- put SWAGGER_PATH parameter as optional -- Generate the word_boundary file if it does not exist - -# 2.2.0 -- Speaker diarization feature: pyBK package -- Mulithreading feature: Speech decoding and Speaker diarization processes -- Optional parameter: real number of speaker in the audio - -# 2.0.0 -- Reimplement LinTO-Platform-stt-standalone-worker using Pykaldi package - -# 1.1.2 -- New features: - - Word timestamp computing - - Response type: plain/text: simple text output and application/json: the transcription and the words timestamp. 
- - Swagger: integrate swagger in the service using a python package - - Fix minor bugs - -# 1.0.0 -- First build of LinTO-Platform-stt-standalone-worker \ No newline at end of file diff --git a/celery_app/celeryapp.py b/celery_app/celeryapp.py index e04d73b..d1c4099 100644 --- a/celery_app/celeryapp.py +++ b/celery_app/celeryapp.py @@ -1,7 +1,6 @@ import os from celery import Celery - from stt import logger celery = Celery(__name__, include=["celery_app.tasks"]) @@ -10,9 +9,14 @@ if os.environ.get("BROKER_PASS", False): components = broker_url.split("//") broker_url = f'{components[0]}//:{os.environ.get("BROKER_PASS")}@{components[1]}' + celery.conf.broker_url = f"{broker_url}/0" celery.conf.result_backend = f"{broker_url}/1" -celery.conf.update(result_expires=3600, task_acks_late=True, task_track_started=True) +celery.conf.task_acks_late = False +celery.conf.task_track_started = True +celery.conf.broker_transport_options = {"visibility_timeout": float("inf")} +# celery.conf.result_backend_transport_options = {"visibility_timeout": float("inf")} +# celery.conf.result_expires = 3600 * 24 # Queues celery.conf.update( diff --git a/celery_app/tasks.py b/celery_app/tasks.py index ce2ca4d..114df2a 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -1,10 +1,11 @@ import asyncio import os -from celery_app.celeryapp import celery from stt import logger -from stt.processing import decode, model -from stt.processing.utils import load_wave +from stt.processing import MODEL, decode +from stt.processing.utils import load_audiofile + +from celery_app.celeryapp import celery @celery.task(name="transcribe_task") @@ -15,16 +16,22 @@ def transcribe_task(file_name: str, with_metadata: bool): # Load wave file_path = os.path.join("/opt/audio", file_name) try: - file_content = load_wave(file_path) + file_content = load_audiofile(file_path) except Exception as err: - logger.error(f"Failed to load ressource: {repr(err)}") - raise Exception(f"Could not open ressource {file_path}") from err + import traceback + + msg = f"{traceback.format_exc()}\nFailed to load ressource {file_path}" + logger.error(msg) + raise Exception(msg) # from err # Decode try: - result = decode(file_content, model, 16000, with_metadata) + result = decode(file_content, MODEL, with_metadata) except Exception as err: - logger.error(f"Failed to decode: {repr(err)}") - raise Exception(f"Failed to decode {file_path}") from err + import traceback + + msg = f"{traceback.format_exc()}\nFailed to decode {file_path}" + logger.error(msg) + raise Exception(msg) # from err return result diff --git a/document/swagger.yml b/document/swagger.yml index 70bc9fc..6da4ed6 100644 --- a/document/swagger.yml +++ b/document/swagger.yml @@ -2,7 +2,7 @@ swagger: "2.0" info: version: "1.0.0" - title: LinTo-Platform-STT + title: LinTo-STT description: Speech To Text API contact: email: "support@linto.ai" diff --git a/http_server/confparser.py b/http_server/confparser.py index 2396d71..d296dbb 100644 --- a/http_server/confparser.py +++ b/http_server/confparser.py @@ -7,24 +7,6 @@ def createParser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - # SERVICE - parser.add_argument( - "--service_name", - type=str, - help="Service Name", - default=os.environ.get("SERVICE_NAME", "stt"), - ) - - # MODELS - parser.add_argument("--am_path", type=str, help="Acoustic Model Path", default="/opt/models/AM") - parser.add_argument("--lm_path", type=str, help="Decoding graph path", default="/opt/models/LM") - parser.add_argument( - "--config_path", - type=str, 
- help="Configuration files path", - default="/opt/config", - ) - # GUNICORN parser.add_argument("--service_port", type=int, help="Service port", default=80) parser.add_argument( diff --git a/http_server/ingress.py b/http_server/ingress.py index 5a9c661..6c71478 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -3,17 +3,15 @@ import json import logging import os -from time import time +import time from confparser import createParser -from flask import Flask, Response, abort, json, request -from flask_sock import Sock -from serving import GunicornServing +from flask import Flask, json, request +from serving import GeventServing, GunicornServing +from stt import logger as stt_logger +from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer from swagger import setupSwaggerUI -from stt.processing import decode, formatAudio, model -from stt.processing.streaming import ws_streaming - app = Flask("__stt-standalone-worker__") app.config["JSON_AS_ASCII"] = False app.config["JSON_SORT_KEYS"] = False @@ -26,13 +24,16 @@ # If websocket streaming route is enabled if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]: + from flask_sock import Sock + from stt.processing.streaming import ws_streaming + logger.info("Init websocket serving ...") sock = Sock(app) logger.info("Streaming is enabled") @sock.route("/streaming") def streaming(web_socket): - ws_streaming(web_socket, model) + ws_streaming(web_socket, MODEL) @app.route("/healthcheck", methods=["GET"]) @@ -51,39 +52,38 @@ def transcribe(): logger.info("Transcribe request received") # get response content type - logger.debug(request.headers.get("accept").lower()) + # logger.debug(request.headers.get("accept").lower()) if request.headers.get("accept").lower() == "application/json": join_metadata = True elif request.headers.get("accept").lower() == "text/plain": join_metadata = False else: - raise ValueError("Not accepted header") - logger.debug("Metadata: {}".format(join_metadata)) + raise ValueError( + f"Not accepted header (accept={request.headers.get('accept')} should be either application/json or text/plain)" + ) + # logger.debug("Metadata: {}".format(join_metadata)) # get input file - if "file" in request.files.keys(): - file_buffer = request.files["file"].read() - audio_data, sampling_rate = formatAudio(file_buffer) - start_t = time() + if "file" not in request.files.keys(): + raise ValueError(f"No audio file was uploaded (missing 'file' key)") - # Transcription - transcription = decode(audio_data, model, sampling_rate, join_metadata) - logger.debug("Transcription complete (t={}s)".format(time() - start_t)) + file_buffer = request.files["file"].read() - logger.debug("... 
Complete") + audio_data = load_wave_buffer(file_buffer) - else: - raise ValueError("No audio file was uploaded") + # Transcription + transcription = decode(audio_data, MODEL, join_metadata) if join_metadata: return json.dumps(transcription, ensure_ascii=False), 200 return transcription["text"], 200 - except ValueError as error: - return str(error), 400 except Exception as error: - logger.error(error) - return "Server Error: {}".format(str(error)), 500 + import traceback + + logger.error(traceback.format_exc()) + logger.error(repr(error)) + return "Server Error: {}".format(str(error)), 400 if isinstance(error, ValueError) else 500 @app.errorhandler(405) @@ -107,7 +107,9 @@ def server_error(error): parser = createParser() args = parser.parse_args() - logger.setLevel(logging.DEBUG if args.debug else logging.INFO) + logger_level = logging.DEBUG if args.debug else logging.INFO + logger.setLevel(logger_level) + stt_logger.setLevel(logger_level) try: # Setup SwaggerUI if args.swagger_path is not None: @@ -116,12 +118,21 @@ def server_error(error): except Exception as err: logger.warning("Could not setup swagger: {}".format(str(err))) - serving = GunicornServing( + logger.info(f"Using {args.workers} workers") + + if USE_GPU: # TODO: get rid of this? + serving_type = GeventServing + logger.debug("Serving with gevent") + else: + serving_type = GunicornServing + logger.debug("Serving with gunicorn") + + serving = serving_type( app, { "bind": f"0.0.0.0:{args.service_port}", "workers": args.workers, - "timeout": 3600, + "timeout": 3600 * 24, }, ) logger.info(args) diff --git a/http_server/serving.py b/http_server/serving.py index d2dd7e8..9230eb4 100644 --- a/http_server/serving.py +++ b/http_server/serving.py @@ -1,5 +1,9 @@ +import gevent.monkey +import gevent.pywsgi import gunicorn.app.base +gevent.monkey.patch_all() + class GunicornServing(gunicorn.app.base.BaseApplication): def __init__(self, app, options=None): @@ -18,3 +22,22 @@ def load_config(self): def load(self): return self.application + + +class GeventServing: + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + + def run(self): + bind = self.options.get("bind", "0.0.0.0:8080") + workers = self.options.get("workers", 1) + listener = bind.split(":") + try: + assert len(listener) == 2 + listener = (listener[0], int(listener[1])) + except: + print(f"Invalid bind address {bind}") + + server = gevent.pywsgi.WSGIServer(listener, self.application, spawn=workers) + server.serve_forever() diff --git a/http_server/swagger.py b/http_server/swagger.py index a9b93d0..31344cd 100644 --- a/http_server/swagger.py +++ b/http_server/swagger.py @@ -11,7 +11,7 @@ def setupSwaggerUI(app, args): args.swagger_prefix + args.swagger_url, args.swagger_path, config={ # Swagger UI config overrides - "app_name": "LinTO Platform STT", + "app_name": "LinTO STT", "spec": swagger_yml, }, ) diff --git a/.envdefault b/kaldi/.envdefault similarity index 100% rename from .envdefault rename to kaldi/.envdefault diff --git a/Dockerfile b/kaldi/Dockerfile similarity index 91% rename from Dockerfile rename to kaldi/Dockerfile index bdf65c0..f062951 100644 --- a/Dockerfile +++ b/kaldi/Dockerfile @@ -45,7 +45,7 @@ RUN git clone -b vosk --single-branch https://github.com/alphacep/kaldi /opt/kal && make -j $(nproc) online2 lm rnnlm # Install python dependencies -COPY requirements.txt ./ +COPY kaldi/requirements.txt ./ RUN pip install --no-cache-dir -r requirements.txt # Install Custom Vosk API @@ -57,13 +57,13 @@ RUN git clone --depth 1 
https://github.com/alphacep/vosk-api /opt/vosk-api && cd WORKDIR /usr/src/app -COPY stt /usr/src/app/stt COPY celery_app /usr/src/app/celery_app COPY http_server /usr/src/app/http_server COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document -COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ -COPY lin_to_vosk.py /usr/src/app/lin_to_vosk.py +COPY kaldi/stt /usr/src/app/stt +COPY kaldi/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY kaldi/lin_to_vosk.py /usr/src/app/lin_to_vosk.py RUN mkdir -p /var/log/supervisor/ diff --git a/kaldi/README.md b/kaldi/README.md new file mode 100644 index 0000000..0e3a31a --- /dev/null +++ b/kaldi/README.md @@ -0,0 +1,222 @@ +# LinTO-STT-Kaldi + +LinTO-STT-Kaldi is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack) +based on Speech-To-Text (STT) models trained with [Kaldi](https://github.com/kaldi-asr/kaldi). + +LinTO-STT-Kaldi can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. + +## Pre-requisites + +### Hardware +To run the transcription models you'll need: +* At least 7Go of disk space to build the docker image. +* Up to 7GB of RAM depending on the model used. +* One CPU per worker. Inference time scales on CPU performances. + +### Model +LinTO-STT-Kaldi accepts two kinds of models: +* LinTO Acoustic and Languages models. +* Vosk models. + +We provide home-cured models (v2) on [dl.linto.ai](https://doc.linto.ai/docs/developpers/apis/ASR/models). +Or you can also use Vosk models available [here](https://alphacephei.com/vosk/models). + +### Docker +The transcription service requires docker up and running. + +### (micro-service) Service broker and shared folder +The STT only entry point in task mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS. +On addition, as to prevent large audio from transiting through the message broker, STT-Worker use a shared storage folder (SHARED_FOLDER). + +## Deploy LinTO-STT-Kaldi + +**1- First step is to build or pull the image:** + +```bash +git clone https://github.com/linto-ai/linto-stt.git +cd linto-stt +docker build . -f kaldi/Dockerfile -t linto-stt-kaldi:latest +``` +or + +```bash +docker pull lintoai/linto-stt-kaldi +``` + +**2- Download the models** + +Have the acoustic and language model ready at AM_PATH and LM_PATH if you are using LinTO models. If you are using a Vosk model, have it ready at MODEL. + +**3- Fill the .env** + +```bash +cp kaldi/.envdefault kaldi/.env +``` + +| PARAMETER | DESCRIPTION | EXEMPLE | +|---|---|---| +| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | http\|task\|websocket | +| MODEL_TYPE | Type of STT model used. | lin\|vosk | +| ENABLE_STREAMING | Using http serving mode, enable the /streaming websocket route | true\|false | +| SERVICE_NAME | Using the task mode, set the queue's name for task processing | my-stt | +| SERVICE_BROKER | Using the task mode, URL of the message broker | redis://my-broker:6379 | +| BROKER_PASS | Using the task mode, broker password | my-password | +| STREAMING_PORT | Using the websocket mode, the listening port for ingoing WS connexions. | 80 | +| CONCURRENCY | Maximum number of parallel requests | >1 | + +### Serving mode +![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png) + +STT can be used three ways: +* Through an [HTTP API](#http-server) using the **http**'s mode. 
+* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. +* Through a [websocket server](#websocket-server) **websocket**'s mode. + +Mode is specified using the .env value or environment variable ```SERVING_MODE```. +```bash +SERVICE_MODE=http +``` +### HTTP Server +The HTTP serving mode deploys a HTTP server and a swagger-ui to allow transcription request on a dedicated route. + +The SERVICE_MODE value in the .env should be set to ```http```. + +```bash +docker run --rm \ +-p HOST_SERVING_PORT:80 \ +-v AM_PATH:/opt/AM \ +-v LM_PATH:/opt/LM \ +--env-file kaldi/.env \ +linto-stt-kaldi:latest +``` + +This will run a container providing an [HTTP API](#http-api) binded on the host HOST_SERVING_PORT port. + +**Parameters:** +| Variables | Description | Example | +|:-|:-|:-| +| HOST_SERVING_PORT | Host serving port | 80 | +| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 | +| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | +| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | + +### Micro-service within LinTO-Platform stack +The TASK serving mode connect a celery worker to a message broker. + +The SERVICE_MODE value in the .env should be set to ```task```. + +You need a message broker up and running at MY_SERVICE_BROKER. + +```bash +docker run --rm \ +-v AM_PATH:/opt/AM \ +-v LM_PATH:/opt/LM \ +-v SHARED_AUDIO_FOLDER:/opt/audio \ +--env-file kaldi/.env \ +linto-stt-kaldi:latest +``` + +**Parameters:** +| Variables | Description | Example | +|:-|:-|:-| +| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 | +| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | +| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | +| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model | + + +### Websocket Server +Websocket server's mode deploy a streaming transcription service only. + +The SERVICE_MODE value in the .env should be set to ```websocket```. + +Usage is the same as the [http streaming API](#/streaming) + +## Usages +### HTTP API +#### /healthcheck +Returns the state of the API + +Method: GET + +Returns "1" if healthcheck passes. + +#### /transcribe +Transcription API + +* Method: POST +* Response content: text/plain or application/json +* File: An Wave file 16b 16Khz + +Return the transcripted text using "text/plain" or a json object when using "application/json" structure as followed: +```json +{ + "text" : "This is the transcription", + "words" : [ + {"word":"This", "start": 0.123, "end": 0.453, "conf": 0.9}, + ... + ] + "confidence-score": 0.879 +} +``` + +#### /streaming +The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true. + +The route accepts websocket connexions. Exchanges are structured as followed: +1. Client send a json {"config": {"sample_rate":16000}}. +2. Client send audio chunk (go to 3- ) or {"eof" : 1} (go to 5-). +3. Server send either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}. +4. Back to 2- +5. Server send a final result and close the connexion. 
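For illustration, a minimal Python client following this exchange could look like the sketch below (it is not part of the repository; the URL, audio file name and chunk size are assumptions, and the file is expected to contain 16 kHz, 16-bit mono audio):

```python
import asyncio
import json

import websockets  # pip install websockets


async def transcribe_stream(uri, wav_path, chunk_size=8000):
    async with websockets.connect(uri) as ws:
        # 1. Send the configuration message
        await ws.send(json.dumps({"config": {"sample_rate": 16000}}))
        # 2./3./4. Send audio chunks and print the partial or final result returned for each
        with open(wav_path, "rb") as audio:
            while chunk := audio.read(chunk_size):
                await ws.send(chunk)
                print(json.loads(await ws.recv()))  # {"partial": ...} or {"text": ...}
        # 5. Signal end of stream and read the last final result before the server closes
        await ws.send(json.dumps({"eof": 1}))
        print(json.loads(await ws.recv()))


asyncio.run(transcribe_stream("ws://localhost:80/streaming", "audio.wav"))
```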
+ +> Connexion will be closed and the worker will be freed if no chunk are received for 10s. + +#### /docs +The /docs route offers a OpenAPI/swagger interface. + +### Through the message broker + +STT-Worker accepts requests with the following arguments: +```file_path: str, with_metadata: bool``` + +* file_path: Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path} +* with_metadata: If True, words timestamps and confidence will be computed and returned. If false, the fields will be empty. + +#### Return format +On a successfull transcription the returned object is a json object structured as follow: +```json +{ + "text" : "this is the transcription as text", + "words": [ + { + "word" : "this", + "start": 0.0, + "end": 0.124, + "conf": 1.0 + }, + ... + ], + "confidence-score": "" +} +``` + +* The text field contains the raw transcription. +* The word field contains each word with their time stamp and individual confidence. (Empty if with_metadata=False) +* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) + + +## Test +### Curl +You can test you http API using curl: +```bash +curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav" +``` + +## License +This project is developped under the AGPLv3 License (see LICENSE). + +## Acknowlegment. + +* [Vosk, speech recognition toolkit](https://alphacephei.com/vosk/). +* [Kaldi Speech Recognition Toolkit](https://github.com/kaldi-asr/kaldi) diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md new file mode 100644 index 0000000..e11f89a --- /dev/null +++ b/kaldi/RELEASE.md @@ -0,0 +1,3 @@ +# 1.0.0 +- First build of linto-stt-kaldi +- Based on 3.3.2 of linto-stt (https://github.com/linto-ai/linto-stt/blob/4361300a4463c90cec0bf3fa2975d7cc2ddf8d36/RELEASE.md) diff --git a/docker-entrypoint.sh b/kaldi/docker-entrypoint.sh similarity index 100% rename from docker-entrypoint.sh rename to kaldi/docker-entrypoint.sh diff --git a/lin_to_vosk.py b/kaldi/lin_to_vosk.py similarity index 100% rename from lin_to_vosk.py rename to kaldi/lin_to_vosk.py diff --git a/requirements.txt b/kaldi/requirements.txt similarity index 96% rename from requirements.txt rename to kaldi/requirements.txt index 132bdfc..867a095 100644 --- a/requirements.txt +++ b/kaldi/requirements.txt @@ -4,6 +4,7 @@ flask>=1.1.2 flask-cors>=3.0.10 flask-swagger-ui>=3.36.0 flask-sock +gevent gunicorn pyyaml>=5.4.1 wavio>=0.0.4 diff --git a/stt/__init__.py b/kaldi/stt/__init__.py similarity index 100% rename from stt/__init__.py rename to kaldi/stt/__init__.py diff --git a/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py similarity index 68% rename from stt/processing/__init__.py rename to kaldi/stt/processing/__init__.py index 2a3eca5..9f99406 100644 --- a/stt/processing/__init__.py +++ b/kaldi/stt/processing/__init__.py @@ -2,13 +2,19 @@ import sys from time import time -from vosk import Model - from stt import logger from stt.processing.decoding import decode -from stt.processing.utils import formatAudio, load_wave +from stt.processing.utils import load_audiofile, load_wave_buffer +from vosk import Model -__all__ = ["model", "logger", "decode", "load_wave", "formatAudio"] +__all__ = [ + "logger", + "decode", + "load_audiofile", + "load_wave_buffer", + "MODEL", + "USE_GPU", +] # Model locations (should be mounted) MODEL_PATH = "/opt/model" @@ -17,8 +23,11 @@ logger.info("Loading acoustic model and 
decoding graph ...") start = time() try: - model = Model(MODEL_PATH) + MODEL = Model(MODEL_PATH) except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err sys.exit(-1) logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start)) + +# Not implemented yet in Kaldi +USE_GPU = False diff --git a/stt/processing/decoding.py b/kaldi/stt/processing/decoding.py similarity index 89% rename from stt/processing/decoding.py rename to kaldi/stt/processing/decoding.py index 2e1fb7c..8c06007 100644 --- a/stt/processing/decoding.py +++ b/kaldi/stt/processing/decoding.py @@ -4,10 +4,12 @@ from vosk import KaldiRecognizer, Model -def decode(audio_data: bytes, model: Model, sampling_rate: int, with_metadata: bool) -> dict: +def decode(audio: tuple[bytes, int], model: Model, with_metadata: bool) -> dict: """Transcribe the audio data using the vosk library with the defined model.""" result = {"text": "", "confidence-score": 0.0, "words": []} + audio_data, sampling_rate = audio + recognizer = KaldiRecognizer(model, sampling_rate) recognizer.SetMaxAlternatives(0) # Set confidence per words recognizer.SetWords(with_metadata) diff --git a/stt/processing/streaming.py b/kaldi/stt/processing/streaming.py similarity index 99% rename from stt/processing/streaming.py rename to kaldi/stt/processing/streaming.py index 28274b8..a33ecfc 100644 --- a/stt/processing/streaming.py +++ b/kaldi/stt/processing/streaming.py @@ -3,11 +3,10 @@ from typing import Union from simple_websocket.ws import Server as WSServer +from stt import logger from vosk import KaldiRecognizer, Model from websockets.legacy.server import WebSocketServerProtocol -from stt import logger - async def wssDecode(ws: WebSocketServerProtocol, model: Model): """Async Decode function endpoint""" diff --git a/stt/processing/utils.py b/kaldi/stt/processing/utils.py similarity index 84% rename from stt/processing/utils.py rename to kaldi/stt/processing/utils.py index b81cc5d..eb3349d 100644 --- a/stt/processing/utils.py +++ b/kaldi/stt/processing/utils.py @@ -1,16 +1,16 @@ import io import wavio -from numpy import int16, squeeze, mean +from numpy import int16, mean, squeeze -def load_wave(file_path): +def load_audiofile(file_path): """Formats audio from a wavFile buffer to a bytebuffer""" audio = squeeze(wavio.read(file_path).data) - return audio.tobytes() + return (audio.tobytes(), 16000) -def formatAudio(file_buffer): +def load_wave_buffer(file_buffer): """Formats audio from a wavFile buffer to a numpy array for processing.""" file_buffer_io = io.BytesIO(file_buffer) file_content = wavio.read(file_buffer_io) diff --git a/whisper/.envdefault b/whisper/.envdefault new file mode 100644 index 0000000..88c27ea --- /dev/null +++ b/whisper/.envdefault @@ -0,0 +1,39 @@ +############################################ +# SERVING PARAMETERS +############################################ +# "http" or "task" +SERVICE_MODE=http + +# Below: used when SERVICE_MODE=task +SERVICE_NAME=stt +SERVICES_BROKER=redis://172.17.0.1:6379 +BROKER_PASS= + +############################################ +# STT MODELING PARAMETERS +############################################ + +# The model can be a path to a model, or a model name ("tiny", "base", "small", "medium", "large-v1", "large-v2" or "large-v3") +MODEL=medium + +# The language can be in different formats: "en", "en-US", "English", ... +# If not set or set to "*", the language will be detected automatically. 
+LANGUAGE=* + +# An alignment wav2vec model can be used to get word timestamps. +# It can be a path to a model, a language code (fr, en, ...), or "wav2vec" to automatically chose a model for the language +# This option is experimental (and not implemented with ctranslate2). +# ALIGNMENT_MODEL=wav2vec + +############################################ +# EFFICIENCY PARAMETERS +############################################ + +# Device to use. It can be "cuda" to force/check GPU, "cpu" to force computation on CPU, or a specific GPU ("cuda:0", "cuda:1", ...) +# DEVICE=cuda:0 + +# Number of threads per worker when running on CPU +OMP_NUM_THREADS=4 + +# Number of workers +CONCURRENCY=2 diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2 new file mode 100644 index 0000000..52fbc44 --- /dev/null +++ b/whisper/Dockerfile.ctranslate2 @@ -0,0 +1,23 @@ +FROM ghcr.io/opennmt/ctranslate2:latest-ubuntu20.04-cuda11.2 +LABEL maintainer="jlouradour@linagora.com" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git + +# Install python dependencies +COPY whisper/requirements.ctranslate2.txt ./ +RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt + +WORKDIR /usr/src/app + +COPY celery_app /usr/src/app/celery_app +COPY http_server /usr/src/app/http_server +COPY websocket /usr/src/app/websocket +COPY document /usr/src/app/document +COPY whisper/stt /usr/src/app/stt +COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ + +ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" + +HEALTHCHECK CMD ./healthcheck.sh + +ENTRYPOINT ["./docker-entrypoint.sh"] \ No newline at end of file diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu new file mode 100644 index 0000000..c8d6972 --- /dev/null +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -0,0 +1,23 @@ +FROM python:3.9 +LABEL maintainer="jlouradour@linagora.com" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git + +# Install python dependencies +COPY whisper/requirements.ctranslate2.txt ./ +RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt + +WORKDIR /usr/src/app + +COPY celery_app /usr/src/app/celery_app +COPY http_server /usr/src/app/http_server +COPY websocket /usr/src/app/websocket +COPY document /usr/src/app/document +COPY whisper/stt /usr/src/app/stt +COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ + +ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" + +HEALTHCHECK CMD ./healthcheck.sh + +ENTRYPOINT ["./docker-entrypoint.sh"] \ No newline at end of file diff --git a/whisper/Dockerfile.torch b/whisper/Dockerfile.torch new file mode 100644 index 0000000..2f3a0d0 --- /dev/null +++ b/whisper/Dockerfile.torch @@ -0,0 +1,23 @@ +FROM python:3.9 +LABEL maintainer="jlouradour@linagora.com" + +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg + +# Install python dependencies +COPY whisper/requirements.torch.txt ./ +RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt + +WORKDIR /usr/src/app + +COPY celery_app /usr/src/app/celery_app +COPY http_server /usr/src/app/http_server +COPY websocket /usr/src/app/websocket +COPY document /usr/src/app/document +COPY whisper/stt /usr/src/app/stt +COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ + +ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" + +HEALTHCHECK CMD ./healthcheck.sh + 
+ENTRYPOINT ["./docker-entrypoint.sh"] \ No newline at end of file diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu new file mode 100644 index 0000000..e9198d5 --- /dev/null +++ b/whisper/Dockerfile.torch.cpu @@ -0,0 +1,29 @@ +FROM python:3.9 +LABEL maintainer="jlouradour@linagora.com" + +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg + +# Force CPU versions of torch +RUN pip3 install \ + torch==1.13.1+cpu \ + torchaudio==0.13.1+cpu \ + -f https://download.pytorch.org/whl/torch_stable.html + +# Install python dependencies +COPY whisper/requirements.torch.txt ./ +RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt + +WORKDIR /usr/src/app + +COPY celery_app /usr/src/app/celery_app +COPY http_server /usr/src/app/http_server +COPY websocket /usr/src/app/websocket +COPY document /usr/src/app/document +COPY whisper/stt /usr/src/app/stt +COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ + +ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" + +HEALTHCHECK CMD ./healthcheck.sh + +ENTRYPOINT ["./docker-entrypoint.sh"] \ No newline at end of file diff --git a/whisper/README.md b/whisper/README.md new file mode 100644 index 0000000..20a3c7d --- /dev/null +++ b/whisper/README.md @@ -0,0 +1,282 @@ +# LinTO-STT-Whisper + +LinTO-STT-Whisper is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack) +based on Speech-To-Text (STT) [Whisper models](https://openai.com/research/whisper). + +LinTO-STT-Whisper can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. + +## Pre-requisites + +### Hardware +To run the transcription models you'll need: +* At least 8Go of disk space to build the docker image. +* Up to 7GB of RAM depending on the model used. +* One CPU per worker. Inference time scales on CPU performances. + +### Model(s) + +LinTO-STT-Whisper works with a Whisper model to perform Automatic Speech Recognition. +If not downloaded already, the model will be downloaded when calling the first transcription, +and can occupy several GB of disk space. + +#### Optional alignment model (deprecated) + +LinTO-STT-Whisper has also the option to work with a wav2vec model to perform word alignment. +The wav2vec model can be specified either +* (TorchAudio) with a string corresponding to a `torchaudio` pipeline (e.g. "WAV2VEC2_ASR_BASE_960H") or +* (HuggingFace's Transformers) with a string corresponding to a HuggingFace repository of a wav2vec model (e.g. "jonatasgrosman/wav2vec2-large-xlsr-53-english"), or +* (SpeechBrain) with a path corresponding to a folder with a SpeechBrain model + +Default wav2vec models are provided for French (fr), English (en), Spanish (es), German (de), Dutch (nl), Japanese (ja), Chinese (zh). + +But we advise not to use a companion wav2vec alignment model. +This is not needed neither tested anymore. + +### Docker +The transcription service requires docker up and running. + +### (micro-service) Service broker and shared folder +The STT only entry point in task mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS. +On addition, as to prevent large audio from transiting through the message broker, STT-Worker use a shared storage folder (SHARED_FOLDER). 
+ +## Deploy LinTO-STT-Whisper + +### 1- First step is to build or pull the image + +```bash +git clone https://github.com/linto-ai/linto-stt.git +cd linto-stt +docker build . -f whisper/Dockerfile.ctranslate2 -t linto-stt-whisper:latest +``` +or + +```bash +docker pull lintoai/linto-stt-whisper +``` + +### 2- Fill the .env + +```bash +cp whisper/.envdefault whisper/.env +``` + +| PARAMETER | DESCRIPTION | EXEMPLE | +|---|---|---| +| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | +| MODEL | Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | \ \| `large-v3` \| `distil-whisper/distil-large-v2` \| ... | +| LANGUAGE | (Optional) Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... | +| PROMPT | (Optional) Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | +| ALIGNMENT_MODEL | (Optional) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | \ \| `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| ... | +| CONCURRENCY | Maximum number of parallel requests | `3` | +| SERVICE_NAME | (For the task mode) queue's name for task processing | `my-stt` | +| SERVICE_BROKER | (For the task mode) URL of the message broker | `redis://my-broker:6379` | +| BROKER_PASS | (For the task mode only) broker password | `my-password` | + +#### MODEL environment variable + +**Warning:** +The model will be (downloaded if required and) loaded in memory when calling the first transcription. +When using a Whisper model from Hugging Face (transformers) along with ctranslate2 (faster_whisper), +it will also download torch library to make the conversion from torch to ctranslate2. 
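One way to avoid that download at the first request with the faster-whisper (ctranslate2) backend is to warm up the cache on the host and mount it into the container (see CACHE_PATH below). A minimal sketch, assuming the `faster-whisper` package is installed locally and the default `medium` model:

```python
from faster_whisper import WhisperModel

# Downloads the CTranslate2 weights for "medium" into the local cache (under ~/.cache),
# so a container started with -v ~/.cache:/root/.cache does not download at request time.
WhisperModel("medium", device="cpu", compute_type="int8")
```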
+ +If you want to preload the model (and later specify a path `ASR_PATH` as `MODEL`), +you may want to download one of OpenAI Whisper models: +* Mutli-lingual Whisper models can be downloaded with the following links: + * [tiny](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt) + * [base](https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt) + * [small](https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt) + * [medium](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) + * [large-v1](https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt) + * [large-v2](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt) + * [large-v3](https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt) +* Whisper models specialized for English can also be found here: + * [tiny.en](https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt) + * [base.en](https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt) + * [small.en](https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt) + * [medium.en](https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt) + +If you already used Whisper in the past locally using [OpenAI-Whipser](https://github.com/openai/whisper), models can be found under ~/.cache/whisper. + +The same apply for Whisper models from Hugging Face (transformers), as for instance https://huggingface.co/distil-whisper/distil-large-v2 +(you can either download the model or use the Hugging Face identifier `distil-whisper/distil-large-v2`). + +#### LANGUAGE + +If `*` is used for the `LANGUAGE` environment variable, or if `LANGUAGE` is not defined, +automatic language detection will be performed by Whisper. + +The language can be a code of two or three letters. 
The list of languages supported by Whisper are: +``` +af(afrikaans), am(amharic), ar(arabic), as(assamese), az(azerbaijani), +ba(bashkir), be(belarusian), bg(bulgarian), bn(bengali), bo(tibetan), br(breton), bs(bosnian), +ca(catalan), cs(czech), cy(welsh), da(danish), de(german), el(greek), en(english), es(spanish), +et(estonian), eu(basque), fa(persian), fi(finnish), fo(faroese), fr(french), gl(galician), +gu(gujarati), ha(hausa), haw(hawaiian), he(hebrew), hi(hindi), hr(croatian), ht(haitian creole), +hu(hungarian), hy(armenian), id(indonesian), is(icelandic), it(italian), ja(japanese), +jw(javanese), ka(georgian), kk(kazakh), km(khmer), kn(kannada), ko(korean), la(latin), +lb(luxembourgish), ln(lingala), lo(lao), lt(lithuanian), lv(latvian), mg(malagasy), mi(maori), +mk(macedonian), ml(malayalam), mn(mongolian), mr(marathi), ms(malay), mt(maltese), my(myanmar), +ne(nepali), nl(dutch), nn(nynorsk), no(norwegian), oc(occitan), pa(punjabi), pl(polish), +ps(pashto), pt(portuguese), ro(romanian), ru(russian), sa(sanskrit), sd(sindhi), si(sinhala), +sk(slovak), sl(slovenian), sn(shona), so(somali), sq(albanian), sr(serbian), su(sundanese), +sv(swedish), sw(swahili), ta(tamil), te(telugu), tg(tajik), th(thai), tk(turkmen), tl(tagalog), +tr(turkish), tt(tatar), uk(ukrainian), ur(urdu), uz(uzbek), vi(vietnamese), yi(yiddish), +yo(yoruba), zh(chinese) +``` +and also `yue(cantonese)` since large-v3. + +### Serving mode +![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png) + +STT can be used in two ways: +* Through an [HTTP API](#http-server) using the **http**'s mode. +* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. + +Mode is specified using the .env value or environment variable ```SERVING_MODE```. +```bash +SERVICE_MODE=http +``` +### HTTP Server +The HTTP serving mode deploys a HTTP server and a swagger-ui to allow transcription request on a dedicated route. + +The SERVICE_MODE value in the .env should be set to ```http```. + +```bash +docker run --rm \ +-p HOST_SERVING_PORT:80 \ +-v ASR_PATH:/opt/model.pt \ +--env-file whisper/.env \ +linto-stt-whisper:latest +``` + +This will run a container providing an [HTTP API](#http-api) binded on the host HOST_SERVING_PORT port. + +You may also want to mount your cache folder CACHE_PATH (e.g. "~/.cache") ```-v CACHE_PATH:/root/.cache``` +in order to avoid downloading models each time. + +Also if you want to specifiy a custom alignment model already downloaded in a folder WAV2VEC_PATH, +you can add option ```-v WAV2VEC_PATH:/opt/wav2vec``` and environment variable ```ALIGNMENT_MODEL=/opt/wav2vec```. + +**Parameters:** +| Variables | Description | Example | +|:-|:-|:-| +| HOST_SERVING_PORT | Host serving port | 8080 | +| ASR_PATH | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | +| CACHE_PATH | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | +| WAV2VEC_PATH | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | + +### Micro-service within LinTO-Platform stack +The TASK serving mode connect a celery worker to a message broker. + +The SERVICE_MODE value in the .env should be set to ```task```. + +You need a message broker up and running at MY_SERVICE_BROKER. 
+ +```bash +docker run --rm \ +-v ASR_PATH:/opt/model.pt \ +-v SHARED_AUDIO_FOLDER:/opt/audio \ +--env-file whisper/.env \ +linto-stt-whisper:latest +``` + +You may also want to mount your cache folder CACHE_PATH (e.g. "~/.cache") ```-v CACHE_PATH:/root/.cache``` +in order to avoid downloading models each time. + +Also if you want to specifiy a custom alignment model already downloaded in a folder WAV2VEC_PATH, +you can add option ```-v WAV2VEC_PATH:/opt/wav2vec``` and environment variable ```ALIGNMENT_MODEL=/opt/wav2vec```. + +**Parameters:** +| Variables | Description | Example | +|:-|:-|:-| +| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model | +| ASR_PATH | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | +| CACHE_PATH | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | +| WAV2VEC_PATH | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | + + +## Usages +### HTTP API +#### /healthcheck +Returns the state of the API + +Method: GET + +Returns "1" if healthcheck passes. + +#### /transcribe +Transcription API + +* Method: POST +* Response content: text/plain or application/json +* File: An Wave file 16b 16Khz + +Return the transcripted text using "text/plain" or a json object when using "application/json" structure as followed: +```json +{ + "text" : "This is the transcription as text", + "words": [ + { + "word" : "This", + "start": 0.0, + "end": 0.124, + "conf": 0.82341 + }, + ... + ], + "confidence-score": 0.879 +} +``` + +#### /docs +The /docs route offers a OpenAPI/swagger interface. + +### Through the message broker + +STT-Worker accepts requests with the following arguments: +```file_path: str, with_metadata: bool``` + +* file_path: Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path} +* with_metadata: If True, words timestamps and confidence will be computed and returned. If false, the fields will be empty. + +#### Return format +On a successfull transcription the returned object is a json object structured as follow: +```json +{ + "text" : "This is the transcription as text", + "words": [ + { + "word" : "This", + "start": 0.0, + "end": 0.124, + "conf": 0.82341 + }, + ... + ], + "confidence-score": 0.879 +} +``` + +* The text field contains the raw transcription. +* The word field contains each word with their time stamp and individual confidence. (Empty if with_metadata=False) +* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) + + +## Test +### Curl +You can test you http API using curl: +```bash +curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav" +``` + +## License +This project is developped under the AGPLv3 License (see LICENSE). + +## Acknowlegment. 
+ +* [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) +* [OpenAI Whisper](https://github.com/openai/whisper) +* [Ctranslate2](https://github.com/OpenNMT/CTranslate2) +* [SpeechBrain](https://github.com/speechbrain/speechbrain) +* [TorchAudio](https://github.com/pytorch/audio) +* [HuggingFace Transformers](https://github.com/huggingface/transformers) \ No newline at end of file diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md new file mode 100644 index 0000000..2967139 --- /dev/null +++ b/whisper/RELEASE.md @@ -0,0 +1,3 @@ +# 1.0.0 +- First build of linto-stt-whisper +- Based on 4.0.5 of linto-stt https://github.com/linto-ai/linto-stt/blob/a54b7b7ac2bc491a1795bb6dfb318a39c8b76d63/RELEASE.md diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh new file mode 100755 index 0000000..97a3804 --- /dev/null +++ b/whisper/docker-entrypoint.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set -a + +echo "RUNNING STT" + +# Check model +echo "Checking model format ..." +if [ -z "$MODEL" ] +then + echo "Model type not specified, choosing Whisper medium model" + export MODEL=medium +fi + +# Launch parameters, environement variables and dependencies check +if [ -z "$SERVICE_MODE" ] +then + echo "ERROR: Must specify a serving mode: [ http | task | websocket ]" + exit -1 +else + if [ "$SERVICE_MODE" = "http" ] + then + echo "RUNNING STT HTTP SERVER" + python3 http_server/ingress.py --debug + elif [ "$SERVICE_MODE" == "task" ] + then + if [[ -z "$SERVICES_BROKER" ]] + then + echo "ERROR: SERVICES_BROKER variable not specified, cannot start celery worker." + exit -1 + fi + nvidia-smi 2> /dev/null > /dev/null + if [ $? -eq 0 ];then + echo "GPU detected" + GPU=1 + OPT="--pool=solo" + else + echo "No GPU detected" + GPU=0 + OPT="" + fi + /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit 1 + echo "RUNNING STT CELERY WORKER" + celery --app=celery_app.celeryapp worker $OPT -Ofair --queues=${SERVICE_NAME} -c ${CONCURRENCY} -n ${SERVICE_NAME}_worker@%h + + else + echo "ERROR: Wrong serving command: $SERVICE_MODE" + exit -1 + fi +fi + +echo "Service stopped" \ No newline at end of file diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt new file mode 100644 index 0000000..2ddc118 --- /dev/null +++ b/whisper/requirements.ctranslate2.txt @@ -0,0 +1,15 @@ +celery[redis,auth,msgpack]>=4.4.7 +flask>=1.1.2 +flask-cors>=3.0.10 +flask-sock +flask-swagger-ui>=3.36.0 +gevent +gunicorn +lockfile +pyyaml>=5.4.1 +requests>=2.26.0 +wavio>=0.0.4 +websockets +#faster_whisper==0.10.0 +# This is version faster_whisper==0.9.0 + prompt propagation + fix for large-v3 +git+https://github.com/linto-ai/faster-whisper.git@aad9e7508b528e79be2a9975ac79ef8317f02a6d \ No newline at end of file diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt new file mode 100644 index 0000000..75e747c --- /dev/null +++ b/whisper/requirements.torch.txt @@ -0,0 +1,19 @@ +celery[redis,auth,msgpack]>=4.4.7 +flask>=1.1.2 +flask-cors>=3.0.10 +flask-sock +flask-swagger-ui>=3.36.0 +gevent +gunicorn +lockfile +num2words +pyyaml>=5.4.1 +requests>=2.26.0 +speechbrain +transformers +wavio>=0.0.4 +websockets +# openai-whisper +git+https://github.com/linto-ai/whisper-timestamped.git +onnxruntime +torchaudio \ No newline at end of file diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py new file mode 100644 index 0000000..f5551af --- /dev/null +++ b/whisper/stt/__init__.py @@ -0,0 
+1,38 @@ +import logging +import os + +logging.basicConfig( + format="[%(asctime)s,%(msecs)03d %(name)s] %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("__stt__") + +# The following is to have GPU in the right order (as nvidia-smi show them) +# It is important to set that before loading ctranslate2 +# see https://github.com/guillaumekln/faster-whisper/issues/150 +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order + +try: + import faster_whisper + + USE_CTRANSLATE2 = True +except ImportError as err: + try: + import whisper + except: + raise err + USE_CTRANSLATE2 = False + +try: + import torch + + USE_TORCH = True +except ImportError: + USE_TORCH = False + +try: + import torchaudio + + USE_TORCHAUDIO = True +except ImportError: + USE_TORCHAUDIO = False diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py new file mode 100644 index 0000000..b0e7f6d --- /dev/null +++ b/whisper/stt/processing/__init__.py @@ -0,0 +1,80 @@ +import logging +import os + +from lockfile import FileLock +from stt import USE_CTRANSLATE2, logger + +from .alignment_model import get_alignment_model, load_alignment_model +from .decoding import decode +from .load_model import load_whisper_model +from .utils import get_device, get_language, load_audiofile, load_wave_buffer + +__all__ = [ + "logger", + "decode", + "load_audiofile", + "load_wave_buffer", + "MODEL", + "USE_GPU", +] + + +class LazyLoadedModel: + def __init__(self, model_type, device): + self.model_type = model_type + self.device = device + self._model = None + + def check_loaded(self): + if self._model is None: + lockfile = os.path.basename(self.model_type) + with FileLock(lockfile): + self._model = load_whisper_model(self.model_type, device=self.device) + + def __getattr__(self, name): + self.check_loaded() + return getattr(self._model, name) + + def __call__(self, *args, **kwargs): + self.check_loaded() + return self._model(*args, **kwargs) + + +# Set informative log +logger.setLevel(logging.INFO) + +# Set device +device, USE_GPU = get_device() +logger.info(f"Using device {device}") + +# Check language +language = get_language() +logger.info(f"Using language {language}") + +# Load ASR model +model_type = os.environ.get("MODEL", "medium") +logger.info( + f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." +) +try: + model = LazyLoadedModel(model_type, device=device) + # model = load_whisper_model(model_type, device=device) +except Exception as err: + raise Exception("Failed to load transcription model: {}".format(str(err))) from err + +# Load alignment model (if any) +alignment_model = get_alignment_model(os.environ.get("alignment_model"), language) +if alignment_model: + logger.info( + f"Loading alignment model {alignment_model} ({'local' if os.path.exists(alignment_model) else 'remote'})..." + ) + alignment_model = load_alignment_model(alignment_model, device=device, download_root="/opt") +elif alignment_model is None: + logger.info("Alignment will be done using Whisper cross-attention weights") +else: + logger.info( + "No alignment model preloaded. It will be loaded on the fly depending on the detected language." 
+ ) + alignment_model = {} # Alignement model(s) will be loaded on the fly + +MODEL = (model, alignment_model) diff --git a/whisper/stt/processing/alignment_model.py b/whisper/stt/processing/alignment_model.py new file mode 100644 index 0000000..ea958db --- /dev/null +++ b/whisper/stt/processing/alignment_model.py @@ -0,0 +1,409 @@ +import math +import os +import time + +import requests +from stt import USE_TORCH, USE_TORCHAUDIO, logger + +from .utils import LANGUAGES, SAMPLE_RATE + +if USE_TORCH: + import torch + import torch.nn.utils.rnn as rnn_utils + + try: + import huggingface_hub + import speechbrain as sb + except ImportError: + pass + try: + import transformers + except ImportError: + pass + +if USE_TORCHAUDIO: + import torchaudio + +################################################################################ +# Load models + +# Sources: +# * https://github.com/m-bain/whisperX (in whisperx/transcribe.py) +# * https://pytorch.org/audio/stable/pipelines.html +# * https://huggingface.co/jonatasgrosman + +ALIGNMENT_MODELS = { + "en": "WAV2VEC2_ASR_BASE_960H", + # "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english", + "fr": "VOXPOPULI_ASR_BASE_10K_FR", + # "fr": "jonatasgrosman/wav2vec2-large-xlsr-53-french", + "de": "VOXPOPULI_ASR_BASE_10K_DE", + # "de": "jonatasgrosman/wav2vec2-large-xlsr-53-german", + "es": "VOXPOPULI_ASR_BASE_10K_ES", + # "it": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish", + "it": "VOXPOPULI_ASR_BASE_10K_IT", + # "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian", + "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", + "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", + "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", + "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", + "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", + "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", + "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", + "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", + "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", + "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", + "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", + "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", + "vi": "nguyenvulebinh/wav2vec2-base-vietnamese-250h", +} + + +def get_alignment_model(alignment_model_name, language, force=False): + if alignment_model_name in ["wav2vec", "wav2vec2"]: + if language is None: + # Will load alignment model on the fly depending + # on detected language + return {} + elif language in ALIGNMENT_MODELS: + return ALIGNMENT_MODELS[language] + elif force: + raise ValueError(f"No wav2vec alignment model for language '{language}'.") + else: + logger.warn( + f"No wav2vec alignment model for language '{language}'. Fallback to English." 
+ ) + return ALIGNMENT_MODELS["en"] + elif alignment_model_name in LANGUAGES.keys(): + return get_alignment_model("wav2vec", alignment_model_name, force=True) + return alignment_model_name + + +def load_alignment_model(source, device="cpu", download_root="/opt"): + if not USE_TORCH: + raise NotImplementedError("Alignement model not available without Torch") + + start = time.time() + + if (source in torchaudio.pipelines.__all__) if USE_TORCHAUDIO else False: + model = load_torchaudio_model(source, device=device, download_root=download_root) + else: + try: + model = load_transformers_model(source, device=device, download_root=download_root) + except Exception as err1: + try: + model = load_speechbrain_model(source, device=device, download_root=download_root) + except Exception as err2: + raise Exception( + f"Failed to load alignment model:\n<<< transformers <<<\n{str(err1)}\n<<< speechbrain <<<\n{str(err2)}" + ) from err2 + + logger.info( + f"Alignment Model of type {get_model_type(model)} loaded. (t={time.time() - start}s)" + ) + + return model + + +def load_speechbrain_model(source, device="cpu", download_root="/opt"): + if os.path.isdir(source): + yaml_file = os.path.join(source, "hyperparams.yaml") + assert os.path.isfile(yaml_file), f"Hyperparams file {yaml_file} not found" + else: + try: + yaml_file = huggingface_hub.hf_hub_download( + repo_id=source, + filename="hyperparams.yaml", + cache_dir=os.path.join(download_root, "huggingface/hub"), + ) + except requests.exceptions.HTTPError: + yaml_file = None + overrides = make_yaml_overrides( + yaml_file, {"save_path": os.path.join(download_root, "speechbrain")} + ) + + savedir = os.path.join(download_root, "speechbrain") + try: + model = sb.pretrained.EncoderASR.from_hparams( + source=source, run_opts={"device": device}, savedir=savedir, overrides=overrides + ) + except ValueError: + model = sb.pretrained.EncoderDecoderASR.from_hparams( + source=source, run_opts={"device": device}, savedir=savedir, overrides=overrides + ) + + model.train(False) + model.requires_grad_(False) + return model + + +def load_transformers_model(source, device="cpu", download_root="/opt"): + model = transformers.Wav2Vec2ForCTC.from_pretrained(source).to(device) + processor = transformers.Wav2Vec2Processor.from_pretrained(source) + + model.eval() + model.requires_grad_(False) + return model, processor + + +def load_torchaudio_model(source, device="cpu", download_root="/opt"): + bundle = torchaudio.pipelines.__dict__[source] + model = bundle.get_model().to(device) + labels = bundle.get_labels() + + model.eval() + model.requires_grad_(False) + return model, labels + + +def get_model_type(model): + if not isinstance(model, tuple): + return "speechbrain" + assert len(model) == 2, "Invalid model type" + if isinstance(model[0], transformers.Wav2Vec2ForCTC): + return "transformers" + return "torchaudio" + + +def make_yaml_overrides(yaml_file, key_values): + """ + return a dictionary of overrides to be used with speechbrain (hyperyaml files) + yaml_file: path to yaml file + key_values: dict of key values to override + """ + if yaml_file is None: + return None + + override = {} + with open(yaml_file, "r") as f: + parent = None + for line in f: + if line.strip() == "": + parent = None + elif line == line.lstrip(): + if ":" in line: + parent = line.split(":")[0].strip() + if parent in key_values: + override[parent] = key_values[parent] + elif ":" in line: + child = line.strip().split(":")[0].strip() + if child in key_values: + override[parent] = override.get(parent, {}) | 
{child: key_values[child]} + return override + + +################################################################################ +# Get list of labels (and blank_id) from model + + +def get_vocab(model): + type = get_model_type(model) + if type == "speechbrain": + labels, blank_id = get_vocab_speechbrain(model) + elif type == "transformers": + labels, blank_id = get_vocab_transformers(model) + else: + labels, blank_id = get_vocab_torchaudio(model) + assert isinstance(labels, list) and min( + [isinstance(l, str) for l in labels] + ), "labels must be a list of strings" + return norm_labels(labels, blank_id), blank_id + + +def get_vocab_speechbrain(model): + tokenizer = model.tokenizer + # Is this general enough? + labels = [ + {"": " ", " ⁇ ": ""}.get(i, i) + for i in tokenizer.decode([[i] for i in range(tokenizer.get_piece_size())]) + ] + blank_id = labels.index("") + return labels, blank_id + + +def get_vocab_torchaudio(model_and_labels): + _, labels = model_and_labels + labels = list(labels) + # WTF : blank_id = labels.index("-") ...? Is it general enough? + blank_id = 0 + return labels, blank_id + + +def get_vocab_transformers(model_and_processor): + _, processor = model_and_processor + labels_dict = dict((v, k) for k, v in processor.tokenizer.get_vocab().items()) + labels = [labels_dict[i] for i in range(len(labels_dict))] + blank_id = labels.index("") + return labels, blank_id + + +def norm_labels(labels, blank_id): + labels[blank_id] = "" + return [l if l != "|" else " " for l in labels] + + +################################################################################ +# Compute log-probabilities from model + + +# The following limit is to handle the corner Case of too long audio segment (which is better to split it to avoid memory overflow). +# But it is 2240400 / 16000 Hz ~ 140 seconds, which should not happen for segments detected by Whisper (usually one sentence). +# Also note that Whisper works with 30 seconds segment, so there is chance that this limit is never reached. 
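+# Note: MAX_LEN is expressed in audio samples at SAMPLE_RATE (16000 Hz).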
+MAX_LEN = 2240400 + + +def compute_logprobas(model, audios, max_len=MAX_LEN): + # Single audio + if not isinstance(audios, list): + audios = [audios] + logits = compute_logprobas(model, audios, max_len=max_len) + return logits[0] + + # Batch of audios (can occur when max_len is reached) + assert len(audios) > 0, "audios must be a non-empty list" + + type = get_model_type(model) + if type == "speechbrain": + logits = compute_logits_speechbrain(model, audios, max_len) + elif type == "transformers": + logits = compute_logits_transformers(model, audios, max_len) + else: + logits = compute_logits_torchaudio(model, audios, max_len) + + return torch.log_softmax(logits, dim=-1) + + +def compute_logits_speechbrain(model, audios, max_len): + if not isinstance(audios[0], torch.Tensor): + audios = [torch.from_numpy(a) for a in audios] + if max([len(a) for a in audios]) > max_len: + # Split audios into chunks of max_len + batch_size = len(audios) + chunks = [] + i_audio = [] + for a in audios: + chunks.extend([a[i : min(i + max_len, len(a))] for i in range(0, len(a), max_len)]) + i_audio.append(len(chunks)) + if len(chunks) > 1: + logger.warning( + "Audio too long, splitting into {} chunks for alignment".format(len(chunks)) + ) + # Decode chunks of audio and concatenate results + log_probas = [[] for i in range(len(audios))] + for i in range(0, len(chunks), batch_size): + chunk = chunks[i : min(i + batch_size, len(chunks))] + log_probas_tmp = compute_logits_speechbrain(model, chunk) + for j in range(i, i + len(chunk)): + k = 0 + while j >= i_audio[k]: + k += 1 + log_probas[k].append(log_probas_tmp[j - i]) + log_probas = [torch.cat(p, dim=0) for p in log_probas] + log_probas, wav_lens = pack_sequences(log_probas, device=model.device) + else: + batch, wav_lens = pack_sequences(audios, device=model.device) + log_probas = model.forward(batch, wav_lens) + + return log_probas.cpu().detach() + + +def pack_sequences(tensors, device="cpu"): + if len(tensors) == 1: + return tensors[0].unsqueeze(0).to(device), torch.Tensor([1.0]).to(device) + tensor = rnn_utils.pad_sequence(tensors, batch_first=True) + wav_lens = [len(x) for x in tensors] + maxwav_lens = max(wav_lens) + wav_lens = torch.Tensor([l / maxwav_lens for l in wav_lens]) + return tensor.to(device), wav_lens.to(device) + + +def compute_logits_transformers(model_and_processor, audios, max_len): + model, processor = model_and_processor + + # can be different from processor.feature_extractor.sampling_rate + sample_rate = SAMPLE_RATE + device = model.device + + audios = [audio.numpy() for audio in audios] + processed_batch = processor(audios, sampling_rate=sample_rate) + + padded_batch = processor.pad( + processed_batch, + padding=True, + max_length=None, + pad_to_multiple_of=None, + return_tensors="pt", + ) + + l = padded_batch.input_values.shape[1] + + use_mask = hasattr(padded_batch, "attention_mask") + + with torch.inference_mode(): + if l > max_len: + # Split batch in smaller chunks + logger.warning( + "Audio too long, splitting into {} chunks for alignment".format( + math.ceil(l / max_len) + ) + ) + logits = [] + for i in range(0, l, max_len): + j = min(i + max_len, l) + if use_mask: + logits.append( + model( + padded_batch.input_values[:, i:j].to(device), + attention_mask=padded_batch.attention_mask[:, i:j].to(device), + ).logits + ) + else: + logits.append(model(padded_batch.input_values[:, i:j].to(device)).logits) + logits = torch.cat(logits, dim=1) + elif use_mask: + logits = model( + padded_batch.input_values.to(device), + 
attention_mask=padded_batch.attention_mask.to(device), + ).logits + else: + logits = model(padded_batch.input_values.to(device)).logits + + return logits.cpu().detach() + + +def compute_logits_torchaudio(model_and_labels, audios, max_len): + # TODO: factorize with compute_logits_transformers, and add support for batch of audios + + model, _ = model_and_labels + + # Get the device where is running the model + device = "cpu" + for p in model.parameters(): + device = p.device + break + + all_logits = [] + + with torch.inference_mode(): + for audio in audios: + l = len(audio) + if l > max_len: + # Split audio in smaller chunks + logger.warning( + "Audio too long, splitting into {} chunks for alignment".format( + math.ceil(l / max_len) + ) + ) + logits = [] + for i in range(0, l, max_len): + j = min(i + max_len, l) + logits.append(model(audio[i:j].unsqueeze(0).to(device))[0]) + logits = torch.cat(logits, dim=1) + else: + logits, _ = model(audio.unsqueeze(0).to(device)) + + all_logits.append(logits.cpu().detach()) + + assert len(all_logits) == 1 # TODO: support batch of audios + + return all_logits[0] diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py new file mode 100644 index 0000000..9f8411f --- /dev/null +++ b/whisper/stt/processing/decoding.py @@ -0,0 +1,366 @@ +import copy +import os +import time +from typing import Tuple, Union + +import numpy as np +from stt import USE_CTRANSLATE2, logger + +from .alignment_model import get_alignment_model, load_alignment_model +from .text_normalize import normalize_text, remove_emoji, remove_punctuation +from .utils import SAMPLE_RATE, get_language +from .word_alignment import compute_alignment + +if not USE_CTRANSLATE2: + import torch + import whisper_timestamped + +USE_ACCURATE = True +USE_VAD = True + +if USE_ACCURATE: + default_beam_size = 5 + default_best_of = 5 + default_temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) +else: + default_beam_size = None + default_best_of = None + default_temperature = 0.0 + +default_initial_prompt = os.environ.get("PROMPT", None) + + +def decode( + audio, + model_and_alignementmodel, # Tuple[model, alignment_model] + with_word_timestamps: bool, + language: str = None, + remove_punctuation_from_words=False, + beam_size: int = default_beam_size, + best_of: int = default_best_of, + temperature: Union[float, Tuple[float, ...]] = default_temperature, + condition_on_previous_text: bool = False, + no_speech_threshold: float = 0.6, + compression_ratio_threshold: float = 2.4, + initial_prompt: str = default_initial_prompt, +) -> dict: + if language is None: + language = get_language() + + kwargs = copy.copy(locals()) + kwargs.pop("model_and_alignementmodel") + kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel + + logger.info( + "Transcribing audio with " + + (f"language {language}" if language else "automatic language detection") + + "..." 
+ ) + + start_t = time.time() + + if USE_CTRANSLATE2: + kwargs.pop("alignment_model") + res = decode_ct2(**kwargs) + else: + print("OK") + res = decode_torch(**kwargs) + + logger.info("Transcription complete (t={}s)".format(time.time() - start_t)) + + return res + + +def decode_ct2( + audio, model, with_word_timestamps, language, remove_punctuation_from_words, **kwargs +): + kwargs["no_speech_threshold"] = 1 # To avoid empty output + if kwargs.get("beam_size") is None: + kwargs["beam_size"] = 1 + if kwargs.get("best_of") is None: + kwargs["best_of"] = 1 + + segments, info = model.transcribe( + audio, + word_timestamps=with_word_timestamps, + language=language, + # Careful with the following options + max_initial_timestamp=10000.0, + vad_filter=USE_VAD, + **kwargs, + ) + + segments = list(segments) + + return format_faster_whisper_response( + segments, info, remove_punctuation_from_words=remove_punctuation_from_words + ) + + +def decode_torch( + audio, + model, + alignment_model, + with_word_timestamps, + language, + remove_punctuation_from_words, + beam_size, + best_of, + temperature, + condition_on_previous_text, + no_speech_threshold, + compression_ratio_threshold, + normalize_text_as_words=False, + initial_prompt=None, +): + """Transcribe the audio data using Whisper with the defined model.""" + + fp16 = model.device != torch.device("cpu") + + kwargs = dict( + language=language, + fp16=fp16, + temperature=temperature, + beam_size=beam_size, + best_of=best_of, + condition_on_previous_text=condition_on_previous_text, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + vad=USE_VAD, + initial_prompt=initial_prompt, + ) + + if alignment_model is None: + # Use Whisper cross-attention weights + whisper_res = whisper_timestamped.transcribe(model, audio, verbose=None, **kwargs) + if language is None: + language = whisper_res["language"] + logger.info(f"Detected language: {language}") + return format_whisper_timestamped_response( + whisper_res, remove_punctuation_from_words=remove_punctuation_from_words + ) + + # Force deterministic results + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + + whisper_res = model.transcribe(audio, verbose=None, **kwargs) + + text = whisper_res["text"] + text = remove_emoji(text).strip() + if normalize_text_as_words: + text = normalize_text(text, language) + if remove_punctuation_from_words: + text = remove_punctuation(text) + segments = whisper_res["segments"] + if language is None: + language = whisper_res["language"] + logger.info(f"Detected language: {language}") + if isinstance(alignment_model, dict): + # Load alignment model on the fly + if language not in alignment_model: + alignment_model_name = get_alignment_model(language) + logger.info( + f"Loading alignment model {alignment_model_name} ({'local' if os.path.exists(alignment_model_name) else 'remote'})..." 
+ ) + alignment_model[language] = load_alignment_model( + alignment_model_name, device=model.device, download_root="/opt" + ) + spec_alignment_model = alignment_model[language] + else: + spec_alignment_model = alignment_model + + result = {} + result["text"] = text + result["language"] = language + result["confidence-score"] = ( + np.exp(np.array([r["avg_logprob"] for r in segments])).mean() if len(segments) else 0.0 + ) + + if not with_word_timestamps: + if not normalize_text_as_words: + text = normalize_text(text, language) + if remove_punctuation_from_words: + text = remove_punctuation(text) + result["words"] = text.split() + else: + # Compute word timestamps + result["words"] = [] + max_t = audio.shape[0] + + # Ensure that the segments start / end time are increasing + # (because there is no guarantee with Whisper) + previous_start = 0.0 + for segment in segments: + if segment["start"] < previous_start: + segment["start"] = previous_start + if segment["end"] <= segment["start"]: + segment["end"] = segment["start"] + 1.0 + previous_start = segment["end"] + + for segment in segments: + offset = segment["start"] + start = min(max_t, round(segment["start"] * SAMPLE_RATE)) + end = min(max_t, round(segment["end"] * SAMPLE_RATE)) + sub_audio = audio[start:end] + sub_text = segment["text"] + logger.debug(f"Aligning text: {sub_text}") + sub_text = remove_emoji(sub_text).strip() + sub_text = normalize_text(sub_text, language) + if remove_punctuation_from_words: + sub_text = remove_punctuation(sub_text) + if not sub_text: + logger.warn(f"Lost text in segment {segment['start']}-{segment['end']}") + continue + labels, emission, trellis, segments, word_segments = compute_alignment( + sub_audio, sub_text, spec_alignment_model + ) + ratio = len(sub_audio) / (trellis.size(0) * SAMPLE_RATE) + sub_words = sub_text.split() + words = [] + use_original_words = True + if len(sub_words) != len(word_segments): + logger.warn( + f"Alignment failed. 
Some words might be mis-rendered.\nNumber of words: {len(sub_words)} != {len(word_segments)}\n>>>\n{sub_words}\n<<<\n{[segment.label for segment in word_segments]}" + ) + assert len(word_segments) < len(sub_words) + use_original_words = False + for word, seg in zip(sub_words, word_segments): + words.append( + { + "word": word if use_original_words else seg.label, + "start": seg.start * ratio + offset, + "end": seg.end * ratio + offset, + "conf": seg.score, + } + ) + # Glue the words inside a segment + for i, word in enumerate(words): + if i == 0: + word["start"] = segment["start"] + else: + word["start"] = words[i - 1]["end"] + if i == len(words) - 1: + word["end"] = segment["end"] + else: + word["end"] = 0.5 * (words[i + 1]["start"] + word["end"]) + # Accumulate results + result["words"] += words + + return result + + +def format_whisper_timestamped_response(transcription, remove_punctuation_from_words=False): + """Format Whisper response.""" + + for i, seg in enumerate(transcription["segments"][:-1]): + for expected_keys in ["start", "end", "words", "avg_logprob"]: + assert ( + expected_keys in seg + ), f"Missing '{expected_keys}' in segment {i} (that has keys {list(seg.keys())})" + + words = [] + + segments = transcription.get("segments", []) + + for seg in segments: + for word in seg.get("words", []): + text = word["text"] + if remove_punctuation_from_words: + text = remove_punctuation(text) + words.append( + { + "word": text, + "start": word["start"], + "end": word["end"], + "conf": word["confidence"], + } + ) + + return { + "text": transcription["text"].strip(), + "language": transcription["language"], + "confidence-score": round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2) + if len(segments) + else 0.0, + "words": words, + } + + +def format_faster_whisper_response( + segments, + info, + remove_punctuation_from_words=False, + glue_punctuations="'-&@.,", +): + language = info.language + duration = info.duration + + def checked_timestamps(start, end=None): + if start > duration or (end is not None and end > duration): + print( + "WARNING, timestamp %f is greater than duration %f" + % (max(start, end if end else start), duration) + ) + if end and end <= start: + if end == start: + pass # end = start + 0.01 + else: + print("WARNING, end timestamp %f is smaller than start timestamp %f" % (end, start)) + if end is None: + return start + return (start, end) + + segments_list = [] + for segment in segments: + start, end = checked_timestamps(segment.start, segment.end) + + words = [] + if segment.words: + for word in segment.words: + start, end = checked_timestamps(word.start, word.end) + word_strip = word.word.strip() + if ( + glue_punctuations + and len(words) + and len(word_strip) > 1 + and word_strip[0] in glue_punctuations + ): + words[-1]["text"] += word.word.lstrip() + words[-1]["confidence"].append(word.probability) + words[-1]["end"] = max(words[-1]["end"], end) + continue + words.append( + { + "text": word.word, + "confidence": [word.probability], + "start": start, + "end": end, + } + ) + + for word in words: + word["text"] = word["text"].strip() + word["confidence"] = round(np.mean([c for c in word["confidence"]]), 2) + + segments_list.append( + { + "text": segment.text.strip(), + "start": start, + "end": end, + "avg_logprob": segment.avg_logprob, + "words": words, + } + ) + + transcription = { + "text": " ".join(segment["text"] for segment in segments_list), + "language": language, + "confidence": round( + np.exp(np.mean([segment["avg_logprob"] for segment in 
segments_list])), 2 + ), + "segments": segments_list, + } + return format_whisper_timestamped_response( + transcription, remove_punctuation_from_words=remove_punctuation_from_words + ) diff --git a/whisper/stt/processing/load_model.py b/whisper/stt/processing/load_model.py new file mode 100644 index 0000000..b87a414 --- /dev/null +++ b/whisper/stt/processing/load_model.py @@ -0,0 +1,369 @@ +import os +import shutil +import subprocess +import sys +import time + +from stt import USE_CTRANSLATE2, logger + +if USE_CTRANSLATE2: + import faster_whisper +else: + import whisper_timestamped as whisper + + +def load_whisper_model(model_type_or_file, device="cpu", download_root=None): + start = time.time() + + logger.info("Loading Whisper model {}...".format(model_type_or_file)) + + default_cache_root = os.path.join(os.path.expanduser("~"), ".cache") + if download_root is None: + download_root = default_cache_root + + if USE_CTRANSLATE2: + if not os.path.isdir(model_type_or_file): + # Note: There is no good way to set the root cache directory + # with the current version of faster_whisper: + # if "download_root" is specified to faster_whisper.WhisperModel + # (or "output_dir" in faster_whisper.utils.download_model), + # then files are downloaded directly in it without symbolic links + # to the cache directory. So it's different from the behavior + # of the huggingface_hub. + # So we try to create a symbolic link to the cache directory that will be used by HuggingFace... + if not os.path.exists(download_root): + if not os.path.exists(default_cache_root): + os.makedirs(download_root) + if default_cache_root != download_root: + os.symlink(download_root, default_cache_root) + else: + os.symlink(default_cache_root, download_root) + elif not os.path.exists(default_cache_root): + os.symlink(download_root, default_cache_root) + + if device == "cpu": + compute_types = ["int8", "float32"] + else: + compute_types = ["int8", "int8_float16", "float16", "float32"] + + device_index = 0 + if device.startswith("cuda:"): + device_index = [int(dev) for dev in device[5:].split(",")] + device = "cuda" + + if not os.path.isfile(os.path.join(model_type_or_file, "model.bin")) and not max( + [ + model_type_or_file.startswith(prefix) + for prefix in ["tiny", "base", "small", "medium", "large"] + ] + ): + # Convert transformer model + + output_dir = os.path.join( + download_root, + f"ctranslate2/converters/transformers--{model_type_or_file.replace('/', '--')}", + ) + logger.info(f"CTranslate2 model in {output_dir}") + if not os.path.isdir(output_dir): + import huggingface_hub + + delete_hf_path = False + if not os.path.isdir(model_type_or_file): + hf_path = huggingface_hub.hf_hub_download( + repo_id=model_type_or_file, filename="pytorch_model.bin" + ) + hf_path = os.path.dirname(os.path.dirname(os.path.dirname(hf_path))) + + delete_hf_path = not os.path.exists(hf_path) + else: + assert os.path.isfile( + os.path.join(model_type_or_file, "pytorch_model.bin") + ), f"Could not find pytorch_model.bin in {model_type_or_file}" + + check_torch_installed() + + # from ctranslate2.converters.transformers import TransformersConverter + # converter = TransformersConverter( + # model_type_or_file, + # activation_scales=None, # Path to the pre-computed activation scales, see https://github.com/mit-han-lab/smoothquant + # copy_files=[], # Note: "tokenizer.json" does not always exist, we will copy it separately + # load_as_float16=False, + # revision=None, + # low_cpu_mem_usage=False, + # trust_remote_code=False, + # ) + + try: + # 
converter.convert( + # output_dir, + # force=False + # ) + + subprocess.check_call( + [ + "ct2-transformers-converter", + "--model", + model_type_or_file, + "--output_dir", + os.path.realpath(output_dir), + "--quantization", + "float16", + ] + ) + except Exception as err: + shutil.rmtree(output_dir, ignore_errors=True) + raise err + + finally: + if delete_hf_path: + logger.info(f"Deleting {hf_path}") + shutil.rmtree(hf_path, ignore_errors=True) + + assert os.path.isdir(output_dir), f"Failed to build {output_dir}" + + model_type_or_file = output_dir + + model = None + for i, compute_type in enumerate(compute_types): + try: + model = faster_whisper.WhisperModel( + model_type_or_file, + device=device, + device_index=device_index, + compute_type=compute_type, + # cpu_threads=0, # Can be controled with OMP_NUM_THREADS + # num_workers=1, + # download_root=os.path.join(download_root, f"huggingface/hub/models--guillaumekln--faster-whisper-{model_type_or_file}"), + ) + break + except ValueError as err: + logger.info( + "WARNING: failed to load model with compute_type={}".format(compute_type) + ) + # On some old GPU we may have the error + # "ValueError: Requested int8_float16 compute type, + # but the target device or backend do not support efficient int8_float16 computation." + if i == len(compute_types) - 1: + raise err + + else: + extension = ( + os.path.splitext(model_type_or_file)[-1] if os.path.isfile(model_type_or_file) else None + ) + + if model_type_or_file in whisper.available_models() or extension == ".pt": + model = whisper.load_model( + model_type_or_file, + device=device, + download_root=os.path.join(download_root, "whisper"), + ) + + else: + # Convert HuggingFace model + import torch + + peft_folder = None + + if extension in [".ckpt", ".bin"]: + model_path = model_type_or_file + else: + # Search for the cached file (download if necessary) + if os.path.isdir(model_type_or_file): + for root, _, files in os.walk(model_type_or_file): + if "adapter_config.json" in files: + peft_folder = root + break + try: + import transformers + except ImportError: + raise ImportError( + f"If you are trying to download a HuggingFace model with {model_type_or_file}, please install first the transformers library" + ) + from transformers.utils import cached_file + + try: + model_path = cached_file( + model_type_or_file, + "pytorch_model.bin", + cache_dir=download_root, + use_auth_token=None, + revision=None, + ) + except Exception as e: + try: + if isinstance(e, OSError): + model_path = cached_file( + model_type_or_file, + "whisper.ckpt", + cache_dir=download_root, + use_auth_token=None, + revision=None, + ) + else: + raise e + except: + if peft_folder is None: + raise RuntimeError( + f"Original error: {e}\nCould not find model {model_type_or_file} from HuggingFace nor local folders." 
+ ) + + # Load HF Model + if peft_folder is not None: + import transformers + from peft import PeftConfig, PeftModel + + peft_config = PeftConfig.from_pretrained(peft_folder) + base_model = peft_config.base_model_name_or_path + + model = transformers.WhisperForConditionalGeneration.from_pretrained(base_model) + model = PeftModel.from_pretrained(model, peft_folder) + hf_state_dict = model.state_dict() + del model + else: + hf_state_dict = torch.load(model_path, map_location="cpu") + + # Rename layers + for key in list(hf_state_dict.keys()): + new_key = hf_to_whisper_states(key) + if new_key is None: + hf_state_dict.pop(key) + elif new_key != key: + hf_state_dict[new_key] = hf_state_dict.pop(key) + + # Init Whisper Model and replace model weights + dims = whisper.model.ModelDimensions(**states_to_dim(hf_state_dict)) + if "proj_out.weight" in hf_state_dict: + hf_state_dict["decoder.proj_out.weight"] = hf_state_dict.pop("proj_out.weight") + print("WARNING: Using untied projection layer") + whisper_model = WhisperUntied(dims) + else: + whisper_model = whisper.model.Whisper(dims) + whisper_model.load_state_dict(hf_state_dict) + del hf_state_dict + whisper_model = whisper_model.to(device) + return whisper_model + + model.eval() + model.requires_grad_(False) + + logger.info("Whisper model loaded. (t={}s)".format(time.time() - start)) + + return model + + +def check_torch_installed(): + try: + import torch + except ImportError: + # Install transformers with torch + subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers[torch]>=4.23"]) + + # # Re-load ctranslate2 + # import importlib + # import ctranslate2 + # importlib.reload(ctranslate2) + # importlib.reload(ctranslate2.converters.transformers) + + # import torch + + +# Credit: https://github.com/openai/whisper/discussions/830 +def hf_to_whisper_states(text): + import re + + # From Speechbrain + if text == "_mel_filters": + return None + + # From PEFT + if "default" in text: + # print(f"WARNING: Ignoring {text}") + return None + if text.startswith("base_model.model."): + text = text[len("base_model.model.") :] + + text = re.sub(".layers.", ".blocks.", text) + text = re.sub(".self_attn.", ".attn.", text) + text = re.sub(".q_proj.", ".query.", text) + text = re.sub(".k_proj.", ".key.", text) + text = re.sub(".v_proj.", ".value.", text) + text = re.sub(".out_proj.", ".out.", text) + text = re.sub(".fc1.", ".mlp.0.", text) + text = re.sub(".fc2.", ".mlp.2.", text) + text = re.sub(".fc3.", ".mlp.3.", text) + text = re.sub(".fc3.", ".mlp.3.", text) + text = re.sub(".encoder_attn.", ".cross_attn.", text) + text = re.sub(".cross_attn.ln.", ".cross_attn_ln.", text) + text = re.sub(".embed_positions.weight", ".positional_embedding", text) + text = re.sub(".embed_tokens.", ".token_embedding.", text) + text = re.sub("model.", "", text) + text = re.sub("attn.layer_norm.", "attn_ln.", text) + text = re.sub(".final_layer_norm.", ".mlp_ln.", text) + text = re.sub("encoder.layer_norm.", "encoder.ln_post.", text) + text = re.sub("decoder.layer_norm.", "decoder.ln.", text) + return text + + +def states_to_dim(state_dict): + n_audio_state = len(state_dict["encoder.ln_post.bias"]) + n_text_state = len(state_dict["decoder.ln.bias"]) + return { + "n_mels": state_dict["encoder.conv1.weight"].shape[1], # 80 + "n_vocab": state_dict["decoder.token_embedding.weight"].shape[0], # 51864 / 51865 + "n_audio_ctx": state_dict["encoder.positional_embedding"].shape[0], # 1500 + "n_audio_state": n_audio_state, # 384 / 512 / 768 / 1024 / 1280 + "n_audio_head": 
n_audio_state // 64, # 6 / 8 / 12 / 16 / 20 + "n_audio_layer": len( + set([".".join(k.split(".")[:3]) for k in state_dict.keys() if "encoder.blocks." in k]) + ), # 4 / 6 / 12 / 24 / 32 + "n_text_ctx": state_dict["decoder.positional_embedding"].shape[0], # 448 + "n_text_state": n_text_state, # 384 / 512 / 768 / 1024 / 1280 + "n_text_head": n_text_state // 64, # 6 / 8 / 12 / 16 / 20 + "n_text_layer": len( + set([".".join(k.split(".")[:3]) for k in state_dict.keys() if "decoder.blocks." in k]) + ), # 4 / 6 / 12 / 24 / 32 + } + + +if not USE_CTRANSLATE2: + + class TextDecoderUntied(whisper.model.TextDecoder): + """ + Same as TextDecoder but with untied weights + """ + + def __init__(self, *args, **kwargs): + import torch + + super().__init__(*args, **kwargs) + + n_vocab, n_state = self.token_embedding.weight.shape + + self.proj_out = torch.nn.Linear(n_state, n_vocab, bias=False) + + def forward(self, x, xa, kv_cache=None): + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + + # logits = self.proj_out(x).float() + # logits = (x @ torch.transpose(self.proj_out.weight.to(x.dtype), 0, 1)).float() + logits = self.proj_out.to(x.dtype)(x).float() + + return logits + + class WhisperUntied(whisper.model.Whisper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.decoder = TextDecoderUntied( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) diff --git a/whisper/stt/processing/text_normalize.py b/whisper/stt/processing/text_normalize.py new file mode 100644 index 0000000..cde8f38 --- /dev/null +++ b/whisper/stt/processing/text_normalize.py @@ -0,0 +1,393 @@ +import math +import re +# import string +import unicodedata + +from stt import logger + +from .utils import flatten + +# All punctuations and symbols EXCEPT: +# * apostrophe (') and hyphen (-) +# * underscore (_) +# * currency symbols ($, €, £, ...) -> \p{Sc} +# * math symbols (%, +, ×). ex: C++ +# * misc (#, @). ex: C#, @user +# and the space character (which can separate several series of punctuation marks) +# Example of punctuations that can output models like Whisper: !,.:;?¿،؛؟…、。!,:?>/]:!(~\u200b[ா「«»“”"< ?;…,*」.)' +_punctuation_regex = r"[^\w\p{Sc}" + re.escape("'-_%+×#@&") + "]" +_leading_punctuations_regex = r"^" + _punctuation_regex + r"+" +_trailing_punctuations_regex = _punctuation_regex + r"+$" + +# A list of symbols that can be an isolated words and not in the exclusion list above +# * & +# * candidates not retained: §, <, =, >, ≤, ≥ +_maybe_word_regex = None # r"[" + re.escape("&") + r"]$" + + +def remove_punctuation(text: str, ensure_no_spaces_in_words: bool = False) -> str: + text = text.strip() + # Note: we don't remove dots inside words (e.g. 
"ab@gmail.com") + new_text = re.sub(_leading_punctuations_regex, "", text) # .lstrip() + new_text = re.sub(_trailing_punctuations_regex, "", new_text) # .rstrip() + # Let punctuation marks that are alone + if not new_text: + if _maybe_word_regex and re.match(_maybe_word_regex, text): + new_text = text + else: + new_text = "" + # Ensure that there is no space in the middle of a word + if ensure_no_spaces_in_words and " " in new_text: + new_text, tail = new_text.split(" ", 1) + # OK if the tail only contains non alphanumeric characters (then we just keep the first part) + assert not re.search(r"[^\W\d\'\-_]", tail), f"Got unexpected word containing space: {text}" + return remove_punctuation(new_text, ensure_no_spaces_in_words=ensure_no_spaces_in_words) + return new_text + + +def transliterate(c): + # Transliterates a character to its closest ASCII equivalent. + # Example: transliterate("à ß œ fl") = "a ss oe fl" + c = re.sub("œ", "oe", c) + c = re.sub("æ", "ae", c) + c = re.sub("Œ", "OE", c) + c = re.sub("Æ", "AE", c) + c = re.sub("ß", "ss", c) + return unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii") + + +def remove_emoji(text): + # Remove emojis + return re.sub( + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF]+", + "", + text, + ) + + +def normalize_text(text: str, lang: str) -> str: + """Transform digits into characters...""" + + # Reorder currencies (1,20€ -> 1 € 20) + coma = "," if lang in ["fr"] else "\." + for c in _currencies: + if c in text: + text = re.sub(r"\b(\d+)" + coma + r"(\d+)\s*" + c, r"\1 " + c + r" \2", text) + + # Roman digits + if re.search(r"[IVX]", text): + if lang == "en": + digits = re.findall(r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})(º|st|nd|rd|th)?\b", text) + digits = ["".join(d) for d in digits] + elif lang == "fr": + digits = re.findall( + r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})(º|ème|eme|e|er|ère)?\b", text + ) + digits = ["".join(d) for d in digits] + else: + digits = re.findall(r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})\b", text) + digits = ["".join(d) for d in digits] + if digits: + digits = sorted(list(set(digits)), reverse=True, key=lambda x: (len(x), x)) + for s in digits: + filtered = re.sub("[a-zèº]", "", s) + ordinal = filtered != s + digit = roman_to_decimal(filtered) + v = undigit(str(digit), lang=lang, to="ordinal" if ordinal else "cardinal") + text = re.sub(r"\b" + s + r"\b", v, text) + + # Ordinal digits + if lang == "en": + digits = re.findall(r"\b\d*1(?:st)|\d*2(?:nd)|\d*3(?:rd)|\d+(?:º|th)\b", text) + elif lang == "fr": + digits = re.findall(r"\b1(?:ère|ere|er|re|r)|2(?:nd|nde)|\d+(?:º|ème|eme|e)\b", text) + else: + logger.warn( + f"Language {lang} not supported for some normalization. Some words might be mis-localized." 
+ ) + digits = [] + if digits: + digits = sorted(list(set(digits)), reverse=True, key=lambda x: (len(x), x)) + for digit in digits: + word = undigit(re.findall(r"\d+", digit)[0], to="ordinal", lang=lang) + text = re.sub(r"\b" + str(digit) + r"\b", word, text) + + # Cardinal digits + digits = re.findall(r"(?:\-?\b[\d/]*\d+(?: \d\d\d)+\b)|(?:\-?\d[/\d]*)", text) + digits = list(map(lambda s: s.strip(r"[/ ]"), digits)) + digits = list(set(digits)) + digits = digits + flatten([c.split() for c in digits if " " in c]) + digits = digits + flatten([c.split("/") for c in digits if "/" in c]) + digits = sorted(digits, reverse=True, key=lambda x: (len(x), x)) + for digit in digits: + digitf = re.sub("/+", "/", digit) + if not digitf: + continue + numslash = len(re.findall("/", digitf)) + if numslash == 0: + word = undigit(digitf, lang=lang) + elif numslash == 1: # Fraction or date + i = digitf.index("/") + is_date = False + if len(digitf[i + 1 :]) == 2: + try: + first = int(digitf[:i]) + second = int(digitf[i + 1 :]) + is_date = first > 0 and first < 32 and second > 0 and second < 13 + except: + pass + if is_date: + first = digitf[:i].lstrip("0") + use_ordinal = (lang == "fr" and first == "1") or ( + lang != "fr" and first[-1] in ["1", "2", "3"] + ) + first = undigit(first, lang=lang, to="ordinal" if use_ordinal else "cardinal") + second = _int_to_month.get(lang, {}).get(second, digitf[i + 1 :]) + else: + first = undigit(digitf[:i], lang=lang) + second = undigit(digitf[i + 1 :], to="denominator", lang=lang) + if float(digitf[:i]) > 2.0 and second[-1] != "s": + second += "s" + word = first + " " + second + elif numslash == 2: # Maybe a date + i1 = digitf.index("/") + i2 = digitf.index("/", i1 + 1) + is_date = False + if len(digitf[i1 + 1 : i2]) == 2 and len(digitf[i2 + 1 :]) == 4: + try: + first = int(digitf[:i1]) + second = int(digitf[i1 + 1 : i2]) + third = int(digitf[i2 + 1 :]) + is_date = ( + first > 0 and first < 32 and second > 0 and second < 13 and third > 1000 + ) + except: + pass + third = undigit(digitf[i2 + 1 :], lang=lang) + if is_date: + first = digitf[:i1].lstrip("0") + use_ordinal = (lang == "fr" and first == "1") or ( + lang != "fr" and first[-1] in ["1", "2", "3"] + ) + first = undigit(first, lang=lang, to="ordinal" if use_ordinal else "cardinal") + second = _int_to_month.get(lang, {}).get( + int(digitf[i1 + 1 : i2]), digitf[i1 + 1 : i2] + ) + word = " ".join([first, second, third]) + else: + word = " / ".join([undigit(s, lang=lang) for s in digitf.split("/")]) + else: + word = " / ".join([undigit(s, lang=lang) for s in digitf.split("/")]) + text = replace_keeping_word_boundaries(digit, word, text) + + # Symbols (currencies, percent...) 
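+    # Illustrative examples, based on the _symbol_to_word tables below:
+    #   fr: "20 %" becomes "20 pour cents", "5 €" becomes "5 euros"
+    #   en: "20 %" becomes "20 percent"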
+ symbol_table = _symbol_to_word.get(lang, {}) + for k, v in symbol_table.items(): + text = replace_keeping_word_boundaries(k, v, text) + + # Remove extra spaces before punctuation + # text = re.sub(r" ([\.,!:;])",r"\1",text) + + return collapse_whitespace(text) + + +def replace_keeping_word_boundaries(orig, dest, text): + if orig in text: + text = re.sub(r"(\W)" + orig + r"(\W)", r"\1" + dest + r"\2", text) + text = re.sub(orig + r"(\W)", " " + dest + r"\1", text) + text = re.sub(r"(\W)" + orig, r"\1" + dest + " ", text) + text = re.sub(orig, " " + dest + " ", text) + return text + + +def undigit(str, lang, to="cardinal"): + str = re.sub(" ", "", str) + if to == "denominator": + if lang == "fr": + if str == "2": + return "demi" + if str == "3": + return "tiers" + if str == "4": + return "quart" + elif lang == "en": + if str == "2": + return "half" + if str == "4": + return "quarter" + elif lang == "es": + if str == "2": + return "mitad" + if str == "3": + return "tercio" + to = "ordinal" + if str.startswith("0") and to == "cardinal": + numZeros = len(re.findall(r"0+", str)[0]) + if numZeros < len(str): + return numZeros * (robust_num2words(0, lang=lang) + " ") + robust_num2words( + float(str), lang=lang, to=to + ) + return robust_num2words(float(str), lang=lang, to=to) + + +def robust_num2words(x, lang, to="cardinal", orig=""): + """ + Bugfix for num2words + """ + from num2words import num2words + + try: + res = num2words(x, lang=lang, to=to) + if lang == "fr" and to == "ordinal": + res = res.replace("vingtsième", "vingtième") + return res + except OverflowError: + if x == math.inf: # ! + return " ".join(robust_num2words(xi, lang=lang, to=to) for xi in orig) + if x == -math.inf: # ! + return "moins " + robust_num2words(-x, lang=lang, to=to, orig=orig.replace("-", "")) + # TODO: print a warning + return robust_num2words(x // 10, lang=lang, to=to) + + +def roman_to_decimal(str): + def value(r): + if r == "I": + return 1 + if r == "V": + return 5 + if r == "X": + return 10 + if r == "L": + return 50 + if r == "C": + return 100 + if r == "D": + return 500 + if r == "M": + return 1000 + return -1 + + res = 0 + i = 0 + while i < len(str): + s1 = value(str[i]) + if i + 1 < len(str): + s2 = value(str[i + 1]) + if s1 >= s2: + # Value of current symbol is greater or equal to the next symbol + res = res + s1 + i = i + 1 + else: + # Value of current symbol is greater or equal to the next symbol + res = res + s2 - s1 + i = i + 2 + else: + res = res + s1 + i = i + 1 + return res + + +_int_to_month = { + "fr": { + 1: "janvier", + 2: "février", + 3: "mars", + 4: "avril", + 5: "mai", + 6: "juin", + 7: "juillet", + 8: "août", + 9: "septembre", + 10: "octobre", + 11: "novembre", + 12: "décembre", + }, + "en": { + 1: "january", + 2: "february", + 3: "march", + 4: "april", + 5: "may", + 6: "june", + 7: "july", + 8: "august", + 9: "september", + 10: "october", + 11: "november", + 12: "december", + }, +} + +_currencies = ["€", "$", "£", "¥"] + +_symbol_to_word = { + "fr": { + "%": "pour cents", + "÷": "divisé par", + "\*": "fois", # ? 
+ "×": "fois", + "±": "plus ou moins", + "\+": "plus", + "&": "et", + "@": "arobase", + "m²": "mètres carrés", + "m³": "mètres cubes", + "²": "au carré", + "³": "au cube", + "¼": "un quart", + "½": "un demi", + "¾": "trois quarts", + "§": "section", + "°C": "degrés Celsius", + "°F": "degrés Fahrenheit", + "°K": "kelvins", + "°": "degrés", + "€": "euros", + "¢": "cents", + "\$": "dollars", + "£": "livres", + "¥": "yens", + # Below: not in Whisper tokens + # "₩": "wons", + # "₽": "roubles", + # "₹": "roupies", + # "₺": "liras", + # "₪": "shekels", + # "₴": "hryvnias", + # "₮": "tugriks", + # "℃": "degrés Celsius", + # "℉": "degrés Fahrenheit", + # "Ω": "ohms", + # "Ω": "ohms", + # "K": "kelvins", + # "ℓ": "litres", + }, + "en": { + "%": "percent", + "÷": "divided by", + "\*": "times", # ? + "×": "times", + "±": "plus or minus", + "\+": "plus", + "&": "and", + "@": "at", + "m²": "square meters", + "m³": "cubic meters", + "²": "squared", + "³": "cubed", + "¼": "one quarter", + "½": "one half", + "¾": "three quarters", + "§": "section", + "°C": "degrees Celsius", + "°F": "degrees Fahrenheit", + "°K": "kelvins", + "°": "degrees", + "€": "euros", + "¢": "cents", + "\$": "dollars", + "£": "pounds", + "¥": "yens", + }, +} diff --git a/whisper/stt/processing/utils.py b/whisper/stt/processing/utils.py new file mode 100644 index 0000000..106167a --- /dev/null +++ b/whisper/stt/processing/utils.py @@ -0,0 +1,225 @@ +import io +import os + +import numpy as np +import wavio +from stt import USE_CTRANSLATE2, USE_TORCH, USE_TORCHAUDIO + +SAMPLE_RATE = 16000 # whisper.audio.SAMPLE_RATE + +if USE_CTRANSLATE2: + import ctranslate2 + import faster_whisper +else: + import torch + + import whisper + +if USE_TORCHAUDIO: + import torchaudio + + +def has_cuda(): + if USE_CTRANSLATE2: + return ctranslate2.get_cuda_device_count() > 0 + else: + return torch.cuda.is_available() + + +def get_device(): + device = os.environ.get("DEVICE", "cuda" if has_cuda() else "cpu") + use_gpu = "cuda" in device + + if USE_CTRANSLATE2: + try: + if device.startswith("cuda:"): + _ = [int(dev) for dev in device[5:].split(",")] + else: + assert device in ["cpu", "cuda"] + except: + raise ValueError( + f"Invalid DEVICE '{device}' (should be 'cpu' or 'cuda' or 'cuda: or 'cuda:,,...')" + ) + else: + try: + device = torch.device(device) + except Exception as err: + raise Exception("Failed to set device: {}".format(str(err))) from err + return device, use_gpu + + +def get_language(): + """ + Get the language from the environment variable LANGUAGE, and format as expected by Whisper. + """ + language = os.environ.get("LANGUAGE", "*") + # "fr-FR" -> "fr" (language-country code to ISO 639-1 code) + if len(language) > 2 and language[2] == "-": + language = language.split("-")[0] + # "*" means "all languages" + if language == "*": + language = None + # Convert French -> fr + if isinstance(language, str) and language not in LANGUAGES: + language = {v: k for k, v in LANGUAGES.items()}.get(language.lower(), language) + # Raise an exception for unknown languages + if language not in LANGUAGES: + available_languages = ( + list(LANGUAGES.keys()) + + [k[0].upper() + k[1:] for k in LANGUAGES.values()] + + ["*", None] + ) + raise ValueError( + f"Language '{language}' is not available. 
Available languages are: {available_languages}" + ) + return language + + +def conform_audio(audio, sample_rate=16_000): + if sample_rate != SAMPLE_RATE: + if not USE_TORCHAUDIO: + raise NotImplementedError("Resampling not available without torchaudio") + # Down or Up sample to the right sampling rate + audio = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)(audio) + if audio.shape[0] > 1: + # Stereo to mono + # audio = torchaudio.transforms.DownmixMono()(audio, channels_first = True) + audio = audio.mean(0) + else: + audio = audio.squeeze(0) + return audio + + +def load_audiofile(path): + if not os.path.isfile(path): + raise RuntimeError("File not found: %s" % path) + elif not os.access(path, os.R_OK): + raise RuntimeError("Missing reading permission for: %s" % path) + if USE_CTRANSLATE2: + return faster_whisper.decode_audio(path, sampling_rate=SAMPLE_RATE) + audio = whisper.load_audio(path) + audio = torch.from_numpy(audio) + return audio + + +def load_wave_buffer(file_buffer): + """Formats audio from a wavFile buffer to a torch array for processing.""" + file_buffer_io = io.BytesIO(file_buffer) + if USE_CTRANSLATE2: + return faster_whisper.decode_audio(file_buffer_io, sampling_rate=SAMPLE_RATE) + file_content = wavio.read(file_buffer_io) + sample_rate = file_content.rate + audio = file_content.data.astype(np.float32) / 32768 + audio = audio.transpose() + audio = torch.from_numpy(audio) + return conform_audio(audio, sample_rate) + + +def flatten(l): + """ + flatten a list of lists + """ + return [item for sublist in l for item in sublist] + + +LANGUAGES = { # whisper.tokenizer.LANGUAGES + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} diff --git 
diff --git a/whisper/stt/processing/word_alignment.py b/whisper/stt/processing/word_alignment.py
new file mode 100644
index 0000000..e7a9256
--- /dev/null
+++ b/whisper/stt/processing/word_alignment.py
@@ -0,0 +1,224 @@
+"""
+Credits: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
+"""
+from dataclasses import dataclass
+
+from stt import USE_TORCH, logger
+
+from .alignment_model import compute_logprobas, get_vocab
+from .text_normalize import transliterate
+from .utils import flatten
+
+if USE_TORCH:
+    import torch
+
+_unknown_chars = []
+
+
+def compute_alignment(audio, transcript, model):
+    """Compute the alignment of the audio and a transcript, for a given model that returns log-probabilities on the charset defining the transcript."""
+
+    emission = compute_logprobas(model, audio)
+    labels, blank_id = get_vocab(model)
+    labels = labels[: emission.shape[1]]
+    dictionary = {c: i for i, c in enumerate(labels)}
+
+    default = labels.index("-") if "-" in labels else None
+    tokens = [loose_get_char_index(dictionary, c, default) for c in transcript]
+    tokens = flatten(tokens)
+
+    num_emissions = emission.shape[0]
+    num_repetitions = count_repetitions(tokens)
+    if len(tokens) + num_repetitions > num_emissions:
+        # It will be impossible to find a path...
+        # It can happen when Whisper is lost in a loop (ex: "Ha ha ha ha ...")
+        logger.warn(f"Got too many characters from Whisper. Shrinking to the first characters.")
+        tokens = tokens[:num_emissions]
+        num_repetitions = count_repetitions(tokens)
+        while len(tokens) + num_repetitions > num_emissions:
+            tokens = tokens[:-1]
+            num_repetitions = count_repetitions(tokens)
+
+    # Make sure transcript has the same length as tokens (it could be different just because of transliteration "œ" -> "oe")
+    transcript = "".join([labels[i][0] for i in tokens])
+
+    trellis = get_trellis(emission, tokens, blank_id=blank_id)
+
+    path = backtrack(trellis, emission, tokens, blank_id=blank_id)
+
+    segments = merge_repeats(transcript, path)
+
+    word_segments = merge_words(segments)
+
+    return labels, emission, trellis, segments, word_segments
+
+
+def count_repetitions(tokens):
+    return sum([a == b for a, b in zip(tokens[1:], tokens[:-1])])
+
+
+def loose_get_char_index(dictionary, c, default=None):
+    global _unknown_chars
+    i = dictionary.get(c, None)
+    if i is None:
+        # Try with alternative versions of the character
+        tc = transliterate(c)
+        other_char = list(set([c.lower(), c.upper(), tc, tc.lower(), tc.upper()]))
+        for c2 in other_char:
+            i = dictionary.get(c2, None)
+            if i is not None:
+                i = [i]
+                break
+        # Some transliterated versions may correspond to multiple characters
+        if i is None:
+            for c2 in other_char:
+                if len(c2) > 1:
+                    candidate = [dictionary[c3] for c3 in c2 if c3 in dictionary]
+                    if len(candidate) > 0 and (i is None or len(candidate) > len(i)):
+                        i = candidate
+        # If still not found
+        if i is None:
+            if c not in _unknown_chars:
+                logger.warn(
+                    "Character not correctly handled by alignment model: '"
+                    + "' / '".join(list(set([c] + other_char)))
+                    + "'"
+                )
+                _unknown_chars.append(c)
+            i = [default] if default is not None else []
+    else:
+        i = [i]
+    return i
+
+
+def get_trellis(emission, tokens, blank_id=0, use_max=False):
+    num_frame = emission.size(0)
+    num_tokens = len(tokens)
+
+    # Trellis has extra dimensions for both time axis and tokens.
+    # The extra dim for tokens represents <SoS> (start-of-sentence)
+    # The extra dim for time axis is for simplification of the code.
+    trellis = torch.empty((num_frame + 1, num_tokens + 1)).to(emission.device)
+    trellis[0, 0] = 0
+    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
+    trellis[0, -num_tokens:] = -float("inf")
+    trellis[-num_tokens:, 0] = float("inf")
+
+    for t in range(num_frame):
+        trellis[t + 1, 1:] = (
+            torch.maximum(
+                # Score for staying at the same token
+                trellis[t, 1:] + emission[t, blank_id],
+                torch.maximum(
+                    trellis[t, 1:] + emission[t, tokens],
+                    # Score for changing to the next token
+                    trellis[t, :-1] + emission[t, tokens],
+                ),
+            )
+            if use_max
+            else torch.logaddexp(
+                trellis[t, 1:] + emission[t, blank_id],
+                torch.logaddexp(
+                    trellis[t, 1:] + emission[t, tokens], trellis[t, :-1] + emission[t, tokens]
+                ),
+            )
+        )
+    return trellis
+
+
+@dataclass
+class Point:
+    token_index: int
+    time_index: int
+    score: float
+
+
+def backtrack(trellis, emission, tokens, blank_id=0):
+    # Note:
+    # j and t are indices for trellis, which has extra dimensions
+    # for time and tokens at the beginning.
+    # When referring to time frame index `T` in trellis,
+    # the corresponding index in emission is `T-1`.
+    # Similarly, when referring to token index `J` in trellis,
+    # the corresponding index in transcript is `J-1`.
+    j = trellis.size(1) - 1
+    t_start = torch.argmax(trellis[:, j]).item()
+
+    path = []
+    for t in range(t_start, 0, -1):
+        # 1. Figure out if the current position was stay or change
+        # Note (again):
+        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
+        # Score for token staying the same from time frame J-1 to T.
+        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
+        # Score for token changing from J-1 at T-1 to J at T.
+        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+        # 2. Store the path with frame-wise probability.
+        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
+        # Return token index and time index in non-trellis coordinate.
+        path.append(Point(j - 1, t - 1, prob))
+
+        # 3. Update the token
+        if changed > stayed:
+            j -= 1
+            if j == 0:
+                break
+    else:
+        logger.warn(f"Failed to align {len(tokens)} tokens")
+        return path
+    return path[::-1]
+
+
+# Merge the labels
+@dataclass
+class Segment:
+    label: str
+    start: int
+    end: int
+    score: float
+
+    def __repr__(self):
+        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+
+    @property
+    def length(self):
+        return self.end - self.start
+
+
+def merge_repeats(transcript, path):
+    i1, i2 = 0, 0
+    segments = []
+    while i1 < len(path):
+        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
+            i2 += 1
+        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
+        segments.append(
+            Segment(
+                transcript[path[i1].token_index],
+                path[i1].time_index,
+                path[i2 - 1].time_index + 1,
+                score,
+            )
+        )
+        i1 = i2
+    return segments
+
+
+def merge_words(segments, separator=" "):
+    words = []
+    i1, i2 = 0, 0
+    while i1 < len(segments):
+        if i2 >= len(segments) or segments[i2].label == separator:
+            if i1 != i2:
+                segs = segments[i1:i2]
+                word = "".join([seg.label for seg in segs])
+                score = sum(seg.score * seg.length for seg in segs) / sum(
+                    seg.length for seg in segs
+                )
+                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
+            i1 = i2 + 1
+            i2 = i1
+        else:
+            i2 += 1
+    return words
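For context, the segments returned by `compute_alignment` are indexed in emission frames rather than seconds. The sketch below (not part of the patch) shows one way the word segments could be turned into timestamps, assuming the module is importable as `stt.processing.word_alignment`; the `SECONDS_PER_FRAME` value is illustrative and depends on the alignment model's stride, it is not defined in this diff:

```python
# Hypothetical post-processing of compute_alignment() output (not part of the diff).
from stt.processing.word_alignment import compute_alignment

SECONDS_PER_FRAME = 0.02  # illustrative; depends on the alignment model's frame stride


def words_with_timestamps(audio, transcript, model):
    """Return word-level timestamps (in seconds) with confidence scores."""
    _labels, _emission, _trellis, _segments, word_segments = compute_alignment(
        audio, transcript, model
    )
    return [
        {
            "word": w.label,
            "start": round(w.start * SECONDS_PER_FRAME, 3),
            "end": round(w.end * SECONDS_PER_FRAME, 3),
            "conf": round(w.score, 3),
        }
        for w in word_segments
    ]
```

Each entry in `word_segments` is a `Segment` whose label is the concatenation of its character labels and whose score is a length-weighted average of the per-character scores, so `w.score` can be reported directly as a word-level confidence.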