From b8e4acd325ccd2f7355ced66ab3d5e1f1c1d9496 Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Thu, 16 Jan 2025 10:09:17 +0100 Subject: [PATCH] Add interceptor to aggregate CCFB reports --- internal/test/mock_stream.go | 9 +- pkg/ccfb/ccfb_receiver.go | 59 ++++++ pkg/ccfb/ccfb_receiver_test.go | 193 +++++++++++++++++ pkg/ccfb/duplicate_ack_filter.go | 29 +++ pkg/ccfb/duplicate_ack_filter_test.go | 106 ++++++++++ pkg/ccfb/history.go | 110 ++++++++++ pkg/ccfb/history_test.go | 114 ++++++++++ pkg/ccfb/interceptor.go | 224 ++++++++++++++++++++ pkg/ccfb/interceptor_test.go | 290 ++++++++++++++++++++++++++ pkg/ccfb/twcc_receiver.go | 88 ++++++++ pkg/ccfb/twcc_receiver_test.go | 125 +++++++++++ 11 files changed, 1343 insertions(+), 4 deletions(-) create mode 100644 pkg/ccfb/ccfb_receiver.go create mode 100644 pkg/ccfb/ccfb_receiver_test.go create mode 100644 pkg/ccfb/duplicate_ack_filter.go create mode 100644 pkg/ccfb/duplicate_ack_filter_test.go create mode 100644 pkg/ccfb/history.go create mode 100644 pkg/ccfb/history_test.go create mode 100644 pkg/ccfb/interceptor.go create mode 100644 pkg/ccfb/interceptor_test.go create mode 100644 pkg/ccfb/twcc_receiver.go create mode 100644 pkg/ccfb/twcc_receiver_test.go diff --git a/internal/test/mock_stream.go b/internal/test/mock_stream.go index bf96e31b..e791ac8a 100644 --- a/internal/test/mock_stream.go +++ b/internal/test/mock_stream.go @@ -41,6 +41,7 @@ type RTPWithError struct { // RTCPWithError is used to send a batch of rtcp packets or an error on a channel type RTCPWithError struct { Packets []rtcp.Packet + Attr interceptor.Attributes Err error } @@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc go func() { buf := make([]byte, 1500) for { - i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) + i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) if err != nil { if !errors.Is(err, io.EOF) { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} } return } pkts, err := rtcp.Unmarshal(buf[:i]) if err != nil { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} return } - s.rtcpInModified <- RTCPWithError{Packets: pkts} + s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts} } }() go func() { diff --git a/pkg/ccfb/ccfb_receiver.go b/pkg/ccfb/ccfb_receiver.go new file mode 100644 index 00000000..dd11198c --- /dev/null +++ b/pkg/ccfb/ccfb_receiver.go @@ -0,0 +1,59 @@ +package ccfb + +import ( + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" +) + +type acknowledgement struct { + seqNr uint16 + arrived bool + arrival time.Time + ecn rtcp.ECN +} + +func convertCCFB(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) { + if feedback == nil { + return time.Time{}, nil + } + result := map[uint32][]acknowledgement{} + referenceTime := ntp.ToTime32(feedback.ReportTimestamp, ts) + for _, rb := range feedback.ReportBlocks { + result[rb.MediaSSRC] = convertMetricBlock(referenceTime, rb.BeginSequence, rb.MetricBlocks) + } + return referenceTime, result +} + +func convertMetricBlock(reference time.Time, seqNrOffset uint16, blocks []rtcp.CCFeedbackMetricBlock) []acknowledgement { + reports := make([]acknowledgement, len(blocks)) + for i, mb := range blocks { + if mb.Received { + arrival := time.Time{} + + // RFC 8888 states: If the measurement is unavailable or if the + // arrival time of the RTP packet is after the time represented by + // the RTS field, then an ATO value of 0x1FFF MUST be reported for + // the packet. In that case, we set a zero time.Time value. + if mb.ArrivalTimeOffset != 0x1FFF { + delta := time.Duration((float64(mb.ArrivalTimeOffset) / 1024.0) * float64(time.Second)) + arrival = reference.Add(-delta) + } + reports[i] = acknowledgement{ + seqNr: seqNrOffset + uint16(i), // nolint:gosec + arrived: true, + arrival: arrival, + ecn: mb.ECN, + } + } else { + reports[i] = acknowledgement{ + seqNr: seqNrOffset + uint16(i), // nolint:gosec + arrived: false, + arrival: time.Time{}, + ecn: 0, + } + } + } + return reports +} diff --git a/pkg/ccfb/ccfb_receiver_test.go b/pkg/ccfb/ccfb_receiver_test.go new file mode 100644 index 00000000..18a248c1 --- /dev/null +++ b/pkg/ccfb/ccfb_receiver_test.go @@ -0,0 +1,193 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertCCFB(t *testing.T) { + timeZero := time.Now() + cases := []struct { + ts time.Time + feedback *rtcp.CCFeedbackReport + expect map[uint32][]acknowledgement + expectTS time.Time + }{ + {}, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.CCFeedbackReport{ + SenderSSRC: 1, + ReportBlocks: []rtcp.CCFeedbackReportBlock{ + { + MediaSSRC: 2, + BeginSequence: 17, + MetricBlocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + }, + }, + }, + ReportTimestamp: ntp.ToNTP32(timeZero.Add(time.Second)), + }, + expect: map[uint32][]acknowledgement{ + 2: { + { + seqNr: 17, + arrived: true, + arrival: timeZero.Add(500 * time.Millisecond), + ecn: 0, + }, + }, + }, + expectTS: timeZero.Add(time.Second), + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + resTS, res := convertCCFB(tc.ts, tc.feedback) + + assert.InDelta(t, tc.expectTS.UnixNano(), resTS.UnixNano(), float64(time.Millisecond.Nanoseconds())) + + // Can't directly check equality since arrival timestamp conversions + // may be slightly off due to ntp conversions. + assert.Equal(t, len(tc.expect), len(res)) + for i, acks := range tc.expect { + for j, ack := range acks { + assert.Equal(t, ack.seqNr, res[i][j].seqNr) + assert.Equal(t, ack.arrived, res[i][j].arrived) + assert.Equal(t, ack.ecn, res[i][j].ecn) + assert.InDelta(t, ack.arrival.UnixNano(), res[i][j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds())) + } + } + }) + } +} + +func TestConvertMetricBlock(t *testing.T) { + cases := []struct { + ts time.Time + reference time.Time + seqNrOffset uint16 + blocks []rtcp.CCFeedbackMetricBlock + expected []acknowledgement + }{ + { + ts: time.Time{}, + reference: time.Time{}, + seqNrOffset: 0, + blocks: []rtcp.CCFeedbackMetricBlock{}, + expected: []acknowledgement{}, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + }, + expected: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + }, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0x1FFF, + }, + }, + expected: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + { + seqNr: 6, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }, + }, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertMetricBlock(tc.reference, tc.seqNrOffset, tc.blocks) + assert.Equal(t, tc.expected, res) + }) + } +} diff --git a/pkg/ccfb/duplicate_ack_filter.go b/pkg/ccfb/duplicate_ack_filter.go new file mode 100644 index 00000000..79f8f6db --- /dev/null +++ b/pkg/ccfb/duplicate_ack_filter.go @@ -0,0 +1,29 @@ +package ccfb + +// DuplicateAckFilter is a helper to remove duplicate acks from a Report. +type DuplicateAckFilter struct { + highestAckedBySSRC map[uint32]int64 +} + +// NewDuplicateAckFilter creates a new DuplicateAckFilter +func NewDuplicateAckFilter() *DuplicateAckFilter { + return &DuplicateAckFilter{ + highestAckedBySSRC: make(map[uint32]int64), + } +} + +// Filter filters duplicate acks. It filters out all acks for packets with a +// sequence number smaller than the highest seen sequence number for each SSRC. +func (f *DuplicateAckFilter) Filter(reports Report) { + for ssrc, prs := range reports.SSRCToPacketReports { + n := 0 + for _, report := range prs { + if highest, ok := f.highestAckedBySSRC[ssrc]; !ok || report.SeqNr > highest { + f.highestAckedBySSRC[ssrc] = report.SeqNr + prs[n] = report + n++ + } + } + reports.SSRCToPacketReports[ssrc] = prs[:n] + } +} diff --git a/pkg/ccfb/duplicate_ack_filter_test.go b/pkg/ccfb/duplicate_ack_filter_test.go new file mode 100644 index 00000000..20e4d6f8 --- /dev/null +++ b/pkg/ccfb/duplicate_ack_filter_test.go @@ -0,0 +1,106 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDuplicateAckFilter(t *testing.T) { + cases := []struct { + in []Report + expect []Report + }{ + { + in: []Report{}, + expect: []Report{}, + }, + { + in: []Report{ + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: {}, + }, + }, + }, + expect: []Report{ + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: {}, + }, + }, + }, + }, + { + in: []Report{ + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + }, + }, + }, + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + { + SeqNr: 3, + }, + }, + }, + }, + }, + expect: []Report{ + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + }, + }, + }, + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 3, + }, + }, + }, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + daf := NewDuplicateAckFilter() + for i, m := range tc.in { + daf.Filter(m) + assert.Equal(t, tc.expect[i], m) + } + }) + } +} diff --git a/pkg/ccfb/history.go b/pkg/ccfb/history.go new file mode 100644 index 00000000..9c144702 --- /dev/null +++ b/pkg/ccfb/history.go @@ -0,0 +1,110 @@ +package ccfb + +import ( + "container/list" + "errors" + "sync" + "time" + + "github.com/pion/interceptor/internal/sequencenumber" + "github.com/pion/rtcp" +) + +var errSequenceNumberWentBackwards = errors.New("sequence number went backwards") + +// PacketReport contains departure and arrival information about an acknowledged +// packet. +type PacketReport struct { + SeqNr int64 + Size int + Departure time.Time + Arrived bool + Arrival time.Time + ECN rtcp.ECN +} + +type sentPacket struct { + seqNr int64 + size int + departure time.Time +} + +type historyList struct { + lock sync.Mutex + size int + evictList *list.List + seqNrToPacket map[int64]*list.Element + sentSeqNr *sequencenumber.Unwrapper + ackedSeqNr *sequencenumber.Unwrapper +} + +func newHistoryList(size int) *historyList { + return &historyList{ + lock: sync.Mutex{}, + size: size, + evictList: list.New(), + seqNrToPacket: make(map[int64]*list.Element), + sentSeqNr: &sequencenumber.Unwrapper{}, + ackedSeqNr: &sequencenumber.Unwrapper{}, + } +} + +func (h *historyList) add(seqNr uint16, size int, departure time.Time) error { + h.lock.Lock() + defer h.lock.Unlock() + + sn := h.sentSeqNr.Unwrap(seqNr) + last := h.evictList.Back() + if last != nil { + if p, ok := last.Value.(sentPacket); ok && sn < p.seqNr { + return errSequenceNumberWentBackwards + } + } + ent := h.evictList.PushBack(sentPacket{ + seqNr: sn, + size: size, + departure: departure, + }) + h.seqNrToPacket[sn] = ent + + if h.evictList.Len() > h.size { + h.removeOldest() + } + return nil +} + +// Must be called while holding the lock +func (h *historyList) removeOldest() { + if ent := h.evictList.Front(); ent != nil { + v := h.evictList.Remove(ent) + if sp, ok := v.(sentPacket); ok { + delete(h.seqNrToPacket, sp.seqNr) + } + } +} + +func (h *historyList) getReportForAck(acks []acknowledgement) []PacketReport { + h.lock.Lock() + defer h.lock.Unlock() + + reports := []PacketReport{} + for _, pr := range acks { + sn := h.ackedSeqNr.Unwrap(pr.seqNr) + ent, ok := h.seqNrToPacket[sn] + // Ignore report for unknown packets (migth have been dropped from + // history) + if ok { + if ack, ok := ent.Value.(sentPacket); ok { + reports = append(reports, PacketReport{ + SeqNr: sn, + Size: ack.size, + Departure: ack.departure, + Arrived: pr.arrived, + Arrival: pr.arrival, + ECN: pr.ecn, + }) + } + } + } + return reports +} diff --git a/pkg/ccfb/history_test.go b/pkg/ccfb/history_test.go new file mode 100644 index 00000000..c500242e --- /dev/null +++ b/pkg/ccfb/history_test.go @@ -0,0 +1,114 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestHistory(t *testing.T) { + t.Run("errorOnDecreasingSeqNr", func(t *testing.T) { + h := newHistoryList(200) + assert.NoError(t, h.add(10, 1200, time.Now())) + assert.NoError(t, h.add(11, 1200, time.Now())) + assert.Error(t, h.add(9, 1200, time.Now())) + }) + + t.Run("getReportForAck", func(t *testing.T) { + cases := []struct { + outgoing []struct { + seqNr uint16 + size int + ts time.Time + } + acks []acknowledgement + expectedReport []PacketReport + expectedHistorySize int + }{ + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{}, + acks: []acknowledgement{}, + expectedReport: []PacketReport{}, + expectedHistorySize: 0, + }, + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: []acknowledgement{}, + expectedReport: []PacketReport{}, + expectedHistorySize: 4, + }, + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: []acknowledgement{ + {1, true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, false, time.Time{}, 0}, + {3, true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + expectedReport: []PacketReport{ + {1, 1200, time.Time{}.Add(2 * time.Millisecond), true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond), false, time.Time{}, 0}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond), true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + expectedHistorySize: 4, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + h := newHistoryList(200) + for _, op := range tc.outgoing { + assert.NoError(t, h.add(op.seqNr, op.size, op.ts)) + } + prl := h.getReportForAck(tc.acks) + assert.Equal(t, tc.expectedReport, prl) + assert.Equal(t, tc.expectedHistorySize, len(h.seqNrToPacket)) + assert.Equal(t, tc.expectedHistorySize, h.evictList.Len()) + }) + } + }) + + t.Run("garbageCollection", func(t *testing.T) { + h := newHistoryList(200) + + for i := uint16(0); i < 300; i++ { + assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond))) + } + + acks := []acknowledgement{} + for i := uint16(200); i < 290; i++ { + acks = append(acks, acknowledgement{ + seqNr: i, + arrived: true, + arrival: time.Time{}.Add(time.Duration(500+i) * time.Millisecond), + ecn: 0, + }) + } + prl := h.getReportForAck(acks) + assert.Len(t, prl, 90) + assert.Equal(t, 200, len(h.seqNrToPacket)) + assert.Equal(t, 200, h.evictList.Len()) + }) +} diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go new file mode 100644 index 00000000..f4304915 --- /dev/null +++ b/pkg/ccfb/interceptor.go @@ -0,0 +1,224 @@ +// Package ccfb implements feedback aggregation for CCFB and TWCC packets. +package ccfb + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + +type ccfbAttributesKeyType uint32 + +// CCFBAttributesKey is the key which can be used to retrieve the Report objects +// from the interceptor.Attributes +const CCFBAttributesKey ccfbAttributesKeyType = iota + +// A Report contains Arrival and Departure (from the remote end) times of a RTCP +// feedback packet (CCFB or TWCC) and a list of PacketReport for all +// acknowledged packets that were still in the history. +type Report struct { + Arrival time.Time + Departure time.Time + SSRCToPacketReports map[uint32][]PacketReport +} + +type history interface { + add(seqNr uint16, size int, departure time.Time) error + getReportForAck([]acknowledgement) []PacketReport +} + +// Option can be used to set initial options on CCFB interceptors +type Option func(*Interceptor) error + +// HistorySize sets the size of the history of outgoing packets. +func HistorySize(size int) Option { + return func(i *Interceptor) error { + i.historySize = size + return nil + } +} + +func timeFactory(f func() time.Time) Option { + return func(i *Interceptor) error { + i.timestamp = f + return nil + } +} + +func historyFactory(f func(int) history) Option { + return func(i *Interceptor) error { + i.historyFactory = f + return nil + } +} + +func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement)) Option { + return func(i *Interceptor) error { + i.convertCCFB = f + return nil + } +} + +func twccConverterFactory(f func(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement)) Option { + return func(i *Interceptor) error { + i.convertTWCC = f + return nil + } +} + +// InterceptorFactory is a factory for CCFB interceptors +type InterceptorFactory struct { + opts []Option +} + +// NewInterceptor returns a new CCFB InterceptorFactory +func NewInterceptor(opts ...Option) (*InterceptorFactory, error) { + return &InterceptorFactory{ + opts: opts, + }, nil +} + +// NewInterceptor returns a new ccfb.Interceptor +func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + i := &Interceptor{ + NoOp: interceptor.NoOp{}, + lock: sync.Mutex{}, + timestamp: time.Now, + convertCCFB: convertCCFB, + convertTWCC: convertTWCC, + ssrcToHistory: make(map[uint32]history), + historySize: 200, + historyFactory: func(size int) history { + return newHistoryList(size) + }, + } + for _, opt := range f.opts { + if err := opt(i); err != nil { + return nil, err + } + } + return i, nil +} + +// Interceptor implements a congestion control feedback receiver. It keeps track +// of outgoing packets and reads incoming feedback reports (CCFB or TWCC). For +// each incoming feedback report, it will add an entry to the interceptor +// attributes, which can be read from the `RTCPReader` +// (`webrtc.RTPSender.Read`). For each acknowledgement included in the feedback +// report and for which there still is an entry in the history of outgoing +// packets, a PacketReport will be added to the ccfb.Report map. The map +// contains a list of packets for each outgoing SSRC if CCFB is used. The map +// contains a single entry with SSRC=0 if TWCC is used. +type Interceptor struct { + interceptor.NoOp + lock sync.Mutex + timestamp func() time.Time + convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) + convertTWCC func(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) + ssrcToHistory map[uint32]history + historySize int + historyFactory func(int) history +} + +// BindLocalStream implements interceptor.Interceptor. +func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + var twccHdrExtID uint8 + var useTWCC bool + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + twccHdrExtID = uint8(e.ID) // nolint:gosec + useTWCC = true + break + } + } + + i.lock.Lock() + defer i.lock.Unlock() + + ssrc := info.SSRC + if useTWCC { + ssrc = 0 + } + i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize) + + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + i.lock.Lock() + defer i.lock.Unlock() + + // If we are using TWCC, we use the sequence number from the TWCC header + // extension and save all TWCC sequence numbers with the same SSRC (0). + // If we are not using TWCC, we save a history per SSRC and use the + // normal RTP sequence numbers. + ssrc := header.SSRC + seqNr := header.SequenceNumber + if useTWCC { + ssrc = 0 + var twccHdrExt rtp.TransportCCExtension + if err := twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID)); err != nil { + return 0, err + } + seqNr = twccHdrExt.TransportSequence + } + if err := i.ssrcToHistory[ssrc].add(seqNr, header.MarshalSize()+len(payload), i.timestamp()); err != nil { + return 0, err + } + return writer.Write(header, payload, attributes) + }) +} + +// BindRTCPReader implements interceptor.Interceptor. +func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + n, attr, err := reader.Read(b, a) + if err != nil { + return n, attr, err + } + now := i.timestamp() + + buf := make([]byte, n) + copy(buf, b[:n]) + + if attr == nil { + attr = make(interceptor.Attributes) + } + + res := []Report{} + + pkts, err := attr.GetRTCPPackets(buf) + if err != nil { + return n, attr, err + } + for _, pkt := range pkts { + var reportLists map[uint32][]acknowledgement + var reportDeparture time.Time + switch fb := pkt.(type) { + case *rtcp.CCFeedbackReport: + reportDeparture, reportLists = i.convertCCFB(now, fb) + case *rtcp.TransportLayerCC: + reportDeparture, reportLists = i.convertTWCC(fb) + default: + } + ssrcToPrl := map[uint32][]PacketReport{} + for ssrc, reportList := range reportLists { + prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) + if _, ok := ssrcToPrl[ssrc]; !ok { + ssrcToPrl[ssrc] = prl + } else { + ssrcToPrl[ssrc] = append(ssrcToPrl[ssrc], prl...) + } + } + res = append(res, Report{ + Arrival: now, + Departure: reportDeparture, + SSRCToPacketReports: ssrcToPrl, + }) + } + attr.Set(CCFBAttributesKey, res) + return n, attr, err + }) +} diff --git a/pkg/ccfb/interceptor_test.go b/pkg/ccfb/interceptor_test.go new file mode 100644 index 00000000..a37cd38a --- /dev/null +++ b/pkg/ccfb/interceptor_test.go @@ -0,0 +1,290 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +type mockHistoryAddEntry struct { + seqNr uint16 + size int + departure time.Time +} + +type mockHistory struct { + added []mockHistoryAddEntry + report []PacketReport +} + +// add implements history. +func (m *mockHistory) add(seqNr uint16, size int, departure time.Time) error { + m.added = append(m.added, mockHistoryAddEntry{ + seqNr: seqNr, + size: size, + departure: departure, + }) + return nil +} + +// getReportForAck implements history. +func (m *mockHistory) getReportForAck(_ []acknowledgement) []PacketReport { + return m.report +} + +func TestInterceptor(t *testing.T) { + mockTimestamp := time.Time{}.Add(17 * time.Second) + t.Run("writeRTP", func(t *testing.T) { + type addPkt struct { + pkt *rtp.Packet + ext *rtp.TransportCCExtension + } + cases := []struct { + add []addPkt + twcc bool + expect *mockHistory + }{ + { + add: []addPkt{}, + expect: &mockHistory{ + added: []mockHistoryAddEntry{}, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + }, + }, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {137, 12, mockTimestamp}, + }, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + ext: &rtp.TransportCCExtension{ + TransportSequence: 16, + }, + }, + }, + twcc: true, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {16, 20, mockTimestamp}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mh := &mockHistory{ + added: []mockHistoryAddEntry{}, + } + f, err := NewInterceptor( + historyFactory(func(_ int) history { + return mh + }), + timeFactory(mt), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + if tc.twcc { + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + } + stream := test.NewMockStream(info, i) + + for _, pkt := range tc.add { + if pkt.ext != nil { + ext, err := pkt.ext.Marshal() + assert.NoError(t, err) + assert.NoError(t, pkt.pkt.SetExtension(2, ext)) + } + assert.NoError(t, stream.WriteRTP(pkt.pkt)) + } + + assert.Equal(t, tc.expect, mh) + }) + } + }) + + t.Run("readRTCP", func(t *testing.T) { + cases := []struct { + mh *mockHistory + rtcp rtcp.Packet + }{ + { + mh: &mockHistory{ + report: []PacketReport{}, + }, + rtcp: &rtcp.CCFeedbackReport{}, + }, + { + mh: &mockHistory{ + report: []PacketReport{ + { + SeqNr: 3, + Size: 12, + Departure: mockTimestamp, + Arrived: true, + Arrival: mockTimestamp, + ECN: 0, + }, + }, + }, + rtcp: &rtcp.CCFeedbackReport{}, + }, + { + mh: &mockHistory{ + report: []PacketReport{}, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 3, + PacketStatusCount: 0, + ReferenceTime: 5, + FbPktCount: 6, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + { + mh: &mockHistory{ + report: []PacketReport{ + { + SeqNr: 3, + Size: 12, + Departure: mockTimestamp, + Arrived: true, + Arrival: mockTimestamp, + ECN: 0, + }, + }, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mockCCFBConverter := func(_ time.Time, _ *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) { + return mockTimestamp, map[uint32][]acknowledgement{ + 0: {}, + } + } + mockTWCCConverter := func(_ *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) { + return mockTimestamp, map[uint32][]acknowledgement{ + 0: {}, + } + } + f, err := NewInterceptor( + historyFactory(func(_ int) history { + return tc.mh + }), + timeFactory(mt), + ccfbConverterFactory(mockCCFBConverter), + twccConverterFactory(mockTWCCConverter), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + if _, ok := tc.rtcp.(*rtcp.TransportLayerCC); ok { + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + } + stream := test.NewMockStream(info, i) + + stream.ReceiveRTCP([]rtcp.Packet{tc.rtcp}) + + report := <-stream.ReadRTCP() + + assert.NoError(t, report.Err) + + prlsInterface, ok := report.Attr[CCFBAttributesKey] + assert.True(t, ok) + prls, ok := prlsInterface.([]Report) + assert.True(t, ok) + assert.Len(t, prls, 1) + assert.Equal(t, tc.mh.report, prls[0].SSRCToPacketReports[0]) + }) + } + }) +} diff --git a/pkg/ccfb/twcc_receiver.go b/pkg/ccfb/twcc_receiver.go new file mode 100644 index 00000000..98af8bde --- /dev/null +++ b/pkg/ccfb/twcc_receiver.go @@ -0,0 +1,88 @@ +package ccfb + +import ( + "time" + + "github.com/pion/rtcp" +) + +func convertTWCC(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) { + if feedback == nil { + return time.Time{}, nil + } + var acks []acknowledgement + + nextTimestamp := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond) + reportDeparture := nextTimestamp + recvDeltaIndex := 0 + + offset := 0 + for _, pc := range feedback.PacketChunks { + switch chunk := pc.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + seqNr := feedback.BaseSequenceNumber + uint16(offset) // nolint:gosec + offset++ + switch chunk.PacketStatusSymbol { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + case *rtcp.StatusVectorChunk: + for _, s := range chunk.SymbolList { + seqNr := feedback.BaseSequenceNumber + uint16(offset) // nolint:gosec + offset++ + switch s { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + } + } + + return reportDeparture, map[uint32][]acknowledgement{0: acks} +} diff --git a/pkg/ccfb/twcc_receiver_test.go b/pkg/ccfb/twcc_receiver_test.go new file mode 100644 index 00000000..8c820041 --- /dev/null +++ b/pkg/ccfb/twcc_receiver_test.go @@ -0,0 +1,125 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertTWCC(t *testing.T) { + // timeZero := time.Now() + cases := []struct { + feedback *rtcp.TransportLayerCC + expect map[uint32][]acknowledgement + expectTS time.Time + }{ + {}, + { + // ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 0, + ReferenceTime: 3, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{}, + RecvDeltas: []*rtcp.RecvDelta{}, + }, + expect: map[uint32][]acknowledgement{ + 0: nil, + }, + expectTS: time.Time{}.Add(3 * 64 * time.Millisecond), + }, + { + // ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 18, + ReferenceTime: 3, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 1000}, + }, + }, + expect: map[uint32][]acknowledgement{ + 0: { + // first run length chunk + {seqNr: 178, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 1*time.Millisecond), ecn: 0}, + {seqNr: 179, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 2*time.Millisecond), ecn: 0}, + {seqNr: 180, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 3*time.Millisecond), ecn: 0}, + + // first status vector chunk + {seqNr: 181, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 4*time.Millisecond), ecn: 0}, + {seqNr: 182, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 5*time.Millisecond), ecn: 0}, + {seqNr: 183, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 6*time.Millisecond), ecn: 0}, + {seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0}, + + // second status vector chunk + {seqNr: 189, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 7*time.Millisecond), ecn: 0}, + {seqNr: 190, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 8*time.Millisecond), ecn: 0}, + {seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0}, + }, + }, + expectTS: time.Time{}.Add(3 * 64 * time.Millisecond), + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + resTS, res := convertTWCC(tc.feedback) + assert.Equal(t, tc.expect, res) + assert.Equal(t, tc.expectTS, resTS) + }) + } +}