diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index 2b765921..ef1e2348 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -15,8 +15,8 @@ import ( // ManagerConfig a bag of config params for Manager. type ManagerConfig struct { LeveledLogger logging.LeveledLogger - AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) - AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + AllocatePacketConn func(network string, requestedPort int, username string) (net.PacketConn, net.Addr, error) + AllocateConn func(network string, requestedPort int, username string) (net.Conn, net.Addr, error) PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool } @@ -33,8 +33,8 @@ type Manager struct { allocations map[FiveTupleFingerprint]*Allocation reservations []*reservation - allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) - allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + allocatePacketConn func(network string, requestedPort int, username string) (net.PacketConn, net.Addr, error) + allocateConn func(network string, requestedPort int, username string) (net.Conn, net.Addr, error) permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool } @@ -86,7 +86,7 @@ func (m *Manager) Close() error { } // CreateAllocation creates a new allocation and starts relaying -func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) { +func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration, username string) (*Allocation, error) { switch { case fiveTuple == nil: return nil, errNilFiveTuple @@ -105,7 +105,7 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo } a := NewAllocation(turnSocket, fiveTuple, m.log) - conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) + conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort, username) if err != nil { return nil, err } @@ -180,9 +180,13 @@ func (m *Manager) GetReservation(reservationToken string) (int, bool) { } // GetRandomEvenPort returns a random un-allocated udp4 port -func (m *Manager) GetRandomEvenPort() (int, error) { +func (m *Manager) GetRandomEvenPort(username string) (int, error) { for i := 0; i < 128; i++ { - conn, addr, err := m.allocatePacketConn("udp4", 0) + conn, addr, err := m.allocatePacketConn("udp4", 0, username) + if err != nil { + return 0, err + } + if err != nil { return 0, err } diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 014d85d3..77c4be11 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -7,6 +7,7 @@ package allocation import ( + "errors" "io" "math/rand" "net" @@ -19,10 +20,12 @@ import ( "github.com/stretchr/testify/assert" ) +var errUnexpectedTestUsername = errors.New("unexpected user name") + func TestManager(t *testing.T) { tt := []struct { name string - f func(*testing.T, net.PacketConn) + f func(*testing.T, net.PacketConn, string) }{ {"CreateInvalidAllocation", subTestCreateInvalidAllocation}, {"CreateAllocation", subTestCreateAllocation}, @@ -42,34 +45,34 @@ func TestManager(t *testing.T) { for _, tc := range tt { f := tc.f t.Run(tc.name, func(t *testing.T) { - f(t, turnSocket) + f(t, turnSocket, "test_user_1") }) } } // Test invalid Allocation creations -func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) - if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime, username); a != nil || err == nil { t.Errorf("Illegally created allocation with nil FiveTuple") } - if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime, username); a != nil || err == nil { t.Errorf("Illegally created allocation with nil turnSocket") } - if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0, username); a != nil || err == nil { t.Errorf("Illegally created allocation with 0 lifetime") } } // Test valid Allocation creations -func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, username); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -79,26 +82,26 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { } // Test that two allocations can't be created with the same FiveTuple -func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, username); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, username); a != nil || err == nil { t.Errorf("Was able to create allocation with same FiveTuple twice") } } -func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, username); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -113,8 +116,8 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { } // Test that allocation should be closed if timeout -func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) allocations := make([]*Allocation, 5) @@ -123,7 +126,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { for index := range allocations { fiveTuple := randomFiveTuple() - a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime) + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime, username) if err != nil { t.Errorf("Failed to create allocation with %v", fiveTuple) } @@ -141,15 +144,15 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { } // Test for manager close -func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestManagerClose(t *testing.T, turnSocket net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) allocations := make([]*Allocation, 2) - a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) + a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second, username) allocations[0] = a1 - a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) + a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute, username) allocations[1] = a2 // Make a1 timeout @@ -174,12 +177,15 @@ func randomFiveTuple() *FiveTuple { } } -func newTestManager() (*Manager, error) { +func newTestManager(expectedUsername string) (*Manager, error) { loggerFactory := logging.NewDefaultLoggerFactory() config := ManagerConfig{ LeveledLogger: loggerFactory.NewLogger("test"), - AllocatePacketConn: func(string, int) (net.PacketConn, net.Addr, error) { + AllocatePacketConn: func(_ string, _ int, username string) (net.PacketConn, net.Addr, error) { + if username != expectedUsername { + return nil, nil, errUnexpectedTestUsername + } conn, err := net.ListenPacket("udp4", "0.0.0.0:0") if err != nil { return nil, nil, err @@ -187,8 +193,9 @@ func newTestManager() (*Manager, error) { return conn, conn.LocalAddr(), nil }, - AllocateConn: func(string, int) (net.Conn, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, nil }, } + return NewManager(config) } @@ -197,11 +204,11 @@ func isClose(conn io.Closer) bool { return closeErr != nil && strings.Contains(closeErr.Error(), "use of closed network connection") } -func subTestGetRandomEvenPort(t *testing.T, _ net.PacketConn) { - m, err := newTestManager() +func subTestGetRandomEvenPort(t *testing.T, _ net.PacketConn, username string) { + m, err := newTestManager(username) assert.NoError(t, err) - port, err := m.GetRandomEvenPort() + port, err := m.GetRandomEvenPort(username) assert.NoError(t, err) assert.True(t, port > 0) assert.True(t, port%2 == 0) diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index 49269d68..f1d9e15b 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -259,9 +259,12 @@ func subTestAllocationClose(t *testing.T) { } func subTestPacketHandler(t *testing.T) { - network := "udp" + const ( + network = "udp" + testUsername = "test_user_2" + ) - m, _ := newTestManager() + m, _ := newTestManager(testUsername) // TURN server initialization turnSocket, err := net.ListenPacket(network, "127.0.0.1:0") @@ -292,7 +295,7 @@ func subTestPacketHandler(t *testing.T) { a, err := m.CreateAllocation(&FiveTuple{ SrcAddr: clientListener.LocalAddr(), DstAddr: turnSocket.LocalAddr(), - }, turnSocket, 0, proto.DefaultLifetime) + }, turnSocket, 0, proto.DefaultLifetime, testUsername) assert.Nil(t, err, "should succeed") diff --git a/internal/server/turn.go b/internal/server/turn.go index 46e45ecb..a56d3623 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -25,8 +25,8 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] // unless the client and server agree to use another mechanism through // some procedure outside the scope of this document. - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodAllocate) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodAllocate) + if !authResult.hasAuth { return err } @@ -51,7 +51,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { return buildAndSendErr(r.Conn, r.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...) } // A retry allocation - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, messageIntegrity)...) + msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, authResult.messageIntegrity)...) return buildAndSend(r.Conn, r.SrcAddr, msg...) } @@ -104,7 +104,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { var evenPort proto.EvenPort if err = evenPort.GetFrom(m); err == nil { var randomPort int - randomPort, err = r.AllocationManager.GetRandomEvenPort() + randomPort, err = r.AllocationManager.GetRandomEvenPort(authResult.username) if err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) } @@ -131,7 +131,8 @@ func handleAllocateRequest(r Request, m *stun.Message) error { fiveTuple, r.Conn, requestedPort, - lifetimeDuration) + lifetimeDuration, + authResult.username) if err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) } @@ -177,7 +178,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { responseAttrs = append(responseAttrs, proto.ReservationToken([]byte(reservationToken))) } - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, messageIntegrity)...) + msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, authResult.messageIntegrity)...) a.SetResponseCache(m.TransactionID, responseAttrs) return buildAndSend(r.Conn, r.SrcAddr, msg...) } @@ -185,8 +186,8 @@ func handleAllocateRequest(r Request, m *stun.Message) error { func handleRefreshRequest(r Request, m *stun.Message) error { r.Log.Debugf("Received RefreshRequest from %s", r.SrcAddr) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodRefresh) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodRefresh) + if !authResult.hasAuth { return err } @@ -212,7 +213,7 @@ func handleRefreshRequest(r Request, m *stun.Message) error { &proto.Lifetime{ Duration: lifetimeDuration, }, - messageIntegrity, + authResult.messageIntegrity, }...)...) } @@ -228,8 +229,8 @@ func handleCreatePermissionRequest(r Request, m *stun.Message) error { return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) } - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodCreatePermission) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodCreatePermission) + if !authResult.hasAuth { return err } @@ -267,7 +268,7 @@ func handleCreatePermissionRequest(r Request, m *stun.Message) error { respClass = stun.ClassErrorResponse } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{messageIntegrity}...)...) + return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{authResult.messageIntegrity}...)...) } func handleSendIndication(r Request, m *stun.Message) error { @@ -317,8 +318,8 @@ func handleChannelBindRequest(r Request, m *stun.Message) error { badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodChannelBind) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodChannelBind) + if !authResult.hasAuth { return err } @@ -351,7 +352,7 @@ func handleChannelBindRequest(r Request, m *stun.Message) error { return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{messageIntegrity}...)...) + return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{authResult.messageIntegrity}...)...) } func handleChannelData(r Request, c *proto.ChannelData) error { diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index e4a3b947..da181436 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -64,7 +64,7 @@ func TestAllocationLifeTime(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("turn") allocationManager, err := allocation.NewManager(allocation.ManagerConfig{ - AllocatePacketConn: func(network string, _ int) (net.PacketConn, net.Addr, error) { + AllocatePacketConn: func(network string, _ int, _ string) (net.PacketConn, net.Addr, error) { conn, listenErr := net.ListenPacket(network, "0.0.0.0:0") if err != nil { return nil, nil, listenErr @@ -72,7 +72,7 @@ func TestAllocationLifeTime(t *testing.T) { return conn, conn.LocalAddr(), nil }, - AllocateConn: func(string, int) (net.Conn, net.Addr, error) { + AllocateConn: func(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, nil }, LeveledLogger: logger, @@ -97,7 +97,7 @@ func TestAllocationLifeTime(t *testing.T) { fiveTuple := &allocation.FiveTuple{SrcAddr: r.SrcAddr, DstAddr: r.Conn.LocalAddr(), Protocol: allocation.UDP} - _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour) + _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour, "") assert.NoError(t, err) assert.NotNil(t, r.AllocationManager.GetAllocation(fiveTuple)) diff --git a/internal/server/util.go b/internal/server/util.go index 7c01d329..89186572 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -42,14 +42,21 @@ func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageTy return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) } -func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (stun.MessageIntegrity, bool, error) { - respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) { +type authenticationResult struct { + messageIntegrity stun.MessageIntegrity + username string + realm string + hasAuth bool +} + +func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (authenticationResult, error) { + respondWithNonce := func(responseCode stun.ErrorCode) (authenticationResult, error) { nonce, err := r.NonceHash.Generate() if err != nil { - return nil, false, err + return authenticationResult{nil, "", "", false}, err } - return nil, false, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, + return authenticationResult{nil, "", "", false}, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: responseCode}, stun.NewNonce(nonce), @@ -70,11 +77,11 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) // Respond with 400 so clients don't retry if r.AuthHandler == nil { sendErr := buildAndSend(r.Conn, r.SrcAddr, badRequestMsg...) - return nil, false, sendErr + return authenticationResult{}, sendErr } if err := nonceAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } // Assert Nonce is signed and is not expired @@ -83,21 +90,26 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) } if err := realmAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } else if err := usernameAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } ourKey, ok := r.AuthHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) if !ok { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) } if err := stun.MessageIntegrity(ourKey).Check(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } - return stun.MessageIntegrity(ourKey), true, nil + return authenticationResult{ + messageIntegrity: stun.MessageIntegrity(ourKey), + username: usernameAttr.String(), + realm: realmAttr.String(), + hasAuth: true, + }, nil } func allocationLifeTime(m *stun.Message) time.Duration { diff --git a/internal/server/util_test.go b/internal/server/util_test.go new file mode 100644 index 00000000..32aa6886 --- /dev/null +++ b/internal/server/util_test.go @@ -0,0 +1,354 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "net" + "testing" + + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/proto" + "github.com/stretchr/testify/require" +) + +func TestAuthenticateRequest(t *testing.T) { + const ( + testUsername = "test-user" + testRealm = "test-realm" + ) + testMsgIntegrity := stun.NewLongTermIntegrity(testUsername, testRealm, "pass") + + var conn net.PacketConn + var nonce string + var r *Request + + type options struct { + noAuthHandler bool + } + + setUp := func(t *testing.T, opts options) func() { + var err error + conn, err = net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + srcAddr := conn.LocalAddr() + + nonceHash, err := NewNonceHash() + require.NoError(t, err) + nonce, err = nonceHash.Generate() + require.NoError(t, err) + + r = &Request{ + Conn: conn, + SrcAddr: srcAddr, + AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + return testMsgIntegrity, username == testUsername && realm == testRealm + }, + NonceHash: nonceHash, + } + if opts.noAuthHandler { + r.AuthHandler = nil + } + + return func() { + err = conn.Close() + if err != nil { + t.Errorf("failed to close connection: %v", err) + } + } + } + + t.Run("auth success", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.True(t, authResult.hasAuth) + require.Equal(t, testRealm, authResult.realm, "Realm value should be present in the result") + require.Equal(t, testUsername, authResult.username, "Username value should be present in the result") + }) + + t.Run("no message integrity", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Message integrity attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeUnauthorised, attrErrorCode.Code) + }) + + t.Run("no auth handler", func(t *testing.T) { + tearDown := setUp(t, options{noAuthHandler: true}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) + + t.Run("no nonce", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Nonce attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) + + t.Run("invalid nonce", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce("bad nonce"), // <- bad nonce + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeStaleNonce, attrErrorCode.Code) + }) + + t.Run("no realm", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Realm attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) + + t.Run("no username", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Username attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) + + t.Run("unknown username", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername("bad user"), // <- user name that does not exist + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorContains(t, err, "no such user") + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) + + t.Run("invalid message integrity", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + stun.NewLongTermIntegrity(testUsername, testRealm, "bad"), // <- bad message integrity + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrIntegrityMismatch) + require.False(t, authResult.hasAuth) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + var resp stun.Message + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(&resp) + require.NoError(t, err) + require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + }) +} diff --git a/lt_cred.go b/lt_cred.go index 42466c38..bd3197f1 100644 --- a/lt_cred.go +++ b/lt_cred.go @@ -79,7 +79,7 @@ func LongTermTURNRESTAuthHandler(sharedSecret string, l logging.LeveledLogger) A l = logging.NewDefaultLoggerFactory().NewLogger("turn") } return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - l.Tracef("Authentication username=%q realm=%q srcAddr=%v\n", username, realm, srcAddr) + l.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) timestamp := strings.Split(username, ":")[0] t, err := strconv.Atoi(timestamp) if err != nil { diff --git a/relay_address_generator_none.go b/relay_address_generator_none.go index b0974010..bc7cdbb2 100644 --- a/relay_address_generator_none.go +++ b/relay_address_generator_none.go @@ -39,7 +39,7 @@ func (r *RelayAddressGeneratorNone) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int, _ string) (net.PacketConn, net.Addr, error) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -49,6 +49,6 @@ func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requested } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorNone) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorNone) AllocateConn(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_range.go b/relay_address_generator_range.go index d87a57f9..c64d27b5 100644 --- a/relay_address_generator_range.go +++ b/relay_address_generator_range.go @@ -68,7 +68,7 @@ func (r *RelayAddressGeneratorPortRange) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int, _ string) (net.PacketConn, net.Addr, error) { if requestedPort != 0 { conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, requestedPort)) if err != nil { @@ -103,6 +103,6 @@ func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requ } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_static.go b/relay_address_generator_static.go index 39c68777..07832268 100644 --- a/relay_address_generator_static.go +++ b/relay_address_generator_static.go @@ -45,7 +45,7 @@ func (r *RelayAddressGeneratorStatic) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, requestedPort int, _ string) (net.PacketConn, net.Addr, error) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -63,6 +63,6 @@ func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, request } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorStatic) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorStatic) AllocateConn(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/server.go b/server.go index 3b58938f..47200883 100644 --- a/server.go +++ b/server.go @@ -171,11 +171,11 @@ type nilAddressGenerator struct{} func (n *nilAddressGenerator) Validate() error { return errRelayAddressGeneratorNil } -func (n *nilAddressGenerator) AllocatePacketConn(string, int) (net.PacketConn, net.Addr, error) { +func (n *nilAddressGenerator) AllocatePacketConn(string, int, string) (net.PacketConn, net.Addr, error) { return nil, nil, errRelayAddressGeneratorNil } -func (n *nilAddressGenerator) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (n *nilAddressGenerator) AllocateConn(string, int, string) (net.Conn, net.Addr, error) { return nil, nil, errRelayAddressGeneratorNil } diff --git a/server_config.go b/server_config.go index eab2988e..02a9edc6 100644 --- a/server_config.go +++ b/server_config.go @@ -20,10 +20,10 @@ type RelayAddressGenerator interface { Validate() error // Allocate a PacketConn (UDP) RelayAddress - AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) + AllocatePacketConn(network string, requestedPort int, username string) (net.PacketConn, net.Addr, error) // Allocate a Conn (TCP) RelayAddress - AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) + AllocateConn(network string, requestedPort int, username string) (net.Conn, net.Addr, error) } // PermissionHandler is a callback to filter incoming CreatePermission and ChannelBindRequest