diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/__init__.py b/common/speech/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt new file mode 100644 index 000000000..204934829 --- /dev/null +++ b/common/speech/lasr_speech_recognition_interfaces/CMakeLists.txt @@ -0,0 +1,41 @@ +cmake_minimum_required(VERSION 3.8) +project(lasr_speech_recognition_interfaces) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +find_package(rclpy REQUIRED) +find_package(action_msgs REQUIRED) + +# uncomment the following section in order to fill in +# further dependencies manually. +# find_package( REQUIRED) + +# For actions, messages, and services +find_package(rosidl_default_generators REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "action/TranscribeSpeech.action" + "msg/Transcription.msg" + "srv/TranscribeAudio.srv" + DEPENDENCIES builtin_interfaces # Add packages that above messages depend on +) + +ament_export_dependencies(rosidl_default_runtime) + +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + # the following line skips the linter which checks for copyrights + # comment the line when a copyright and license is added to all source files + set(ament_cmake_copyright_FOUND TRUE) + # the following line skips cpplint (only works in a git repo) + # comment the line when this package is in a git repo and when + # a copyright and license is added to all source files + set(ament_cmake_cpplint_FOUND TRUE) + ament_lint_auto_find_test_dependencies() +endif() + +ament_package() diff --git a/common/speech/lasr_speech_recognition_interfaces/LICENSE b/common/speech/lasr_speech_recognition_interfaces/LICENSE new file mode 100644 index 000000000..30e8e2ece --- /dev/null +++ b/common/speech/lasr_speech_recognition_interfaces/LICENSE @@ -0,0 +1,17 @@ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
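The `rosidl_generate_interfaces` call in the CMakeLists above generates Python (and C++) bindings for the listed action, message, and service. As a rough sketch of how the generated Python API would be imported once the package is built and sourced (field names taken from the interface definitions later in this diff):

```python
# Sketch only: assumes lasr_speech_recognition_interfaces has been built with
# colcon and the workspace has been sourced.
from lasr_speech_recognition_interfaces.msg import Transcription
from lasr_speech_recognition_interfaces.srv import TranscribeAudio
from lasr_speech_recognition_interfaces.action import TranscribeSpeech

# Transcription.msg: a phrase plus a flag marking whether the phrase is complete
msg = Transcription(phrase="hello world", finished=True)

# TranscribeSpeech.action goal: 0.0 leaves the server's configured values untouched
goal = TranscribeSpeech.Goal()
goal.energy_threshold = 0.0
goal.max_phrase_limit = 0.0

# TranscribeAudio.srv: the request carries no fields, the response returns `phrase`
request = TranscribeAudio.Request()
```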
diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md
new file mode 100644
index 000000000..8e7aab96f
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/README.md
@@ -0,0 +1,55 @@
+# lasr_speech_recognition_interfaces
+
+Common messages used for speech recognition
+
+This package is maintained by:
+
+- [Maayan Armony](mailto:maayan.armony@gmail.com)
+- [Paul Makles](mailto:me@insrt.uk) (ROS1)
+
+## Prerequisites
+
+This package depends on the following ROS packages:
+
+- colcon (buildtool)
+- rosidl_default_generators (build)
+- rosidl_default_runtime (exec)
+
+## Usage
+
+Ask the package maintainer to write a `doc/USAGE.md` for their package!
+
+## Example
+
+Ask the package maintainer to write a `doc/EXAMPLE.md` for their package!
+
+## Technical Overview
+
+Ask the package maintainer to write a `doc/TECHNICAL.md` for their package!
+
+## ROS Definitions
+
+### Launch Files
+
+This package has no launch files.
+
+### Messages
+
+#### `Transcription`
+
+| Field | Type | Description |
+|:--------:|:------:|-------------|
+| phrase | string | The transcribed phrase |
+| finished | bool | Whether the transcription is complete |
+
+### Services
+
+#### `TranscribeAudio`
+
+The request is empty; the response contains the transcribed `phrase` (string).
+
+### Actions
+
+#### `TranscribeSpeech`
+
+The goal sets `energy_threshold` and `max_phrase_limit` (both float32); the result and feedback each carry the transcribed `sequence` (string).
diff --git a/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
new file mode 100644
index 000000000..5cac9317e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/action/TranscribeSpeech.action
@@ -0,0 +1,11 @@
+# Energy threshold
+float32 energy_threshold
+
+# Max phrase duration
+float32 max_phrase_limit
+---
+#result definition
+string sequence
+---
+#feedback
+string sequence
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
new file mode 100644
index 000000000..9c7483636
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/msg/Transcription.msg
@@ -0,0 +1,2 @@
+string phrase
+bool finished
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml
new file mode 100644
index 000000000..fd72011b7
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -0,0 +1,23 @@
+
+
+
+  lasr_speech_recognition_interfaces
+  0.0.0
+  Common messages used for speech recognition
+  maayan
+  MIT
+
+  ament_cmake
+
+  rosidl_default_generators
+  action_msgs
+  rosidl_default_runtime
+  rosidl_interface_packages
+
+  ament_lint_auto
+  ament_lint_common
+
+
+    ament_cmake
+
+
diff --git a/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
new file mode 100644
index 000000000..f416a67c4
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_interfaces/srv/TranscribeAudio.srv
@@ -0,0 +1,2 @@
+---
+string phrase
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_whisper/LICENSE b/common/speech/lasr_speech_recognition_whisper/LICENSE
new file mode 100644
index 000000000..30e8e2ece
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md
new file mode 100644
index 000000000..c9f58557e
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/README.md
@@ -0,0 +1,109 @@
+# lasr_speech_recognition_whisper
+
+Speech recognition implemented using OpenAI Whisper
+
+This package is maintained by:
+
+- [Maayan Armony](mailto:maayan.armony@gmail.com)
+- [Paul Makles](mailto:me@insrt.uk) (ROS1)
+
+## Prerequisites
+
+This package depends on the following ROS packages:
+
+- colcon (buildtool)
+- lasr_speech_recognition_interfaces
+
+This package requires Python 3.10 to be present.
+
+This package has the following Python dependencies:
+
+- [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0
+- [openai-whisper](https://pypi.org/project/openai-whisper)==20231117
+- [PyAudio](https://pypi.org/project/PyAudio)==0.2.13
+- [PyYaml](https://pypi.org/project/PyYaml)==6.0.1
+- .. and sub dependencies (see [requirements file](requirements.txt))
+
+This package requires that [ffmpeg](https://ffmpeg.org/) is available during runtime.
+
+## Usage
+
+> **Warning**: this package is not complete and is subject to change.
+
+List available microphones:
+
+```bash
+ros2 run lasr_speech_recognition_whisper list_microphones.py
+```
+
+Start the example script:
+
+```bash
+ros2 run lasr_speech_recognition_whisper transcribe_microphone by-index <device_index>
+ros2 run lasr_speech_recognition_whisper transcribe_microphone by-name <device_name>
+```
+
+Then start listening to people:
+
+```bash
+ros2 service call /whisper/start_listening "{}"
+```
+
+You can now listen on `/transcription` for a live transcription.
+
+Stop listening at any time:
+
+```bash
+ros2 service call /whisper/stop_listening "{}"
+```
+
+## Example
+
+Ask the package maintainer to write a `doc/EXAMPLE.md` for their package!
+
+## Technical Overview
+
+This package does speech recognition in three parts:
+
+- Adjusting for background noise
+
+  We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice
+  data.
+
+- Collecting appropriate voice data for phrases
+
+  We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually
+  speaking with enough energy that we would consider them to be speaking to the robot.
+
+- Running inference on phrases
+
+  We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after
+  which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
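A condensed sketch of how those three stages fit together, using the same `SpeechRecognition` and `openai-whisper` calls the nodes in this package build on; the timeout, model, and threshold values here are illustrative, not the package's exact configuration.

```python
# Minimal sketch of the adjust/collect/transcribe loop (illustrative values).
import numpy as np
import speech_recognition as sr
import whisper

recogniser = sr.Recognizer()
microphone = sr.Microphone(sample_rate=16000)
model = whisper.load_model("medium.en")

with microphone as source:
    # 1. Adjust for background noise
    recogniser.adjust_for_ambient_noise(source)
    # 2. Collect a phrase once speech is detected above the energy threshold
    audio = recogniser.listen(source, timeout=5.0, phrase_time_limit=10.0)

# 3. Convert the 16-bit PCM samples to normalised float32 and run Whisper on them
wav_data = audio.get_wav_data()
float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32) / 32768.0
print(model.transcribe(float_data)["text"])
```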
+ +The package can input from the following sources: + +- On-board or external microphone on device +- Audio data from ROS topic (WORK IN PROGRESS) + +The package can output transcriptions to: + +- Standard output +- A ROS topic + +## ROS Definitions + +### Launch Files + +This package has no launch files. + +### Messages + +This package has no messages. + +### Services + +This package has no services. + +### Actions + +This package has no actions. \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py new file mode 100644 index 000000000..7b3b1f8a0 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py @@ -0,0 +1,101 @@ +#!/usr/bin python3 +import os +import torch +import rclpy +from ament_index_python import packages + +import sys +from pathlib import Path +import speech_recognition as sr +import numpy as np + +import sounddevice # needed to remove ALSA error messages +from lasr_speech_recognition_interfaces.srv import TranscribeAudio +from src import ModelCache # type: ignore + +MODEL = "medium.en" # Whisper model +TIMEOUT = 5.0 # Timeout for listening for the start of a phrase +PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase + +WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") +os.makedirs(WHISPER_CACHE, exist_ok=True) +os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE + +if len(sys.argv) < 3: + print("Usage:") + print( + "ros2 run lasr_speech_recognition transcribe_microphone by-index " + ) + print("ros2 run lasr_speech_recognition transcribe_microphone by-name ") + exit(1) +else: + matcher = sys.argv[1] + device_index = None + if matcher == "by-index": + device_index = int(sys.argv[2]) + elif matcher == "by-name": + import speech_recognition as sr + + microphones = enumerate(sr.Microphone.list_microphone_names()) + + target_name = sys.argv[2] + for index, name in microphones: + if target_name in name: + device_index = index + break + + if device_index is None: + print("Could not find device!") + exit(1) + else: + print("Invalid matcher") + exit(1) + +rclpy.init(args=sys.argv) +node = rclpy.create_node("transcribe_mic") + +device = "cuda" if torch.cuda.is_available() else "cpu" +model_cache = ModelCache() +model = model_cache.load_model("medium.en", device=device) + +# try to run inference on the example file +package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") +package_root = os.path.abspath( + os.path.join( + package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" + ) +) +example_fp = os.path.join(package_root, "test.m4a") +node.get_logger().info( + "Running transcription on example file to ensure model is loaded..." 
+) +transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available()) +node.get_logger().info(str(transcription)) + +microphone = sr.Microphone(device_index=device_index, sample_rate=16000) +r = sr.Recognizer() +with microphone as source: + r.adjust_for_ambient_noise(source) + + +def handle_transcribe_audio(_): + with microphone as source: + + wav_data = r.listen( + source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT + ).get_wav_data() + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + phrase = model.transcribe(float_data, fp16=device == "cuda")["text"] + return TranscribeAudio.Response(phrase=phrase) + + +node.create_service( + TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio +) + +node.get_logger().info("Whisper service ready") +rclpy.spin(node) diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py new file mode 100644 index 000000000..3225072c3 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py @@ -0,0 +1,111 @@ +#!/usr/bin python3 +import os +import sys +import torch +from pathlib import Path + +import rclpy +from rclpy.node import Node +from ament_index_python import packages +from std_srvs.srv import Empty +from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache + +WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper") +os.makedirs(WHISPER_CACHE, exist_ok=True) +os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE + + +class TranscribeMicrophone(Node): + def __init__(self): + Node.__init__(self, "transcribe_microphone") + self.worker = None + self.collector = None + + self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise) + self.create_service(Empty, "/whisper/start_listening", self.start_listening) + self.create_service(Empty, "/whisper/stop_listening", self.stop_listening) + + self.get_logger().info("Starting the Whisper worker!") + self.run_transcription() + + def run_transcription(self): + if len(sys.argv) < 3: + print("Usage:") + print( + "rosrun lasr_speech_recognition transcribe_microphone by-index " + ) + print( + "rosrun lasr_speech_recognition transcribe_microphone by-name " + ) + exit(1) + else: + matcher = sys.argv[1] + device_index = None + if matcher == "by-index": + device_index = int(sys.argv[2]) + elif matcher == "by-name": + import speech_recognition as sr + + microphones = enumerate(sr.Microphone.list_microphone_names()) + + target_name = sys.argv[2] + for index, name in microphones: + if target_name in name: + device_index = index + break + + if device_index is None: + print("Could not find device!") + exit(1) + else: + print("Invalid matcher") + exit(1) + + self.collector = MicrophonePhraseCollector(device_index=device_index) + self.collector.adjust_for_noise() + + model_cache = ModelCache() + model = model_cache.load_model("medium.en") + + # try to run inference on the example file + package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") + package_root = os.path.abspath( + os.path.join( + package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper" + ) + ) + example_fp = os.path.join(package_root, "test.m4a") + + self.get_logger().info( + "Running transcription on example file to ensure model is loaded..." 
+ ) + model_transcription = model.transcribe( + example_fp, fp16=torch.cuda.is_available() + ) + self.get_logger().info(str(model_transcription)) + + self.worker = SpeechRecognitionToTopic( + self.collector, model, "transcription", infer_partial=False + ) + + def adjust_for_noise(self, request, response): + self.collector.adjust_for_noise() + return response + + def start_listening(self, request, response): + self.worker.start() + return response + + def stop_listening(self, request, response): + self.worker.stop() + return response + + +def main(args=None): + rclpy.init(args=args) + transcribe_microphone = TranscribeMicrophone() + rclpy.spin(transcribe_microphone) + + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py new file mode 100644 index 000000000..8adf3bd8c --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py @@ -0,0 +1,387 @@ +#!/usr/bin python3 +import os +import sounddevice # needed to remove ALSA error messages +import argparse +from typing import Optional +from dataclasses import dataclass +from pathlib import Path +from timeit import default_timer as timer + +import numpy as np +import torch + +import rclpy +from rclpy.node import Node +from rclpy.action.server import ActionServer, CancelResponse + +import speech_recognition as sr # type: ignore +from lasr_speech_recognition_interfaces.action import TranscribeSpeech # type: ignore +from rclpy.executors import ExternalShutdownException +from std_msgs.msg import String # type: ignore +from src import ModelCache # type: ignore + +# TODO: argpars -> ROS2 params, test behaviour of preemption + + +@dataclass +class speech_model_params: + """Class for storing speech recognition model parameters. + + Args: + model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en". + Must be a valid Whisper model name. + device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu". + start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0. + phrase_duration (Optional[float]): Max number of seconds of the phrase. Defaults to 10 seconds. + sample_rate (int): Sample rate of the microphone. Defaults to 16000Hz. + mic_device (Optional[str]): Microphone device index or name. Defaults to None. + timer_duration (Optional[int]): Duration of the timer for adjusting the microphone for ambient noise. Defaults to 20 seconds. + warmup (bool): Whether to warmup the model by running inference on a test file. Defaults to True. + energy_threshold (Optional[int]): Energy threshold for silence detection. Using this disables automatic adjustment. Defaults to None. + pause_threshold (Optional[float]): Seconds of non-speaking audio before a phrase is considered complete. Defaults to 0.8 seconds. 
+ """ + + model_name: str = "medium.en" + device: str = "cuda" if torch.cuda.is_available() else "cpu" + start_timeout: float = 5.0 + phrase_duration: Optional[float] = 10 + sample_rate: int = 16000 + mic_device: Optional[str] = None + timer_duration: Optional[int] = 20 + warmup: bool = True + energy_threshold: Optional[int] = None + pause_threshold: Optional[float] = 2.0 + + +class TranscribeSpeechAction(Node): + # create messages that are used to publish feedback/result + _feedback = TranscribeSpeech.Feedback() + _result = TranscribeSpeech.Result() + + def __init__( + self, + action_name: str, + model_params: speech_model_params, + ) -> None: + """Starts an action server for transcribing speech. + + Args: + action_name (str): Name of the action server. + """ + Node.__init__(self, "transcribe_speech_action") + self._action_name = action_name + self._model_params = model_params + self._transcription_server = self.create_publisher( + String, "/live_speech_transcription", 10 + ) + + model_cache = ModelCache() + self._model = model_cache.load_model( + self._model_params.model_name, + self._model_params.device, + self._model_params.warmup, + ) + # Configure the speech recogniser object and adjust for ambient noise + self.recogniser = self._configure_recogniser() + + # Set up the action server and register execution callback + self._action_server = ActionServer( + self, + TranscribeSpeech, + self._action_name, + execute_callback=self.execute_cb, + cancel_callback=self.cancel_cb, + # auto_start=False, # not required in ROS2 ?? (cb is async) + ) + self._action_server.register_cancel_callback(self.cancel_cb) + self._listening = False + + # self._action_server.start() # not required in ROS2 + self.get_logger().info(f"Speech Action server {self._action_name} started") + + def _configure_microphone(self) -> sr.Microphone: + """Configures the microphone for listening to speech based on the + microphone device index or name. + + Returns: microphone object + """ + + if self._model_params.mic_device is None: + # If no microphone device is specified, use the system default microphone + return sr.Microphone(sample_rate=self._model_params.sample_rate) + elif self._model_params.mic_device.isdigit(): + return sr.Microphone( + device_index=int(self._model_params.mic_device), + sample_rate=self._model_params.sample_rate, + ) + else: + microphones = enumerate(sr.Microphone.list_microphone_names()) + for index, name in microphones: + if self._model_params.mic_device in name: + return sr.Microphone( + device_index=index, + sample_rate=self._model_params.sample_rate, + ) + raise ValueError( + f"Could not find microphone with name: {self._model_params.mic_device}" + ) + + def _configure_recogniser( + self, + energy_threshold: Optional[float] = None, + pause_threshold: Optional[float] = None, + ) -> sr.Recognizer: + """Configures the speech recogniser object. + + Args: + energy_threshold (float): Energy threshold for silence detection. Using this disables automatic adjustment. + pause_threshold (float): Seconds of non-speaking audio before a phrase is considered complete. + + Returns: + sr.Recognizer: speech recogniser object. 
+ """ + self._listening = True + recogniser = sr.Recognizer() + + if pause_threshold: + recogniser.pause_threshold = pause_threshold + + elif self._model_params.pause_threshold: + recogniser.pause_threshold = self._model_params.pause_threshold + + if energy_threshold: + recogniser.dynamic_energy_threshold = False + recogniser.energy_threshold = energy_threshold + return recogniser + + if self._model_params.energy_threshold: + recogniser.dynamic_energy_threshold = False + recogniser.energy_threshold = self._model_params.energy_threshold + return recogniser + + with self._configure_microphone() as source: + recogniser.adjust_for_ambient_noise(source) + self._listening = False + return recogniser + + def cancel_cb(self, goal_handle) -> CancelResponse: + """Callback for cancelling the action server. + Sets server to 'canceled' state. + """ + cancel_str = f"{self._action_name} has been cancelled" + self.get_logger().info(cancel_str) + self._result.sequence = cancel_str + + # self._action_server.set_preempted(result=self._result, text=cancel_str) + goal_handle.canceled() + + return CancelResponse.ACCEPT # TODO decide if always accept cancellation + + async def execute_cb(self, goal_handle) -> None: + """Callback for executing the action server. + + Checks for cancellation before listening and before and after transcribing, returning + if cancellation is requested. + + Args: + :param goal_handle: handles the goal request, and provides access to the goal parameters + """ + + goal = goal_handle.request + + self.get_logger().info("Request Received") + if goal_handle.is_cancel_requested: + return + + if goal.energy_threshold > 0.0 and goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + goal.energy_threshold, goal.max_phrase_limit + ) + elif goal.energy_threshold > 0.0: + self.recogniser = self._configure_recogniser(goal.energy_threshold) + elif goal.max_phrase_limit > 0.0: + self.recogniser = self._configure_recogniser( + pause_threshold=goal.max_phrase_limit + ) + + with self._configure_microphone() as src: + self._listening = True + wav_data = self.recogniser.listen( + src, + timeout=self._model_params.start_timeout, + phrase_time_limit=self._model_params.phrase_duration, + ).get_wav_data() + # Magic number 32768.0 is the maximum value of a 16-bit signed integer + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + if goal_handle.is_cancel_requested(): + self._listening = False + self.get_logger().info("Goal was cancelled during execution.") + goal_handle.canceled() + return self._result + + self.get_logger().info(f"Transcribing phrase with Whisper...") + transcription_start_time = timer() + # Cast to fp16 if using GPU + phrase = self._model.transcribe( + float_data, + fp16=self._model_params.device == "cuda", + )["text"] + transcription_end_time = timer() + self.get_logger().info(f"Transcription finished!") + self.get_logger().info( + f"Time taken: {transcription_end_time - transcription_start_time:.2f}s" + ) + self._transcription_server.publish(phrase) + if goal_handle.is_cancel_requested(): + self._listening = False + return + + self._result.sequence = phrase + self.get_logger().info(f"Transcribed phrase: {phrase}") + self.get_logger().info(f"{self._action_name} has succeeded") + + goal_handle.succeed() + + # Have this at the very end to not disrupt the action server + self._listening = False + + return self._result + + +def parse_args() -> dict: + """Parses the command line arguments into a name: value dictinoary. 
+ + Returns: + dict: Dictionary of name: value pairs of command line arguments. + """ + parser = argparse.ArgumentParser( + description="Starts an action server for transcribing speech." + ) + + # TODO change to ROS2 rosparams: + # port = node.declare_parameter('port', '/dev/ttyUSB0').value + # assert isinstance(port, str), 'port parameter must be a str' + + parser.add_argument( + "--action_name", + type=str, + default="transcribe_speech", + help="Name of the action server.", + ) + parser.add_argument( + "--model_name", + type=str, + default="medium.en", + help="Name of the speech recognition model.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run the model on.", + ) + parser.add_argument( + "--start_timeout", + type=float, + default=5.0, + help="Timeout for listening for the start of a phrase.", + ) + parser.add_argument( + "--phrase_duration", + type=float, + default=10, + help="Maximum phrase duration after starting listening in seconds.", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=16000, + help="Sample rate of the microphone.", + ) + parser.add_argument( + "--mic_device", + type=str, + default=None, + help="Microphone device index or name", + ) + parser.add_argument( + "--no_warmup", + action="store_true", + help="Disable warming up the model by running inference on a test file.", + ) + + parser.add_argument( + "--energy_threshold", + type=Optional[int], + default=None, + help="Energy threshold for silence detection. Using this disables automatic adjustment", + ) + + parser.add_argument( + "--pause_threshold", + type=float, + default=2.0, + help="Seconds of non-speaking audio before a phrase is considered complete.", + ) + + args, unknown = parser.parse_known_args() + return vars(args) + + +def configure_model_params(config: dict) -> speech_model_params: + """Configures the speech model parameters based on the provided + command line parameters. + + Args: + config (dict): Command line parameters parsed in dictionary form. 
+ + Returns: + speech_model_params: dataclass containing the speech model parameters + """ + model_params = speech_model_params() + if config["model_name"]: + model_params.model_name = config["model_name"] + if config["device"]: + model_params.device = config["device"] + if config["start_timeout"]: + model_params.start_timeout = config["start_timeout"] + if config["phrase_duration"]: + model_params.phrase_duration = config["phrase_duration"] + if config["sample_rate"]: + model_params.sample_rate = config["sample_rate"] + if config["mic_device"]: + model_params.mic_device = config["mic_device"] + if config["no_warmup"]: + model_params.warmup = False + # if config["energy_threshold"]: + # model_params.energy_threshold = config["energy_threshold"] + if config["pause_threshold"]: + model_params.pause_threshold = config["pause_threshold"] + + return model_params + + +def configure_whisper_cache() -> None: + """Configures the whisper cache directory.""" + whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper") + os.makedirs(whisper_cache, exist_ok=True) + # Environmental variable required to run whisper locally + os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache + + +def main(args=None): + rclpy.init(args=args) + + configure_whisper_cache() + config = parse_args() + + server = TranscribeSpeechAction("transcribe_speech", configure_model_params(config)) + + try: + rclpy.spin(server) + except (KeyboardInterrupt, ExternalShutdownException): + pass diff --git a/common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE b/common/speech/lasr_speech_recognition_whisper/log/COLCON_IGNORE new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml new file mode 100644 index 000000000..1cac47617 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/package.xml @@ -0,0 +1,30 @@ + + + + lasr_speech_recognition_whisper + 0.0.0 + Speech recognition implemented using OpenAI Whisper + maayan + MIT + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + + lasr_speech_recognition_interfaces + actionlib + actionlib_msgs + actionlib + actionlib_msgs + + + ament_python + requirements.txt + + diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in new file mode 100644 index 000000000..1d515543e --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/requirements.in @@ -0,0 +1,6 @@ +SpeechRecognition==3.10.0 +sounddevice==0.4.6 +openai-whisper==20231117 +PyAudio~=0.2.13 +PyYaml==6.0.1 +setuptools==60.0.1 \ No newline at end of file diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt new file mode 100644 index 000000000..eade8e0a3 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -0,0 +1,45 @@ +certifi==2024.2.2 # via requests +cffi==1.16.0 # via sounddevice +charset-normalizer==3.3.2 # via requests +filelock==3.14.0 # via torch, triton +fsspec==2024.3.1 # via torch +idna==3.7 # via requests +jinja2==3.1.4 # via torch +llvmlite==0.42.0 # via numba +markupsafe==2.1.5 # via jinja2 +more-itertools==10.2.0 # via openai-whisper +mpmath==1.3.0 # via sympy +networkx==3.3 # via torch +numba==0.59.1 # via openai-whisper +numpy==1.26.4 # via numba, openai-whisper +nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch 
+nvidia-cuda-cupti-cu12==12.1.105 # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 # via torch +nvidia-cuda-runtime-cu12==12.1.105 # via torch +nvidia-cudnn-cu12==8.9.2.26 # via torch +nvidia-cufft-cu12==11.0.2.54 # via torch +nvidia-curand-cu12==10.3.2.106 # via torch +nvidia-cusolver-cu12==11.4.5.107 # via torch +nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch +nvidia-nccl-cu12==2.20.5 # via torch +nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 # via torch +openai-whisper==20231117 # via -r requirements.in +pyaudio==0.2.13 # via -r requirements.in +pycparser==2.22 # via cffi +pyyaml==6.0.1 # via -r requirements.in +regex==2024.4.28 # via tiktoken +requests==2.31.0 # via speechrecognition, tiktoken +six==1.16.0 # via python-dateutil +sounddevice==0.4.6 # via -r requirements.in +speechrecognition==3.10.0 # via -r requirements.in +sympy==1.12 # via torch +tiktoken==0.6.0 # via openai-whisper +torch==2.3.0 # via openai-whisper +tqdm==4.66.4 # via openai-whisper +triton==2.3.0 # via openai-whisper, torch +typing-extensions==4.11.0 # via torch +urllib3==2.2.1 # via requests + +# The following packages are considered to be unsafe in a requirements file: +# setuptools == 60.0.1 diff --git a/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper b/common/speech/lasr_speech_recognition_whisper/resource/lasr_speech_recognition_whisper new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py b/common/speech/lasr_speech_recognition_whisper/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py new file mode 100755 index 000000000..a3ce21904 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py @@ -0,0 +1,19 @@ +#!/usr/bin python3 +import speech_recognition as sr +import sounddevice # needed to remove ALSA error messages + + +def main(): + microphones = enumerate(sr.Microphone.list_microphone_names()) + + print("\nAvailable microphones:") + for index, name in microphones: + print(f"[{index}] {name}") + + # # Uncomment for debugging, to see if sounddevice recongises the microphone as well + # print("Available microphone devices (sounddevice):") + # print(sounddevice.query_devices()) + + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py new file mode 100755 index 000000000..026ab2875 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py @@ -0,0 +1,76 @@ +#!/usr/bin python3 +import argparse +import os +import torch +import numpy as np +from pathlib import Path +import speech_recognition as sr +from src import ModelCache # type: ignore +import sounddevice # needed to remove ALSA error messages +from typing import Dict +import rclpy + +# TODO argparse -> ROS params + + +def parse_args() -> Dict: + parser = argparse.ArgumentParser() + parser.add_argument( + "--device_index", help="Microphone index", type=int, default=None + ) + return vars(parser.parse_args()) + + +def configure_whisper_cache() -> None: + """Configures the whisper cache directory.""" + whisper_cache = os.path.join(str(Path.home()), ".cache", 
"whisper") + os.makedirs(whisper_cache, exist_ok=True) + # Environmental variable required to run whisper locally + os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache + + +def main(args=None): + rclpy.init(args=args) # Have to initialise rclpy for the ModelCache + + configure_whisper_cache() + args = parse_args() + + recognizer = sr.Recognizer() + recognizer.pause_threshold = 2 + microphone = sr.Microphone(device_index=args["device_index"], sample_rate=16000) + threshold = 100 + recognizer.dynamic_energy_threshold = False + recognizer.energy_threshold = threshold + model_cache = ModelCache() + transcription_model = model_cache.load_model( + "medium.en", "cuda" if torch.cuda.is_available() else "cpu", True + ) + transcription_result = "The quick brown fox jumps over the lazy dog." + while transcription_result != "": + print(f"Listening...") + with microphone as source: + wav_data = recognizer.listen( + source, phrase_time_limit=10, timeout=5 + ).get_wav_data() + print(f"Processing...") + # Magic number 32768.0 is the maximum value of a 16-bit signed integer + float_data = ( + np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") + / 32768.0 + ) + + # Cast to fp16 if using GPU + transcription_result = transcription_model.transcribe( + float_data, fp16=torch.cuda.is_available() + )["text"] + + print( + f"Transcription: {transcription_result} at energy threshold {recognizer.energy_threshold}" + ) + threshold += 100 + recognizer.energy_threshold = threshold + + +if __name__ == "__main__": + + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py new file mode 100755 index 000000000..d14144e21 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py @@ -0,0 +1,72 @@ +#!/usr/bin python3 + +import os +import argparse +import speech_recognition as sr +import rclpy +import sounddevice # needed to remove ALSA error messages + +# TODO argparse -> ROS params + + +def parse_args() -> dict: + """Parse command line arguments into a dictionary. + + Returns: + dict: name: value pairs of command line arguments + """ + + parser = argparse.ArgumentParser(description="Test microphones") + parser.add_argument( + "-m", "--microphone", type=int, help="Microphone index", default=None + ) + parser.add_argument( + "-o", "--output_dir", type=str, help="Directory to save audio files" + ) + + # return vars(parser.parse_args()) + args, _ = parser.parse_known_args() + return vars(args) + + +def main(args: dict = None) -> None: + """Generate audio files from microphone input. + + Args: + args (dict): dictionary of command line arguments. 
+ """ + + # Adapted from https://github.com/Uberi/speech_recognition/blob/master/examples/write_audio.py + + rclpy.init(args=args) + + parser_args = parse_args() + + mic_index = parser_args["microphone"] + output_dir = parser_args["output_dir"] + + r = sr.Recognizer() + r.pause_threshold = 2 + microphone = sr.Microphone(device_index=mic_index, sample_rate=16000) + with microphone as source: + print("Say something!") + audio = r.listen(source, timeout=5, phrase_time_limit=10) + print("Finished listening") + + with open(os.path.join(output_dir, "microphone.raw"), "wb") as f: + f.write(audio.get_raw_data()) + + with open(os.path.join(output_dir, "microphone.wav"), "wb") as f: + f.write(audio.get_wav_data()) + + with open(os.path.join(output_dir, "microphone.flac"), "wb") as f: + f.write(audio.get_flac_data()) + + with open(os.path.join(output_dir, "microphone.aiff"), "wb") as f: + f.write(audio.get_aiff_data()) + + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py new file mode 100755 index 000000000..2448e73ec --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py @@ -0,0 +1,67 @@ +#!/usr/bin python3 +import rclpy +from rclpy.node import Node +from rclpy.action import ActionClient +from lasr_speech_recognition_interfaces.srv import TranscribeAudio # type: ignore +from lasr_speech_recognition_interfaces.action import TranscribeSpeech + +# https://docs.ros2.org/latest/api/rclpy/api/actions.html + + +class TestSpeechServerClient(Node): + def __init__(self): + Node.__init__(self, "listen_action_client") + + self.client = ActionClient(self, TranscribeSpeech, "transcribe_speech") + self.goal_future = None + self.result_future = None + + def send_goal(self, goal): + self.get_logger().info("Waiting for Whisper server...") + self.client.wait_for_server() + self.get_logger().info("Server activated, sending goal...") + + self.goal_future = self.client.send_goal_async( + goal, feedback_callback=self.feedback_cb + ) # Returns a Future instance when the goal request has been accepted or rejected. 
+ self.goal_future.add_done_callback( + self.response_cb + ) # When received get response + + def feedback_cb(self, msg): + self.get_logger().info(f"Received feedback: {msg.feedback}") + + def response_cb(self, future): + handle = future.result() + if not handle.accepted: + self.get_logger().info("Goal was rejected") + return + + self.get_logger().info("Goal was accepted") + self.result_future = ( + handle.get_result_async() + ) # Not using get_result() in cb, as can cause deadlock according to docs + self.result_future.add_done_callback(self.result_cb) + + def result_cb(self, future): + result = future.result().result + self.get_logger().info(f"Transcribed Speech: {result.sequence}") + + +def main(args=None): + rclpy.init(args=args) + while rclpy.ok(): + goal = TranscribeSpeech.Goal() + client = TestSpeechServerClient() + try: + client.send_goal(goal) + rclpy.spin(client) + except KeyboardInterrupt: + client.get_logger().info("Shutting down...") + finally: + client.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg new file mode 100644 index 000000000..1f6a54400 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir = $base/lib/lasr_speech_recognition_whisper +[install] +install_scripts = $base/lib/lasr_speech_recognition_whisper diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py new file mode 100644 index 000000000..c6a801483 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/setup.py @@ -0,0 +1,38 @@ +from setuptools import find_packages, setup + +package_name = "lasr_speech_recognition_whisper" + +setup( + name=package_name, + version="0.0.0", + packages=find_packages(exclude=["test"]), + # packages=[package_name, f"{package_name}.lasr_speech_recognition_whisper", f"{package_name}.src"], + # package_dir={ + # '': '.', + # package_name: os.path.join(package_name), + # f"{package_name}.whisper": os.path.join(package_name, 'whisper'), + # f"{package_name}.src": os.path.join(package_name, 'src'), + # }, + data_files=[ + ("share/ament_index/resource_index/packages", ["resource/" + package_name]), + ("share/" + package_name, ["package.xml"]), + ], + install_requires=["setuptools"], + zip_safe=True, + maintainer="maayan", + maintainer_email="maayan.armony@gmail.com", + description="Speech recognition implemented using OpenAI Whisper", + license="MIT", + tests_require=["pytest"], + entry_points={ + "console_scripts": [ + "transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main", + "transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main", + "simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main", + "list_microphones = scripts.list_microphones:main", + "microphone_tuning_test = scripts.microphone_tuning_test:main", + "test_microphones = scripts.test_microphones:main", + "test_speech_server = scripts.test_speech_server:main", + ], + }, +) diff --git a/common/speech/lasr_speech_recognition_whisper/src/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/__init__.py new file mode 100644 index 000000000..473b206b7 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/__init__.py @@ -0,0 +1,12 @@ +# from .collector import AbstractPhraseCollector, 
AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector +from .lasr_speech_recognition_whisper.collector import ( + AbstractPhraseCollector, + MicrophonePhraseCollector, + RecognizerPhraseCollector, +) +from .lasr_speech_recognition_whisper.worker import ( + SpeechRecognitionWorker, + SpeechRecognitionToStdout, + SpeechRecognitionToTopic, +) +from .lasr_speech_recognition_whisper.cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py new file mode 100644 index 000000000..372e26477 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -0,0 +1,12 @@ +# from .collector import AbstractPhraseCollector, AudioTopicPhraseCollector, MicrophonePhraseCollector, RecognizerPhraseCollector +from .collector import ( + AbstractPhraseCollector, + MicrophonePhraseCollector, + RecognizerPhraseCollector, +) +from .worker import ( + SpeechRecognitionWorker, + SpeechRecognitionToStdout, + SpeechRecognitionToTopic, +) +from .cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py new file mode 100644 index 000000000..1f86b7ffc --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/bytesfifo.py @@ -0,0 +1,137 @@ +import io + + +class BytesFIFO(object): + """ + A FIFO that can store a fixed number of bytes. + https://github.com/hbock/byte-fifo/blob/master/fifo.py + """ + + def __init__(self, init_size): + """Create a FIFO of ``init_size`` bytes.""" + self._buffer = io.BytesIO(b"\x00" * init_size) + self._size = init_size + self._filled = 0 + self._read_ptr = 0 + self._write_ptr = 0 + + def read(self, size=-1): + """ + Read at most ``size`` bytes from the FIFO. + + If less than ``size`` bytes are available, or ``size`` is negative, + return all remaining bytes. + """ + if size < 0: + size = self._filled + + # Go to read pointer + self._buffer.seek(self._read_ptr) + + # Figure out how many bytes we can really read + size = min(size, self._filled) + contig = self._size - self._read_ptr + contig_read = min(contig, size) + + ret = self._buffer.read(contig_read) + self._read_ptr += contig_read + if contig_read < size: + leftover_size = size - contig_read + self._buffer.seek(0) + ret += self._buffer.read(leftover_size) + self._read_ptr = leftover_size + + self._filled -= size + + return ret + + def write(self, data): + """ + Write as many bytes of ``data`` as are free in the FIFO. + + If less than ``len(data)`` bytes are free, write as many as can be written. + Returns the number of bytes written. 
+ """ + free = self.free() + write_size = min(len(data), free) + + if write_size: + contig = self._size - self._write_ptr + contig_write = min(contig, write_size) + # TODO: avoid 0 write + # TODO: avoid copy + # TODO: test performance of above + self._buffer.seek(self._write_ptr) + self._buffer.write(data[:contig_write]) + self._write_ptr += contig_write + + if contig < write_size: + self._buffer.seek(0) + self._buffer.write(data[contig_write:write_size]) + # self._buffer.write(buffer(data, contig_write, write_size - contig_write)) + self._write_ptr = write_size - contig_write + + self._filled += write_size + + return write_size + + def flush(self): + """Flush all data from the FIFO.""" + self._filled = 0 + self._read_ptr = 0 + self._write_ptr = 0 + + def empty(self): + """Return ```True``` if FIFO is empty.""" + return self._filled == 0 + + def full(self): + """Return ``True`` if FIFO is full.""" + return self._filled == self._size + + def free(self): + """Return the number of bytes that can be written to the FIFO.""" + return self._size - self._filled + + def capacity(self): + """Return the total space allocated for this FIFO.""" + return self._size + + def __len__(self): + """Return the amount of data filled in FIFO""" + return self._filled + + def __nonzero__(self): + """Return ```True``` if the FIFO is not empty.""" + return self._filled > 0 + + def resize(self, new_size): + """ + Resize FIFO to contain ``new_size`` bytes. If FIFO currently has + more than ``new_size`` bytes filled, :exc:`ValueError` is raised. + If ``new_size`` is less than 1, :exc:`ValueError` is raised. + + If ``new_size`` is smaller than the current size, the internal + buffer is not contracted (yet). + """ + if new_size < 1: + raise ValueError("Cannot resize to zero or less bytes.") + + if new_size < self._filled: + raise ValueError( + "Cannot contract FIFO to less than {} bytes, " + "or data will be lost.".format(self._filled) + ) + + # original data is non-contiguous. we need to copy old data, + # re-write to the beginning of the buffer, and re-sync + # the read and write pointers. + if self._read_ptr >= self._write_ptr: + old_data = self.read(self._filled) + self._buffer.seek(0) + self._buffer.write(old_data) + self._filled = len(old_data) + self._read_ptr = 0 + self._write_ptr = self._filled + + self._size = new_size diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py new file mode 100644 index 000000000..259ffffa5 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -0,0 +1,57 @@ +import os +import whisper # type: ignore +from ament_index_python import packages +from rclpy.node import Node + +# Keep all loaded models in memory +MODEL_CACHE = {} + + +class ModelCache(Node): + def __init__(self): + super().__init__("lasr_speech_recognition_whisper_cache") + + def load_model( + self, name: str, device: str = "cpu", load_test_file: bool = False + ) -> whisper.Whisper: + """Loads a whisper model from disk, or from cache if it has already been loaded. + + Args: + name (str): Name of the whisper model. Must be the name of an official whisper + model, or the path to a model checkpoint. + device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'. + load_test_file (bool, optional): Whether to run inference on a test audio file + after loading the model (if model is not in cache). 
Defaults to False. Test file + is assumed to be called "test.m4a" and be in the root of the package directory. + + Returns: + whisper.Whisper: Whisper model instance + """ + global MODEL_CACHE + + if name not in MODEL_CACHE: + self.get_logger().info(f"Loading model {name}") + MODEL_CACHE[name] = whisper.load_model(name, device=device) + self.get_logger().info(f"Sucessfully loaded model {name} on {device}") + if load_test_file: + package_install = packages.get_package_prefix( + "lasr_speech_recognition_whisper" + ) + package_root = os.path.abspath( + os.path.join( + package_install, + os.pardir, + os.pardir, + "lasr_speech_recognition_whisper", + ) + ) + example_fp = os.path.join(package_root, "test.m4a") + self.get_logger().info( + "Running transcription on example file to ensure model is loaded..." + ) + test_result: str = MODEL_CACHE[name].transcribe( + example_fp, fp16=device == "cuda" + ) + self.get_logger().info(f"Transcription test result: {test_result}") + + return MODEL_CACHE[name] diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py new file mode 100644 index 000000000..d8c5fbea4 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py @@ -0,0 +1,131 @@ +import rclpy +from rclpy.node import Node +import speech_recognition as sr + +from queue import Queue +from abc import ABC, abstractmethod + + +class AbstractPhraseCollector(ABC): + """ + Supertype holding a queue of audio data representing a phrase + """ + + data: Queue[bytes] = Queue() + + @abstractmethod + def start(self): + """ + Start collecting phrases + """ + pass + + @abstractmethod + def stop(self): + """ + Stop collecting phrases + """ + pass + + @abstractmethod + def sample_rate(self): + """ + Sample rate of the data + """ + pass + + @abstractmethod + def sample_width(self): + """ + Sample width of the data + """ + pass + + +class RecognizerPhraseCollector(AbstractPhraseCollector, Node): + """ + Collect phrases using a SoundRecognition Recognizer + + This will monitor energy levels on the input and only + capture when a certain threshold of activity is met. 
+ """ + + _recorder: sr.Recognizer + _phrase_time_limit: float + + def _record_callback(self, _, audio: sr.AudioData) -> None: + """ + Collect raw audio data from the microphone + """ + self.data.put(audio.get_raw_data()) + + def __init__( + self, energy_threshold: int = 500, phrase_time_limit: float = 2 + ) -> None: + super().__init__("collector") + # Node.__init__(self, "collector") + + self._recorder = sr.Recognizer() + self._recorder.dynamic_energy_threshold = False + self._recorder.energy_threshold = energy_threshold + self._phrase_time_limit = phrase_time_limit + + @abstractmethod + def adjust_for_noise(self, source: sr.AudioSource): + self.get_logger().info("Adjusting for background noise...") + with source: + self._recorder.adjust_for_ambient_noise(source) + + @abstractmethod + def start(self, source: sr.AudioSource): + self.get_logger().info("Started source listen thread") + self._stopper = self._recorder.listen_in_background( + source, self._record_callback, phrase_time_limit=self._phrase_time_limit + ) + + def stop(self): + self._stopper() + + def sample_rate(self): + return self._source.SAMPLE_RATE + + def sample_width(self): + return self._source.SAMPLE_WIDTH + + +class MicrophonePhraseCollector(RecognizerPhraseCollector): + """ + Collect phrases from the default microphone + """ + + _source: sr.Microphone + + def __init__( + self, + energy_threshold: int = 500, + phrase_time_limit: float = 2, + device_index: int = None, + ) -> None: + self._source = sr.Microphone(device_index=device_index, sample_rate=16000) + super().__init__(energy_threshold, phrase_time_limit) + + def adjust_for_noise(self): + return super().adjust_for_noise(self._source) + + def start(self): + return super().start(self._source) + + +# class AudioTopicPhraseCollector(RecognizerPhraseCollector): +# ''' +# Collect phrases from an audio topic +# ''' + +# _source: AudioTopic + +# def __init__(self, topic: str, energy_threshold: int = 100, phrase_time_limit: float = 2) -> None: +# self._source = AudioTopic(topic) +# super().__init__(energy_threshold, phrase_time_limit) + +# def start(self): +# return super().start(self._source) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py new file mode 100644 index 000000000..e405ca8c0 --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py @@ -0,0 +1,74 @@ +import rclpy +from rclpy.node import Node +import pyaudio +import speech_recognition as sr + +from audio_common_msgs.msg import AudioInfo, AudioData + +from .bytesfifo import BytesFIFO + +# TODO rospy.wait_for_message() + + +class AudioTopic(sr.AudioSource, Node): + """ + Use a ROS topic as an AudioSource + """ + + _topic: str + # _sub: node.create_subscription TODO add type if possible + + def __init__(self, topic: str, chunk_size=1024) -> None: + Node.__init__(self, "source") + + self._topic = topic + self.subscription = self.create_subscription( + AudioInfo, f"{topic}/audio_info", self.callback, 10 + ) + # config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo) + self.config = None # TODO test that this works + if self.config is not None: + assert self.config.coding_format == "wave", "Expected Wave audio format" + assert self.config.sample_format == "S16LE", "Expected sample format S16LE" + self.get_logger().info(self.config) + + self.SAMPLE_WIDTH = pyaudio.get_sample_size(pyaudio.paInt16) + 
+        self.SAMPLE_RATE = self.config.sample_rate
+
+        self.CHUNK = chunk_size
+        self.stream = None
+
+    def callback(self, msg):
+        self.get_logger().info("Message received")
+        self.config = msg
+
+    def __enter__(self):
+        """
+        Start stream when entering with: block
+        """
+
+        assert (
+            self.stream is None
+        ), "This audio source is already inside a context manager"
+        self.stream = BytesFIFO(1024 * 10)  # 10 kB buffer
+        # Subscribe on this node (QoS depth 10) and feed raw audio into the FIFO
+        self._sub = self.create_subscription(
+            AudioData, f"{self._topic}/audio", self._read, 10
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """
+        Close out stream on exit
+        """
+
+        self.stream = None
+        self.destroy_subscription(self._sub)  # was self._sub.unregister() in ROS 1
+
+    def _read(self, msg: AudioData) -> None:
+        """
+        Forward raw audio data to queue
+        """
+
+        self.stream.write(msg.data)
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py new file mode 100644 index 000000000..43eac780b --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py @@ -0,0 +1,207 @@
+import torch
+
+from rclpy.node import Node
+from rclpy.publisher import Publisher
+
+import whisper
+import speech_recognition as sr
+
+from io import BytesIO
+from time import sleep
+from threading import Thread
+from abc import ABC, abstractmethod
+from tempfile import NamedTemporaryFile
+from datetime import datetime, timedelta
+
+from .collector import AbstractPhraseCollector
+
+from lasr_speech_recognition_interfaces.msg import Transcription
+
+
+class SpeechRecognitionWorker(ABC, Node):
+    """
+    Collect and run inference on phrases to produce a transcription
+    """
+
+    _collector: AbstractPhraseCollector
+    _tmp_file: str
+    _model: whisper.Whisper
+    _current_sample: bytes
+    _phrase_start: datetime | None
+    _maximum_phrase_length: timedelta | None
+    _infer_partial: bool
+    _stopped = True
+
+    def __init__(
+        self,
+        collector: AbstractPhraseCollector,
+        model: whisper.Whisper,
+        maximum_phrase_length=timedelta(seconds=3),
+        infer_partial=True,
+    ) -> None:
+        Node.__init__(self, "worker")
+        self._collector = collector
+        self._tmp_file = NamedTemporaryFile().name  # path of the temporary WAV file
+        self._model = model
+        self._current_sample = bytes()
+        self._phrase_start = None
+        self._maximum_phrase_length = maximum_phrase_length
+        self._infer_partial = infer_partial
+
+    @abstractmethod
+    def on_phrase(self, phrase: str, finished: bool) -> None:
+        """
+        Handle a partial or complete transcription
+        """
+        pass
+
+    def _finish_phrase(self):
+        """
+        Complete the current phrase and clear the sample
+        """
+
+        text = self._perform_inference()
+        if text is not None:
+            self.on_phrase(text, True)
+
+        self._current_sample = bytes()
+        self._phrase_start = None
+
+    def _perform_inference(self):
+        """
+        Run inference on the current sample
+        """
+
+        self.get_logger().info("Processing sample")
+        audio_data = sr.AudioData(
+            self._current_sample,
+            self._collector.sample_rate(),
+            self._collector.sample_width(),
+        )
+        wav_data = BytesIO(audio_data.get_wav_data())
+
+        with open(self._tmp_file, "w+b") as f:
+            f.write(wav_data.read())
+
+        self.get_logger().info("Running inference")
+        try:
+            result = self._model.transcribe(
+                self._tmp_file, fp16=torch.cuda.is_available()
+            )
+        except RuntimeError:
+            return None
+        text = result["text"].strip()
+
+        # Detect and drop garbage
+        if len(text) == 0 or text.lower() in [".", "you", "thanks for watching!"]:
+            self._phrase_start = None
+            self._current_sample = bytes()
+            self.get_logger().info("Skipping garbage...")
+            return None
+
+        return text
+
+    def _worker(self):
+        """
+        Indefinitely perform inference on the given data
+        """
+
+        self.get_logger().info("Started inference worker")
+
+        while not self._stopped:
+            try:
+                # Check whether the current phrase has timed out
+                now = datetime.utcnow()
+                if (
+                    self._phrase_start
+                    and now - self._phrase_start > self._maximum_phrase_length
+                ):
+                    self.get_logger().info("Reached timeout for phrase, ending now.")
+                    self._finish_phrase()
+
+                # Start / continue phrase if data is coming in
+                if not self._collector.data.empty():
+                    self._phrase_start = datetime.utcnow()
+
+                    # Concatenate new data with current sample
+                    while not self._collector.data.empty():
+                        self._current_sample += self._collector.data.get()
+
+                    self.get_logger().info(
+                        "Received and added more data to current audio sample."
+                    )
+
+                    # Run inference on partial sample if enabled
+                    if self._infer_partial:
+                        text = self._perform_inference()
+
+                        # Handle partial transcription
+                        if text is not None:
+                            self.on_phrase(text, False)
+
+                sleep(0.2)
+            except KeyboardInterrupt:
+                self._stopped = True
+
+        self.get_logger().info("Worker finished")
+
+    def start(self):
+        """
+        Start performing inference on incoming data
+        """
+
+        assert self._stopped, "Already running inference"
+        self._stopped = False
+        self._collector.start()
+        worker_thread = Thread(target=self._worker)
+        worker_thread.start()
+
+    def stop(self):
+        """
+        Stop the worker from running inference
+        """
+
+        assert not self._stopped, "Not currently running"
+        self._collector.stop()
+        self._stopped = True
+
+        # Drain any remaining audio into the current sample
+        self._current_sample = bytes()
+        while not self._collector.data.empty():
+            self._current_sample += self._collector.data.get()
+
+
+class SpeechRecognitionToStdout(SpeechRecognitionWorker):
+    """
+    Recognise speech and pass it through to standard output
+    """
+
+    def on_phrase(self, phrase: str, finished: bool) -> None:
+        self.get_logger().info("[" + ("x" if finished else " ") + "] " + phrase)
+
+
+class SpeechRecognitionToTopic(SpeechRecognitionToStdout):
+    """
+    Recognise speech and publish it to a topic
+    """
+
+    _pub: Publisher
+
+    def __init__(
+        self,
+        collector: AbstractPhraseCollector,
+        model: whisper.Whisper,
+        topic: str,
+        maximum_phrase_length=timedelta(seconds=1),
+        infer_partial=True,
+    ) -> None:
+        super().__init__(collector, model, maximum_phrase_length, infer_partial)
+        self.get_logger().info(f"Will be publishing transcription to {topic}")
+        self._pub = self.create_publisher(Transcription, topic, 5)
+
+    def on_phrase(self, phrase: str, finished: bool) -> None:
+        super().on_phrase(phrase, finished)
+        msg = Transcription()
+        msg.phrase = phrase
+        msg.finished = finished
+        self._pub.publish(msg)
diff --git a/common/speech/lasr_speech_recognition_whisper/test.m4a b/common/speech/lasr_speech_recognition_whisper/test.m4a new file mode 100644 index 000000000..1fbef3f08 Binary files /dev/null and b/common/speech/lasr_speech_recognition_whisper/test.m4a differ
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py new file mode 100644 index 000000000..ceffe896d --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py @@ -0,0 +1,27 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_copyright.main import main
+import pytest
+
+
+# Remove the `skip` decorator once the source file(s) have a copyright header
+@pytest.mark.skip(
+    reason="No copyright header has been placed in the generated source file."
+)
+@pytest.mark.copyright
+@pytest.mark.linter
+def test_copyright():
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found errors"
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py new file mode 100644 index 000000000..ee79f31ac --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py @@ -0,0 +1,25 @@
+# Copyright 2017 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_flake8.main import main_with_errors
+import pytest
+
+
+@pytest.mark.flake8
+@pytest.mark.linter
+def test_flake8():
+    rc, errors = main_with_errors(argv=[])
+    assert rc == 0, "Found %d code style errors / warnings:\n" % len(
+        errors
+    ) + "\n".join(errors)
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py new file mode 100644 index 000000000..a2c3deb8e --- /dev/null +++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py @@ -0,0 +1,23 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_pep257.main import main
+import pytest
+
+
+@pytest.mark.linter
+@pytest.mark.pep257
+def test_pep257():
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found code style errors / warnings"
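For reviewers, a minimal sketch of how the collector and worker classes introduced in this patch could be wired together. This is not part of the changeset: the import paths, the "base.en" model choice, and the "transcription" topic name are illustrative assumptions.

# Hypothetical usage sketch -- not part of this changeset.
# Assumes rclpy, openai-whisper, and the modules added above are importable.
import rclpy
import whisper

from lasr_speech_recognition_whisper.collector import MicrophonePhraseCollector
from lasr_speech_recognition_whisper.worker import SpeechRecognitionToTopic


def main(args=None):
    rclpy.init(args=args)

    # Collect phrases from the default microphone and calibrate for ambient noise
    collector = MicrophonePhraseCollector(energy_threshold=500, phrase_time_limit=2)
    collector.adjust_for_noise()

    # "base.en" is an illustrative checkpoint; any whisper model name would do
    model = whisper.load_model("base.en")

    # Publish partial and final transcriptions on the "transcription" topic
    worker = SpeechRecognitionToTopic(collector, model, "transcription")
    worker.start()

    try:
        rclpy.spin(worker)
    except KeyboardInterrupt:
        pass
    finally:
        worker.stop()
        rclpy.shutdown()


if __name__ == "__main__":
    main()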