Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in a03526d8eb82 #276

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 6 additions & 50 deletions pkg/nack/responder_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package nack

import (
"encoding/binary"
"sync"

"github.com/pion/interceptor"
Expand All @@ -19,7 +18,7 @@
}

type packetFactory interface {
NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error)
NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error)
}

// NewInterceptor constructs a new ResponderInterceptor
Expand Down Expand Up @@ -63,11 +62,6 @@
type localStream struct {
sendBuffer *sendBuffer
rtpWriter interceptor.RTPWriter

// Non-zero if Retransmissions should be sent on a distinct stream
rtxSsrc uint32
rtxPayloadType uint8
rtxSequencer rtp.Sequencer
}

// NewResponderInterceptor returns a new ResponderInterceptorFactor
Expand Down Expand Up @@ -115,16 +109,13 @@
sendBuffer, _ := newSendBuffer(n.size)
n.streamsMu.Lock()
n.streams[info.SSRC] = &localStream{
sendBuffer: sendBuffer,
rtpWriter: writer,
rtxSsrc: info.SSRCRetransmission,
rtxPayloadType: info.PayloadTypeRetransmission,
rtxSequencer: rtp.NewRandomSequencer(),
sendBuffer: sendBuffer,
rtpWriter: writer,
}
n.streamsMu.Unlock()

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
pkt, err := n.packetFactory.NewPacket(header, payload)
pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission)
if err != nil {
return 0, err
}
Expand All @@ -151,43 +142,8 @@
for i := range nack.Nacks {
nack.Nacks[i].Range(func(seq uint16) bool {
if p := stream.sendBuffer.get(seq); p != nil {
if stream.rtxSsrc != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.Header().SequenceNumber
p.Header().SequenceNumber = stream.rtxSequencer.NextSequenceNumber()

// Rewrite the SSRC.
p.Header().SSRC = stream.rtxSsrc
// Rewrite the payload type.
p.Header().PayloadType = stream.rtxPayloadType

// Remove padding if present.
paddingLength := 0
originPayload := p.Payload()
if p.Header().Padding {
paddingLength = int(originPayload[len(originPayload)-1])
p.Header().Padding = false
}

// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
payload = append(payload, originPayload[:len(originPayload)-paddingLength]...)

// Send RTX packet.
if _, err := stream.rtpWriter.Write(p.Header(), payload, interceptor.Attributes{}); err != nil {
n.log.Warnf("failed sending rtx packet: %+v", err)
}

// Resore the Padding and SSRC.
if paddingLength > 0 {
p.Header().Padding = true
}
p.Header().SequenceNumber = originalSequenceNumber
} else {
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)
}
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)

Check warning on line 146 in pkg/nack/responder_interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/nack/responder_interceptor.go#L146

Added line #L146 was not covered by tests
}
p.Release()
}
Expand Down
48 changes: 39 additions & 9 deletions pkg/nack/retainable_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package nack

import (
"encoding/binary"
"io"
"sync"

Expand All @@ -13,8 +14,9 @@
const maxPayloadLen = 1460

type packetManager struct {
headerPool *sync.Pool
payloadPool *sync.Pool
headerPool *sync.Pool
payloadPool *sync.Pool
rtxSequencer rtp.Sequencer
}

func newPacketManager() *packetManager {
Expand All @@ -30,16 +32,18 @@
return &buf
},
},
rtxSequencer: rtp.NewRandomSequencer(),
}
}

func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) {
func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) {
if len(payload) > maxPayloadLen {
return nil, io.ErrShortBuffer
}

p := &retainablePacket{
onRelease: m.releasePacket,
onRelease: m.releasePacket,
sequenceNumber: header.SequenceNumber,
// new packets have retain count of 1
count: 1,
}
Expand All @@ -62,6 +66,29 @@
p.payload = (*p.buffer)[:size]
}

if rtxSsrc != 0 && rtxPayloadType != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.header.SequenceNumber
p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()

// Rewrite the SSRC.
p.header.SSRC = rtxSsrc
// Rewrite the payload type.
p.header.PayloadType = rtxPayloadType

// Remove padding if present.
paddingLength := 0
if p.header.Padding {
paddingLength = int(p.payload[len(p.payload)-1])
p.header.Padding = false

Check warning on line 83 in pkg/nack/retainable_packet.go

View check run for this annotation

Codecov / codecov/patch

pkg/nack/retainable_packet.go#L82-L83

Added lines #L82 - L83 were not covered by tests
}

// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...)
}

return p, nil
}

Expand All @@ -74,12 +101,13 @@

type noOpPacketFactory struct{}

func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) {
func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*retainablePacket, error) {
return &retainablePacket{
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
sequenceNumber: header.SequenceNumber,
}, nil
}

Expand All @@ -96,6 +124,8 @@
header *rtp.Header
buffer *[]byte
payload []byte

sequenceNumber uint16
}

func (p *retainablePacket) Header() *rtp.Header {
Expand Down
4 changes: 2 additions & 2 deletions pkg/nack/send_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (s *sendBuffer) add(packet *retainablePacket) {
s.m.Lock()
defer s.m.Unlock()

seq := packet.Header().SequenceNumber
seq := packet.sequenceNumber
if !s.started {
s.packets[seq%s.size] = packet
s.lastAdded = seq
Expand Down Expand Up @@ -92,7 +92,7 @@ func (s *sendBuffer) get(seq uint16) *retainablePacket {

pkt := s.packets[seq%s.size]
if pkt != nil {
if pkt.Header().SequenceNumber != seq {
if pkt.sequenceNumber != seq {
return nil
}
// already released
Expand Down
8 changes: 4 additions & 4 deletions pkg/nack/send_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestSendBuffer(t *testing.T) {
add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0)
require.NoError(t, err)
sb.add(pkt)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func TestSendBuffer_Overridden(t *testing.T) {
require.Equal(t, uint16(1), sb.size)

originalBytes := []byte("originalContent")
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes, 0, 0)
require.NoError(t, err)
sb.add(pkt)

Expand All @@ -91,7 +91,7 @@ func TestSendBuffer_Overridden(t *testing.T) {
require.Equal(t, 1, retrieved.count)

// ensure original packet is released
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes)
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 0, 0)
require.NoError(t, err)
sb.add(pkt)
require.Equal(t, 0, retrieved.count)
Expand All @@ -113,7 +113,7 @@ func TestSendBuffer_Race(t *testing.T) {
add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0)
require.NoError(t, err)
sb.add(pkt)
}
Expand Down
Loading