Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Apply SO_REUSEPORT to UDP sockets #5725

Merged
merged 14 commits into from
Oct 17, 2023
106 changes: 97 additions & 9 deletions listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,34 @@ func reuseUnixSocket(network, addr string) (any, error) {
return nil, nil
}

func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (net.Listener, error) {
sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
ln, err := config.Listen(ctx, network, address)
func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) {
switch network {
case "udp", "udp4", "udp6", "unixgram":
WeidiDeng marked this conversation as resolved.
Show resolved Hide resolved
sharedPc, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
pc, err := config.ListenPacket(ctx, network, address)
if err != nil {
return nil, err
}
return &sharedPacketConn{PacketConn: pc, key: lnKey}, nil
})
if err != nil {
return nil, err
}
return &sharedListener{Listener: ln, key: lnKey}, nil
})
if err != nil {
return nil, err
return &fakeClosePacketConn{sharedPacketConn: sharedPc.(*sharedPacketConn)}, nil

default:
sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) {
ln, err := config.Listen(ctx, network, address)
if err != nil {
return nil, err
}
return &sharedListener{Listener: ln, key: lnKey}, nil
})
if err != nil {
return nil, err
}
return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil
}
return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil
}

// fakeCloseListener is a private wrapper over a listener that
Expand Down Expand Up @@ -98,7 +114,7 @@ func (fcl *fakeCloseListener) Accept() (net.Conn, error) {
// so that it's clear in the code that side-effects are shared with other
// users of this listener, not just our own reference to it; we also don't
// do anything with the error because all we could do is log it, but we
// expliclty assign it to nothing so we don't forget it's there if needed
// explicitly assign it to nothing so we don't forget it's there if needed
_ = fcl.sharedListener.clearDeadline()

if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
Expand Down Expand Up @@ -172,3 +188,75 @@ func (sl *sharedListener) setDeadline() error {
func (sl *sharedListener) Destruct() error {
return sl.Listener.Close()
}

// fakeClosePacketConn is like fakeCloseListener, but for PacketConns,
// or more specifically, *net.UDPConn
type fakeClosePacketConn struct {
closed int32 // accessed atomically; belongs to this struct only
*sharedPacketConn // embedded, so we also become a net.PacketConn; its key is used in Close
}

func (fcpc *fakeClosePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
// if the listener is already "closed", return error
if atomic.LoadInt32(&fcpc.closed) == 1 {
return 0, nil, &net.OpError{
Op: "readfrom",
Net: fcpc.LocalAddr().Network(),
Addr: fcpc.LocalAddr(),
Err: errFakeClosed,
}
}

// call underlying readfrom
n, addr, err = fcpc.sharedPacketConn.ReadFrom(p)
if err != nil {
// this server was stopped, so clear the deadline and let
// any new server continue reading; but we will exit
if atomic.LoadInt32(&fcpc.closed) == 1 {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if err = fcpc.SetReadDeadline(time.Time{}); err != nil {
return
}
}
}
return
}

return
}

// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it.
func (fcpc *fakeClosePacketConn) Close() error {
if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) {
_ = fcpc.SetReadDeadline(time.Now()) // unblock ReadFrom() calls to kick old servers out of their loops
_, _ = listenerPool.Delete(fcpc.sharedPacketConn.key)
}
return nil
}

func (fcpc *fakeClosePacketConn) Unwrap() net.PacketConn {
return fcpc.sharedPacketConn.PacketConn
}

// sharedPacketConn is like sharedListener, but for net.PacketConns.
type sharedPacketConn struct {
net.PacketConn
key string
}

// Destruct closes the underlying socket.
func (spc *sharedPacketConn) Destruct() error {
return spc.PacketConn.Close()
}

// Unwrap returns the underlying socket
func (spc *sharedPacketConn) Unwrap() net.PacketConn {
return spc.PacketConn
}

// Interface guards (see https://github.com/caddyserver/caddy/issues/3998)
var (
_ (interface {
Unwrap() net.PacketConn
}) = (*fakeClosePacketConn)(nil)
)
75 changes: 72 additions & 3 deletions listen_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ package caddy
import (
"context"
"errors"
"io"
"io/fs"
"net"
"os"
"sync/atomic"
"syscall"

Expand Down Expand Up @@ -87,7 +89,7 @@ func reuseUnixSocket(network, addr string) (any, error) {
return nil, nil
}

func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (net.Listener, error) {
func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) {
// wrap any Control function set by the user so we can also add our reusePort control without clobbering theirs
oldControl := config.Control
config.Control = func(network, address string, c syscall.RawConn) error {
Expand All @@ -103,7 +105,14 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string,
// we still put it in the listenerPool so we can count how many
// configs are using this socket; necessary to ensure we can know
// whether to enforce shutdown delays, for example (see #5393).
ln, err := config.Listen(ctx, network, address)
var ln io.Closer
var err error
switch network {
case "udp", "udp4", "udp6", "unixgram":
ln, err = config.ListenPacket(ctx, network, address)
default:
ln, err = config.Listen(ctx, network, address)
}
if err == nil {
listenerPool.LoadOrStore(lnKey, nil)
}
Expand All @@ -117,9 +126,23 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string,
unixSockets[lnKey] = ln.(*unixListener)
}

// TODO: Not 100% sure this is necessary, but we do this for net.UnixListener in listen_unix.go, so...
if unix, ok := ln.(*net.UnixConn); ok {
ln = &unixConn{unix, address, lnKey, &one}
unixSockets[lnKey] = ln.(*unixConn)
}

// lightly wrap the listener so that when it is closed,
// we can decrement the usage pool counter
return deleteListener{ln, lnKey}, err
switch specificLn := ln.(type) {
case net.Listener:
return deleteListener{specificLn, lnKey}, err
case net.PacketConn:
return deletePacketConn{specificLn, lnKey}, err
}

// other types, I guess we just return them directly
return ln, err
}

// reusePort sets SO_REUSEPORT. Ineffective for unix sockets.
Expand Down Expand Up @@ -158,6 +181,36 @@ func (uln *unixListener) Close() error {
return uln.UnixListener.Close()
}

type unixConn struct {
*net.UnixConn
filename string
mapKey string
count *int32 // accessed atomically
}

func (uc *unixConn) Close() error {
newCount := atomic.AddInt32(uc.count, -1)
if newCount == 0 {
defer func() {
unixSocketsMu.Lock()
delete(unixSockets, uc.mapKey)
unixSocketsMu.Unlock()
_ = syscall.Unlink(uc.filename)
}()
}
return uc.UnixConn.Close()
}

func (uc *unixConn) Unwrap() net.PacketConn {
return uc.UnixConn
}

// unixSockets keeps track of the currently-active unix sockets
// so we can transfer their FDs gracefully during reloads.
var unixSockets = make(map[string]interface {
File() (*os.File, error)
})

// deleteListener is a type that simply deletes itself
// from the listenerPool when it closes. It is used
// solely for the purpose of reference counting (i.e.
Expand All @@ -171,3 +224,19 @@ func (dl deleteListener) Close() error {
_, _ = listenerPool.Delete(dl.lnKey)
return dl.Listener.Close()
}

// deletePacketConn is like deleteListener, but
// for net.PacketConns.
type deletePacketConn struct {
net.PacketConn
lnKey string
}

func (dl deletePacketConn) Close() error {
_, _ = listenerPool.Delete(dl.lnKey)
return dl.PacketConn.Close()
}

func (dl deletePacketConn) Unwrap() net.PacketConn {
return dl.PacketConn
}
Loading
Loading