From 16bad5160a7ac4212dd5250257eb94673a9f7494 Mon Sep 17 00:00:00 2001 From: Bikal Lem Date: Sun, 11 Dec 2022 16:50:32 +0000 Subject: [PATCH] dns-client(eio): improve performance --- eio/client/dns_client_eio.ml | 114 ++++++++++++++++------------------- 1 file changed, 52 insertions(+), 62 deletions(-) diff --git a/eio/client/dns_client_eio.ml b/eio/client/dns_client_eio.ml index 2c8c7338..5ad6ed4d 100644 --- a/eio/client/dns_client_eio.ml +++ b/eio/client/dns_client_eio.ml @@ -30,7 +30,7 @@ module Transport : Dns_client.S type nonrec stack = stack type +'a io = 'a - type t = { + type t = { nameservers : Dns.proto * nameservers ; stack : stack ; timeout : Eio.Time.Timeout.t ; @@ -38,11 +38,11 @@ module Transport : Dns_client.S mutable ctx : (Dns.proto * context) option ; } - and context = { + and context = { t : t ; mutable requests : Cstruct.t Eio.Promise.u IM.t ; mutable ns_connection: ; - mutable buf : Cstruct.t ; + mutable recv_buf : Cstruct.t ; } (* DNS nameservers. *) @@ -161,10 +161,7 @@ module Transport : Dns_client.S let he, actions = Happy_eyeballs.event he (clock ()) event in he_handle_actions t he actions end - | Connect_failed _ -> - fun () -> - Log.debug (fun m -> m "[he_handle_actions] connection failed"); - None + | Connect_failed _ -> fun () -> None | Connect_cancelled _ | Resolve_a _ | Resolve_aaaa _ as a -> fun () -> Log.warn (fun m -> m "[he_handle_actions] ignoring action %a" Happy_eyeballs.pp_action a); @@ -185,7 +182,6 @@ module Transport : Dns_client.S | Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m) let rec connect t = - Log.debug (fun m -> m "connect : establishing connection to nameservers"); match t.ctx, t.ns_connection_condition with | Some ctx, _ -> Ok ctx | None, Some condition -> @@ -209,16 +205,17 @@ module Transport : Dns_client.S let config = Tls.Config.(client ~authenticator ()) in (Tls_eio.client_of_flow config conn :> Eio.Flow.two_way) in - let context = + let ctx = { t = t ; requests = IM.empty ; ns_connection = conn - ; buf = Cstruct.empty + ; recv_buf = Cstruct.create 2048 } in - t.ctx <- Some (`Tcp, context); + t.ctx <- Some (`Tcp, ctx); + Eio.Fiber.fork ~sw:ctx.t.stack.sw ( fun () -> recv_dns_packets ctx ); Eio.Condition.broadcast ns_connection_condition; - Ok (`Tcp, context) + Ok (`Tcp, ctx) | None -> t.ns_connection_condition <- None; Eio.Condition.broadcast ns_connection_condition; @@ -231,47 +228,46 @@ module Transport : Dns_client.S Error (`Msg error_msg) end - let recv_data t flow id : unit = - let buf = Cstruct.create 512 in - Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf)); - let got = Eio.Flow.single_read flow buf in - Log.debug (fun m -> m "recv_data (%X): got %d" id got); - let buf = Cstruct.sub buf 0 got in - t.buf <- if Cstruct.length t.buf = 0 then buf else Cstruct.append t.buf buf; - Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf)) + and recv_dns_packets ?(recv_data = Cstruct.empty) (ctx : context) = - let rec recv_packet t ns_connection request_id = - Log.debug (fun m -> m "recv_packet (%X)" request_id); - let buf_len = Cstruct.length t.buf in - if buf_len > 2 then ( - let packet_len = Cstruct.BE.get_uint16 t.buf 0 in - Log.debug (fun m -> m "recv_packet (%X): packet_len %d" request_id (Cstruct.length t.buf)); - if buf_len - 2 >= packet_len then - let packet, rest = - if buf_len - 2 = packet_len - then t.buf, Cstruct.empty - else Cstruct.split t.buf (packet_len + 2) - in - t.buf <- rest; - let response_id = Cstruct.BE.get_uint16 packet 2 in - Log.debug (fun m -> m "recv_packet (%X): got response %X" request_id response_id); - if response_id = request_id - then packet - else begin - (match IM.find response_id t.requests with - | r -> Eio.Promise.resolve r packet - | exception Not_found -> ()); - recv_packet t ns_connection request_id - end - else begin - recv_data t ns_connection request_id; - recv_packet t ns_connection request_id - end - ) - else begin - recv_data t ns_connection request_id; - recv_packet t ns_connection request_id - end + let append_recv_buf ctx got recv_data = + let buf = Cstruct.sub ctx.recv_buf 0 got in + if Cstruct.is_empty recv_data + then buf + else Cstruct.append recv_data buf + in + + let rec handle_data recv_data = + let recv_data_len = Cstruct.length recv_data in + if recv_data_len < 2 + then recv_dns_packets ~recv_data ctx + else + match Cstruct.BE.get_uint16 recv_data 0 with + | packet_len when recv_data_len - 2 >= packet_len -> + let packet, recv_data = Cstruct.split recv_data @@ packet_len + 2 in + let response_id = Cstruct.BE.get_uint16 packet 2 in + (match IM.find response_id ctx.requests with + | r -> + ctx.requests <- IM.remove response_id ctx.requests ; + Eio.Promise.resolve r packet + | exception Not_found -> () (* spurious data, ignore *) + ); + if not @@ IM.is_empty ctx.requests then handle_data recv_data else () + | _ -> recv_dns_packets ~recv_data ctx + in + + match Eio.Flow.single_read ctx.ns_connection ctx.recv_buf with + | got -> + let recv_data = append_recv_buf ctx got recv_data in + handle_data recv_data + | exception End_of_file -> + ctx.t.ns_connection_condition <- None ; + ctx.t.ctx <- None ; + if not @@ IM.is_empty ctx.requests then + (match connect ctx.t with + | Ok _ -> recv_dns_packets ~recv_data ctx + | Error _ -> Log.warn (fun m -> m "[recv_dns_packets] connection closed while processing dns requests") ) + else () let validate_query_packet tx = if Cstruct.length tx > 4 then Ok () else @@ -281,22 +277,16 @@ module Transport : Dns_client.S let* () = validate_query_packet packet in try let request_id = Cstruct.BE.get_uint16 packet 2 in + let response_p, response_r = Eio.Promise.create () in + ctx.requests <- IM.add request_id response_r ctx.requests; Eio.Time.Timeout.run_exn ctx.t.timeout (fun () -> Eio.Flow.write ctx.ns_connection [packet]; - Log.debug (fun m -> m "send_recv (%X): wrote request" request_id); - let response_p, response_r = Eio.Promise.create () in - ctx.requests <- IM.add request_id response_r ctx.requests; - let response = - Eio.Fiber.first - (fun () -> recv_packet ctx ctx.ns_connection request_id) - (fun () -> Eio.Promise.await response_p) - in - Log.debug (fun m -> m "send_recv (%X): got response" request_id); + let response = Eio.Promise.await response_p in Ok response ) with | Eio.Time.Timeout -> Error (`Msg "DNS request timeout") - (* | exn -> Error (`Msg (Printexc.to_string exn)) *) + | End_of_file -> Error (`Msg "Nameserver closed connection") let close _ = () let bind a f = f a