diff --git a/pkg/jitterbuffer/priority_queue.go b/pkg/jitterbuffer/priority_queue.go index f6d7d93b..11a8679c 100644 --- a/pkg/jitterbuffer/priority_queue.go +++ b/pkg/jitterbuffer/priority_queue.go @@ -114,6 +114,7 @@ func (q *PriorityQueue) Pop() (*rtp.Packet, error) { return nil, ErrInvalidOperation } val := q.next.val + q.next.val = nil q.length-- q.next = q.next.next return val, nil @@ -126,6 +127,7 @@ func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { } if q.next.priority == sqNum { val := q.next.val + q.next.val = nil q.next = q.next.next q.length-- return val, nil @@ -135,6 +137,7 @@ func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { for pos != nil { if pos.priority == sqNum { val := pos.val + pos.val = nil prev.next = pos.next if prev.next != nil { prev.next.prev = prev @@ -156,6 +159,7 @@ func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { } if q.next.val.Timestamp == timestamp { val := q.next.val + q.next.val = nil q.next = q.next.next q.length-- return val, nil @@ -165,6 +169,7 @@ func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { for pos != nil { if pos.val.Timestamp == timestamp { val := pos.val + pos.val = nil prev.next = pos.next if prev.next != nil { prev.next.prev = prev diff --git a/pkg/jitterbuffer/priority_queue_test.go b/pkg/jitterbuffer/priority_queue_test.go index 7fb2a7a6..8b8d23e1 100644 --- a/pkg/jitterbuffer/priority_queue_test.go +++ b/pkg/jitterbuffer/priority_queue_test.go @@ -4,7 +4,10 @@ package jitterbuffer import ( + "runtime" + "sync/atomic" "testing" + "time" "github.com/pion/rtp" "github.com/stretchr/testify/assert" @@ -136,3 +139,46 @@ func TestPriorityQueue_Clean(t *testing.T) { assert.EqualValues(t, 1, packets.Length()) packets.Clear() } + +func TestPriorityQueue_Unreference(t *testing.T) { + packets := NewQueue() + + var refs int64 + finalizer := func(*rtp.Packet) { + atomic.AddInt64(&refs, -1) + } + + numPkts := 100 + for i := 0; i < numPkts; i++ { + atomic.AddInt64(&refs, 1) + seq := uint16(i) + p := rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: seq, + Timestamp: uint32(i + 42), + }, + Payload: []byte{byte(i)}, + } + runtime.SetFinalizer(&p, finalizer) + packets.Push(&p, seq) + } + for i := 0; i < numPkts-1; i++ { + switch i % 3 { + case 0: + packets.Pop() //nolint + case 1: + packets.PopAt(uint16(i)) //nolint + case 2: + packets.PopAtTimestamp(uint32(i + 42)) //nolint + } + } + + runtime.GC() + time.Sleep(10 * time.Millisecond) + + remainedRefs := atomic.LoadInt64(&refs) + runtime.KeepAlive(packets) + + // only the last packet should be still referenced + assert.Equal(t, int64(1), remainedRefs) +}