diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 8c21b078..044f56af 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -105,11 +105,12 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri // error is already checked in NewGeneratorInterceptor rtpBuffer, _ := rtpbuffer.NewRTPBuffer(n.size) - n.streamsMu.Lock() - n.streams[info.SSRC] = &localStream{ + stream := &localStream{ rtpBuffer: rtpBuffer, rtpWriter: writer, } + n.streamsMu.Lock() + n.streams[info.SSRC] = stream n.streamsMu.Unlock() return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { @@ -122,8 +123,8 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri if err != nil { return 0, err } - n.streams[info.SSRC].rtpBufferMutex.Lock() - defer n.streams[info.SSRC].rtpBufferMutex.Unlock() + stream.rtpBufferMutex.Lock() + defer stream.rtpBufferMutex.Unlock() rtpBuffer.Add(pkt) diff --git a/pkg/nack/responder_interceptor_test.go b/pkg/nack/responder_interceptor_test.go index c48d4fec..036a237b 100644 --- a/pkg/nack/responder_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -5,6 +5,7 @@ package nack import ( "encoding/binary" + "sync" "testing" "time" @@ -153,6 +154,37 @@ func TestResponderInterceptor_Race(t *testing.T) { } } +// this test is only useful when being run with the race detector, it won't fail otherwise: +// +// go test -race ./pkg/nack/ +func TestResponderInterceptor_RaceConcurrentStreams(t *testing.T) { + f, err := NewResponderInterceptor( + ResponderSize(32768), + ResponderLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + ) + require.NoError(t, err) + + i, err := f.NewInterceptor("") + require.NoError(t, err) + + var wg sync.WaitGroup + for j := 0; j < 5; j++ { + stream := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + wg.Add(1) + go func() { + for seqNum := uint16(0); seqNum < 500; seqNum++ { + require.NoError(t, stream.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}})) + } + wg.Done() + }() + } + + wg.Wait() +} + func TestResponderInterceptor_StreamFilter(t *testing.T) { f, err := NewResponderInterceptor( ResponderSize(8),