diff --git a/.github/workflows/dockerhub-description.yml b/.github/workflows/dockerhub-description.yml
index 0367b21..1301449 100644
--- a/.github/workflows/dockerhub-description.yml
+++ b/.github/workflows/dockerhub-description.yml
@@ -7,7 +7,7 @@ on:
- README.md
- .github/workflows/dockerhub-description.yml
jobs:
- dockerHubDescription:
+ dockerHubDescriptionKaldi:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@@ -16,5 +16,16 @@ jobs:
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- repository: lintoai/linto-platform-stt
- readme-filepath: ./README.md
+ repository: lintoai/linto-stt-kaldi
+ readme-filepath: ./kaldi/README.md
+ dockerHubDescriptionWhisper:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Docker Hub Description
+ uses: peter-evans/dockerhub-description@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
+ repository: lintoai/linto-stt-whisper
+ readme-filepath: ./whisper/README.md
diff --git a/.gitignore b/.gitignore
index 0b8d9ad..06b349b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
start_container.sh
.env*
test/*
+tmp*
+__pycache__
\ No newline at end of file
diff --git a/Jenkinsfile b/Jenkinsfile
index 95e42b0..cd1ad07 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -1,51 +1,72 @@
+def buildDockerfile(main_folder, dockerfilePath, image_name, version, changedFiles) {
+ if (changedFiles.contains(main_folder) || changedFiles.contains('celery_app') || changedFiles.contains('http_server') || changedFiles.contains('websocket') || changedFiles.contains('document')) {
+ echo "Building Dockerfile for ${image_name} with version ${version} (using ${dockerfilePath})"
+
+ script {
+ def image = docker.build(image_name, "-f ${dockerfilePath} .")
+
+ docker.withRegistry('https://registry.hub.docker.com', 'docker-hub-credentials') {
+ if (version == 'latest-unstable') {
+ image.push('latest-unstable')
+ } else {
+ image.push('latest')
+ image.push(version)
+ }
+ }
+ }
+ }
+}
+
pipeline {
agent any
environment {
- DOCKER_HUB_REPO = "lintoai/linto-platform-stt"
- DOCKER_HUB_CRED = 'docker-hub-credentials'
-
- VERSION = ''
+ DOCKER_HUB_REPO_KALDI = "lintoai/linto-stt-kaldi"
+ DOCKER_HUB_REPO_WHISPER = "lintoai/linto-stt-whisper"
}
-
- stages{
- stage('Docker build for master branch'){
- when{
+
+ stages {
+ stage('Docker build for master branch') {
+ when {
branch 'master'
}
steps {
echo 'Publishing latest'
script {
- image = docker.build(env.DOCKER_HUB_REPO)
- VERSION = sh(
+ def changedFiles = sh(returnStdout: true, script: 'git diff --name-only HEAD^ HEAD').trim()
+ echo "My changed files: ${changedFiles}"
+
+ version_kaldi = sh(
+ returnStdout: true,
+ script: "awk -v RS='' '/#/ {print; exit}' kaldi/RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
+ ).trim()
+
+ version_whisper = sh(
returnStdout: true,
- script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
+ script: "awk -v RS='' '/#/ {print; exit}' whisper/RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
).trim()
- docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) {
- image.push("${VERSION}")
- image.push('latest')
- }
+ buildDockerfile('kaldi', 'kaldi/Dockerfile', env.DOCKER_HUB_REPO_KALDI, version_kaldi, changedFiles)
+ buildDockerfile('whisper', 'whisper/Dockerfile.ctranslate2', env.DOCKER_HUB_REPO_WHISPER, version_whisper, changedFiles)
}
}
}
- stage('Docker build for next (unstable) branch'){
- when{
+ stage('Docker build for next (unstable) branch') {
+ when {
branch 'next'
}
steps {
echo 'Publishing unstable'
script {
- image = docker.build(env.DOCKER_HUB_REPO)
- VERSION = sh(
- returnStdout: true,
- script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
- ).trim()
- docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) {
- image.push('latest-unstable')
- }
+ def changedFiles = sh(returnStdout: true, script: 'git diff --name-only HEAD^ HEAD').trim()
+ echo "My changed files: ${changedFiles}"
+
+ version = 'latest-unstable'
+
+ buildDockerfile('kaldi', 'kaldi/Dockerfile', env.DOCKER_HUB_REPO_KALDI, version, changedFiles)
+ buildDockerfile('whisper', 'whisper/Dockerfile.ctranslate2', env.DOCKER_HUB_REPO_WHISPER, version, changedFiles)
}
}
}
- }// end stages
+ }
}
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 71be1a8..24db387 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
.DEFAULT_GOAL := help
-target_dirs := stt http_server celery_app
+target_dirs := kaldi/stt whisper/stt http_server celery_app
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
diff --git a/README.md b/README.md
index ec70060..09009fe 100644
--- a/README.md
+++ b/README.md
@@ -1,224 +1,12 @@
-# LINTO-PLATFORM-STT
-LinTO-platform-stt is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack).
+# LinTO-STT
-LinTO-platform-stt can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector.
+LinTO-STT is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack).
+The following families of Speech-To-Text (STT) models are currently supported (please refer to the respective documentation for more details):
+* [Kaldi models](kaldi/README.md)
+* [Whisper models](whisper/README.md)
-## Pre-requisites
-
-### Hardware
-To run the transcription models you'll need:
-* At least 7Go of disk space to build the docker image.
-* Up to 7GB of RAM depending on the model used.
-* One CPU per worker. Inference time scales on CPU performances.
-
-### Model
-LinTO-Platform-STT accepts two kinds of models:
-* LinTO Acoustic and Languages models.
-* Vosk models.
-
-We provide home-cured models (v2) on [dl.linto.ai](https://doc.linto.ai/docs/developpers/apis/ASR/models).
-Or you can also use Vosk models available [here](https://alphacephei.com/vosk/models).
-
-### Docker
-The transcription service requires docker up and running.
-
-### (micro-service) Service broker and shared folder
-The STT only entry point in task mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS.
-On addition, as to prevent large audio from transiting through the message broker, STT-Worker use a shared storage folder (SHARED_FOLDER).
-
-## Deploy linto-platform-stt
-
-**1- First step is to build or pull the image:**
-
-```bash
-git clone https://github.com/linto-ai/linto-platform-stt.git
-cd linto-platform-stt
-docker build . -t linto-platform-stt:latest
-```
-or
-
-```bash
-docker pull lintoai/linto-platform-stt
-```
-
-**2- Download the models**
-
-Have the acoustic and language model ready at AM_PATH and LM_PATH if you are using LinTO models. If you are using a Vosk model, have it ready at MODEL.
-
-**3- Fill the .env**
-
-```bash
-cp .envdefault .env
-```
-
-| PARAMETER | DESCRIPTION | EXEMPLE |
-|---|---|---|
-| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | http\|task\|websocket |
-| MODEL_TYPE | Type of STT model used. | lin\|vosk |
-| ENABLE_STREAMING | Using http serving mode, enable the /streaming websocket route | true\|false |
-| SERVICE_NAME | Using the task mode, set the queue's name for task processing | my-stt |
-| SERVICE_BROKER | Using the task mode, URL of the message broker | redis://my-broker:6379 |
-| BROKER_PASS | Using the task mode, broker password | my-password |
-| STREAMING_PORT | Using the websocket mode, the listening port for ingoing WS connexions. | 80 |
-| CONCURRENCY | Maximum number of parallel requests | >1 |
-
-### Serving mode
-![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png)
-
-STT can be used three ways:
-* Through an [HTTP API](#http-server) using the **http**'s mode.
-* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode.
-* Through a [websocket server](#websocket-server) **websocket**'s mode.
-
-Mode is specified using the .env value or environment variable ```SERVING_MODE```.
-```bash
-SERVICE_MODE=http
-```
-### HTTP Server
-The HTTP serving mode deploys a HTTP server and a swagger-ui to allow transcription request on a dedicated route.
-
-The SERVICE_MODE value in the .env should be set to ```http```.
-
-```bash
-docker run --rm \
--p HOST_SERVING_PORT:80 \
--v AM_PATH:/opt/AM \
--v LM_PATH:/opt/LM \
---env-file .env \
-linto-platform-stt:latest
-```
-
-This will run a container providing an [HTTP API](#http-api) binded on the host HOST_SERVING_PORT port.
-
-**Parameters:**
-| Variables | Description | Example |
-|:-|:-|:-|
-| HOST_SERVING_PORT | Host serving port | 80 |
-| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 |
-| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 |
-| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model |
-
-### Micro-service within LinTO-Platform stack
-The HTTP serving mode connect a celery worker to a message broker.
-
-The SERVICE_MODE value in the .env should be set to ```task```.
-
->LinTO-platform-stt can be deployed within the linto-platform-stack through the use of linto-platform-services-manager. Used this way, the container spawn celery worker waiting for transcription task on a message broker.
->LinTO-platform-stt in task mode is not intended to be launch manually.
->However, if you intent to connect it to your custom message's broker here are the parameters:
-
-You need a message broker up and running at MY_SERVICE_BROKER.
-
-```bash
-docker run --rm \
--v AM_PATH:/opt/AM \
--v LM_PATH:/opt/LM \
--v SHARED_AUDIO_FOLDER:/opt/audio \
---env-file .env \
-linto-platform-stt:latest
-```
-
-**Parameters:**
-| Variables | Description | Example |
-|:-|:-|:-|
-| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 |
-| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 |
-| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model |
-| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model |
-
-
-### Websocket Server
-Websocket server's mode deploy a streaming transcription service only.
-
-The SERVICE_MODE value in the .env should be set to ```websocket```.
-
-Usage is the same as the [http streaming API](#/streaming)
-
-## Usages
-### HTTP API
-#### /healthcheck
-Returns the state of the API
-
-Method: GET
-
-Returns "1" if healthcheck passes.
-
-#### /transcribe
-Transcription API
-
-* Method: POST
-* Response content: text/plain or application/json
-* File: An Wave file 16b 16Khz
-
-Return the transcripted text using "text/plain" or a json object when using "application/json" structure as followed:
-```json
-{
- "text" : "This is the transcription",
- "words" : [
- {"word":"This", "start": 0.123, "end": 0.453, "conf": 0.9},
- ...
- ]
- "confidence-score": 0.879
-}
-```
-
-#### /streaming
-The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true.
-
-The route accepts websocket connexions. Exchanges are structured as followed:
-1. Client send a json {"config": {"sample_rate":16000}}.
-2. Client send audio chunk (go to 3- ) or {"eof" : 1} (go to 5-).
-3. Server send either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}.
-4. Back to 2-
-5. Server send a final result and close the connexion.
-
-> Connexion will be closed and the worker will be freed if no chunk are received for 10s.
-
-#### /docs
-The /docs route offers a OpenAPI/swagger interface.
-
-### Through the message broker
-
-STT-Worker accepts requests with the following arguments:
-```file_path: str, with_metadata: bool```
-
-* file_path: Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path}
-* with_metadata: If True, words timestamps and confidence will be computed and returned. If false, the fields will be empty.
-
-#### Return format
-On a successfull transcription the returned object is a json object structured as follow:
-```json
-{
- "text" : "this is the transcription as text",
- "words": [
- {
- "word" : "this",
- "start": 0.0,
- "end": 0.124,
- "conf": 1.0
- },
- ...
- ],
- "confidence-score": ""
-}
-```
-
-* The text field contains the raw transcription.
-* The word field contains each word with their time stamp and individual confidence. (Empty if with_metadata=False)
-* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False)
-
-
-## Test
-### Curl
-You can test you http API using curl:
-```bash
-curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav"
-```
+LinTO-STT can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector.
## License
This project is developped under the AGPLv3 License (see LICENSE).
-
-## Acknowlegment.
-
-* [Vosk, speech recognition toolkit](https://alphacephei.com/vosk/).
-* [Kaldi Speech Recognition Toolkit](https://github.com/kaldi-asr/kaldi)
diff --git a/RELEASE.md b/RELEASE.md
deleted file mode 100644
index 9966250..0000000
--- a/RELEASE.md
+++ /dev/null
@@ -1,52 +0,0 @@
-# 3.3.2
-- Fixed use of stereo audio in http serving mode
-
-# 3.3.1
-- Fixed lin_to_vosk throwing an error on a already existing container.
-- Corrected an error on the README regarding mounting model volumes.
-- Code styling (PEP 8)
-
-# 3.3.0
-- Added optional streaming route to the http serving mode
-- Added serving mode: websocket
-- Added Dynamic model conversion allowing to use either Vosk Models or Linagora AM/LM models
-- Changer Vosk dependency to alphacep/vosk
-- Updated README.md
-
-# 3.2.1
-- Repository total rework. The goal being to have a simple transcription service embeddable within a micro-service infrastructure.
-- Changed repository name from linto-platform-stt-standalone-worker to linto-platform-stt.
-- Added celery connector for microservice integration.
-- Added launch option to specify serving mode between task and http.
-- Removed diarization functionnality.
-- Removed punctuation functionnality.
-- Removed Async requests/Job management.
-- Updated README to reflect those changes.
-
-# 3.1.1
-- Change Pykaldi with vosk-API (no python wrapper for decoding function, no extrat packages during installation, c++ implementation based on kaldi functions)
-- New feature: Compute a confidence score per transcription
-- Fix minor bugs
-
-# 2.2.1
-- Fix minor bugs
-- put SWAGGER_PATH parameter as optional
-- Generate the word_boundary file if it does not exist
-
-# 2.2.0
-- Speaker diarization feature: pyBK package
-- Mulithreading feature: Speech decoding and Speaker diarization processes
-- Optional parameter: real number of speaker in the audio
-
-# 2.0.0
-- Reimplement LinTO-Platform-stt-standalone-worker using Pykaldi package
-
-# 1.1.2
-- New features:
- - Word timestamp computing
- - Response type: plain/text: simple text output and application/json: the transcription and the words timestamp.
- - Swagger: integrate swagger in the service using a python package
- - Fix minor bugs
-
-# 1.0.0
-- First build of LinTO-Platform-stt-standalone-worker
\ No newline at end of file
diff --git a/celery_app/celeryapp.py b/celery_app/celeryapp.py
index e04d73b..d1c4099 100644
--- a/celery_app/celeryapp.py
+++ b/celery_app/celeryapp.py
@@ -1,7 +1,6 @@
import os
from celery import Celery
-
from stt import logger
celery = Celery(__name__, include=["celery_app.tasks"])
@@ -10,9 +9,14 @@
if os.environ.get("BROKER_PASS", False):
components = broker_url.split("//")
broker_url = f'{components[0]}//:{os.environ.get("BROKER_PASS")}@{components[1]}'
+
celery.conf.broker_url = f"{broker_url}/0"
celery.conf.result_backend = f"{broker_url}/1"
-celery.conf.update(result_expires=3600, task_acks_late=True, task_track_started=True)
+celery.conf.task_acks_late = False
+celery.conf.task_track_started = True
+celery.conf.broker_transport_options = {"visibility_timeout": float("inf")}
+# celery.conf.result_backend_transport_options = {"visibility_timeout": float("inf")}
+# celery.conf.result_expires = 3600 * 24
# Queues
celery.conf.update(
diff --git a/celery_app/tasks.py b/celery_app/tasks.py
index ce2ca4d..114df2a 100644
--- a/celery_app/tasks.py
+++ b/celery_app/tasks.py
@@ -1,10 +1,11 @@
import asyncio
import os
-from celery_app.celeryapp import celery
from stt import logger
-from stt.processing import decode, model
-from stt.processing.utils import load_wave
+from stt.processing import MODEL, decode
+from stt.processing.utils import load_audiofile
+
+from celery_app.celeryapp import celery
@celery.task(name="transcribe_task")
@@ -15,16 +16,22 @@ def transcribe_task(file_name: str, with_metadata: bool):
# Load wave
file_path = os.path.join("/opt/audio", file_name)
try:
- file_content = load_wave(file_path)
+ file_content = load_audiofile(file_path)
except Exception as err:
- logger.error(f"Failed to load ressource: {repr(err)}")
- raise Exception(f"Could not open ressource {file_path}") from err
+ import traceback
+
+ msg = f"{traceback.format_exc()}\nFailed to load ressource {file_path}"
+ logger.error(msg)
+ raise Exception(msg) # from err
# Decode
try:
- result = decode(file_content, model, 16000, with_metadata)
+ result = decode(file_content, MODEL, with_metadata)
except Exception as err:
- logger.error(f"Failed to decode: {repr(err)}")
- raise Exception(f"Failed to decode {file_path}") from err
+ import traceback
+
+ msg = f"{traceback.format_exc()}\nFailed to decode {file_path}"
+ logger.error(msg)
+ raise Exception(msg) # from err
return result
diff --git a/document/swagger.yml b/document/swagger.yml
index 70bc9fc..6da4ed6 100644
--- a/document/swagger.yml
+++ b/document/swagger.yml
@@ -2,7 +2,7 @@ swagger: "2.0"
info:
version: "1.0.0"
- title: LinTo-Platform-STT
+ title: LinTo-STT
description: Speech To Text API
contact:
email: "support@linto.ai"
diff --git a/http_server/confparser.py b/http_server/confparser.py
index 2396d71..d296dbb 100644
--- a/http_server/confparser.py
+++ b/http_server/confparser.py
@@ -7,24 +7,6 @@
def createParser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
- # SERVICE
- parser.add_argument(
- "--service_name",
- type=str,
- help="Service Name",
- default=os.environ.get("SERVICE_NAME", "stt"),
- )
-
- # MODELS
- parser.add_argument("--am_path", type=str, help="Acoustic Model Path", default="/opt/models/AM")
- parser.add_argument("--lm_path", type=str, help="Decoding graph path", default="/opt/models/LM")
- parser.add_argument(
- "--config_path",
- type=str,
- help="Configuration files path",
- default="/opt/config",
- )
-
# GUNICORN
parser.add_argument("--service_port", type=int, help="Service port", default=80)
parser.add_argument(
diff --git a/http_server/ingress.py b/http_server/ingress.py
index 5a9c661..6c71478 100644
--- a/http_server/ingress.py
+++ b/http_server/ingress.py
@@ -3,17 +3,15 @@
import json
import logging
import os
-from time import time
+import time
from confparser import createParser
-from flask import Flask, Response, abort, json, request
-from flask_sock import Sock
-from serving import GunicornServing
+from flask import Flask, json, request
+from serving import GeventServing, GunicornServing
+from stt import logger as stt_logger
+from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer
from swagger import setupSwaggerUI
-from stt.processing import decode, formatAudio, model
-from stt.processing.streaming import ws_streaming
-
app = Flask("__stt-standalone-worker__")
app.config["JSON_AS_ASCII"] = False
app.config["JSON_SORT_KEYS"] = False
@@ -26,13 +24,16 @@
# If websocket streaming route is enabled
if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]:
+ from flask_sock import Sock
+ from stt.processing.streaming import ws_streaming
+
logger.info("Init websocket serving ...")
sock = Sock(app)
logger.info("Streaming is enabled")
@sock.route("/streaming")
def streaming(web_socket):
- ws_streaming(web_socket, model)
+ ws_streaming(web_socket, MODEL)
@app.route("/healthcheck", methods=["GET"])
@@ -51,39 +52,38 @@ def transcribe():
logger.info("Transcribe request received")
# get response content type
- logger.debug(request.headers.get("accept").lower())
+ # logger.debug(request.headers.get("accept").lower())
if request.headers.get("accept").lower() == "application/json":
join_metadata = True
elif request.headers.get("accept").lower() == "text/plain":
join_metadata = False
else:
- raise ValueError("Not accepted header")
- logger.debug("Metadata: {}".format(join_metadata))
+ raise ValueError(
+ f"Not accepted header (accept={request.headers.get('accept')} should be either application/json or text/plain)"
+ )
+ # logger.debug("Metadata: {}".format(join_metadata))
# get input file
- if "file" in request.files.keys():
- file_buffer = request.files["file"].read()
- audio_data, sampling_rate = formatAudio(file_buffer)
- start_t = time()
+ if "file" not in request.files.keys():
+ raise ValueError(f"No audio file was uploaded (missing 'file' key)")
- # Transcription
- transcription = decode(audio_data, model, sampling_rate, join_metadata)
- logger.debug("Transcription complete (t={}s)".format(time() - start_t))
+ file_buffer = request.files["file"].read()
- logger.debug("... Complete")
+ audio_data = load_wave_buffer(file_buffer)
- else:
- raise ValueError("No audio file was uploaded")
+ # Transcription
+ transcription = decode(audio_data, MODEL, join_metadata)
if join_metadata:
return json.dumps(transcription, ensure_ascii=False), 200
return transcription["text"], 200
- except ValueError as error:
- return str(error), 400
except Exception as error:
- logger.error(error)
- return "Server Error: {}".format(str(error)), 500
+ import traceback
+
+ logger.error(traceback.format_exc())
+ logger.error(repr(error))
+ return "Server Error: {}".format(str(error)), 400 if isinstance(error, ValueError) else 500
@app.errorhandler(405)
@@ -107,7 +107,9 @@ def server_error(error):
parser = createParser()
args = parser.parse_args()
- logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
+ logger_level = logging.DEBUG if args.debug else logging.INFO
+ logger.setLevel(logger_level)
+ stt_logger.setLevel(logger_level)
try:
# Setup SwaggerUI
if args.swagger_path is not None:
@@ -116,12 +118,21 @@ def server_error(error):
except Exception as err:
logger.warning("Could not setup swagger: {}".format(str(err)))
- serving = GunicornServing(
+ logger.info(f"Using {args.workers} workers")
+
+ if USE_GPU: # TODO: get rid of this?
+ serving_type = GeventServing
+ logger.debug("Serving with gevent")
+ else:
+ serving_type = GunicornServing
+ logger.debug("Serving with gunicorn")
+
+ serving = serving_type(
app,
{
"bind": f"0.0.0.0:{args.service_port}",
"workers": args.workers,
- "timeout": 3600,
+ "timeout": 3600 * 24,
},
)
logger.info(args)
diff --git a/http_server/serving.py b/http_server/serving.py
index d2dd7e8..9230eb4 100644
--- a/http_server/serving.py
+++ b/http_server/serving.py
@@ -1,5 +1,9 @@
+import gevent.monkey
+import gevent.pywsgi
import gunicorn.app.base
+gevent.monkey.patch_all()
+
class GunicornServing(gunicorn.app.base.BaseApplication):
def __init__(self, app, options=None):
@@ -18,3 +22,22 @@ def load_config(self):
def load(self):
return self.application
+
+
+class GeventServing:
+ def __init__(self, app, options=None):
+ self.options = options or {}
+ self.application = app
+
+ def run(self):
+ bind = self.options.get("bind", "0.0.0.0:8080")
+ workers = self.options.get("workers", 1)
+ listener = bind.split(":")
+ try:
+ assert len(listener) == 2
+ listener = (listener[0], int(listener[1]))
+ except:
+ print(f"Invalid bind address {bind}")
+
+ server = gevent.pywsgi.WSGIServer(listener, self.application, spawn=workers)
+ server.serve_forever()
diff --git a/http_server/swagger.py b/http_server/swagger.py
index a9b93d0..31344cd 100644
--- a/http_server/swagger.py
+++ b/http_server/swagger.py
@@ -11,7 +11,7 @@ def setupSwaggerUI(app, args):
args.swagger_prefix + args.swagger_url,
args.swagger_path,
config={ # Swagger UI config overrides
- "app_name": "LinTO Platform STT",
+ "app_name": "LinTO STT",
"spec": swagger_yml,
},
)
diff --git a/.envdefault b/kaldi/.envdefault
similarity index 100%
rename from .envdefault
rename to kaldi/.envdefault
diff --git a/Dockerfile b/kaldi/Dockerfile
similarity index 91%
rename from Dockerfile
rename to kaldi/Dockerfile
index bdf65c0..f062951 100644
--- a/Dockerfile
+++ b/kaldi/Dockerfile
@@ -45,7 +45,7 @@ RUN git clone -b vosk --single-branch https://github.com/alphacep/kaldi /opt/kal
&& make -j $(nproc) online2 lm rnnlm
# Install python dependencies
-COPY requirements.txt ./
+COPY kaldi/requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
# Install Custom Vosk API
@@ -57,13 +57,13 @@ RUN git clone --depth 1 https://github.com/alphacep/vosk-api /opt/vosk-api && cd
WORKDIR /usr/src/app
-COPY stt /usr/src/app/stt
COPY celery_app /usr/src/app/celery_app
COPY http_server /usr/src/app/http_server
COPY websocket /usr/src/app/websocket
COPY document /usr/src/app/document
-COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
-COPY lin_to_vosk.py /usr/src/app/lin_to_vosk.py
+COPY kaldi/stt /usr/src/app/stt
+COPY kaldi/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+COPY kaldi/lin_to_vosk.py /usr/src/app/lin_to_vosk.py
RUN mkdir -p /var/log/supervisor/
diff --git a/kaldi/README.md b/kaldi/README.md
new file mode 100644
index 0000000..0e3a31a
--- /dev/null
+++ b/kaldi/README.md
@@ -0,0 +1,222 @@
+# LinTO-STT-Kaldi
+
+LinTO-STT-Kaldi is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack)
+based on Speech-To-Text (STT) models trained with [Kaldi](https://github.com/kaldi-asr/kaldi).
+
+LinTO-STT-Kaldi can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector.
+
+## Pre-requisites
+
+### Hardware
+To run the transcription models you'll need:
+* At least 7GB of disk space to build the docker image.
+* Up to 7GB of RAM depending on the model used.
+* One CPU per worker. Inference time scales with CPU performance.
+
+### Model
+LinTO-STT-Kaldi accepts two kinds of models:
+* LinTO Acoustic and Language models.
+* Vosk models.
+
+We provide home-curated models (v2) on [dl.linto.ai](https://doc.linto.ai/docs/developpers/apis/ASR/models).
+You can also use the Vosk models available [here](https://alphacephei.com/vosk/models).
+
+### Docker
+The transcription service requires Docker to be up and running.
+
+### (micro-service) Service broker and shared folder
+In task mode, the only entry point to the STT service is a task posted on a message broker. Supported message brokers are RabbitMQ, Redis and Amazon SQS.
+In addition, to prevent large audio files from transiting through the message broker, the STT worker uses a shared storage folder (SHARED_FOLDER).
+
+## Deploy LinTO-STT-Kaldi
+
+**1- First step is to build or pull the image:**
+
+```bash
+git clone https://github.com/linto-ai/linto-stt.git
+cd linto-stt
+docker build . -f kaldi/Dockerfile -t linto-stt-kaldi:latest
+```
+or
+
+```bash
+docker pull lintoai/linto-stt-kaldi
+```
+
+**2- Download the models**
+
+Have the acoustic and language model ready at AM_PATH and LM_PATH if you are using LinTO models. If you are using a Vosk model, have it ready at MODEL.
+
+**3- Fill the .env**
+
+```bash
+cp kaldi/.envdefault kaldi/.env
+```
+
+| PARAMETER | DESCRIPTION | EXAMPLE |
+|---|---|---|
+| SERVICE_MODE | STT serving mode, see [Serving mode](#serving-mode) | http\|task\|websocket |
+| MODEL_TYPE | Type of STT model used. | lin\|vosk |
+| ENABLE_STREAMING | Using http serving mode, enable the /streaming websocket route | true\|false |
+| SERVICE_NAME | Using the task mode, set the queue's name for task processing | my-stt |
+| SERVICE_BROKER | Using the task mode, URL of the message broker | redis://my-broker:6379 |
+| BROKER_PASS | Using the task mode, broker password | my-password |
+| STREAMING_PORT | Using the websocket mode, the listening port for incoming WebSocket connections. | 80 |
+| CONCURRENCY | Maximum number of parallel requests | >1 |
+
+### Serving mode
+![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png)
+
+STT can be used three ways:
+* Through an [HTTP API](#http-server) using the **http** mode.
+* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task** mode.
+* Through a [websocket server](#websocket-server) using the **websocket** mode.
+
+The mode is specified using the .env value or environment variable ```SERVICE_MODE```.
+```bash
+SERVICE_MODE=http
+```
+### HTTP Server
+The HTTP serving mode deploys an HTTP server and a Swagger UI to allow transcription requests on a dedicated route.
+
+The SERVICE_MODE value in the .env should be set to ```http```.
+
+```bash
+docker run --rm \
+-p HOST_SERVING_PORT:80 \
+-v AM_PATH:/opt/AM \
+-v LM_PATH:/opt/LM \
+--env-file kaldi/.env \
+linto-stt-kaldi:latest
+```
+
+This will run a container providing an [HTTP API](#http-api) bound to the HOST_SERVING_PORT port of the host.
+
+**Parameters:**
+| Variables | Description | Example |
+|:-|:-|:-|
+| HOST_SERVING_PORT | Host serving port | 80 |
+| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 |
+| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 |
+| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model |
+
+### Micro-service within LinTO-Platform stack
+The TASK serving mode connects a Celery worker to a message broker.
+
+The SERVICE_MODE value in the .env should be set to ```task```.
+
+You need a message broker up and running at MY_SERVICE_BROKER.
+
+```bash
+docker run --rm \
+-v AM_PATH:/opt/AM \
+-v LM_PATH:/opt/LM \
+-v SHARED_AUDIO_FOLDER:/opt/audio \
+--env-file kaldi/.env \
+linto-stt-kaldi:latest
+```
+
+**Parameters:**
+| Variables | Description | Example |
+|:-|:-|:-|
+| AM_PATH | Path to the acoustic model on the host machine mounted to /opt/AM | /my/path/to/models/AM_fr-FR_v2.2.0 |
+| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 |
+| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model |
+| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/shared/audio-folder |
+
+
+### Websocket Server
+The websocket serving mode deploys a streaming transcription service only.
+
+The SERVICE_MODE value in the .env should be set to ```websocket```.
+
+Usage is the same as the [http streaming API](#/streaming)
+
+## Usages
+### HTTP API
+#### /healthcheck
+Returns the state of the API
+
+Method: GET
+
+Returns "1" if healthcheck passes.
+
+#### /transcribe
+Transcription API
+
+* Method: POST
+* Response content: text/plain or application/json
+* File: A WAV file, 16-bit, 16kHz
+
+Returns the transcribed text as "text/plain", or a JSON object when using "application/json", structured as follows:
+```json
+{
+ "text" : "This is the transcription",
+ "words" : [
+ {"word":"This", "start": 0.123, "end": 0.453, "conf": 0.9},
+ ...
+  ],
+ "confidence-score": 0.879
+}
+```
+
+#### /streaming
+The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true.
+
+The route accepts websocket connections. Exchanges are structured as follows (a minimal client sketch is given after this list):
+1. The client sends a JSON message {"config": {"sample_rate":16000}}.
+2. The client sends an audio chunk (go to step 3) or {"eof" : 1} (go to step 5).
+3. The server sends either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}.
+4. Back to step 2.
+5. The server sends a final result and closes the connection.
+
+> The connection will be closed and the worker freed if no chunk is received for 10s.
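+
+Below is a minimal client sketch for this protocol. It assumes the service is exposed at `ws://localhost:8080/streaming`, that `audio.wav` is a 16kHz, 16-bit mono WAV file, and that the `websockets` Python package is installed on the client side; adapt these placeholders to your own deployment.
+
+```python
+import asyncio
+import json
+
+import websockets  # pip install websockets
+
+
+async def stream(path="audio.wav", url="ws://localhost:8080/streaming"):
+    async with websockets.connect(url) as ws:
+        # 1. Send the configuration message
+        await ws.send(json.dumps({"config": {"sample_rate": 16000}}))
+        with open(path, "rb") as f:
+            f.read(44)  # crude skip of the WAV header; raw PCM follows
+            # 2./3. Send audio chunks and print partial/final results
+            while chunk := f.read(8000):
+                await ws.send(chunk)
+                print(json.loads(await ws.recv()))
+        # 5. Send end of stream and print the final result
+        await ws.send(json.dumps({"eof": 1}))
+        print(json.loads(await ws.recv()))
+
+
+asyncio.run(stream())
+```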
+
+#### /docs
+The /docs route offers an OpenAPI/Swagger interface.
+
+### Through the message broker
+
+STT-Worker accepts requests with the following arguments:
+```file_path: str, with_metadata: bool```
+
+* file_path: the location of the file within the shared folder: /.../SHARED_FOLDER/{file_path}
+* with_metadata: if True, word timestamps and confidence scores will be computed and returned. If False, these fields will be empty.
+
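+As an illustration, here is a minimal sketch of a client posting such a task with Celery. The broker URL, queue name (`my-stt`) and file name are placeholders: they must match the broker URL and SERVICE_NAME configured for the worker, and the audio file must already be present in the shared folder.
+
+```python
+from celery import Celery  # pip install celery[redis]
+
+# Broker and result backend must match the worker configuration
+# (database 0 for the broker, database 1 for the results).
+client = Celery(
+    broker="redis://my-broker:6379/0",
+    backend="redis://my-broker:6379/1",
+)
+
+# "transcribe_task" is the task name registered by the worker.
+result = client.send_task(
+    "transcribe_task",
+    queue="my-stt",             # the worker's SERVICE_NAME queue
+    args=["myfile.wav", True],  # file_path (relative to the shared folder), with_metadata
+)
+print(result.get())  # {"text": ..., "words": [...], "confidence-score": ...}
+```
+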
+#### Return format
+On a successful transcription, the returned object is a JSON object structured as follows:
+```json
+{
+ "text" : "this is the transcription as text",
+ "words": [
+ {
+ "word" : "this",
+ "start": 0.0,
+ "end": 0.124,
+ "conf": 1.0
+ },
+ ...
+ ],
+ "confidence-score": ""
+}
+```
+
+* The text field contains the raw transcription.
+* The words field contains each word with its timestamp and individual confidence. (Empty if with_metadata=False)
+* The confidence-score field contains the overall confidence for the transcription. (0.0 if with_metadata=False)
+
+
+## Test
+### Curl
+You can test your HTTP API using curl:
+```bash
+curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav"
+```
+
+## License
+This project is developed under the AGPLv3 License (see LICENSE).
+
+## Acknowledgments
+
+* [Vosk, speech recognition toolkit](https://alphacephei.com/vosk/).
+* [Kaldi Speech Recognition Toolkit](https://github.com/kaldi-asr/kaldi)
diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md
new file mode 100644
index 0000000..e11f89a
--- /dev/null
+++ b/kaldi/RELEASE.md
@@ -0,0 +1,3 @@
+# 1.0.0
+- First build of linto-stt-kaldi
+- Based on 3.3.2 of linto-stt (https://github.com/linto-ai/linto-stt/blob/4361300a4463c90cec0bf3fa2975d7cc2ddf8d36/RELEASE.md)
diff --git a/docker-entrypoint.sh b/kaldi/docker-entrypoint.sh
similarity index 100%
rename from docker-entrypoint.sh
rename to kaldi/docker-entrypoint.sh
diff --git a/lin_to_vosk.py b/kaldi/lin_to_vosk.py
similarity index 100%
rename from lin_to_vosk.py
rename to kaldi/lin_to_vosk.py
diff --git a/requirements.txt b/kaldi/requirements.txt
similarity index 96%
rename from requirements.txt
rename to kaldi/requirements.txt
index 132bdfc..867a095 100644
--- a/requirements.txt
+++ b/kaldi/requirements.txt
@@ -4,6 +4,7 @@ flask>=1.1.2
flask-cors>=3.0.10
flask-swagger-ui>=3.36.0
flask-sock
+gevent
gunicorn
pyyaml>=5.4.1
wavio>=0.0.4
diff --git a/stt/__init__.py b/kaldi/stt/__init__.py
similarity index 100%
rename from stt/__init__.py
rename to kaldi/stt/__init__.py
diff --git a/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py
similarity index 68%
rename from stt/processing/__init__.py
rename to kaldi/stt/processing/__init__.py
index 2a3eca5..9f99406 100644
--- a/stt/processing/__init__.py
+++ b/kaldi/stt/processing/__init__.py
@@ -2,13 +2,19 @@
import sys
from time import time
-from vosk import Model
-
from stt import logger
from stt.processing.decoding import decode
-from stt.processing.utils import formatAudio, load_wave
+from stt.processing.utils import load_audiofile, load_wave_buffer
+from vosk import Model
-__all__ = ["model", "logger", "decode", "load_wave", "formatAudio"]
+__all__ = [
+ "logger",
+ "decode",
+ "load_audiofile",
+ "load_wave_buffer",
+ "MODEL",
+ "USE_GPU",
+]
# Model locations (should be mounted)
MODEL_PATH = "/opt/model"
@@ -17,8 +23,11 @@
logger.info("Loading acoustic model and decoding graph ...")
start = time()
try:
- model = Model(MODEL_PATH)
+ MODEL = Model(MODEL_PATH)
except Exception as err:
raise Exception("Failed to load transcription model: {}".format(str(err))) from err
sys.exit(-1)
logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start))
+
+# Not implemented yet in Kaldi
+USE_GPU = False
diff --git a/stt/processing/decoding.py b/kaldi/stt/processing/decoding.py
similarity index 89%
rename from stt/processing/decoding.py
rename to kaldi/stt/processing/decoding.py
index 2e1fb7c..8c06007 100644
--- a/stt/processing/decoding.py
+++ b/kaldi/stt/processing/decoding.py
@@ -4,10 +4,12 @@
from vosk import KaldiRecognizer, Model
-def decode(audio_data: bytes, model: Model, sampling_rate: int, with_metadata: bool) -> dict:
+def decode(audio: tuple[bytes, int], model: Model, with_metadata: bool) -> dict:
"""Transcribe the audio data using the vosk library with the defined model."""
result = {"text": "", "confidence-score": 0.0, "words": []}
+ audio_data, sampling_rate = audio
+
recognizer = KaldiRecognizer(model, sampling_rate)
recognizer.SetMaxAlternatives(0) # Set confidence per words
recognizer.SetWords(with_metadata)
diff --git a/stt/processing/streaming.py b/kaldi/stt/processing/streaming.py
similarity index 99%
rename from stt/processing/streaming.py
rename to kaldi/stt/processing/streaming.py
index 28274b8..a33ecfc 100644
--- a/stt/processing/streaming.py
+++ b/kaldi/stt/processing/streaming.py
@@ -3,11 +3,10 @@
from typing import Union
from simple_websocket.ws import Server as WSServer
+from stt import logger
from vosk import KaldiRecognizer, Model
from websockets.legacy.server import WebSocketServerProtocol
-from stt import logger
-
async def wssDecode(ws: WebSocketServerProtocol, model: Model):
"""Async Decode function endpoint"""
diff --git a/stt/processing/utils.py b/kaldi/stt/processing/utils.py
similarity index 84%
rename from stt/processing/utils.py
rename to kaldi/stt/processing/utils.py
index b81cc5d..eb3349d 100644
--- a/stt/processing/utils.py
+++ b/kaldi/stt/processing/utils.py
@@ -1,16 +1,16 @@
import io
import wavio
-from numpy import int16, squeeze, mean
+from numpy import int16, mean, squeeze
-def load_wave(file_path):
+def load_audiofile(file_path):
"""Formats audio from a wavFile buffer to a bytebuffer"""
audio = squeeze(wavio.read(file_path).data)
- return audio.tobytes()
+ return (audio.tobytes(), 16000)
-def formatAudio(file_buffer):
+def load_wave_buffer(file_buffer):
"""Formats audio from a wavFile buffer to a numpy array for processing."""
file_buffer_io = io.BytesIO(file_buffer)
file_content = wavio.read(file_buffer_io)
diff --git a/whisper/.envdefault b/whisper/.envdefault
new file mode 100644
index 0000000..88c27ea
--- /dev/null
+++ b/whisper/.envdefault
@@ -0,0 +1,39 @@
+############################################
+# SERVING PARAMETERS
+############################################
+# "http" or "task"
+SERVICE_MODE=http
+
+# Below: used when SERVICE_MODE=task
+SERVICE_NAME=stt
+SERVICES_BROKER=redis://172.17.0.1:6379
+BROKER_PASS=
+
+############################################
+# STT MODELING PARAMETERS
+############################################
+
+# The model can be a path to a model, or a model name ("tiny", "base", "small", "medium", "large-v1", "large-v2" or "large-v3")
+MODEL=medium
+
+# The language can be in different formats: "en", "en-US", "English", ...
+# If not set or set to "*", the language will be detected automatically.
+LANGUAGE=*
+
+# An alignment wav2vec model can be used to get word timestamps.
+# It can be a path to a model, a language code (fr, en, ...), or "wav2vec" to automatically choose a model for the language
+# This option is experimental (and not implemented with ctranslate2).
+# ALIGNMENT_MODEL=wav2vec
+
+############################################
+# EFFICIENCY PARAMETERS
+############################################
+
+# Device to use. It can be "cuda" to force/check GPU, "cpu" to force computation on CPU, or a specific GPU ("cuda:0", "cuda:1", ...)
+# DEVICE=cuda:0
+
+# Number of threads per worker when running on CPU
+OMP_NUM_THREADS=4
+
+# Number of workers
+CONCURRENCY=2
diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2
new file mode 100644
index 0000000..52fbc44
--- /dev/null
+++ b/whisper/Dockerfile.ctranslate2
@@ -0,0 +1,23 @@
+FROM ghcr.io/opennmt/ctranslate2:latest-ubuntu20.04-cuda11.2
+LABEL maintainer="jlouradour@linagora.com"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git
+
+# Install python dependencies
+COPY whisper/requirements.ctranslate2.txt ./
+RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt
+
+WORKDIR /usr/src/app
+
+COPY celery_app /usr/src/app/celery_app
+COPY http_server /usr/src/app/http_server
+COPY websocket /usr/src/app/websocket
+COPY document /usr/src/app/document
+COPY whisper/stt /usr/src/app/stt
+COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
+
+HEALTHCHECK CMD ./healthcheck.sh
+
+ENTRYPOINT ["./docker-entrypoint.sh"]
\ No newline at end of file
diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu
new file mode 100644
index 0000000..c8d6972
--- /dev/null
+++ b/whisper/Dockerfile.ctranslate2.cpu
@@ -0,0 +1,23 @@
+FROM python:3.9
+LABEL maintainer="jlouradour@linagora.com"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git
+
+# Install python dependencies
+COPY whisper/requirements.ctranslate2.txt ./
+RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt
+
+WORKDIR /usr/src/app
+
+COPY celery_app /usr/src/app/celery_app
+COPY http_server /usr/src/app/http_server
+COPY websocket /usr/src/app/websocket
+COPY document /usr/src/app/document
+COPY whisper/stt /usr/src/app/stt
+COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
+
+HEALTHCHECK CMD ./healthcheck.sh
+
+ENTRYPOINT ["./docker-entrypoint.sh"]
\ No newline at end of file
diff --git a/whisper/Dockerfile.torch b/whisper/Dockerfile.torch
new file mode 100644
index 0000000..2f3a0d0
--- /dev/null
+++ b/whisper/Dockerfile.torch
@@ -0,0 +1,23 @@
+FROM python:3.9
+LABEL maintainer="jlouradour@linagora.com"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
+
+# Install python dependencies
+COPY whisper/requirements.torch.txt ./
+RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt
+
+WORKDIR /usr/src/app
+
+COPY celery_app /usr/src/app/celery_app
+COPY http_server /usr/src/app/http_server
+COPY websocket /usr/src/app/websocket
+COPY document /usr/src/app/document
+COPY whisper/stt /usr/src/app/stt
+COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
+
+HEALTHCHECK CMD ./healthcheck.sh
+
+ENTRYPOINT ["./docker-entrypoint.sh"]
\ No newline at end of file
diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu
new file mode 100644
index 0000000..e9198d5
--- /dev/null
+++ b/whisper/Dockerfile.torch.cpu
@@ -0,0 +1,29 @@
+FROM python:3.9
+LABEL maintainer="jlouradour@linagora.com"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
+
+# Force CPU versions of torch
+RUN pip3 install \
+ torch==1.13.1+cpu \
+ torchaudio==0.13.1+cpu \
+ -f https://download.pytorch.org/whl/torch_stable.html
+
+# Install python dependencies
+COPY whisper/requirements.torch.txt ./
+RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt
+
+WORKDIR /usr/src/app
+
+COPY celery_app /usr/src/app/celery_app
+COPY http_server /usr/src/app/http_server
+COPY websocket /usr/src/app/websocket
+COPY document /usr/src/app/document
+COPY whisper/stt /usr/src/app/stt
+COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
+
+HEALTHCHECK CMD ./healthcheck.sh
+
+ENTRYPOINT ["./docker-entrypoint.sh"]
\ No newline at end of file
diff --git a/whisper/README.md b/whisper/README.md
new file mode 100644
index 0000000..20a3c7d
--- /dev/null
+++ b/whisper/README.md
@@ -0,0 +1,282 @@
+# LinTO-STT-Whisper
+
+LinTO-STT-Whisper is the transcription service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack)
+based on Speech-To-Text (STT) [Whisper models](https://openai.com/research/whisper).
+
+LinTO-STT-Whisper can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector.
+
+## Pre-requisites
+
+### Hardware
+To run the transcription models you'll need:
+* At least 8GB of disk space to build the docker image.
+* Up to 7GB of RAM depending on the model used.
+* One CPU per worker. Inference time scales with CPU performance.
+
+### Model(s)
+
+LinTO-STT-Whisper works with a Whisper model to perform Automatic Speech Recognition.
+If not already downloaded, the model will be downloaded at the first transcription request,
+and can occupy several GB of disk space.
+
+#### Optional alignment model (deprecated)
+
+LinTO-STT-Whisper also has the option to work with a wav2vec model to perform word alignment.
+The wav2vec model can be specified either
+* (TorchAudio) with a string corresponding to a `torchaudio` pipeline (e.g. "WAV2VEC2_ASR_BASE_960H") or
+* (HuggingFace's Transformers) with a string corresponding to a HuggingFace repository of a wav2vec model (e.g. "jonatasgrosman/wav2vec2-large-xlsr-53-english"), or
+* (SpeechBrain) with a path corresponding to a folder with a SpeechBrain model
+
+Default wav2vec models are provided for French (fr), English (en), Spanish (es), German (de), Dutch (nl), Japanese (ja), Chinese (zh).
+
+However, we advise against using a companion wav2vec alignment model.
+It is no longer needed nor tested.
+
+### Docker
+The transcription service requires Docker to be up and running.
+
+### (micro-service) Service broker and shared folder
+In task mode, the only entry point to the STT service is a task posted on a message broker. Supported message brokers are RabbitMQ, Redis and Amazon SQS.
+In addition, to prevent large audio files from transiting through the message broker, the STT worker uses a shared storage folder (SHARED_FOLDER).
+
+## Deploy LinTO-STT-Whisper
+
+### 1- First step is to build or pull the image
+
+```bash
+git clone https://github.com/linto-ai/linto-stt.git
+cd linto-stt
+docker build . -f whisper/Dockerfile.ctranslate2 -t linto-stt-whisper:latest
+```
+or
+
+```bash
+docker pull lintoai/linto-stt-whisper
+```
+
+### 2- Fill the .env
+
+```bash
+cp whisper/.envdefault whisper/.env
+```
+
+| PARAMETER | DESCRIPTION | EXAMPLE |
+|---|---|---|
+| SERVICE_MODE | STT serving mode, see [Serving mode](#serving-mode) | `http` \| `task` |
+| MODEL | Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | \ \| `large-v3` \| `distil-whisper/distil-large-v2` \| ... |
+| LANGUAGE | (Optional) Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... |
+| PROMPT | (Optional) Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` |
+| ALIGNMENT_MODEL | (Optional) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | \ \| `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| ... |
+| CONCURRENCY | Maximum number of parallel requests | `3` |
+| SERVICE_NAME | (For the task mode) queue's name for task processing | `my-stt` |
+| SERVICE_BROKER | (For the task mode) URL of the message broker | `redis://my-broker:6379` |
+| BROKER_PASS | (For the task mode only) broker password | `my-password` |
+
+#### MODEL environment variable
+
+**Warning:**
+The model will be downloaded (if needed) and loaded in memory when the first transcription is requested.
+When using a Whisper model from Hugging Face (transformers) along with ctranslate2 (faster_whisper),
+the torch library will also be downloaded in order to convert the model from torch to ctranslate2 format.
+
+If you want to preload the model (and later specify a path `ASR_PATH` as `MODEL`),
+you may want to download one of the OpenAI Whisper models:
+* Multi-lingual Whisper models can be downloaded with the following links:
+ * [tiny](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
+ * [base](https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt)
+ * [small](https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt)
+ * [medium](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt)
+ * [large-v1](https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt)
+ * [large-v2](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt)
+ * [large-v3](https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt)
+* Whisper models specialized for English can also be found here:
+ * [tiny.en](https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt)
+ * [base.en](https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt)
+ * [small.en](https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt)
+ * [medium.en](https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt)
+
+If you have already used Whisper locally with [OpenAI-Whisper](https://github.com/openai/whisper), models can be found under ~/.cache/whisper.
+
+The same applies to Whisper models from Hugging Face (transformers), for instance https://huggingface.co/distil-whisper/distil-large-v2
+(you can either download the model or use the Hugging Face identifier `distil-whisper/distil-large-v2`).
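+
+For example, here is a minimal sketch that pre-downloads the `medium` checkpoint from the link above into a local `models/` folder (the destination path is an arbitrary choice); you can then mount the resulting file into the container and point `MODEL` at it (see `ASR_PATH` below).
+
+```python
+import pathlib
+import urllib.request
+
+# URL of the multi-lingual "medium" checkpoint listed above
+url = (
+    "https://openaipublic.azureedge.net/main/whisper/models/"
+    "345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt"
+)
+dest = pathlib.Path("models/medium.pt")
+dest.parent.mkdir(parents=True, exist_ok=True)
+urllib.request.urlretrieve(url, dest)
+print(f"Model saved to {dest.resolve()}")
+```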
+
+#### LANGUAGE
+
+If `*` is used for the `LANGUAGE` environment variable, or if `LANGUAGE` is not defined,
+automatic language detection will be performed by Whisper.
+
+The language can be a two- or three-letter code. The languages supported by Whisper are:
+```
+af(afrikaans), am(amharic), ar(arabic), as(assamese), az(azerbaijani),
+ba(bashkir), be(belarusian), bg(bulgarian), bn(bengali), bo(tibetan), br(breton), bs(bosnian),
+ca(catalan), cs(czech), cy(welsh), da(danish), de(german), el(greek), en(english), es(spanish),
+et(estonian), eu(basque), fa(persian), fi(finnish), fo(faroese), fr(french), gl(galician),
+gu(gujarati), ha(hausa), haw(hawaiian), he(hebrew), hi(hindi), hr(croatian), ht(haitian creole),
+hu(hungarian), hy(armenian), id(indonesian), is(icelandic), it(italian), ja(japanese),
+jw(javanese), ka(georgian), kk(kazakh), km(khmer), kn(kannada), ko(korean), la(latin),
+lb(luxembourgish), ln(lingala), lo(lao), lt(lithuanian), lv(latvian), mg(malagasy), mi(maori),
+mk(macedonian), ml(malayalam), mn(mongolian), mr(marathi), ms(malay), mt(maltese), my(myanmar),
+ne(nepali), nl(dutch), nn(nynorsk), no(norwegian), oc(occitan), pa(punjabi), pl(polish),
+ps(pashto), pt(portuguese), ro(romanian), ru(russian), sa(sanskrit), sd(sindhi), si(sinhala),
+sk(slovak), sl(slovenian), sn(shona), so(somali), sq(albanian), sr(serbian), su(sundanese),
+sv(swedish), sw(swahili), ta(tamil), te(telugu), tg(tajik), th(thai), tk(turkmen), tl(tagalog),
+tr(turkish), tt(tatar), uk(ukrainian), ur(urdu), uz(uzbek), vi(vietnamese), yi(yiddish),
+yo(yoruba), zh(chinese)
+```
+and also `yue(cantonese)` since large-v3.
+
+### Serving mode
+![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png)
+
+STT can be used in two ways:
+* Through an [HTTP API](#http-server) using the **http** mode.
+* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task** mode.
+
+The mode is specified using the .env value or environment variable ```SERVICE_MODE```.
+```bash
+SERVICE_MODE=http
+```
+### HTTP Server
+The HTTP serving mode deploys an HTTP server and a Swagger UI to allow transcription requests on a dedicated route.
+
+The SERVICE_MODE value in the .env should be set to ```http```.
+
+```bash
+docker run --rm \
+-p HOST_SERVING_PORT:80 \
+-v ASR_PATH:/opt/model.pt \
+--env-file whisper/.env \
+linto-stt-whisper:latest
+```
+
+This will run a container providing an [HTTP API](#http-api) bound to the HOST_SERVING_PORT port of the host.
+
+You may also want to mount your cache folder CACHE_PATH (e.g. "~/.cache") ```-v CACHE_PATH:/root/.cache```
+in order to avoid downloading models each time.
+
+Also, if you want to specify a custom alignment model already downloaded in a folder WAV2VEC_PATH,
+you can add option ```-v WAV2VEC_PATH:/opt/wav2vec``` and environment variable ```ALIGNMENT_MODEL=/opt/wav2vec```.
+
+**Parameters:**
+| Variables | Description | Example |
+|:-|:-|:-|
+| HOST_SERVING_PORT | Host serving port | 8080 |
+| ASR_PATH | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
+| CACHE_PATH | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| WAV2VEC_PATH | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+
+### Micro-service within LinTO-Platform stack
+The TASK serving mode connects a Celery worker to a message broker.
+
+The SERVICE_MODE value in the .env should be set to ```task```.
+
+You need a message broker up and running at MY_SERVICE_BROKER.
+
+```bash
+docker run --rm \
+-v ASR_PATH:/opt/model.pt \
+-v SHARED_AUDIO_FOLDER:/opt/audio \
+--env-file whisper/.env \
+linto-stt-whisper:latest
+```
+
+You may also want to mount your cache folder CACHE_PATH (e.g. "~/.cache") ```-v CACHE_PATH:/root/.cache```
+in order to avoid downloading models each time.
+
+Also, if you want to specify a custom alignment model already downloaded in a folder WAV2VEC_PATH,
+you can add option ```-v WAV2VEC_PATH:/opt/wav2vec``` and environment variable ```ALIGNMENT_MODEL=/opt/wav2vec```.
+
+**Parameters:**
+| Variables | Description | Example |
+|:-|:-|:-|
+| SHARED_AUDIO_FOLDER | Shared audio folder mounted to /opt/audio | /my/path/to/shared/audio-folder |
+| ASR_PATH | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
+| CACHE_PATH | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| WAV2VEC_PATH | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+
+
+## Usages
+### HTTP API
+#### /healthcheck
+Returns the state of the API
+
+Method: GET
+
+Returns "1" if healthcheck passes.
+
+#### /transcribe
+Transcription API
+
+* Method: POST
+* Response content: text/plain or application/json
+* File: A WAV file, 16-bit, 16kHz
+
+Returns the transcribed text as "text/plain", or a JSON object when using "application/json", structured as follows:
+```json
+{
+ "text" : "This is the transcription as text",
+ "words": [
+ {
+ "word" : "This",
+ "start": 0.0,
+ "end": 0.124,
+ "conf": 0.82341
+ },
+ ...
+ ],
+ "confidence-score": 0.879
+}
+```
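+
+For instance, the route can be called from Python with the `requests` package (to be installed on the client side). The host, port and file name below are placeholders to adapt to your deployment:
+
+```python
+import requests  # pip install requests
+
+with open("audio.wav", "rb") as f:
+    resp = requests.post(
+        "http://localhost:8080/transcribe",       # YOUR_SERVICE:YOUR_PORT
+        headers={"accept": "application/json"},   # "text/plain" returns the raw text instead
+        files={"file": ("audio.wav", f, "audio/x-wav")},
+    )
+resp.raise_for_status()
+print(resp.json()["text"])
+```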
+
+#### /docs
+The /docs route offers an OpenAPI/Swagger interface.
+
+### Through the message broker
+
+STT-Worker accepts requests with the following arguments:
+```file_path: str, with_metadata: bool```
+
+* file_path: the location of the file within the shared folder: /.../SHARED_FOLDER/{file_path}
+* with_metadata: if True, word timestamps and confidence scores will be computed and returned. If False, these fields will be empty.
+
+#### Return format
+On a successful transcription, the returned object is a JSON object structured as follows:
+```json
+{
+ "text" : "This is the transcription as text",
+ "words": [
+ {
+ "word" : "This",
+ "start": 0.0,
+ "end": 0.124,
+ "conf": 0.82341
+ },
+ ...
+ ],
+ "confidence-score": 0.879
+}
+```
+
+* The text field contains the raw transcription.
+* The words field contains each word with its timestamp and individual confidence. (Empty if with_metadata=False)
+* The confidence-score field contains the overall confidence for the transcription. (0.0 if with_metadata=False)
+
+
+## Test
+### Curl
+You can test your HTTP API using curl:
+```bash
+curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav"
+```
+
+## License
+This project is developed under the AGPLv3 License (see LICENSE).
+
+## Acknowledgments
+
+* [Faster Whisper](https://github.com/SYSTRAN/faster-whisper)
+* [OpenAI Whisper](https://github.com/openai/whisper)
+* [Ctranslate2](https://github.com/OpenNMT/CTranslate2)
+* [SpeechBrain](https://github.com/speechbrain/speechbrain)
+* [TorchAudio](https://github.com/pytorch/audio)
+* [HuggingFace Transformers](https://github.com/huggingface/transformers)
\ No newline at end of file
diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md
new file mode 100644
index 0000000..2967139
--- /dev/null
+++ b/whisper/RELEASE.md
@@ -0,0 +1,3 @@
+# 1.0.0
+- First build of linto-stt-whisper
+- Based on 4.0.5 of linto-stt https://github.com/linto-ai/linto-stt/blob/a54b7b7ac2bc491a1795bb6dfb318a39c8b76d63/RELEASE.md
diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh
new file mode 100755
index 0000000..97a3804
--- /dev/null
+++ b/whisper/docker-entrypoint.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+set -a
+
+echo "RUNNING STT"
+
+# Check model
+echo "Checking model format ..."
+if [ -z "$MODEL" ]
+then
+ echo "Model type not specified, choosing Whisper medium model"
+ export MODEL=medium
+fi
+
+# Launch parameters, environment variables and dependencies check
+if [ -z "$SERVICE_MODE" ]
+then
+ echo "ERROR: Must specify a serving mode: [ http | task | websocket ]"
+ exit -1
+else
+ if [ "$SERVICE_MODE" = "http" ]
+ then
+ echo "RUNNING STT HTTP SERVER"
+ python3 http_server/ingress.py --debug
+ elif [ "$SERVICE_MODE" == "task" ]
+ then
+ if [[ -z "$SERVICES_BROKER" ]]
+ then
+ echo "ERROR: SERVICES_BROKER variable not specified, cannot start celery worker."
+ exit -1
+ fi
+ nvidia-smi 2> /dev/null > /dev/null
+ if [ $? -eq 0 ];then
+ echo "GPU detected"
+ GPU=1
+ OPT="--pool=solo"
+ else
+ echo "No GPU detected"
+ GPU=0
+ OPT=""
+ fi
+ /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit 1
+ echo "RUNNING STT CELERY WORKER"
+ celery --app=celery_app.celeryapp worker $OPT -Ofair --queues=${SERVICE_NAME} -c ${CONCURRENCY} -n ${SERVICE_NAME}_worker@%h
+
+ else
+ echo "ERROR: Wrong serving command: $SERVICE_MODE"
+ exit -1
+ fi
+fi
+
+echo "Service stopped"
\ No newline at end of file
diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt
new file mode 100644
index 0000000..2ddc118
--- /dev/null
+++ b/whisper/requirements.ctranslate2.txt
@@ -0,0 +1,15 @@
+celery[redis,auth,msgpack]>=4.4.7
+flask>=1.1.2
+flask-cors>=3.0.10
+flask-sock
+flask-swagger-ui>=3.36.0
+gevent
+gunicorn
+lockfile
+pyyaml>=5.4.1
+requests>=2.26.0
+wavio>=0.0.4
+websockets
+#faster_whisper==0.10.0
+# This is version faster_whisper==0.9.0 + prompt propagation + fix for large-v3
+git+https://github.com/linto-ai/faster-whisper.git@aad9e7508b528e79be2a9975ac79ef8317f02a6d
\ No newline at end of file
diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt
new file mode 100644
index 0000000..75e747c
--- /dev/null
+++ b/whisper/requirements.torch.txt
@@ -0,0 +1,19 @@
+celery[redis,auth,msgpack]>=4.4.7
+flask>=1.1.2
+flask-cors>=3.0.10
+flask-sock
+flask-swagger-ui>=3.36.0
+gevent
+gunicorn
+lockfile
+num2words
+pyyaml>=5.4.1
+requests>=2.26.0
+speechbrain
+transformers
+wavio>=0.0.4
+websockets
+# openai-whisper
+git+https://github.com/linto-ai/whisper-timestamped.git
+onnxruntime
+torchaudio
\ No newline at end of file
diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py
new file mode 100644
index 0000000..f5551af
--- /dev/null
+++ b/whisper/stt/__init__.py
@@ -0,0 +1,38 @@
+import logging
+import os
+
+logging.basicConfig(
+ format="[%(asctime)s,%(msecs)03d %(name)s] %(levelname)s: %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger("__stt__")
+
+# The following is to have GPU in the right order (as nvidia-smi show them)
+# It is important to set that before loading ctranslate2
+# see https://github.com/guillaumekln/faster-whisper/issues/150
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order
+
+try:
+ import faster_whisper
+
+ USE_CTRANSLATE2 = True
+except ImportError as err:
+ try:
+ import whisper
+ except:
+ raise err
+ USE_CTRANSLATE2 = False
+
+try:
+ import torch
+
+ USE_TORCH = True
+except ImportError:
+ USE_TORCH = False
+
+try:
+ import torchaudio
+
+ USE_TORCHAUDIO = True
+except ImportError:
+ USE_TORCHAUDIO = False
diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py
new file mode 100644
index 0000000..b0e7f6d
--- /dev/null
+++ b/whisper/stt/processing/__init__.py
@@ -0,0 +1,80 @@
+import logging
+import os
+
+from lockfile import FileLock
+from stt import USE_CTRANSLATE2, logger
+
+from .alignment_model import get_alignment_model, load_alignment_model
+from .decoding import decode
+from .load_model import load_whisper_model
+from .utils import get_device, get_language, load_audiofile, load_wave_buffer
+
+__all__ = [
+ "logger",
+ "decode",
+ "load_audiofile",
+ "load_wave_buffer",
+ "MODEL",
+ "USE_GPU",
+]
+
+
+class LazyLoadedModel:
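+    """Defer loading the Whisper model until it is actually used.
+
+    A file lock (named after the model) prevents several workers from
+    downloading or converting the same model at the same time; attribute
+    access and calls transparently trigger the actual loading.
+    """
+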
+ def __init__(self, model_type, device):
+ self.model_type = model_type
+ self.device = device
+ self._model = None
+
+ def check_loaded(self):
+ if self._model is None:
+ lockfile = os.path.basename(self.model_type)
+ with FileLock(lockfile):
+ self._model = load_whisper_model(self.model_type, device=self.device)
+
+ def __getattr__(self, name):
+ self.check_loaded()
+ return getattr(self._model, name)
+
+ def __call__(self, *args, **kwargs):
+ self.check_loaded()
+ return self._model(*args, **kwargs)
+
+
+# Set informative log
+logger.setLevel(logging.INFO)
+
+# Set device
+device, USE_GPU = get_device()
+logger.info(f"Using device {device}")
+
+# Check language
+language = get_language()
+logger.info(f"Using language {language}")
+
+# Load ASR model
+model_type = os.environ.get("MODEL", "medium")
+logger.info(
+ f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..."
+)
+try:
+ model = LazyLoadedModel(model_type, device=device)
+ # model = load_whisper_model(model_type, device=device)
+except Exception as err:
+ raise Exception("Failed to load transcription model: {}".format(str(err))) from err
+
+# Load alignment model (if any)
+alignment_model = get_alignment_model(os.environ.get("ALIGNMENT_MODEL", os.environ.get("alignment_model")), language)
+if alignment_model:
+ logger.info(
+ f"Loading alignment model {alignment_model} ({'local' if os.path.exists(alignment_model) else 'remote'})..."
+ )
+ alignment_model = load_alignment_model(alignment_model, device=device, download_root="/opt")
+elif alignment_model is None:
+ logger.info("Alignment will be done using Whisper cross-attention weights")
+else:
+ logger.info(
+ "No alignment model preloaded. It will be loaded on the fly depending on the detected language."
+ )
+ alignment_model = {} # Alignement model(s) will be loaded on the fly
+
+MODEL = (model, alignment_model)
diff --git a/whisper/stt/processing/alignment_model.py b/whisper/stt/processing/alignment_model.py
new file mode 100644
index 0000000..ea958db
--- /dev/null
+++ b/whisper/stt/processing/alignment_model.py
@@ -0,0 +1,409 @@
+import math
+import os
+import time
+
+import requests
+from stt import USE_TORCH, USE_TORCHAUDIO, logger
+
+from .utils import LANGUAGES, SAMPLE_RATE
+
+if USE_TORCH:
+ import torch
+ import torch.nn.utils.rnn as rnn_utils
+
+ try:
+ import huggingface_hub
+ import speechbrain as sb
+ except ImportError:
+ pass
+ try:
+ import transformers
+ except ImportError:
+ pass
+
+if USE_TORCHAUDIO:
+ import torchaudio
+
+################################################################################
+# Load models
+
+# Sources:
+# * https://github.com/m-bain/whisperX (in whisperx/transcribe.py)
+# * https://pytorch.org/audio/stable/pipelines.html
+# * https://huggingface.co/jonatasgrosman
+
+ALIGNMENT_MODELS = {
+ "en": "WAV2VEC2_ASR_BASE_960H",
+ # "en": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
+ "fr": "VOXPOPULI_ASR_BASE_10K_FR",
+ # "fr": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
+ "de": "VOXPOPULI_ASR_BASE_10K_DE",
+ # "de": "jonatasgrosman/wav2vec2-large-xlsr-53-german",
+ "es": "VOXPOPULI_ASR_BASE_10K_ES",
+ # "it": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
+ "it": "VOXPOPULI_ASR_BASE_10K_IT",
+ # "it": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
+ "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
+ "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
+ "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
+ "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
+ "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
+ "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
+ "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
+ "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+ "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
+ "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
+ "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
+ "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
+ "vi": "nguyenvulebinh/wav2vec2-base-vietnamese-250h",
+}
+
+
+def get_alignment_model(alignment_model_name, language, force=False):
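+    """Resolve which alignment model to use.
+
+    * "wav2vec"/"wav2vec2": return a default wav2vec model for the given language
+      (an empty dict when no language is set, meaning models are loaded on the fly;
+      English is used as a fallback for unsupported languages unless force=True).
+    * a language code: same as requesting "wav2vec" for that language.
+    * anything else (a path or a model name): returned unchanged.
+    """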
+ if alignment_model_name in ["wav2vec", "wav2vec2"]:
+ if language is None:
+ # Will load alignment model on the fly depending
+ # on detected language
+ return {}
+ elif language in ALIGNMENT_MODELS:
+ return ALIGNMENT_MODELS[language]
+ elif force:
+ raise ValueError(f"No wav2vec alignment model for language '{language}'.")
+ else:
+ logger.warn(
+ f"No wav2vec alignment model for language '{language}'. Fallback to English."
+ )
+ return ALIGNMENT_MODELS["en"]
+ elif alignment_model_name in LANGUAGES.keys():
+ return get_alignment_model("wav2vec", alignment_model_name, force=True)
+ return alignment_model_name
+
+
+def load_alignment_model(source, device="cpu", download_root="/opt"):
+ if not USE_TORCH:
+        raise NotImplementedError("Alignment model not available without Torch")
+
+ start = time.time()
+
+ if (source in torchaudio.pipelines.__all__) if USE_TORCHAUDIO else False:
+ model = load_torchaudio_model(source, device=device, download_root=download_root)
+ else:
+ try:
+ model = load_transformers_model(source, device=device, download_root=download_root)
+ except Exception as err1:
+ try:
+ model = load_speechbrain_model(source, device=device, download_root=download_root)
+ except Exception as err2:
+ raise Exception(
+ f"Failed to load alignment model:\n<<< transformers <<<\n{str(err1)}\n<<< speechbrain <<<\n{str(err2)}"
+ ) from err2
+
+ logger.info(
+ f"Alignment Model of type {get_model_type(model)} loaded. (t={time.time() - start}s)"
+ )
+
+ return model
+
+
+def load_speechbrain_model(source, device="cpu", download_root="/opt"):
+ if os.path.isdir(source):
+ yaml_file = os.path.join(source, "hyperparams.yaml")
+ assert os.path.isfile(yaml_file), f"Hyperparams file {yaml_file} not found"
+ else:
+ try:
+ yaml_file = huggingface_hub.hf_hub_download(
+ repo_id=source,
+ filename="hyperparams.yaml",
+ cache_dir=os.path.join(download_root, "huggingface/hub"),
+ )
+ except requests.exceptions.HTTPError:
+ yaml_file = None
+ overrides = make_yaml_overrides(
+ yaml_file, {"save_path": os.path.join(download_root, "speechbrain")}
+ )
+
+ savedir = os.path.join(download_root, "speechbrain")
+ try:
+ model = sb.pretrained.EncoderASR.from_hparams(
+ source=source, run_opts={"device": device}, savedir=savedir, overrides=overrides
+ )
+ except ValueError:
+ model = sb.pretrained.EncoderDecoderASR.from_hparams(
+ source=source, run_opts={"device": device}, savedir=savedir, overrides=overrides
+ )
+
+ model.train(False)
+ model.requires_grad_(False)
+ return model
+
+
+def load_transformers_model(source, device="cpu", download_root="/opt"):
+ model = transformers.Wav2Vec2ForCTC.from_pretrained(source).to(device)
+ processor = transformers.Wav2Vec2Processor.from_pretrained(source)
+
+ model.eval()
+ model.requires_grad_(False)
+ return model, processor
+
+
+def load_torchaudio_model(source, device="cpu", download_root="/opt"):
+ bundle = torchaudio.pipelines.__dict__[source]
+ model = bundle.get_model().to(device)
+ labels = bundle.get_labels()
+
+ model.eval()
+ model.requires_grad_(False)
+ return model, labels
+
+
+def get_model_type(model):
+ if not isinstance(model, tuple):
+ return "speechbrain"
+ assert len(model) == 2, "Invalid model type"
+ if isinstance(model[0], transformers.Wav2Vec2ForCTC):
+ return "transformers"
+ return "torchaudio"
+
+
+def make_yaml_overrides(yaml_file, key_values):
+ """
+ return a dictionary of overrides to be used with speechbrain (hyperyaml files)
+ yaml_file: path to yaml file
+ key_values: dict of key values to override
+ """
+ if yaml_file is None:
+ return None
+
+ override = {}
+ with open(yaml_file, "r") as f:
+ parent = None
+ for line in f:
+ if line.strip() == "":
+ parent = None
+ elif line == line.lstrip():
+ if ":" in line:
+ parent = line.split(":")[0].strip()
+ if parent in key_values:
+ override[parent] = key_values[parent]
+ elif ":" in line:
+ child = line.strip().split(":")[0].strip()
+ if child in key_values:
+ override[parent] = override.get(parent, {}) | {child: key_values[child]}
+ return override
+
+
+################################################################################
+# Get list of labels (and blank_id) from model
+
+
+def get_vocab(model):
+ type = get_model_type(model)
+ if type == "speechbrain":
+ labels, blank_id = get_vocab_speechbrain(model)
+ elif type == "transformers":
+ labels, blank_id = get_vocab_transformers(model)
+ else:
+ labels, blank_id = get_vocab_torchaudio(model)
+ assert isinstance(labels, list) and min(
+ [isinstance(l, str) for l in labels]
+ ), "labels must be a list of strings"
+ return norm_labels(labels, blank_id), blank_id
+
+
+def get_vocab_speechbrain(model):
+ tokenizer = model.tokenizer
+ # Is this general enough?
+ labels = [
+ {"": " ", " ⁇ ": ""}.get(i, i)
+ for i in tokenizer.decode([[i] for i in range(tokenizer.get_piece_size())])
+ ]
+ blank_id = labels.index("")
+ return labels, blank_id
+
+
+def get_vocab_torchaudio(model_and_labels):
+ _, labels = model_and_labels
+ labels = list(labels)
+    # TODO: should blank_id be labels.index("-") instead? Is 0 general enough?
+ blank_id = 0
+ return labels, blank_id
+
+
+def get_vocab_transformers(model_and_processor):
+ _, processor = model_and_processor
+ labels_dict = dict((v, k) for k, v in processor.tokenizer.get_vocab().items())
+ labels = [labels_dict[i] for i in range(len(labels_dict))]
+ blank_id = labels.index("")
+ return labels, blank_id
+
+
+def norm_labels(labels, blank_id):
+ labels[blank_id] = ""
+ return [l if l != "|" else " " for l in labels]
+
+
+################################################################################
+# Compute log-probabilities from model
+
+
+# The following limit handles the corner case of an audio segment that is too long
+# (it is better to split it to avoid a memory overflow).
+# It corresponds to 2240400 / 16000 Hz ~ 140 seconds, which should not happen for segments
+# detected by Whisper (usually one sentence).
+# Also note that Whisper works with 30-second windows, so this limit may never be reached.
+MAX_LEN = 2240400
+
+
+def compute_logprobas(model, audios, max_len=MAX_LEN):
+ # Single audio
+ if not isinstance(audios, list):
+ audios = [audios]
+ logits = compute_logprobas(model, audios, max_len=max_len)
+ return logits[0]
+
+ # Batch of audios (can occur when max_len is reached)
+ assert len(audios) > 0, "audios must be a non-empty list"
+
+ type = get_model_type(model)
+ if type == "speechbrain":
+ logits = compute_logits_speechbrain(model, audios, max_len)
+ elif type == "transformers":
+ logits = compute_logits_transformers(model, audios, max_len)
+ else:
+ logits = compute_logits_torchaudio(model, audios, max_len)
+
+ return torch.log_softmax(logits, dim=-1)
+
+
+def compute_logits_speechbrain(model, audios, max_len):
+ if not isinstance(audios[0], torch.Tensor):
+ audios = [torch.from_numpy(a) for a in audios]
+ if max([len(a) for a in audios]) > max_len:
+ # Split audios into chunks of max_len
+ batch_size = len(audios)
+ chunks = []
+ i_audio = []
+ for a in audios:
+ chunks.extend([a[i : min(i + max_len, len(a))] for i in range(0, len(a), max_len)])
+ i_audio.append(len(chunks))
+ if len(chunks) > 1:
+ logger.warning(
+ "Audio too long, splitting into {} chunks for alignment".format(len(chunks))
+ )
+ # Decode chunks of audio and concatenate results
+ log_probas = [[] for i in range(len(audios))]
+ for i in range(0, len(chunks), batch_size):
+ chunk = chunks[i : min(i + batch_size, len(chunks))]
+            log_probas_tmp = compute_logits_speechbrain(model, chunk, max_len)
+ for j in range(i, i + len(chunk)):
+ k = 0
+ while j >= i_audio[k]:
+ k += 1
+ log_probas[k].append(log_probas_tmp[j - i])
+ log_probas = [torch.cat(p, dim=0) for p in log_probas]
+ log_probas, wav_lens = pack_sequences(log_probas, device=model.device)
+ else:
+ batch, wav_lens = pack_sequences(audios, device=model.device)
+ log_probas = model.forward(batch, wav_lens)
+
+ return log_probas.cpu().detach()
+
+
+def pack_sequences(tensors, device="cpu"):
+ if len(tensors) == 1:
+ return tensors[0].unsqueeze(0).to(device), torch.Tensor([1.0]).to(device)
+ tensor = rnn_utils.pad_sequence(tensors, batch_first=True)
+ wav_lens = [len(x) for x in tensors]
+ maxwav_lens = max(wav_lens)
+ wav_lens = torch.Tensor([l / maxwav_lens for l in wav_lens])
+ return tensor.to(device), wav_lens.to(device)
+
+
+def compute_logits_transformers(model_and_processor, audios, max_len):
+ model, processor = model_and_processor
+
+ # can be different from processor.feature_extractor.sampling_rate
+ sample_rate = SAMPLE_RATE
+ device = model.device
+
+ audios = [audio.numpy() for audio in audios]
+ processed_batch = processor(audios, sampling_rate=sample_rate)
+
+ padded_batch = processor.pad(
+ processed_batch,
+ padding=True,
+ max_length=None,
+ pad_to_multiple_of=None,
+ return_tensors="pt",
+ )
+
+ l = padded_batch.input_values.shape[1]
+
+ use_mask = hasattr(padded_batch, "attention_mask")
+
+ with torch.inference_mode():
+ if l > max_len:
+ # Split batch in smaller chunks
+ logger.warning(
+ "Audio too long, splitting into {} chunks for alignment".format(
+ math.ceil(l / max_len)
+ )
+ )
+ logits = []
+ for i in range(0, l, max_len):
+ j = min(i + max_len, l)
+ if use_mask:
+ logits.append(
+ model(
+ padded_batch.input_values[:, i:j].to(device),
+ attention_mask=padded_batch.attention_mask[:, i:j].to(device),
+ ).logits
+ )
+ else:
+ logits.append(model(padded_batch.input_values[:, i:j].to(device)).logits)
+ logits = torch.cat(logits, dim=1)
+ elif use_mask:
+ logits = model(
+ padded_batch.input_values.to(device),
+ attention_mask=padded_batch.attention_mask.to(device),
+ ).logits
+ else:
+ logits = model(padded_batch.input_values.to(device)).logits
+
+ return logits.cpu().detach()
+
+
+def compute_logits_torchaudio(model_and_labels, audios, max_len):
+ # TODO: factorize with compute_logits_transformers, and add support for batch of audios
+
+ model, _ = model_and_labels
+
+ # Get the device where is running the model
+ device = "cpu"
+ for p in model.parameters():
+ device = p.device
+ break
+
+ all_logits = []
+
+ with torch.inference_mode():
+ for audio in audios:
+ l = len(audio)
+ if l > max_len:
+ # Split audio in smaller chunks
+ logger.warning(
+ "Audio too long, splitting into {} chunks for alignment".format(
+ math.ceil(l / max_len)
+ )
+ )
+ logits = []
+ for i in range(0, l, max_len):
+ j = min(i + max_len, l)
+ logits.append(model(audio[i:j].unsqueeze(0).to(device))[0])
+ logits = torch.cat(logits, dim=1)
+ else:
+ logits, _ = model(audio.unsqueeze(0).to(device))
+
+ all_logits.append(logits.cpu().detach())
+
+ assert len(all_logits) == 1 # TODO: support batch of audios
+
+ return all_logits[0]
diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py
new file mode 100644
index 0000000..9f8411f
--- /dev/null
+++ b/whisper/stt/processing/decoding.py
@@ -0,0 +1,366 @@
+import copy
+import os
+import time
+from typing import Tuple, Union
+
+import numpy as np
+from stt import USE_CTRANSLATE2, logger
+
+from .alignment_model import get_alignment_model, load_alignment_model
+from .text_normalize import normalize_text, remove_emoji, remove_punctuation
+from .utils import SAMPLE_RATE, get_language
+from .word_alignment import compute_alignment
+
+if not USE_CTRANSLATE2:
+ import torch
+ import whisper_timestamped
+
+USE_ACCURATE = True
+USE_VAD = True
+
+if USE_ACCURATE:
+ default_beam_size = 5
+ default_best_of = 5
+ default_temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+else:
+ default_beam_size = None
+ default_best_of = None
+ default_temperature = 0.0
+
+default_initial_prompt = os.environ.get("PROMPT", None)
+
+
+def decode(
+ audio,
+ model_and_alignementmodel, # Tuple[model, alignment_model]
+ with_word_timestamps: bool,
+ language: str = None,
+ remove_punctuation_from_words=False,
+ beam_size: int = default_beam_size,
+ best_of: int = default_best_of,
+ temperature: Union[float, Tuple[float, ...]] = default_temperature,
+ condition_on_previous_text: bool = False,
+ no_speech_threshold: float = 0.6,
+ compression_ratio_threshold: float = 2.4,
+ initial_prompt: str = default_initial_prompt,
+) -> dict:
+ if language is None:
+ language = get_language()
+
+ kwargs = copy.copy(locals())
+ kwargs.pop("model_and_alignementmodel")
+ kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel
+
+ logger.info(
+ "Transcribing audio with "
+ + (f"language {language}" if language else "automatic language detection")
+ + "..."
+ )
+
+ start_t = time.time()
+
+ if USE_CTRANSLATE2:
+ kwargs.pop("alignment_model")
+ res = decode_ct2(**kwargs)
+ else:
+ res = decode_torch(**kwargs)
+
+ logger.info("Transcription complete (t={}s)".format(time.time() - start_t))
+
+ return res
+
+
+def decode_ct2(
+ audio, model, with_word_timestamps, language, remove_punctuation_from_words, **kwargs
+):
+ kwargs["no_speech_threshold"] = 1 # To avoid empty output
+ if kwargs.get("beam_size") is None:
+ kwargs["beam_size"] = 1
+ if kwargs.get("best_of") is None:
+ kwargs["best_of"] = 1
+
+ segments, info = model.transcribe(
+ audio,
+ word_timestamps=with_word_timestamps,
+ language=language,
+ # Careful with the following options
+ max_initial_timestamp=10000.0,
+ vad_filter=USE_VAD,
+ **kwargs,
+ )
+
+ segments = list(segments)
+
+ return format_faster_whisper_response(
+ segments, info, remove_punctuation_from_words=remove_punctuation_from_words
+ )
+
+
+def decode_torch(
+ audio,
+ model,
+ alignment_model,
+ with_word_timestamps,
+ language,
+ remove_punctuation_from_words,
+ beam_size,
+ best_of,
+ temperature,
+ condition_on_previous_text,
+ no_speech_threshold,
+ compression_ratio_threshold,
+ normalize_text_as_words=False,
+ initial_prompt=None,
+):
+ """Transcribe the audio data using Whisper with the defined model."""
+
+ fp16 = model.device != torch.device("cpu")
+
+ kwargs = dict(
+ language=language,
+ fp16=fp16,
+ temperature=temperature,
+ beam_size=beam_size,
+ best_of=best_of,
+ condition_on_previous_text=condition_on_previous_text,
+ no_speech_threshold=no_speech_threshold,
+ compression_ratio_threshold=compression_ratio_threshold,
+ vad=USE_VAD,
+ initial_prompt=initial_prompt,
+ )
+
+ if alignment_model is None:
+ # Use Whisper cross-attention weights
+ whisper_res = whisper_timestamped.transcribe(model, audio, verbose=None, **kwargs)
+ if language is None:
+ language = whisper_res["language"]
+ logger.info(f"Detected language: {language}")
+ return format_whisper_timestamped_response(
+ whisper_res, remove_punctuation_from_words=remove_punctuation_from_words
+ )
+
+ # Force deterministic results
+ torch.manual_seed(1234)
+ torch.cuda.manual_seed_all(1234)
+
+ whisper_res = model.transcribe(audio, verbose=None, **kwargs)
+
+ text = whisper_res["text"]
+ text = remove_emoji(text).strip()
+ if normalize_text_as_words:
+ text = normalize_text(text, language)
+ if remove_punctuation_from_words:
+ text = remove_punctuation(text)
+ segments = whisper_res["segments"]
+ if language is None:
+ language = whisper_res["language"]
+ logger.info(f"Detected language: {language}")
+ if isinstance(alignment_model, dict):
+ # Load alignment model on the fly
+ if language not in alignment_model:
+ alignment_model_name = get_alignment_model(language)
+ logger.info(
+ f"Loading alignment model {alignment_model_name} ({'local' if os.path.exists(alignment_model_name) else 'remote'})..."
+ )
+ alignment_model[language] = load_alignment_model(
+ alignment_model_name, device=model.device, download_root="/opt"
+ )
+ spec_alignment_model = alignment_model[language]
+ else:
+ spec_alignment_model = alignment_model
+
+ result = {}
+ result["text"] = text
+ result["language"] = language
+ result["confidence-score"] = (
+ np.exp(np.array([r["avg_logprob"] for r in segments])).mean() if len(segments) else 0.0
+ )
+
+ if not with_word_timestamps:
+ if not normalize_text_as_words:
+ text = normalize_text(text, language)
+ if remove_punctuation_from_words:
+ text = remove_punctuation(text)
+ result["words"] = text.split()
+ else:
+ # Compute word timestamps
+ result["words"] = []
+ max_t = audio.shape[0]
+
+ # Ensure that the segments start / end time are increasing
+ # (because there is no guarantee with Whisper)
+ previous_start = 0.0
+ for segment in segments:
+ if segment["start"] < previous_start:
+ segment["start"] = previous_start
+ if segment["end"] <= segment["start"]:
+ segment["end"] = segment["start"] + 1.0
+ previous_start = segment["end"]
+
+ for segment in segments:
+ offset = segment["start"]
+ start = min(max_t, round(segment["start"] * SAMPLE_RATE))
+ end = min(max_t, round(segment["end"] * SAMPLE_RATE))
+ sub_audio = audio[start:end]
+ sub_text = segment["text"]
+ logger.debug(f"Aligning text: {sub_text}")
+ sub_text = remove_emoji(sub_text).strip()
+ sub_text = normalize_text(sub_text, language)
+ if remove_punctuation_from_words:
+ sub_text = remove_punctuation(sub_text)
+ if not sub_text:
+ logger.warn(f"Lost text in segment {segment['start']}-{segment['end']}")
+ continue
+ labels, emission, trellis, segments, word_segments = compute_alignment(
+ sub_audio, sub_text, spec_alignment_model
+ )
+ ratio = len(sub_audio) / (trellis.size(0) * SAMPLE_RATE)
+ sub_words = sub_text.split()
+ words = []
+ use_original_words = True
+ if len(sub_words) != len(word_segments):
+ logger.warn(
+ f"Alignment failed. Some words might be mis-rendered.\nNumber of words: {len(sub_words)} != {len(word_segments)}\n>>>\n{sub_words}\n<<<\n{[segment.label for segment in word_segments]}"
+ )
+ assert len(word_segments) < len(sub_words)
+ use_original_words = False
+ for word, seg in zip(sub_words, word_segments):
+ words.append(
+ {
+ "word": word if use_original_words else seg.label,
+ "start": seg.start * ratio + offset,
+ "end": seg.end * ratio + offset,
+ "conf": seg.score,
+ }
+ )
+ # Glue the words inside a segment
+ for i, word in enumerate(words):
+ if i == 0:
+ word["start"] = segment["start"]
+ else:
+ word["start"] = words[i - 1]["end"]
+ if i == len(words) - 1:
+ word["end"] = segment["end"]
+ else:
+ word["end"] = 0.5 * (words[i + 1]["start"] + word["end"])
+ # Accumulate results
+ result["words"] += words
+
+ return result
+
+
+def format_whisper_timestamped_response(transcription, remove_punctuation_from_words=False):
+ """Format Whisper response."""
+
+ for i, seg in enumerate(transcription["segments"][:-1]):
+ for expected_keys in ["start", "end", "words", "avg_logprob"]:
+ assert (
+ expected_keys in seg
+ ), f"Missing '{expected_keys}' in segment {i} (that has keys {list(seg.keys())})"
+
+ words = []
+
+ segments = transcription.get("segments", [])
+
+ for seg in segments:
+ for word in seg.get("words", []):
+ text = word["text"]
+ if remove_punctuation_from_words:
+ text = remove_punctuation(text)
+ words.append(
+ {
+ "word": text,
+ "start": word["start"],
+ "end": word["end"],
+ "conf": word["confidence"],
+ }
+ )
+
+ return {
+ "text": transcription["text"].strip(),
+ "language": transcription["language"],
+ "confidence-score": round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2)
+ if len(segments)
+ else 0.0,
+ "words": words,
+ }
+
+
+def format_faster_whisper_response(
+ segments,
+ info,
+ remove_punctuation_from_words=False,
+ glue_punctuations="'-&@.,",
+):
+ language = info.language
+ duration = info.duration
+
+ def checked_timestamps(start, end=None):
+ if start > duration or (end is not None and end > duration):
+ print(
+ "WARNING, timestamp %f is greater than duration %f"
+ % (max(start, end if end else start), duration)
+ )
+ if end and end <= start:
+ if end == start:
+ pass # end = start + 0.01
+ else:
+ print("WARNING, end timestamp %f is smaller than start timestamp %f" % (end, start))
+ if end is None:
+ return start
+ return (start, end)
+
+ segments_list = []
+ for segment in segments:
+ start, end = checked_timestamps(segment.start, segment.end)
+
+ words = []
+ if segment.words:
+ for word in segment.words:
+ start, end = checked_timestamps(word.start, word.end)
+ word_strip = word.word.strip()
+ if (
+ glue_punctuations
+ and len(words)
+ and len(word_strip) > 1
+ and word_strip[0] in glue_punctuations
+ ):
+ words[-1]["text"] += word.word.lstrip()
+ words[-1]["confidence"].append(word.probability)
+ words[-1]["end"] = max(words[-1]["end"], end)
+ continue
+ words.append(
+ {
+ "text": word.word,
+ "confidence": [word.probability],
+ "start": start,
+ "end": end,
+ }
+ )
+
+ for word in words:
+ word["text"] = word["text"].strip()
+ word["confidence"] = round(np.mean([c for c in word["confidence"]]), 2)
+
+ segments_list.append(
+ {
+ "text": segment.text.strip(),
+ "start": start,
+ "end": end,
+ "avg_logprob": segment.avg_logprob,
+ "words": words,
+ }
+ )
+
+ transcription = {
+ "text": " ".join(segment["text"] for segment in segments_list),
+ "language": language,
+ "confidence": round(
+ np.exp(np.mean([segment["avg_logprob"] for segment in segments_list])), 2
+ ),
+ "segments": segments_list,
+ }
+ return format_whisper_timestamped_response(
+ transcription, remove_punctuation_from_words=remove_punctuation_from_words
+ )
diff --git a/whisper/stt/processing/load_model.py b/whisper/stt/processing/load_model.py
new file mode 100644
index 0000000..b87a414
--- /dev/null
+++ b/whisper/stt/processing/load_model.py
@@ -0,0 +1,369 @@
+import os
+import shutil
+import subprocess
+import sys
+import time
+
+from stt import USE_CTRANSLATE2, logger
+
+if USE_CTRANSLATE2:
+ import faster_whisper
+else:
+ import whisper_timestamped as whisper
+
+
+def load_whisper_model(model_type_or_file, device="cpu", download_root=None):
+ start = time.time()
+
+ logger.info("Loading Whisper model {}...".format(model_type_or_file))
+
+ default_cache_root = os.path.join(os.path.expanduser("~"), ".cache")
+ if download_root is None:
+ download_root = default_cache_root
+
+ if USE_CTRANSLATE2:
+ if not os.path.isdir(model_type_or_file):
+ # Note: There is no good way to set the root cache directory
+ # with the current version of faster_whisper:
+ # if "download_root" is specified to faster_whisper.WhisperModel
+ # (or "output_dir" in faster_whisper.utils.download_model),
+ # then files are downloaded directly in it without symbolic links
+ # to the cache directory. So it's different from the behavior
+ # of the huggingface_hub.
+ # So we try to create a symbolic link to the cache directory that will be used by HuggingFace...
+ if not os.path.exists(download_root):
+ if not os.path.exists(default_cache_root):
+ os.makedirs(download_root)
+ if default_cache_root != download_root:
+ os.symlink(download_root, default_cache_root)
+ else:
+ os.symlink(default_cache_root, download_root)
+ elif not os.path.exists(default_cache_root):
+ os.symlink(download_root, default_cache_root)
+
+ if device == "cpu":
+ compute_types = ["int8", "float32"]
+ else:
+ compute_types = ["int8", "int8_float16", "float16", "float32"]
+
+ device_index = 0
+ if device.startswith("cuda:"):
+ device_index = [int(dev) for dev in device[5:].split(",")]
+ device = "cuda"
+
+ if not os.path.isfile(os.path.join(model_type_or_file, "model.bin")) and not max(
+ [
+ model_type_or_file.startswith(prefix)
+ for prefix in ["tiny", "base", "small", "medium", "large"]
+ ]
+ ):
+ # Convert transformer model
+
+ output_dir = os.path.join(
+ download_root,
+ f"ctranslate2/converters/transformers--{model_type_or_file.replace('/', '--')}",
+ )
+ logger.info(f"CTranslate2 model in {output_dir}")
+ if not os.path.isdir(output_dir):
+ import huggingface_hub
+
+ delete_hf_path = False
+ if not os.path.isdir(model_type_or_file):
+ hf_path = huggingface_hub.hf_hub_download(
+ repo_id=model_type_or_file, filename="pytorch_model.bin"
+ )
+ hf_path = os.path.dirname(os.path.dirname(os.path.dirname(hf_path)))
+
+ delete_hf_path = not os.path.exists(hf_path)
+ else:
+ assert os.path.isfile(
+ os.path.join(model_type_or_file, "pytorch_model.bin")
+ ), f"Could not find pytorch_model.bin in {model_type_or_file}"
+
+ check_torch_installed()
+
+ # from ctranslate2.converters.transformers import TransformersConverter
+ # converter = TransformersConverter(
+ # model_type_or_file,
+ # activation_scales=None, # Path to the pre-computed activation scales, see https://github.com/mit-han-lab/smoothquant
+ # copy_files=[], # Note: "tokenizer.json" does not always exist, we will copy it separately
+ # load_as_float16=False,
+ # revision=None,
+ # low_cpu_mem_usage=False,
+ # trust_remote_code=False,
+ # )
+
+ try:
+ # converter.convert(
+ # output_dir,
+ # force=False
+ # )
+
+ subprocess.check_call(
+ [
+ "ct2-transformers-converter",
+ "--model",
+ model_type_or_file,
+ "--output_dir",
+ os.path.realpath(output_dir),
+ "--quantization",
+ "float16",
+ ]
+ )
+ except Exception as err:
+ shutil.rmtree(output_dir, ignore_errors=True)
+ raise err
+
+ finally:
+ if delete_hf_path:
+ logger.info(f"Deleting {hf_path}")
+ shutil.rmtree(hf_path, ignore_errors=True)
+
+ assert os.path.isdir(output_dir), f"Failed to build {output_dir}"
+
+ model_type_or_file = output_dir
+
+ model = None
+ for i, compute_type in enumerate(compute_types):
+ try:
+ model = faster_whisper.WhisperModel(
+ model_type_or_file,
+ device=device,
+ device_index=device_index,
+ compute_type=compute_type,
+ # cpu_threads=0, # Can be controled with OMP_NUM_THREADS
+ # num_workers=1,
+ # download_root=os.path.join(download_root, f"huggingface/hub/models--guillaumekln--faster-whisper-{model_type_or_file}"),
+ )
+ break
+ except ValueError as err:
+ logger.info(
+ "WARNING: failed to load model with compute_type={}".format(compute_type)
+ )
+ # On some old GPU we may have the error
+ # "ValueError: Requested int8_float16 compute type,
+ # but the target device or backend do not support efficient int8_float16 computation."
+ if i == len(compute_types) - 1:
+ raise err
+
+ else:
+ extension = (
+ os.path.splitext(model_type_or_file)[-1] if os.path.isfile(model_type_or_file) else None
+ )
+
+ if model_type_or_file in whisper.available_models() or extension == ".pt":
+ model = whisper.load_model(
+ model_type_or_file,
+ device=device,
+ download_root=os.path.join(download_root, "whisper"),
+ )
+
+ else:
+ # Convert HuggingFace model
+ import torch
+
+ peft_folder = None
+
+ if extension in [".ckpt", ".bin"]:
+ model_path = model_type_or_file
+ else:
+ # Search for the cached file (download if necessary)
+ if os.path.isdir(model_type_or_file):
+ for root, _, files in os.walk(model_type_or_file):
+ if "adapter_config.json" in files:
+ peft_folder = root
+ break
+ try:
+ import transformers
+ except ImportError:
+ raise ImportError(
+ f"If you are trying to download a HuggingFace model with {model_type_or_file}, please install first the transformers library"
+ )
+ from transformers.utils import cached_file
+
+ try:
+ model_path = cached_file(
+ model_type_or_file,
+ "pytorch_model.bin",
+ cache_dir=download_root,
+ use_auth_token=None,
+ revision=None,
+ )
+ except Exception as e:
+ try:
+ if isinstance(e, OSError):
+ model_path = cached_file(
+ model_type_or_file,
+ "whisper.ckpt",
+ cache_dir=download_root,
+ use_auth_token=None,
+ revision=None,
+ )
+ else:
+ raise e
+ except:
+ if peft_folder is None:
+ raise RuntimeError(
+ f"Original error: {e}\nCould not find model {model_type_or_file} from HuggingFace nor local folders."
+ )
+
+ # Load HF Model
+ if peft_folder is not None:
+ import transformers
+ from peft import PeftConfig, PeftModel
+
+ peft_config = PeftConfig.from_pretrained(peft_folder)
+ base_model = peft_config.base_model_name_or_path
+
+ model = transformers.WhisperForConditionalGeneration.from_pretrained(base_model)
+ model = PeftModel.from_pretrained(model, peft_folder)
+ hf_state_dict = model.state_dict()
+ del model
+ else:
+ hf_state_dict = torch.load(model_path, map_location="cpu")
+
+ # Rename layers
+ for key in list(hf_state_dict.keys()):
+ new_key = hf_to_whisper_states(key)
+ if new_key is None:
+ hf_state_dict.pop(key)
+ elif new_key != key:
+ hf_state_dict[new_key] = hf_state_dict.pop(key)
+
+ # Init Whisper Model and replace model weights
+ dims = whisper.model.ModelDimensions(**states_to_dim(hf_state_dict))
+ if "proj_out.weight" in hf_state_dict:
+ hf_state_dict["decoder.proj_out.weight"] = hf_state_dict.pop("proj_out.weight")
+ print("WARNING: Using untied projection layer")
+ whisper_model = WhisperUntied(dims)
+ else:
+ whisper_model = whisper.model.Whisper(dims)
+ whisper_model.load_state_dict(hf_state_dict)
+ del hf_state_dict
+ whisper_model = whisper_model.to(device)
+ return whisper_model
+
+ model.eval()
+ model.requires_grad_(False)
+
+ logger.info("Whisper model loaded. (t={}s)".format(time.time() - start))
+
+ return model
+
+
+def check_torch_installed():
+ try:
+ import torch
+ except ImportError:
+ # Install transformers with torch
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers[torch]>=4.23"])
+
+ # # Re-load ctranslate2
+ # import importlib
+ # import ctranslate2
+ # importlib.reload(ctranslate2)
+ # importlib.reload(ctranslate2.converters.transformers)
+
+ # import torch
+
+
+# Credit: https://github.com/openai/whisper/discussions/830
+def hf_to_whisper_states(text):
+ import re
+
+ # From Speechbrain
+ if text == "_mel_filters":
+ return None
+
+ # From PEFT
+ if "default" in text:
+ # print(f"WARNING: Ignoring {text}")
+ return None
+ if text.startswith("base_model.model."):
+ text = text[len("base_model.model.") :]
+
+ text = re.sub(".layers.", ".blocks.", text)
+ text = re.sub(".self_attn.", ".attn.", text)
+ text = re.sub(".q_proj.", ".query.", text)
+ text = re.sub(".k_proj.", ".key.", text)
+ text = re.sub(".v_proj.", ".value.", text)
+ text = re.sub(".out_proj.", ".out.", text)
+ text = re.sub(".fc1.", ".mlp.0.", text)
+ text = re.sub(".fc2.", ".mlp.2.", text)
+ text = re.sub(".fc3.", ".mlp.3.", text)
+ text = re.sub(".fc3.", ".mlp.3.", text)
+ text = re.sub(".encoder_attn.", ".cross_attn.", text)
+ text = re.sub(".cross_attn.ln.", ".cross_attn_ln.", text)
+ text = re.sub(".embed_positions.weight", ".positional_embedding", text)
+ text = re.sub(".embed_tokens.", ".token_embedding.", text)
+ text = re.sub("model.", "", text)
+ text = re.sub("attn.layer_norm.", "attn_ln.", text)
+ text = re.sub(".final_layer_norm.", ".mlp_ln.", text)
+ text = re.sub("encoder.layer_norm.", "encoder.ln_post.", text)
+ text = re.sub("decoder.layer_norm.", "decoder.ln.", text)
+ return text
+
+
+def states_to_dim(state_dict):
+ n_audio_state = len(state_dict["encoder.ln_post.bias"])
+ n_text_state = len(state_dict["decoder.ln.bias"])
+ return {
+ "n_mels": state_dict["encoder.conv1.weight"].shape[1], # 80
+ "n_vocab": state_dict["decoder.token_embedding.weight"].shape[0], # 51864 / 51865
+ "n_audio_ctx": state_dict["encoder.positional_embedding"].shape[0], # 1500
+ "n_audio_state": n_audio_state, # 384 / 512 / 768 / 1024 / 1280
+ "n_audio_head": n_audio_state // 64, # 6 / 8 / 12 / 16 / 20
+ "n_audio_layer": len(
+ set([".".join(k.split(".")[:3]) for k in state_dict.keys() if "encoder.blocks." in k])
+ ), # 4 / 6 / 12 / 24 / 32
+ "n_text_ctx": state_dict["decoder.positional_embedding"].shape[0], # 448
+ "n_text_state": n_text_state, # 384 / 512 / 768 / 1024 / 1280
+ "n_text_head": n_text_state // 64, # 6 / 8 / 12 / 16 / 20
+ "n_text_layer": len(
+ set([".".join(k.split(".")[:3]) for k in state_dict.keys() if "decoder.blocks." in k])
+ ), # 4 / 6 / 12 / 24 / 32
+ }
+
+
+if not USE_CTRANSLATE2:
+
+ class TextDecoderUntied(whisper.model.TextDecoder):
+ """
+ Same as TextDecoder but with untied weights
+ """
+
+ def __init__(self, *args, **kwargs):
+ import torch
+
+ super().__init__(*args, **kwargs)
+
+ n_vocab, n_state = self.token_embedding.weight.shape
+
+ self.proj_out = torch.nn.Linear(n_state, n_vocab, bias=False)
+
+ def forward(self, x, xa, kv_cache=None):
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
+
+ # logits = self.proj_out(x).float()
+ # logits = (x @ torch.transpose(self.proj_out.weight.to(x.dtype), 0, 1)).float()
+ logits = self.proj_out.to(x.dtype)(x).float()
+
+ return logits
+
+ class WhisperUntied(whisper.model.Whisper):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.decoder = TextDecoderUntied(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer,
+ )
diff --git a/whisper/stt/processing/text_normalize.py b/whisper/stt/processing/text_normalize.py
new file mode 100644
index 0000000..cde8f38
--- /dev/null
+++ b/whisper/stt/processing/text_normalize.py
@@ -0,0 +1,393 @@
+import math
+import re
+# import string
+import unicodedata
+
+from stt import logger
+
+from .utils import flatten
+
+# All punctuations and symbols EXCEPT:
+# * apostrophe (') and hyphen (-)
+# * underscore (_)
+# * currency symbols ($, €, £, ...) -> \p{Sc}
+# * math symbols (%, +, ×). ex: C++
+# * misc (#, @). ex: C#, @user
+# and the space character (which can separate several series of punctuation marks)
+# Examples of punctuation marks that models like Whisper can output: !,.:;?¿،؛؟…、。!,:?>/]:!(~\u200b[ா「«»“”"< ?;…,*」.)'
+_punctuation_regex = r"[^\w\p{Sc}" + re.escape("'-_%+×#@&") + "]"
+_leading_punctuations_regex = r"^" + _punctuation_regex + r"+"
+_trailing_punctuations_regex = _punctuation_regex + r"+$"
+
+# A list of symbols that can be isolated words and are not in the exclusion list above
+# * &
+# * candidates not retained: §, <, =, >, ≤, ≥
+_maybe_word_regex = None # r"[" + re.escape("&") + r"]$"
+
+
+def remove_punctuation(text: str, ensure_no_spaces_in_words: bool = False) -> str:
+ text = text.strip()
+ # Note: we don't remove dots inside words (e.g. "ab@gmail.com")
+ new_text = re.sub(_leading_punctuations_regex, "", text) # .lstrip()
+ new_text = re.sub(_trailing_punctuations_regex, "", new_text) # .rstrip()
+    # Keep punctuation marks that stand alone
+ if not new_text:
+ if _maybe_word_regex and re.match(_maybe_word_regex, text):
+ new_text = text
+ else:
+ new_text = ""
+ # Ensure that there is no space in the middle of a word
+ if ensure_no_spaces_in_words and " " in new_text:
+ new_text, tail = new_text.split(" ", 1)
+ # OK if the tail only contains non alphanumeric characters (then we just keep the first part)
+ assert not re.search(r"[^\W\d\'\-_]", tail), f"Got unexpected word containing space: {text}"
+ return remove_punctuation(new_text, ensure_no_spaces_in_words=ensure_no_spaces_in_words)
+ return new_text
+
+
+def transliterate(c):
+ # Transliterates a character to its closest ASCII equivalent.
+ # Example: transliterate("à ß œ fl") = "a ss oe fl"
+ c = re.sub("œ", "oe", c)
+ c = re.sub("æ", "ae", c)
+ c = re.sub("Œ", "OE", c)
+ c = re.sub("Æ", "AE", c)
+ c = re.sub("ß", "ss", c)
+ return unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii")
+
+
+def remove_emoji(text):
+ # Remove emojis
+ return re.sub(
+ r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF]+",
+ "",
+ text,
+ )
+
+
+def normalize_text(text: str, lang: str) -> str:
+ """Transform digits into characters..."""
+
+ # Reorder currencies (1,20€ -> 1 € 20)
+ coma = "," if lang in ["fr"] else "\."
+ for c in _currencies:
+ if c in text:
+ text = re.sub(r"\b(\d+)" + coma + r"(\d+)\s*" + c, r"\1 " + c + r" \2", text)
+
+ # Roman digits
+ if re.search(r"[IVX]", text):
+ if lang == "en":
+ digits = re.findall(r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})(º|st|nd|rd|th)?\b", text)
+ digits = ["".join(d) for d in digits]
+ elif lang == "fr":
+ digits = re.findall(
+ r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})(º|ème|eme|e|er|ère)?\b", text
+ )
+ digits = ["".join(d) for d in digits]
+ else:
+ digits = re.findall(r"\b(?=[XVI])M*(XX{0,3})(I[XV]|V?I{0,3})\b", text)
+ digits = ["".join(d) for d in digits]
+ if digits:
+ digits = sorted(list(set(digits)), reverse=True, key=lambda x: (len(x), x))
+ for s in digits:
+ filtered = re.sub("[a-zèº]", "", s)
+ ordinal = filtered != s
+ digit = roman_to_decimal(filtered)
+ v = undigit(str(digit), lang=lang, to="ordinal" if ordinal else "cardinal")
+ text = re.sub(r"\b" + s + r"\b", v, text)
+
+ # Ordinal digits
+ if lang == "en":
+ digits = re.findall(r"\b\d*1(?:st)|\d*2(?:nd)|\d*3(?:rd)|\d+(?:º|th)\b", text)
+ elif lang == "fr":
+ digits = re.findall(r"\b1(?:ère|ere|er|re|r)|2(?:nd|nde)|\d+(?:º|ème|eme|e)\b", text)
+ else:
+ logger.warn(
+ f"Language {lang} not supported for some normalization. Some words might be mis-localized."
+ )
+ digits = []
+ if digits:
+ digits = sorted(list(set(digits)), reverse=True, key=lambda x: (len(x), x))
+ for digit in digits:
+ word = undigit(re.findall(r"\d+", digit)[0], to="ordinal", lang=lang)
+ text = re.sub(r"\b" + str(digit) + r"\b", word, text)
+
+ # Cardinal digits
+ digits = re.findall(r"(?:\-?\b[\d/]*\d+(?: \d\d\d)+\b)|(?:\-?\d[/\d]*)", text)
+ digits = list(map(lambda s: s.strip(r"[/ ]"), digits))
+ digits = list(set(digits))
+ digits = digits + flatten([c.split() for c in digits if " " in c])
+ digits = digits + flatten([c.split("/") for c in digits if "/" in c])
+ digits = sorted(digits, reverse=True, key=lambda x: (len(x), x))
+ for digit in digits:
+ digitf = re.sub("/+", "/", digit)
+ if not digitf:
+ continue
+ numslash = len(re.findall("/", digitf))
+ if numslash == 0:
+ word = undigit(digitf, lang=lang)
+ elif numslash == 1: # Fraction or date
+ i = digitf.index("/")
+ is_date = False
+ if len(digitf[i + 1 :]) == 2:
+ try:
+ first = int(digitf[:i])
+ second = int(digitf[i + 1 :])
+ is_date = first > 0 and first < 32 and second > 0 and second < 13
+ except:
+ pass
+ if is_date:
+ first = digitf[:i].lstrip("0")
+ use_ordinal = (lang == "fr" and first == "1") or (
+ lang != "fr" and first[-1] in ["1", "2", "3"]
+ )
+ first = undigit(first, lang=lang, to="ordinal" if use_ordinal else "cardinal")
+ second = _int_to_month.get(lang, {}).get(second, digitf[i + 1 :])
+ else:
+ first = undigit(digitf[:i], lang=lang)
+ second = undigit(digitf[i + 1 :], to="denominator", lang=lang)
+ if float(digitf[:i]) > 2.0 and second[-1] != "s":
+ second += "s"
+ word = first + " " + second
+ elif numslash == 2: # Maybe a date
+ i1 = digitf.index("/")
+ i2 = digitf.index("/", i1 + 1)
+ is_date = False
+ if len(digitf[i1 + 1 : i2]) == 2 and len(digitf[i2 + 1 :]) == 4:
+ try:
+ first = int(digitf[:i1])
+ second = int(digitf[i1 + 1 : i2])
+ third = int(digitf[i2 + 1 :])
+ is_date = (
+ first > 0 and first < 32 and second > 0 and second < 13 and third > 1000
+ )
+ except:
+ pass
+ third = undigit(digitf[i2 + 1 :], lang=lang)
+ if is_date:
+ first = digitf[:i1].lstrip("0")
+ use_ordinal = (lang == "fr" and first == "1") or (
+ lang != "fr" and first[-1] in ["1", "2", "3"]
+ )
+ first = undigit(first, lang=lang, to="ordinal" if use_ordinal else "cardinal")
+ second = _int_to_month.get(lang, {}).get(
+ int(digitf[i1 + 1 : i2]), digitf[i1 + 1 : i2]
+ )
+ word = " ".join([first, second, third])
+ else:
+ word = " / ".join([undigit(s, lang=lang) for s in digitf.split("/")])
+ else:
+ word = " / ".join([undigit(s, lang=lang) for s in digitf.split("/")])
+ text = replace_keeping_word_boundaries(digit, word, text)
+
+ # Symbols (currencies, percent...)
+ symbol_table = _symbol_to_word.get(lang, {})
+ for k, v in symbol_table.items():
+ text = replace_keeping_word_boundaries(k, v, text)
+
+ # Remove extra spaces before punctuation
+ # text = re.sub(r" ([\.,!:;])",r"\1",text)
+
+ return collapse_whitespace(text)
+
+
+def replace_keeping_word_boundaries(orig, dest, text):
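+    """Replace `orig` with `dest` in `text`, inserting spaces so that the
+    replacement stays delimited by word boundaries."""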
+ if orig in text:
+ text = re.sub(r"(\W)" + orig + r"(\W)", r"\1" + dest + r"\2", text)
+ text = re.sub(orig + r"(\W)", " " + dest + r"\1", text)
+ text = re.sub(r"(\W)" + orig, r"\1" + dest + " ", text)
+ text = re.sub(orig, " " + dest + " ", text)
+ return text
+
+
+def undigit(str, lang, to="cardinal"):
+ str = re.sub(" ", "", str)
+ if to == "denominator":
+ if lang == "fr":
+ if str == "2":
+ return "demi"
+ if str == "3":
+ return "tiers"
+ if str == "4":
+ return "quart"
+ elif lang == "en":
+ if str == "2":
+ return "half"
+ if str == "4":
+ return "quarter"
+ elif lang == "es":
+ if str == "2":
+ return "mitad"
+ if str == "3":
+ return "tercio"
+ to = "ordinal"
+ if str.startswith("0") and to == "cardinal":
+ numZeros = len(re.findall(r"0+", str)[0])
+ if numZeros < len(str):
+ return numZeros * (robust_num2words(0, lang=lang) + " ") + robust_num2words(
+ float(str), lang=lang, to=to
+ )
+ return robust_num2words(float(str), lang=lang, to=to)
+
+
+def robust_num2words(x, lang, to="cardinal", orig=""):
+ """
+ Bugfix for num2words
+ """
+ from num2words import num2words
+
+ try:
+ res = num2words(x, lang=lang, to=to)
+ if lang == "fr" and to == "ordinal":
+ res = res.replace("vingtsième", "vingtième")
+ return res
+ except OverflowError:
+ if x == math.inf: # !
+ return " ".join(robust_num2words(xi, lang=lang, to=to) for xi in orig)
+ if x == -math.inf: # !
+ return "moins " + robust_num2words(-x, lang=lang, to=to, orig=orig.replace("-", ""))
+ # TODO: print a warning
+ return robust_num2words(x // 10, lang=lang, to=to)
+
+
+def roman_to_decimal(str):
+ def value(r):
+ if r == "I":
+ return 1
+ if r == "V":
+ return 5
+ if r == "X":
+ return 10
+ if r == "L":
+ return 50
+ if r == "C":
+ return 100
+ if r == "D":
+ return 500
+ if r == "M":
+ return 1000
+ return -1
+
+ res = 0
+ i = 0
+ while i < len(str):
+ s1 = value(str[i])
+ if i + 1 < len(str):
+ s2 = value(str[i + 1])
+ if s1 >= s2:
+ # Value of current symbol is greater or equal to the next symbol
+ res = res + s1
+ i = i + 1
+ else:
+                # Value of current symbol is smaller than the next symbol
+ res = res + s2 - s1
+ i = i + 2
+ else:
+ res = res + s1
+ i = i + 1
+ return res
+
+
+_int_to_month = {
+ "fr": {
+ 1: "janvier",
+ 2: "février",
+ 3: "mars",
+ 4: "avril",
+ 5: "mai",
+ 6: "juin",
+ 7: "juillet",
+ 8: "août",
+ 9: "septembre",
+ 10: "octobre",
+ 11: "novembre",
+ 12: "décembre",
+ },
+ "en": {
+ 1: "january",
+ 2: "february",
+ 3: "march",
+ 4: "april",
+ 5: "may",
+ 6: "june",
+ 7: "july",
+ 8: "august",
+ 9: "september",
+ 10: "october",
+ 11: "november",
+ 12: "december",
+ },
+}
+
+_currencies = ["€", "$", "£", "¥"]
+
+_symbol_to_word = {
+ "fr": {
+ "%": "pour cents",
+ "÷": "divisé par",
+ "\*": "fois", # ?
+ "×": "fois",
+ "±": "plus ou moins",
+ "\+": "plus",
+ "&": "et",
+ "@": "arobase",
+ "m²": "mètres carrés",
+ "m³": "mètres cubes",
+ "²": "au carré",
+ "³": "au cube",
+ "¼": "un quart",
+ "½": "un demi",
+ "¾": "trois quarts",
+ "§": "section",
+ "°C": "degrés Celsius",
+ "°F": "degrés Fahrenheit",
+ "°K": "kelvins",
+ "°": "degrés",
+ "€": "euros",
+ "¢": "cents",
+ "\$": "dollars",
+ "£": "livres",
+ "¥": "yens",
+ # Below: not in Whisper tokens
+ # "₩": "wons",
+ # "₽": "roubles",
+ # "₹": "roupies",
+ # "₺": "liras",
+ # "₪": "shekels",
+ # "₴": "hryvnias",
+ # "₮": "tugriks",
+ # "℃": "degrés Celsius",
+ # "℉": "degrés Fahrenheit",
+ # "Ω": "ohms",
+ # "Ω": "ohms",
+ # "K": "kelvins",
+ # "ℓ": "litres",
+ },
+ "en": {
+ "%": "percent",
+ "÷": "divided by",
+ "\*": "times", # ?
+ "×": "times",
+ "±": "plus or minus",
+ "\+": "plus",
+ "&": "and",
+ "@": "at",
+ "m²": "square meters",
+ "m³": "cubic meters",
+ "²": "squared",
+ "³": "cubed",
+ "¼": "one quarter",
+ "½": "one half",
+ "¾": "three quarters",
+ "§": "section",
+ "°C": "degrees Celsius",
+ "°F": "degrees Fahrenheit",
+ "°K": "kelvins",
+ "°": "degrees",
+ "€": "euros",
+ "¢": "cents",
+ "\$": "dollars",
+ "£": "pounds",
+ "¥": "yens",
+ },
+}
diff --git a/whisper/stt/processing/utils.py b/whisper/stt/processing/utils.py
new file mode 100644
index 0000000..106167a
--- /dev/null
+++ b/whisper/stt/processing/utils.py
@@ -0,0 +1,225 @@
+import io
+import os
+
+import numpy as np
+import wavio
+from stt import USE_CTRANSLATE2, USE_TORCH, USE_TORCHAUDIO
+
+SAMPLE_RATE = 16000 # whisper.audio.SAMPLE_RATE
+
+if USE_CTRANSLATE2:
+ import ctranslate2
+ import faster_whisper
+else:
+ import torch
+
+ import whisper
+
+if USE_TORCHAUDIO:
+ import torchaudio
+
+
+def has_cuda():
+ if USE_CTRANSLATE2:
+ return ctranslate2.get_cuda_device_count() > 0
+ else:
+ return torch.cuda.is_available()
+
+
+def get_device():
+ device = os.environ.get("DEVICE", "cuda" if has_cuda() else "cpu")
+ use_gpu = "cuda" in device
+
+ if USE_CTRANSLATE2:
+ try:
+ if device.startswith("cuda:"):
+ _ = [int(dev) for dev in device[5:].split(",")]
+ else:
+ assert device in ["cpu", "cuda"]
+ except:
+ raise ValueError(
+ f"Invalid DEVICE '{device}' (should be 'cpu' or 'cuda' or 'cuda: or 'cuda:,,...')"
+ )
+ else:
+ try:
+ device = torch.device(device)
+ except Exception as err:
+ raise Exception("Failed to set device: {}".format(str(err))) from err
+ return device, use_gpu
+
+
+def get_language():
+ """
+ Get the language from the environment variable LANGUAGE, and format as expected by Whisper.
+ """
+ language = os.environ.get("LANGUAGE", "*")
+ # "fr-FR" -> "fr" (language-country code to ISO 639-1 code)
+ if len(language) > 2 and language[2] == "-":
+ language = language.split("-")[0]
+ # "*" means "all languages"
+ if language == "*":
+ language = None
+ # Convert French -> fr
+ if isinstance(language, str) and language not in LANGUAGES:
+ language = {v: k for k, v in LANGUAGES.items()}.get(language.lower(), language)
+ # Raise an exception for unknown languages
+ if language not in LANGUAGES:
+ available_languages = (
+ list(LANGUAGES.keys())
+ + [k[0].upper() + k[1:] for k in LANGUAGES.values()]
+ + ["*", None]
+ )
+ raise ValueError(
+ f"Language '{language}' is not available. Available languages are: {available_languages}"
+ )
+ return language
+
+
+def conform_audio(audio, sample_rate=16_000):
+ if sample_rate != SAMPLE_RATE:
+ if not USE_TORCHAUDIO:
+ raise NotImplementedError("Resampling not available without torchaudio")
+ # Down or Up sample to the right sampling rate
+ audio = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)(audio)
+ if audio.shape[0] > 1:
+ # Stereo to mono
+ # audio = torchaudio.transforms.DownmixMono()(audio, channels_first = True)
+ audio = audio.mean(0)
+ else:
+ audio = audio.squeeze(0)
+ return audio
+
+
+def load_audiofile(path):
+ if not os.path.isfile(path):
+ raise RuntimeError("File not found: %s" % path)
+ elif not os.access(path, os.R_OK):
+ raise RuntimeError("Missing reading permission for: %s" % path)
+ if USE_CTRANSLATE2:
+ return faster_whisper.decode_audio(path, sampling_rate=SAMPLE_RATE)
+ audio = whisper.load_audio(path)
+ audio = torch.from_numpy(audio)
+ return audio
+
+
+def load_wave_buffer(file_buffer):
+ """Formats audio from a wavFile buffer to a torch array for processing."""
+ file_buffer_io = io.BytesIO(file_buffer)
+ if USE_CTRANSLATE2:
+ return faster_whisper.decode_audio(file_buffer_io, sampling_rate=SAMPLE_RATE)
+ file_content = wavio.read(file_buffer_io)
+ sample_rate = file_content.rate
+ audio = file_content.data.astype(np.float32) / 32768
+ audio = audio.transpose()
+ audio = torch.from_numpy(audio)
+ return conform_audio(audio, sample_rate)
+
+
+def flatten(l):
+ """
+ flatten a list of lists
+ """
+ return [item for sublist in l for item in sublist]
+
+
+LANGUAGES = { # whisper.tokenizer.LANGUAGES
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "he": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
diff --git a/whisper/stt/processing/word_alignment.py b/whisper/stt/processing/word_alignment.py
new file mode 100644
index 0000000..e7a9256
--- /dev/null
+++ b/whisper/stt/processing/word_alignment.py
@@ -0,0 +1,224 @@
+"""
+Credits: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
+"""
+from dataclasses import dataclass
+
+from stt import USE_TORCH, logger
+
+from .alignment_model import compute_logprobas, get_vocab
+from .text_normalize import transliterate
+from .utils import flatten
+
+if USE_TORCH:
+ import torch
+
+_unknown_chars = []
+
+
+def compute_alignment(audio, transcript, model):
+ """Compute the alignment of the audio and a transcript, for a given model that returns log-probabilities on the charset defined the transcript."""
+
+ emission = compute_logprobas(model, audio)
+ labels, blank_id = get_vocab(model)
+ labels = labels[: emission.shape[1]]
+ dictionary = {c: i for i, c in enumerate(labels)}
+
+ default = labels.index("-") if "-" in labels else None
+ tokens = [loose_get_char_index(dictionary, c, default) for c in transcript]
+ tokens = flatten(tokens)
+
+ num_emissions = emission.shape[0]
+ num_repetitions = count_repetitions(tokens)
+ if len(tokens) + num_repetitions > num_emissions:
+ # It will be impossible to find a path...
+ # It can happen when Whisper is lost in a loop (ex: "Ha ha ha ha ...")
+ logger.warn(f"Got too many characters from Whisper. Shrinking to the first characters.")
+ tokens = tokens[:num_emissions]
+ num_repetitions = count_repetitions(tokens)
+ while len(tokens) + num_repetitions > num_emissions:
+ tokens = tokens[:-1]
+ num_repetitions = count_repetitions(tokens)
+
+ # Make sure transcript has the same length as tokens (it could be different just because of transliteration "œ" -> "oe")
+ transcript = "".join([labels[i][0] for i in tokens])
+
+ trellis = get_trellis(emission, tokens, blank_id=blank_id)
+
+ path = backtrack(trellis, emission, tokens, blank_id=blank_id)
+
+ segments = merge_repeats(transcript, path)
+
+ word_segments = merge_words(segments)
+
+ return labels, emission, trellis, segments, word_segments
+
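+# Note: the start/end fields of the returned segments and word_segments are indices in emission
+# frames of the alignment model, not seconds; converting them into timestamps is presumably done
+# by the caller, which needs the model's frame rate.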
+
+def count_repetitions(tokens):
+ return sum([a == b for a, b in zip(tokens[1:], tokens[:-1])])
+
+
+def loose_get_char_index(dictionary, c, default=None):
+ global _unknown_chars
+ i = dictionary.get(c, None)
+ if i is None:
+ # Try with alternative versions of the character
+ tc = transliterate(c)
+ other_char = list(set([c.lower(), c.upper(), tc, tc.lower(), tc.upper()]))
+ for c2 in other_char:
+ i = dictionary.get(c2, None)
+ if i is not None:
+ i = [i]
+ break
+ # Some transliterated versions may correspond to multiple characters
+ if i is None:
+ for c2 in other_char:
+ if len(c2) > 1:
+ candidate = [dictionary[c3] for c3 in c2 if c3 in dictionary]
+ if len(candidate) > 0 and (i is None or len(candidate) > len(i)):
+ i = candidate
+ # If still not found
+ if i is None:
+ if c not in _unknown_chars:
+ logger.warning(
+ "Character not correctly handled by alignment model: '"
+ + "' / '".join(list(set([c] + other_char)))
+ + "'"
+ )
+ _unknown_chars.append(c)
+ i = [default] if default is not None else []
+ else:
+ i = [i]
+ return i
+
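+# Illustrative behaviour of loose_get_char_index(): an accented character such as "é" can fall
+# back to its transliteration (assumed here to be "e") when missing from the model vocabulary,
+# and a multi-character transliteration such as "œ" -> "oe" can map to two indices, which is why
+# the function returns a (possibly empty) list of indices rather than a single index.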
+
+def get_trellis(emission, tokens, blank_id=0, use_max=False):
+ num_frame = emission.size(0)
+ num_tokens = len(tokens)
+
+ # Trellis has extra dimensions for both the time axis and the tokens.
+ # The extra dim for tokens represents <SoS> (start-of-sentence)
+ # The extra dim for time axis is for simplification of the code.
+ trellis = torch.empty((num_frame + 1, num_tokens + 1)).to(emission.device)
+ trellis[0, 0] = 0
+ trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
+ trellis[0, -num_tokens:] = -float("inf")
+ trellis[-num_tokens:, 0] = float("inf")
+
+ for t in range(num_frame):
+ trellis[t + 1, 1:] = (
+ torch.maximum(
+ # Score for staying at the same token
+ trellis[t, 1:] + emission[t, blank_id],
+ torch.maximum(
+ trellis[t, 1:] + emission[t, tokens],
+ # Score for changing to the next token
+ trellis[t, :-1] + emission[t, tokens],
+ ),
+ )
+ if use_max
+ else torch.logaddexp(
+ trellis[t, 1:] + emission[t, blank_id],
+ torch.logaddexp(
+ trellis[t, 1:] + emission[t, tokens], trellis[t, :-1] + emission[t, tokens]
+ ),
+ )
+ )
+ return trellis
+
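+# In log-space, the recurrence implemented above is (for frame t and trellis column j >= 1,
+# with token_j = tokens[j - 1]):
+#   trellis[t + 1, j] = logsumexp(trellis[t, j]     + emission[t, blank],     <- stay (emit blank)
+#                                 trellis[t, j]     + emission[t, token_j],   <- stay (repeat token)
+#                                 trellis[t, j - 1] + emission[t, token_j])   <- advance to token j
+# With use_max=True the logsumexp becomes a max, i.e. Viterbi-style scoring.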
+
+@dataclass
+class Point:
+ token_index: int
+ time_index: int
+ score: float
+
+
+def backtrack(trellis, emission, tokens, blank_id=0):
+ # Note:
+ # j and t are indices for trellis, which has extra dimensions
+ # for time and tokens at the beginning.
+ # When referring to time frame index `T` in trellis,
+ # the corresponding index in emission is `T-1`.
+ # Similarly, when referring to token index `J` in trellis,
+ # the corresponding index in transcript is `J-1`.
+ j = trellis.size(1) - 1
+ t_start = torch.argmax(trellis[:, j]).item()
+
+ path = []
+ for t in range(t_start, 0, -1):
+ # 1. Figure out if the current position was stay or change
+ # Note (again):
+ # `emission[t - 1]` is the emission at time frame `t` of the trellis dimension.
+ # Score for staying on the same token from frame t-1 to t (emitting blank).
+ stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
+ # Score for advancing from token j-1 to token j between frames t-1 and t.
+ changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+ # 2. Store the path with frame-wise probability.
+ prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
+ # Return token index and time index in non-trellis coordinate.
+ path.append(Point(j - 1, t - 1, prob))
+
+ # 3. Update the token
+ if changed > stayed:
+ j -= 1
+ if j == 0:
+ break
+ else:
+ logger.warn(f"Failed to align {len(tokens)} tokens")
+ return path
+ return path[::-1]
+
+
+# Merge the labels
+@dataclass
+class Segment:
+ label: str
+ start: int
+ end: int
+ score: float
+
+ def __repr__(self):
+ return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+
+ @property
+ def length(self):
+ return self.end - self.start
+
+
+def merge_repeats(transcript, path):
+ i1, i2 = 0, 0
+ segments = []
+ while i1 < len(path):
+ while i2 < len(path) and path[i1].token_index == path[i2].token_index:
+ i2 += 1
+ score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
+ segments.append(
+ Segment(
+ transcript[path[i1].token_index],
+ path[i1].time_index,
+ path[i2 - 1].time_index + 1,
+ score,
+ )
+ )
+ i1 = i2
+ return segments
+
+
+def merge_words(segments, separator=" "):
+ words = []
+ i1, i2 = 0, 0
+ while i1 < len(segments):
+ if i2 >= len(segments) or segments[i2].label == separator:
+ if i1 != i2:
+ segs = segments[i1:i2]
+ word = "".join([seg.label for seg in segs])
+ score = sum(seg.score * seg.length for seg in segs) / sum(
+ seg.length for seg in segs
+ )
+ words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
+ i1 = i2 + 1
+ i2 = i1
+ else:
+ i2 += 1
+ return words
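+
+# Illustrative example (hypothetical alignment output): character segments for "hi yo", i.e.
+# 'h', 'i', ' ', 'y', 'o', are grouped by merge_words() into two word Segments, "hi" and "yo",
+# whose scores are the length-weighted averages of their character scores.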