From 05bb688fced53156d14a86b12259a8bc7949d70c Mon Sep 17 00:00:00 2001 From: Landon James Date: Fri, 22 Nov 2024 20:03:35 -0800 Subject: [PATCH] Refactor get/set_telemetry_provider functions to return Result --- .../aws-smithy-observability/src/error.rs | 21 ++++++++- .../aws-smithy-observability/src/global.rs | 46 +++++++++++-------- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/rust-runtime/aws-smithy-observability/src/error.rs b/rust-runtime/aws-smithy-observability/src/error.rs index 34baec8831..84ec342060 100644 --- a/rust-runtime/aws-smithy-observability/src/error.rs +++ b/rust-runtime/aws-smithy-observability/src/error.rs @@ -21,8 +21,10 @@ pub struct ObservabilityError { #[non_exhaustive] #[derive(Debug)] pub enum ErrorKind { - /// An error setting the `GlobalTelemetryProvider`` + /// An error setting the global [crate::TelemetryProvider] SettingGlobalProvider, + /// An error getting the global [crate::TelemetryProvider] + GettingGlobalProvider, /// Error flushing metrics pipeline MetricsFlush, /// Error gracefully shutting down Metrics Provider @@ -54,7 +56,10 @@ impl fmt::Display for ObservabilityError { match &self.kind { ErrorKind::Other => write!(f, "unclassified error"), ErrorKind::SettingGlobalProvider => { - write!(f, "failed to set global telemetry provider") + write!(f, "failed to set global TelemetryProvider") + } + ErrorKind::GettingGlobalProvider => { + write!(f, "failed to get global TelemetryProvider") } ErrorKind::MetricsFlush => write!(f, "failed to flush metrics pipeline"), ErrorKind::MetricsShutdown => write!(f, "failed to shutdown metrics provider"), @@ -67,3 +72,15 @@ impl std::error::Error for ObservabilityError { Some(self.source.as_ref()) } } + +/// An simple error to represent issues with the global [crate::TelemetryProvider]. +#[derive(Debug)] +pub struct GlobalTelemetryProviderError; + +impl std::error::Error for GlobalTelemetryProviderError {} + +impl fmt::Display for GlobalTelemetryProviderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "GlobalTelemetryProviderError") + } +} diff --git a/rust-runtime/aws-smithy-observability/src/global.rs b/rust-runtime/aws-smithy-observability/src/global.rs index 9bee722830..d3d9698cd7 100644 --- a/rust-runtime/aws-smithy-observability/src/global.rs +++ b/rust-runtime/aws-smithy-observability/src/global.rs @@ -11,7 +11,11 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::provider::{GlobalTelemetryProvider, TelemetryProvider}; +use crate::{ + error::{ErrorKind, GlobalTelemetryProviderError}, + provider::{GlobalTelemetryProvider, TelemetryProvider}, + ObservabilityError, +}; // Statically store the global provider static GLOBAL_TELEMETRY_PROVIDER: Lazy> = @@ -19,29 +23,33 @@ static GLOBAL_TELEMETRY_PROVIDER: Lazy> = /// Set the current global [TelemetryProvider]. /// -/// This is meant to be run once at the beginning of an application. It will panic if two threads -/// attempt to call it at the same time. -pub fn set_telemetry_provider(new_provider: TelemetryProvider) { - // TODO(smithyObservability): would probably be nicer to return a Result here, but the Guard held by the error from - // .try_write is not Send so I struggled to build an ObservabilityError from it - let mut old_provider = GLOBAL_TELEMETRY_PROVIDER - .try_write() - .expect("GLOBAL_TELEMETRY_PROVIDER RwLock Poisoned"); +/// This is meant to be run once at the beginning of an application. Will return an [Err] if the +/// [RwLock] holding the global [TelemetryProvider] is locked or poisoned. +pub fn set_telemetry_provider(new_provider: TelemetryProvider) -> Result<(), ObservabilityError> { + if let Ok(mut old_provider) = GLOBAL_TELEMETRY_PROVIDER.try_write() { + let new_global_provider = GlobalTelemetryProvider::new(new_provider); - let new_global_provider = GlobalTelemetryProvider::new(new_provider); + let _ = mem::replace(&mut *old_provider, new_global_provider); - let _ = mem::replace(&mut *old_provider, new_global_provider); + Ok(()) + } else { + Err(ObservabilityError::new( + ErrorKind::GettingGlobalProvider, + GlobalTelemetryProviderError, + )) + } } -/// Get an [Arc] reference to the current global [TelemetryProvider]. [None] is returned if the [RwLock] containing -/// the global [TelemetryProvider] is poisoned or is currently locked by a writer. -pub fn get_telemetry_provider() -> Option> { - // TODO(smithyObservability): would probably make more sense to return a Result rather than an Option here, but the Guard held by the error from - // .try_read is not Send so I struggled to build an ObservabilityError from it +/// Get an [Arc] reference to the current global [TelemetryProvider]. Will return an [Err] if the +/// [RwLock] holding the global [TelemetryProvider] is locked or poisoned. +pub fn get_telemetry_provider() -> Result, ObservabilityError> { if let Ok(tp) = GLOBAL_TELEMETRY_PROVIDER.try_read() { - Some(tp.telemetry_provider().clone()) + Ok(tp.telemetry_provider().clone()) } else { - None + Err(ObservabilityError::new( + ErrorKind::GettingGlobalProvider, + GlobalTelemetryProviderError, + )) } } @@ -59,7 +67,7 @@ mod tests { let my_provider = TelemetryProvider::default(); // Set the new counter and get a reference to the old one - set_telemetry_provider(my_provider); + set_telemetry_provider(my_provider).unwrap(); } #[test]