diff --git a/pkg/ccfb/ccfb_receiver.go b/pkg/ccfb/ccfb_receiver.go index 4ec1d704..811c636d 100644 --- a/pkg/ccfb/ccfb_receiver.go +++ b/pkg/ccfb/ccfb_receiver.go @@ -20,23 +20,35 @@ type acknowledgementList struct { } func convertCCFB(ts time.Time, feedback *rtcp.CCFeedbackReport) map[uint32]acknowledgementList { + if feedback == nil { + return nil + } result := map[uint32]acknowledgementList{} - referenceTime := ntp.ToTime(uint64(feedback.ReportTimestamp) << 16) + referenceTime := ntp.ToTime32(feedback.ReportTimestamp, ts) for _, rb := range feedback.ReportBlocks { result[rb.MediaSSRC] = convertMetricBlock(ts, referenceTime, rb.BeginSequence, rb.MetricBlocks) } return result } -func convertMetricBlock(ts time.Time, referenceTime time.Time, seqNrOffset uint16, blocks []rtcp.CCFeedbackMetricBlock) acknowledgementList { +func convertMetricBlock(ts time.Time, reference time.Time, seqNrOffset uint16, blocks []rtcp.CCFeedbackMetricBlock) acknowledgementList { reports := make([]acknowledgement, len(blocks)) for i, mb := range blocks { if mb.Received { - delta := time.Duration((float64(mb.ArrivalTimeOffset) / 1024.0) * float64(time.Second)) + arrival := time.Time{} + + /// RFC 8888 states: If the measurement is unavailable or if the + //arrival time of the RTP packet is after the time represented by + //the RTS field, then an ATO value of 0x1FFF MUST be reported for + //the packet. In that case, we set a zero time.Time value. + if mb.ArrivalTimeOffset != 0x1FFF { + delta := time.Duration((float64(mb.ArrivalTimeOffset) / 1024.0) * float64(time.Second)) + arrival = reference.Add(-delta) + } reports[i] = acknowledgement{ seqNr: seqNrOffset + uint16(i), arrived: true, - arrival: referenceTime.Add(-delta), + arrival: arrival, ecn: mb.ECN, } } else { diff --git a/pkg/ccfb/ccfb_receiver_test.go b/pkg/ccfb/ccfb_receiver_test.go new file mode 100644 index 00000000..1e05b63a --- /dev/null +++ b/pkg/ccfb/ccfb_receiver_test.go @@ -0,0 +1,202 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertCCFB(t *testing.T) { + timeZero := time.Now() + cases := []struct { + ts time.Time + feedback *rtcp.CCFeedbackReport + expect map[uint32]acknowledgementList + }{ + {}, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.CCFeedbackReport{ + SenderSSRC: 1, + ReportBlocks: []rtcp.CCFeedbackReportBlock{ + { + MediaSSRC: 2, + BeginSequence: 17, + MetricBlocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + }, + }, + }, + ReportTimestamp: ntp.ToNTP32(timeZero.Add(time.Second)), + }, + expect: map[uint32]acknowledgementList{ + 2: { + ts: timeZero.Add(2 * time.Second), + acks: []acknowledgement{ + { + seqNr: 17, + arrived: true, + arrival: timeZero.Add(500 * time.Millisecond), + ecn: 0, + }, + }, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertCCFB(tc.ts, tc.feedback) + + // Can't directly check equality since arrival timestamp conversions + // may be slightly off due to ntp conversions. + assert.Equal(t, len(tc.expect), len(res)) + for i, ee := range tc.expect { + assert.Equal(t, ee.ts, res[i].ts) + for j, ack := range ee.acks { + assert.Equal(t, ack.seqNr, res[i].acks[j].seqNr) + assert.Equal(t, ack.arrived, res[i].acks[j].arrived) + assert.Equal(t, ack.ecn, res[i].acks[j].ecn) + assert.InDelta(t, ack.arrival.UnixNano(), res[i].acks[j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds())) + } + } + }) + } +} + +func TestConvertMetricBlock(t *testing.T) { + cases := []struct { + ts time.Time + reference time.Time + seqNrOffset uint16 + blocks []rtcp.CCFeedbackMetricBlock + expected acknowledgementList + }{ + { + ts: time.Time{}, + reference: time.Time{}, + seqNrOffset: 0, + blocks: []rtcp.CCFeedbackMetricBlock{}, + expected: acknowledgementList{ + ts: time.Time{}, + acks: []acknowledgement{}, + }, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + }, + expected: acknowledgementList{ + ts: time.Time{}.Add(2 * time.Second), + acks: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + }, + }, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0x1FFF, + }, + }, + expected: acknowledgementList{ + ts: time.Time{}.Add(2 * time.Second), + acks: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + { + seqNr: 6, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }, + }, + }, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertMetricBlock(tc.ts, tc.reference, tc.seqNrOffset, tc.blocks) + assert.Equal(t, tc.expected, res) + }) + } +} diff --git a/pkg/ccfb/history.go b/pkg/ccfb/history.go index 4bdaee1f..e78e1922 100644 --- a/pkg/ccfb/history.go +++ b/pkg/ccfb/history.go @@ -1,8 +1,8 @@ package ccfb import ( + "container/list" "errors" - "log" "time" "github.com/pion/interceptor/internal/sequencenumber" @@ -30,63 +30,77 @@ type sentPacket struct { } type history struct { - inflight []sentPacket - sentSeqNr *sequencenumber.Unwrapper - ackedSeqNr *sequencenumber.Unwrapper + size int + evictList *list.List + seqNrToPacket map[int64]*list.Element + sentSeqNr *sequencenumber.Unwrapper + ackedSeqNr *sequencenumber.Unwrapper } -func newHistory() *history { +func newHistory(size int) *history { return &history{ - inflight: []sentPacket{}, - sentSeqNr: &sequencenumber.Unwrapper{}, - ackedSeqNr: &sequencenumber.Unwrapper{}, + size: size, + evictList: list.New(), + seqNrToPacket: make(map[int64]*list.Element), + sentSeqNr: &sequencenumber.Unwrapper{}, + ackedSeqNr: &sequencenumber.Unwrapper{}, } } func (h *history) add(seqNr uint16, size uint16, departure time.Time) error { sn := h.sentSeqNr.Unwrap(seqNr) - if len(h.inflight) > 0 && sn < h.inflight[len(h.inflight)-1].seqNr { - return errors.New("sequence number went backwards") + last := h.evictList.Back() + if last != nil { + if p, ok := last.Value.(sentPacket); ok && sn < p.seqNr { + return errors.New("sequence number went backwards") + } } - h.inflight = append(h.inflight, sentPacket{ + ent := h.evictList.PushBack(sentPacket{ seqNr: sn, size: size, departure: departure, }) + h.seqNrToPacket[sn] = ent + + if h.evictList.Len() > h.size { + h.removeOldest() + } + return nil } func (h *history) getReportForAck(al acknowledgementList) PacketReportList { - reports := []PacketReport{} - log.Printf("highest sent: %v", h.inflight[len(h.inflight)-1].seqNr) + var reports []PacketReport for _, pr := range al.acks { sn := h.ackedSeqNr.Unwrap(pr.seqNr) - i := h.index(sn) - if i > -1 { - reports = append(reports, PacketReport{ - SeqNr: sn, - Size: h.inflight[i].size, - Departure: h.inflight[i].departure, - Arrived: pr.arrived, - Arrival: pr.arrival, - ECN: pr.ecn, - }) - } else { - panic("got feedback for unknown packet") + ent, ok := h.seqNrToPacket[sn] + // Ignore report for unknown packets (migth have been dropped from + // history) + if ok { + if ack, ok := ent.Value.(sentPacket); ok { + reports = append(reports, PacketReport{ + SeqNr: sn, + Size: ack.size, + Departure: ack.departure, + Arrived: pr.arrived, + Arrival: pr.arrival, + ECN: pr.ecn, + }) + } } - log.Printf("processed ack for seq nr %v, arrived: %v", sn, pr.arrived) } + return PacketReportList{ Timestamp: al.ts, Reports: reports, } } -func (h *history) index(seqNr int64) int { - for i := range h.inflight { - if h.inflight[i].seqNr == seqNr { - return i +func (h *history) removeOldest() { + if ent := h.evictList.Front(); ent != nil { + v := h.evictList.Remove(ent) + if sp, ok := v.(sentPacket); ok { + delete(h.seqNrToPacket, sp.seqNr) } } - return -1 } diff --git a/pkg/ccfb/history_test.go b/pkg/ccfb/history_test.go new file mode 100644 index 00000000..4a56fd11 --- /dev/null +++ b/pkg/ccfb/history_test.go @@ -0,0 +1,123 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestHistory(t *testing.T) { + t.Run("errorOnDecreasingSeqNr", func(t *testing.T) { + h := newHistory(200) + assert.NoError(t, h.add(10, 1200, time.Now())) + assert.NoError(t, h.add(11, 1200, time.Now())) + assert.Error(t, h.add(9, 1200, time.Now())) + }) + + t.Run("getReportForAck", func(t *testing.T) { + cases := []struct { + outgoing []struct { + seqNr uint16 + size uint16 + ts time.Time + } + acks acknowledgementList + expectedReport PacketReportList + expectedHistorySize int + }{ + { + outgoing: []struct { + seqNr uint16 + size uint16 + ts time.Time + }{}, + acks: acknowledgementList{}, + expectedReport: PacketReportList{}, + expectedHistorySize: 0, + }, + { + outgoing: []struct { + seqNr uint16 + size uint16 + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: acknowledgementList{}, + expectedReport: PacketReportList{}, + expectedHistorySize: 4, + }, + { + outgoing: []struct { + seqNr uint16 + size uint16 + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: acknowledgementList{ + ts: time.Time{}.Add(time.Second), + acks: []acknowledgement{ + {1, true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, false, time.Time{}, 0}, + {3, true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + }, + expectedReport: PacketReportList{ + Timestamp: time.Time{}.Add(time.Second), + Reports: []PacketReport{ + {1, 1200, time.Time{}.Add(2 * time.Millisecond), true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond), false, time.Time{}, 0}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond), true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + }, + expectedHistorySize: 4, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + h := newHistory(200) + for _, op := range tc.outgoing { + assert.NoError(t, h.add(op.seqNr, op.size, op.ts)) + } + prl := h.getReportForAck(tc.acks) + assert.Equal(t, tc.expectedReport, prl) + assert.Equal(t, tc.expectedHistorySize, len(h.seqNrToPacket)) + assert.Equal(t, tc.expectedHistorySize, h.evictList.Len()) + }) + } + }) + + t.Run("garbageCollection", func(t *testing.T) { + h := newHistory(200) + + for i := uint16(0); i < 300; i++ { + assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond))) + } + + acks := acknowledgementList{ + ts: time.Time{}.Add(time.Second), + acks: []acknowledgement{}, + } + for i := uint16(200); i < 290; i++ { + acks.acks = append(acks.acks, acknowledgement{ + seqNr: i, + arrived: true, + arrival: time.Time{}.Add(time.Duration(500+i) * time.Millisecond), + ecn: 0, + }) + } + prl := h.getReportForAck(acks) + assert.Len(t, prl.Reports, 90) + assert.Equal(t, 200, len(h.seqNrToPacket)) + assert.Equal(t, 200, h.evictList.Len()) + }) +} diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go index cb5ea23a..fba41b71 100644 --- a/pkg/ccfb/interceptor.go +++ b/pkg/ccfb/interceptor.go @@ -50,7 +50,7 @@ type Interceptor struct { func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { i.lock.Lock() defer i.lock.Unlock() - i.ssrcToHistory[info.SSRC] = newHistory() + i.ssrcToHistory[info.SSRC] = newHistory(200) return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { i.lock.Lock()