From 902265e3dc73f6425c9f23baa77f1669c18533b4 Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Mon, 20 Jan 2025 20:17:34 +0100 Subject: [PATCH] Add interceptor tests --- internal/test/mock_stream.go | 9 +- pkg/ccfb/history.go | 12 +- pkg/ccfb/history_test.go | 6 +- pkg/ccfb/interceptor.go | 72 ++++++++- pkg/ccfb/interceptor_test.go | 306 +++++++++++++++++++++++++++++++++++ 5 files changed, 385 insertions(+), 20 deletions(-) create mode 100644 pkg/ccfb/interceptor_test.go diff --git a/internal/test/mock_stream.go b/internal/test/mock_stream.go index bf96e31b..e791ac8a 100644 --- a/internal/test/mock_stream.go +++ b/internal/test/mock_stream.go @@ -41,6 +41,7 @@ type RTPWithError struct { // RTCPWithError is used to send a batch of rtcp packets or an error on a channel type RTCPWithError struct { Packets []rtcp.Packet + Attr interceptor.Attributes Err error } @@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc go func() { buf := make([]byte, 1500) for { - i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) + i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) if err != nil { if !errors.Is(err, io.EOF) { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} } return } pkts, err := rtcp.Unmarshal(buf[:i]) if err != nil { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} return } - s.rtcpInModified <- RTCPWithError{Packets: pkts} + s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts} } }() go func() { diff --git a/pkg/ccfb/history.go b/pkg/ccfb/history.go index 6e8aba3c..e95f9215 100644 --- a/pkg/ccfb/history.go +++ b/pkg/ccfb/history.go @@ -31,7 +31,7 @@ type sentPacket struct { departure time.Time } -type history struct { +type historyList struct { lock sync.Mutex size int evictList *list.List @@ -40,8 +40,8 @@ type history struct { ackedSeqNr *sequencenumber.Unwrapper } -func newHistory(size int) *history { - return &history{ +func newHistoryList(size int) *historyList { + return &historyList{ lock: sync.Mutex{}, size: size, evictList: list.New(), @@ -51,7 +51,7 @@ func newHistory(size int) *history { } } -func (h *history) add(seqNr uint16, size uint16, departure time.Time) error { +func (h *historyList) add(seqNr uint16, size uint16, departure time.Time) error { h.lock.Lock() defer h.lock.Unlock() @@ -76,7 +76,7 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error { } // Must be called while holding the lock -func (h *history) removeOldest() { +func (h *historyList) removeOldest() { if ent := h.evictList.Front(); ent != nil { v := h.evictList.Remove(ent) if sp, ok := v.(sentPacket); ok { @@ -85,7 +85,7 @@ func (h *history) removeOldest() { } } -func (h *history) getReportForAck(al acknowledgementList) PacketReportList { +func (h *historyList) getReportForAck(al acknowledgementList) PacketReportList { h.lock.Lock() defer h.lock.Unlock() diff --git a/pkg/ccfb/history_test.go b/pkg/ccfb/history_test.go index 88ca99e4..2762e56f 100644 --- a/pkg/ccfb/history_test.go +++ b/pkg/ccfb/history_test.go @@ -10,7 +10,7 @@ import ( func TestHistory(t *testing.T) { t.Run("errorOnDecreasingSeqNr", func(t *testing.T) { - h := newHistory(200) + h := newHistoryList(200) assert.NoError(t, h.add(10, 1200, time.Now())) assert.NoError(t, h.add(11, 1200, time.Now())) assert.Error(t, h.add(9, 1200, time.Now())) @@ -84,7 +84,7 @@ func TestHistory(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - h := newHistory(200) + h := newHistoryList(200) for _, op := range tc.outgoing { assert.NoError(t, h.add(op.seqNr, op.size, op.ts)) } @@ -97,7 +97,7 @@ func TestHistory(t *testing.T) { }) t.Run("garbageCollection", func(t *testing.T) { - h := newHistory(200) + h := newHistoryList(200) for i := uint16(0); i < 300; i++ { assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond))) diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go index 08de7429..9b52e5c5 100644 --- a/pkg/ccfb/interceptor.go +++ b/pkg/ccfb/interceptor.go @@ -1,6 +1,7 @@ package ccfb import ( + "log" "sync" "time" @@ -15,8 +16,48 @@ type ccfbAttributesKeyType uint32 const CCFBAttributesKey ccfbAttributesKeyType = iota +type history interface { + add(seqNr uint16, size uint16, departure time.Time) error + getReportForAck(al acknowledgementList) PacketReportList +} + type Option func(*Interceptor) error +func HistorySize(size int) Option { + return func(i *Interceptor) error { + i.historySize = size + return nil + } +} + +func timeFactory(f func() time.Time) Option { + return func(i *Interceptor) error { + i.timestamp = f + return nil + } +} + +func historyFactory(f func(int) history) Option { + return func(i *Interceptor) error { + i.historyFactory = f + return nil + } +} + +func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)) Option { + return func(i *Interceptor) error { + i.convertCCFB = f + return nil + } +} + +func twccConverterFactory(f func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)) Option { + return func(i *Interceptor) error { + i.convertTWCC = f + return nil + } +} + type InterceptorFactory struct { opts []Option } @@ -30,8 +71,15 @@ func NewInterceptor(opts ...Option) (*InterceptorFactory, error) { func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { i := &Interceptor{ NoOp: interceptor.NoOp{}, + lock: sync.Mutex{}, timestamp: time.Now, - ssrcToHistory: make(map[uint32]*history), + convertCCFB: convertCCFB, + convertTWCC: convertTWCC, + ssrcToHistory: make(map[uint32]history), + historySize: 200, + historyFactory: func(size int) history { + return newHistoryList(size) + }, } for _, opt := range f.opts { if err := opt(i); err != nil { @@ -43,9 +91,13 @@ func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, type Interceptor struct { interceptor.NoOp - lock sync.Mutex - timestamp func() time.Time - ssrcToHistory map[uint32]*history + lock sync.Mutex + timestamp func() time.Time + convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList) + convertTWCC func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList) + ssrcToHistory map[uint32]history + historySize int + historyFactory func(int) history } // BindLocalStream implements interceptor.Interceptor. @@ -67,7 +119,7 @@ func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter if useTWCC { ssrc = 0 } - i.ssrcToHistory[ssrc] = newHistory(200) + i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize) return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { i.lock.Lock() @@ -106,17 +158,23 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor. attr = make(interceptor.Attributes) } + log.Printf("%v", buf) pktReportLists := map[uint32]*PacketReportList{} pkts, err := attr.GetRTCPPackets(buf) + if err != nil { + return n, attr, err + } + log.Printf("got rtcp packets: %v, %v", pkts, err) for _, pkt := range pkts { var reportLists map[uint32]acknowledgementList var reportDeparture time.Time switch fb := pkt.(type) { case *rtcp.CCFeedbackReport: - reportDeparture, reportLists = convertCCFB(now, fb) + reportDeparture, reportLists = i.convertCCFB(now, fb) case *rtcp.TransportLayerCC: - reportDeparture, reportLists = convertTWCC(now, fb) + reportDeparture, reportLists = i.convertTWCC(now, fb) + default: } for ssrc, reportList := range reportLists { prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) diff --git a/pkg/ccfb/interceptor_test.go b/pkg/ccfb/interceptor_test.go new file mode 100644 index 00000000..98946200 --- /dev/null +++ b/pkg/ccfb/interceptor_test.go @@ -0,0 +1,306 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +type mockHistoryAddEntry struct { + seqNr uint16 + size uint16 + departure time.Time +} + +type mockHistory struct { + added []mockHistoryAddEntry + report PacketReportList +} + +// add implements history. +func (m *mockHistory) add(seqNr uint16, size uint16, departure time.Time) error { + m.added = append(m.added, mockHistoryAddEntry{ + seqNr: seqNr, + size: size, + departure: departure, + }) + return nil +} + +// getReportForAck implements history. +func (m *mockHistory) getReportForAck(_ acknowledgementList) PacketReportList { + return m.report +} + +func TestInterceptor(t *testing.T) { + mockTimestamp := time.Time{}.Add(17 * time.Second) + t.Run("writeRTP", func(t *testing.T) { + type addPkt struct { + pkt *rtp.Packet + ext *rtp.TransportCCExtension + } + cases := []struct { + add []addPkt + twcc bool + expect *mockHistory + }{ + { + add: []addPkt{}, + expect: &mockHistory{ + added: []mockHistoryAddEntry{}, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + }, + }, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {137, 12, mockTimestamp}, + }, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + ext: &rtp.TransportCCExtension{ + TransportSequence: 16, + }, + }, + }, + twcc: true, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {16, 20, mockTimestamp}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mh := &mockHistory{ + added: []mockHistoryAddEntry{}, + } + f, err := NewInterceptor( + historyFactory(func(i int) history { + return mh + }), + timeFactory(mt), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + if tc.twcc { + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + } + stream := test.NewMockStream(info, i) + + for _, pkt := range tc.add { + if pkt.ext != nil { + ext, err := pkt.ext.Marshal() + assert.NoError(t, err) + pkt.pkt.SetExtension(2, ext) + } + stream.WriteRTP(pkt.pkt) + } + + assert.Equal(t, tc.expect, mh) + }) + } + }) + + t.Run("readRTCP", func(t *testing.T) { + cases := []struct { + mh *mockHistory + rtcp rtcp.Packet + }{ + //{ + // mh: &mockHistory{ + // report: PacketReportList{ + // Arrival: mockTimestamp, + // Departure: mockTimestamp, + // Reports: []PacketReport{}, + // }, + // }, + // rtcp: &rtcp.CCFeedbackReport{}, + //}, + //{ + // mh: &mockHistory{ + // report: PacketReportList{ + // Arrival: mockTimestamp, + // Departure: mockTimestamp, + // Reports: []PacketReport{ + // { + // SeqNr: 3, + // Size: 12, + // Departure: mockTimestamp, + // Arrived: true, + // Arrival: mockTimestamp, + // ECN: 0, + // }, + // }, + // }, + // }, + // rtcp: &rtcp.CCFeedbackReport{}, + //}, + { + mh: &mockHistory{ + report: PacketReportList{ + Arrival: mockTimestamp, + Departure: mockTimestamp, + Reports: []PacketReport{}, + }, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 3, + PacketStatusCount: 0, + ReferenceTime: 5, + FbPktCount: 6, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + { + mh: &mockHistory{ + report: PacketReportList{ + Arrival: mockTimestamp, + Departure: mockTimestamp, + Reports: []PacketReport{ + { + SeqNr: 3, + Size: 12, + Departure: mockTimestamp, + Arrived: true, + Arrival: mockTimestamp, + ECN: 0, + }, + }, + }, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mockCCFBConverter := func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList) { + return mockTimestamp, map[uint32]acknowledgementList{ + 0: {}, + } + } + mockTWCCConverter := func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList) { + return mockTimestamp, map[uint32]acknowledgementList{ + 0: {}, + } + } + f, err := NewInterceptor( + historyFactory(func(i int) history { + return tc.mh + }), + timeFactory(mt), + ccfbConverterFactory(mockCCFBConverter), + twccConverterFactory(mockTWCCConverter), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + // if _, ok := tc.rtcp.(*rtcp.TransportLayerCC); ok { + // info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + // URI: transportCCURI, + // ID: 2, + // }) + // } + stream := test.NewMockStream(info, i) + + stream.ReceiveRTCP([]rtcp.Packet{tc.rtcp}) + + report := <-stream.ReadRTCP() + + assert.NoError(t, report.Err) + + prlsInterface, ok := report.Attr[CCFBAttributesKey] + assert.True(t, ok) + prls, ok := prlsInterface.(map[uint32]*PacketReportList) + assert.True(t, ok) + assert.Len(t, prls, 1) + // assert.Equal(t, tc.mh.report, *prls[0]) + }) + } + }) +}