-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement a default OPTIONS handler and complement the handler for HTTP 405 in router #743
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -2,7 +2,7 @@ use route_recognizer::{Match, Params, Router as MethodRouter}; | |||
use std::collections::HashMap; | ||||
|
||||
use crate::endpoint::DynEndpoint; | ||||
use crate::{Request, Response, StatusCode}; | ||||
use crate::{http::headers, http::Method, Request, Response, StatusCode}; | ||||
|
||||
/// The routing table used by `Server` | ||||
/// | ||||
|
@@ -71,11 +71,20 @@ impl<State: Clone + Send + Sync + 'static> Router<State> { | |||
.filter(|(k, _)| **k != method) | ||||
.any(|(_, r)| r.recognize(path).is_ok()) | ||||
{ | ||||
// If this `path` can be handled by a callback registered with a different HTTP method | ||||
// should return 405 Method Not Allowed | ||||
// If this `path` can be handled by a callback registered with a different HTTP method, | ||||
// the server should return 405 Method Not Allowed. | ||||
// Or for an OPTIONS request, it should response with a success and supported methods. | ||||
let supported_methods = self.get_supported_methods(path).join(", "); | ||||
let mut params = Params::new(); | ||||
params.insert(String::from(SUPPORTED_METHODS_PARAM_KEY), supported_methods); | ||||
// TODO: How to pass a closure as the endpoint here? | ||||
Selection { | ||||
endpoint: &method_not_allowed, | ||||
params: Params::new(), | ||||
endpoint: if method == Method::Options { | ||||
&http_options_endpoint | ||||
} else { | ||||
&method_not_allowed_endpoint | ||||
}, | ||||
params: params, | ||||
} | ||||
} else { | ||||
Selection { | ||||
|
@@ -84,6 +93,37 @@ impl<State: Clone + Send + Sync + 'static> Router<State> { | |||
} | ||||
} | ||||
} | ||||
|
||||
/// Get supported methods for a target resource path | ||||
fn get_supported_methods<'a>(&'a self, path: &'a str) -> Vec<&str> { | ||||
let basic_methods: &[&str]; // implicitly supported methods not registered in the map | ||||
if !self | ||||
.method_map | ||||
.get(&Method::Head) | ||||
.and_then(|r| r.recognize(path).ok()) | ||||
.is_some() | ||||
&& self | ||||
.method_map | ||||
.get(&Method::Get) | ||||
.and_then(|r| r.recognize(path).ok()) | ||||
.is_some() | ||||
{ | ||||
// If the endpoint has no handler for HEAD, but a handler for GET. | ||||
basic_methods = &["OPTIONS", "HEAD"]; | ||||
} else { | ||||
basic_methods = &["OPTIONS"]; | ||||
} | ||||
let registered_methods = self | ||||
.method_map | ||||
.iter() | ||||
.filter(|(_, r)| r.recognize(path).is_ok()) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is inevitable to query every |
||||
.map(|(m, _)| m.as_ref()); | ||||
basic_methods | ||||
.iter() | ||||
.map(|&s| s) | ||||
.chain(registered_methods) | ||||
.collect::<Vec<&str>>() | ||||
} | ||||
} | ||||
|
||||
async fn not_found_endpoint<State: Clone + Send + Sync + 'static>( | ||||
|
@@ -92,8 +132,129 @@ async fn not_found_endpoint<State: Clone + Send + Sync + 'static>( | |||
Ok(Response::new(StatusCode::NotFound)) | ||||
} | ||||
|
||||
async fn method_not_allowed<State: Clone + Send + Sync + 'static>( | ||||
_req: Request<State>, | ||||
pub(crate) const SUPPORTED_METHODS_PARAM_KEY: &'static str = "_SUPPORTED_METHODS"; | ||||
|
||||
/// The endpoint that responses with HTTP status `405 Method Not Allowed` | ||||
/// | ||||
/// The comma-seperated list of supported methods to be set in the HTTP header `Allow` will be | ||||
/// extracted from the request param named [`SUPPORTED_METHODS_PARAM_KEY`]. | ||||
/// Ref: [Section 6.5.5 of IETC RFC 7231](https://tools.ietf.org/html/rfc7231#section-6.5.5). | ||||
async fn method_not_allowed_endpoint<State: Clone + Send + Sync + 'static>( | ||||
req: Request<State>, | ||||
) -> crate::Result { | ||||
let mut resp = Response::new(StatusCode::MethodNotAllowed); | ||||
if let Some(supported_methods) = req.param(SUPPORTED_METHODS_PARAM_KEY).ok() { | ||||
resp.insert_header(headers::ALLOW, supported_methods); | ||||
} | ||||
Ok(resp) | ||||
} | ||||
|
||||
/// The default handler for the HTTP `OPTIONS` method, only meant for listing supported methods | ||||
/// | ||||
/// The comma-separated list of allowed methods to be set in the HTTP header `Allow` will be | ||||
/// extracted from the request param named [`SUPPORTED_METHODS_PARAM_KEY`]. | ||||
/// For CORS preflight requests (i.e. the HTTP header `Origin` is set), it is expected be overrided | ||||
/// by CORSMiddleware, if the latter is activated. | ||||
async fn http_options_endpoint<State: Clone + Send + Sync + 'static>( | ||||
req: Request<State>, | ||||
) -> crate::Result { | ||||
Ok(Response::new(StatusCode::MethodNotAllowed)) | ||||
let mut resp = Response::new(StatusCode::NoContent); | ||||
if let Some(supported_methods) = req.param(SUPPORTED_METHODS_PARAM_KEY).ok() { | ||||
resp.insert_header(headers::ALLOW, supported_methods); | ||||
} | ||||
Ok(resp) | ||||
} | ||||
|
||||
#[cfg(test)] | ||||
mod test { | ||||
use crate::http::{self, Method, Request, StatusCode, Url}; | ||||
use crate::security::{CorsMiddleware, Origin}; | ||||
use crate::Response; | ||||
use http_types::headers::HeaderValue; | ||||
use std::collections::HashSet; | ||||
|
||||
#[async_std::test] | ||||
async fn default_handler_for_http_options() { | ||||
let mut app = crate::Server::new(); | ||||
app.at("/endpoint") | ||||
.get(|_| async { Ok("Hello, GET.") }) | ||||
.post(|_| async { Ok("Hello, POST.") }); | ||||
app.at("/pendoint").post(|_| async { Ok("Hello, POST.") }); | ||||
|
||||
let response: Response = app | ||||
.respond(Request::new( | ||||
Method::Options, | ||||
Url::parse("http://example.com/endpoint").unwrap(), | ||||
)) | ||||
.await | ||||
.unwrap(); | ||||
assert!(response.status().is_success()); | ||||
ensure_methods_allowed(&response, &["get", "head", "post", "options"], true); | ||||
|
||||
let response: Response = app | ||||
.respond(Request::new( | ||||
Method::Options, | ||||
Url::parse("http://example.com/pendoint").unwrap(), | ||||
)) | ||||
.await | ||||
.unwrap(); | ||||
assert!(response.status().is_success()); | ||||
ensure_methods_allowed(&response, &["options", "post"], true); | ||||
ensure_methods_allowed(&response, &["head"], false); | ||||
} | ||||
|
||||
#[async_std::test] | ||||
async fn return_status_405_if_method_not_allowed() { | ||||
let mut app = crate::Server::new(); | ||||
app.at("/endpoint") | ||||
.get(|_| async { Ok("Hello, GET.") }) | ||||
.post(|_| async { Ok("Hello, POST.") }); | ||||
|
||||
let response: Response = app | ||||
.respond(Request::new( | ||||
Method::Put, | ||||
Url::parse("http://example.com/endpoint").unwrap(), | ||||
)) | ||||
.await | ||||
.unwrap(); | ||||
assert_eq!(response.status(), StatusCode::MethodNotAllowed); | ||||
ensure_methods_allowed(&response, &["get", "post", "options"], true); | ||||
} | ||||
|
||||
#[async_std::test] | ||||
async fn options_overrided_for_cors_preflight() { | ||||
let mut app = crate::Server::new(); | ||||
app.at("/").get(|_| async { Ok("Hello, world.") }); | ||||
app.with( | ||||
CorsMiddleware::new() | ||||
.allow_methods("GET, POST, OPTIONS".parse::<HeaderValue>().unwrap()) | ||||
.allow_origin(Origin::Any), | ||||
); | ||||
|
||||
let self_origin = "example.org"; | ||||
let mut request = Request::new(Method::Options, Url::parse("http://example.com/").unwrap()); | ||||
request.append_header(http::headers::ORIGIN, self_origin); | ||||
let response: Response = app.respond(request).await.unwrap(); | ||||
let allowed_origin = response | ||||
.header(http::headers::ACCESS_CONTROL_ALLOW_ORIGIN) | ||||
.map(|origin| Origin::from(origin.as_str())); | ||||
assert_eq!(allowed_origin.unwrap(), Origin::from(self_origin)); | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation of Line 155 in b0eafba
So the newly added endpoint handler won't interfere with the existing CORS handling flow. |
||||
} | ||||
|
||||
fn ensure_methods_allowed(response: &Response, expected_methods: &[&str], positive: bool) { | ||||
let allowed_methods = response.header(http::headers::ALLOW).map(|methods| { | ||||
methods | ||||
.as_str() | ||||
.split(",") | ||||
.map(|method| method.trim().to_ascii_lowercase()) | ||||
.collect::<HashSet<String>>() | ||||
}); | ||||
let allowed_methods = allowed_methods.unwrap(); | ||||
for method in expected_methods | ||||
.iter() | ||||
.map(|&method| method.to_ascii_lowercase()) | ||||
{ | ||||
assert!(!positive ^ allowed_methods.contains(&method)); | ||||
} | ||||
} | ||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to pass a closure here as the endpoint. But it seems impossible as
Selection.endpoint
takes a reference?To circumvent the limitation, I am trying to pass it in the
params
from which it is then extracted in the endpoint function. I am not sure it is good or not as it is a little tricky and incurring unnecessary overhead. But if the way is acceptable, we can let theCorsMiddleware
also extract the supported methods to fillAccess-Control-Allow-Methods
in CORS preflight requests without refactoring a lot.