diff --git a/Cargo.lock b/Cargo.lock index 415e38f..dda0df1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,7 @@ dependencies = [ "serde", "sha2", "sqlx", + "tokio", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4e48221..e2e360c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ rocket_db_pools = { version = "0.1.0", features = ["sqlx_postgres"] } serde = { version = "1.0.193", features = ["derive"] } sha2 = "0.10.8" sqlx = { version = "0.7.2", features = ["macros", "migrate", "chrono", "postgres", "runtime-tokio"], default-features = false } +tokio = { version = "1.36.0", features = ["sync"] } [profile.dev.package.sqlx-macros] # building the proc macros with optimisations speeds up their execution at compile time diff --git a/src/deezer.rs b/src/deezer.rs index b4cb9b8..e5d770a 100644 --- a/src/deezer.rs +++ b/src/deezer.rs @@ -1,9 +1,11 @@ //! A minimal wrapper for the parts of the Deezer API we care about. //! API documentation: <https://developers.deezer.com/api> -use eyre::{Context, Result}; +use eyre::{eyre, Context, Result}; use lazy_static::lazy_static; -use serde::{Deserialize, Serialize}; +use reqwest::RequestBuilder; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; // We don't actually use Rocket here, but Reqwest also uses this `Bytes` type and doesn't re-export it. +use crate::ratelimit::{Backoff, Ratelimit}; use rocket::http::hyper::body::Bytes; /// Base URL for the API. @@ -16,6 +18,77 @@ lazy_static! { headers.insert(reqwest::header::ACCEPT_LANGUAGE, "en".parse().unwrap()); reqwest::Client::builder().default_headers(headers).build().unwrap() }; + + /// A rate limiter for the Deezer API. + /// + /// Deezer currently allows 50 requests per 5 seconds, and this configuration + /// should align with that. + static ref RATELIMIT: Ratelimit = Ratelimit::new(50, std::time::Duration::from_millis(100)); } + +/// Make a request to the Deezer API, respecting the rate limit and retrying +/// if we hit it or the service is busy, with exponential backoff. 
+async fn send_request<T: DeserializeOwned>(req: RequestBuilder) -> Result<Option<T>> { + let backoff = Backoff::new( + std::time::Duration::from_secs(1), + std::time::Duration::from_secs(300), + 2, + ); + for delay in backoff { + RATELIMIT.wait().await; + let response = req + .try_clone() + .expect("reqwest request cloning should not fail") + .send() + .await + .wrap_err("error sending request to Deezer")? + .error_for_status() + .wrap_err("Deezer API returned an HTTP error")? + .json() + .await + .wrap_err("error deserialising Deezer API response")?; + match response { + Response::Data(data) => return Ok(Some(data)), + Response::Error { error } => match error.code { + ErrorCode::Ratelimited | ErrorCode::ServiceBusy => { + eprintln!( + "Deezer API returned a temporary error, retrying in {}s: {error}", + delay.as_secs() + ); + tokio::time::sleep(delay).await; + } + ErrorCode::NotFound => return Ok(None), + ErrorCode::Unknown(_) => { + return Err(eyre!("Deezer API returned an unknown error: {error}")) + } + }, + } + } + unreachable!("backoff iterator should never end") +} + +/// An extension trait allowing for the use of a custom send method on +/// [`RequestBuilder`] with postfix syntax. +#[rocket::async_trait] +trait RequestBuilderExt { + /// Send a request to the Deezer API, returning `None` if the resource was not found. + async fn try_deezer_fetch<T: DeserializeOwned>(self) -> Result<Option<T>>; + + /// Send a request to the Deezer API, returning an error if the resource was not found. + async fn deezer_fetch<T: DeserializeOwned>(self) -> Result<T>; +} + +#[rocket::async_trait] +impl RequestBuilderExt for RequestBuilder { + async fn try_deezer_fetch<T: DeserializeOwned>(self) -> Result<Option<T>> { + send_request(self).await + } + + async fn deezer_fetch<T: DeserializeOwned>(self) -> Result<T> { + send_request(self) + .await + .and_then(|data| data.ok_or_else(|| eyre!("requested resource not found"))) + } }
@@ -24,12 +97,9 @@ pub async fn chart(genre_id: Id) -> Result<Vec<Track>> { let url = format!("{API_URL}/chart/{genre_id}/tracks"); let data: DataWrap<_> = CLIENT .get(&url) - .send() + .deezer_fetch() .await - .wrap_err("error fetching genre chart")? - .json() - .await - .wrap_err("error deserialising genre chart")?; + .wrap_err("error fetching genre chart")?; Ok(data.data) } @@ -46,12 +116,9 @@ pub async fn genres() -> Result<Vec<Genre>> { let url = format!("{API_URL}/genre"); let genres = CLIENT .get(&url) - .send() + .deezer_fetch::<DataWrap<Vec<Genre>>>() .await .wrap_err("error fetching genre list")? - .json::<DataWrap<Vec<Genre>>>() - .await - .wrap_err("error deserialising genre list")? .data .into_iter() .filter(|genre| !GENRE_BLACKLIST.contains(&genre.id)) @@ -67,15 +134,11 @@ pub async fn genres() -> Result<Vec<Genre>> { /// for another reason. pub async fn album(album_id: Id) -> Result<Album> { let url = format!("{API_URL}/album/{album_id}"); - let data = CLIENT + CLIENT .get(&url) - .send() - .await - .wrap_err("error fetching album")? - .json() + .deezer_fetch() .await - .wrap_err("error deserialising album")?; - Ok(data) + .wrap_err("error fetching album") } /// Search for a track by name or artist. @@ -84,12 +147,9 @@ pub async fn track_search(q: &str) -> Result<Vec<Track>> { let data: DataWrap<_> = CLIENT .get(&url) .query(&[("q", q)]) - .send() - .await - .wrap_err("error searching tracks")? - .json() + .deezer_fetch() .await - .wrap_err("error deserialising track search results")?; + .wrap_err("error searching tracks")?; Ok(data.data) } @@ -98,17 +158,14 @@ pub async fn track(id: Id) -> Result<Option<Track>> { let url = format!("{API_URL}/track/{id}"); CLIENT .get(&url) - .send() + .try_deezer_fetch() .await - .wrap_err("error fetching track")? - .json() - .await - // assume that if we failed to deserialise the track, it was a "not found" response - .map_or_else(|_| Ok(None), |track| Ok(Some(track))) + .wrap_err("error fetching track") } /// Download a track from Deezer and save it to the music cache. 
pub async fn track_preview(preview_url: &str) -> Result<Bytes> { + // This isn't an API request so hopefully should be fine without ratelimiting. CLIENT .get(preview_url) .send() @@ -135,6 +192,63 @@ impl<T> std::ops::Deref for DataWrap<T> { } } +/// A response from Deezer which may be an error. +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum Response<T> { + /// A successful response. + Data(T), + /// An error response. + Error { + /// The error object. + error: Error, + }, +} + +/// An error returned by the API. +#[derive(Debug, Deserialize)] +pub struct Error { + /// The error "type". + #[serde(rename = "type")] + kind: String, + /// The error message. + message: String, + /// The error code. + code: ErrorCode, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} ({:?}): {}", self.kind, self.code, self.message) + } } + +/// An error code from an API error. +#[derive(Debug, Deserialize)] +#[serde(from = "u32")] +pub enum ErrorCode { + /// We have exceeded the request quota. + Ratelimited, + /// The service is busy and we should try again later. + ServiceBusy, + /// The requested resource was not found. + NotFound, + /// Other error codes exist, but we treat them all as unresolvable issues. + Unknown(u32), +} + +impl From<u32> for ErrorCode { + fn from(code: u32) -> Self { + // https://developers.deezer.com/api/errors + match code { + 4 => Self::Ratelimited, + 700 => Self::ServiceBusy, + 800 => Self::NotFound, + code => Self::Unknown(code), + } + } +} + /// An artist object returned by the API. #[derive(Debug, Deserialize)] pub struct Artist { diff --git a/src/main.rs b/src/main.rs index e3b786f..51d05e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,7 @@ mod api_error; mod database; mod deezer; mod game; +mod ratelimit; mod tasks; mod track; mod user; diff --git a/src/ratelimit.rs b/src/ratelimit.rs new file mode 100644 index 0000000..296df47 --- /dev/null +++ b/src/ratelimit.rs @@ -0,0 +1,78 @@ +//! 
Various ratelimiting utilities. +use std::{sync::Arc, time::Duration}; + +use tokio::sync::Semaphore; + +/// A simple "leaky bucket" rate limiter. This is intended to be used as a +/// long-lived singleton, and will spawn a background task. +/// +/// Based on [the example from the docs][1]. +/// +/// [1]: https://docs.rs/tokio/1.36.0/tokio/sync/struct.Semaphore.html#rate-limiting-using-a-token-bucket +pub struct Ratelimit(Arc<Semaphore>); + +impl Ratelimit { + /// Set up the ratelimiter and start the background task. + /// + /// `max_requests` is the maximum number of requests to allow at once. + /// `increment_interval` is how long to wait before allowing another request. + pub fn new(max_requests: usize, increment_interval: Duration) -> Self { + let sem = Arc::new(Semaphore::new(max_requests)); + tokio::spawn({ + let sem = sem.clone(); + let mut interval = tokio::time::interval(increment_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + async move { + loop { + interval.tick().await; + if sem.available_permits() < max_requests { + sem.add_permits(1); + } + } + } + }); + Self(sem) + } + + /// Acquire a permit to make a request. The future will resolve once a permit + /// is available. + pub async fn wait(&self) { + self.0 + .acquire() + .await + .expect("semaphore shouldn't be closed") + .forget(); + } +} + +/// An exponential backoff iterator. +pub struct Backoff { + /// The delay for the next iteration. + delay: Duration, + /// The maximum delay to allow. + max_delay: Duration, + /// The factor to multiply the delay by each time. + factor: u32, +} + +impl Backoff { + /// Create a new backoff iterator. 
+ pub const fn new(initial_delay: Duration, max_delay: Duration, factor: u32) -> Self { + Self { + delay: initial_delay, + max_delay, + factor, + } + } +} + +impl Iterator for Backoff { + type Item = Duration; + + fn next(&mut self) -> Option<Self::Item> { + let current = self.delay; + self.delay = (self.delay * self.factor).min(self.max_delay); + Some(current) + } +}