diff --git a/go.mod b/go.mod index fe03158c..c5269e02 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pion/interceptor -go 1.20 +go 1.21 require ( github.com/pion/logging v0.2.2 diff --git a/pkg/bwe/acknowledgment.go b/pkg/bwe/acknowledgment.go new file mode 100644 index 00000000..3fc75d76 --- /dev/null +++ b/pkg/bwe/acknowledgment.go @@ -0,0 +1,21 @@ +package bwe + +import ( + "fmt" + "time" + + "github.com/pion/rtcp" +) + +type acknowledgment struct { + seqNr int64 + size uint16 + departure time.Time + arrived bool + arrival time.Time + ecn rtcp.ECN +} + +func (a acknowledgment) String() string { + return fmt.Sprintf("seq=%v, departure=%v, arrival=%v", a.seqNr, a.departure, a.arrival) +} diff --git a/pkg/bwe/arrival_group_accumulator.go b/pkg/bwe/arrival_group_accumulator.go new file mode 100644 index 00000000..69bd3498 --- /dev/null +++ b/pkg/bwe/arrival_group_accumulator.go @@ -0,0 +1,43 @@ +package bwe + +import "time" + +type arrivalGroup []acknowledgment + +type arrivalGroupAccumulator struct { + next arrivalGroup + burstInterval time.Duration +} + +func newArrivalGroupAccumulator() *arrivalGroupAccumulator { + return &arrivalGroupAccumulator{ + next: make([]acknowledgment, 0), + burstInterval: 5 * time.Millisecond, + } +} + +func (a *arrivalGroupAccumulator) onPacketAcked(ack acknowledgment) arrivalGroup { + if len(a.next) == 0 { + a.next = append(a.next, ack) + return nil + } + + if ack.departure.Sub(a.next[0].departure) < a.burstInterval { + a.next = append(a.next, ack) + return nil + } + + interDepartureTime := ack.departure.Sub(a.next[0].departure) + interArrivalTime := ack.arrival.Sub(a.next[len(a.next)-1].arrival) + interGroupDelay := interArrivalTime - interDepartureTime + + if interArrivalTime < a.burstInterval && interGroupDelay < 0 { + a.next = append(a.next, ack) + return nil + } + + group := make(arrivalGroup, len(a.next)) + copy(group, a.next) + a.next = arrivalGroup{ack} + return group +} diff --git a/pkg/bwe/arrival_group_accumulator_test.go b/pkg/bwe/arrival_group_accumulator_test.go new file mode 100644 index 00000000..31cc2bc4 --- /dev/null +++ b/pkg/bwe/arrival_group_accumulator_test.go @@ -0,0 +1,204 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalGroupAccumulator(t *testing.T) { + triggerNewGroupElement := acknowledgment{ + departure: time.Time{}.Add(time.Second), + arrival: time.Time{}.Add(time.Second), + } + cases := []struct { + name string + log []acknowledgment + exp []arrivalGroup + }{ + { + name: "emptyCreatesNoGroups", + log: []acknowledgment{}, + exp: []arrivalGroup{}, + }, + { + name: "createsSingleElementGroup", + log: []acknowledgment{ + { + departure: time.Time{}, + arrival: time.Time{}.Add(time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + departure: time.Time{}, + arrival: time.Time{}.Add(time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoElementGroup", + log: []acknowledgment{ + { + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(20 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + departure: time.Time{}, + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }}, + }, + { + name: "createsTwoArrivalGroups", + log: []acknowledgment{ + { + departure: time.Time{}, + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + departure: time.Time{}.Add(9 * time.Millisecond), + arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + departure: time.Time{}.Add(9 * time.Millisecond), + arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "ignoresOutOfOrderPackets", + log: []acknowledgment{ + { + departure: time.Time{}, + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + departure: time.Time{}.Add(6 * time.Millisecond), + arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + departure: time.Time{}.Add(8 * time.Millisecond), + arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + departure: time.Time{}, + arrival: time.Time{}.Add(15 * time.Millisecond), + }, + }, + { + { + departure: time.Time{}.Add(6 * time.Millisecond), + arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + departure: time.Time{}.Add(8 * time.Millisecond), + arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "newGroupBecauseOfInterDepartureTime", + log: []acknowledgment{ + { + seqNr: 0, + departure: time.Time{}, + arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + seqNr: 1, + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + seqNr: 2, + departure: time.Time{}.Add(6 * time.Millisecond), + arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + seqNr: 3, + departure: time.Time{}.Add(9 * time.Millisecond), + arrival: time.Time{}.Add(10 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + seqNr: 0, + departure: time.Time{}, + arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + seqNr: 1, + departure: time.Time{}.Add(3 * time.Millisecond), + arrival: time.Time{}.Add(4 * time.Millisecond), + }, + }, + { + { + seqNr: 2, + departure: time.Time{}.Add(6 * time.Millisecond), + arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + seqNr: 3, + departure: time.Time{}.Add(9 * time.Millisecond), + arrival: time.Time{}.Add(10 * time.Millisecond), + }, + }, + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + aga := newArrivalGroupAccumulator() + received := []arrivalGroup{} + for _, ack := range tc.log { + next := aga.onPacketAcked(ack) + if next != nil { + received = append(received, next) + } + } + assert.Equal(t, tc.exp, received) + }) + } +} diff --git a/pkg/bwe/delay_rate_controller.go b/pkg/bwe/delay_rate_controller.go new file mode 100644 index 00000000..d54cbdd4 --- /dev/null +++ b/pkg/bwe/delay_rate_controller.go @@ -0,0 +1,51 @@ +package bwe + +import ( + "log" + "time" +) + +type DelayRateController struct { + aga *arrivalGroupAccumulator + last arrivalGroup + kf *kalman + od *overuseDetector + rc *rateController + latest usage +} + +func NewDelayRateController(initialRate int) *DelayRateController { + return &DelayRateController{ + aga: newArrivalGroupAccumulator(), + last: []acknowledgment{}, + kf: newKalman(), + od: newOveruseDetector(), + rc: newRateController(initialRate), + } +} + +func (c *DelayRateController) OnPacketAcked(ack acknowledgment) { + next := c.aga.onPacketAcked(ack) + if next == nil { + return + } + if len(next) == 0 { + // ignore empty groups, should never occur + return + } + if len(c.last) == 0 { + c.last = next + return + } + interArrivalTime := next[len(next)-1].arrival.Sub(c.last[len(c.last)-1].arrival) + interDepartureTime := next[0].departure.Sub(c.last[0].departure) + interGroupDelay := interArrivalTime - interDepartureTime + estimate := c.kf.updateEstimate(interGroupDelay) + c.latest = c.od.update(ack.arrival, estimate) + c.last = next + log.Printf("interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%v, threshold=%v", interArrivalTime.Nanoseconds(), interDepartureTime.Nanoseconds(), interGroupDelay.Nanoseconds(), estimate.Nanoseconds(), c.od.delayThreshold.Nanoseconds()) +} + +func (c *DelayRateController) Update(ts time.Time, lastDeliveryRate int, rtt time.Duration) int { + return c.rc.update(ts, c.latest, lastDeliveryRate, rtt) +} diff --git a/pkg/bwe/delivery_rate_estimator.go b/pkg/bwe/delivery_rate_estimator.go new file mode 100644 index 00000000..2150a01f --- /dev/null +++ b/pkg/bwe/delivery_rate_estimator.go @@ -0,0 +1,84 @@ +package bwe + +import ( + "container/heap" + "time" +) + +type deliveryRateHeapItem struct { + arrival time.Time + size int +} + +type deliveryRateHeap []deliveryRateHeapItem + +// Len implements heap.Interface. +func (d deliveryRateHeap) Len() int { + return len(d) +} + +// Less implements heap.Interface. +func (d deliveryRateHeap) Less(i int, j int) bool { + return d[i].arrival.Before(d[j].arrival) +} + +// Pop implements heap.Interface. +func (d *deliveryRateHeap) Pop() any { + old := *d + n := len(old) + x := old[n-1] + *d = old[0 : n-1] + return x +} + +// Push implements heap.Interface. +func (d *deliveryRateHeap) Push(x any) { + *d = append(*d, x.(deliveryRateHeapItem)) +} + +// Swap implements heap.Interface. +func (d deliveryRateHeap) Swap(i int, j int) { + d[i], d[j] = d[j], d[i] +} + +type deliveryRateEstimator struct { + window time.Duration + latestArrival time.Time + history *deliveryRateHeap +} + +func newDeliveryRateEstimator(window time.Duration) *deliveryRateEstimator { + return &deliveryRateEstimator{ + window: window, + latestArrival: time.Time{}, + history: &deliveryRateHeap{}, + } +} + +func (e *deliveryRateEstimator) OnPacketAcked(arrival time.Time, size int) { + if arrival.After(e.latestArrival) { + e.latestArrival = arrival + } + heap.Push(e.history, deliveryRateHeapItem{ + arrival: arrival, + size: size, + }) +} + +func (e *deliveryRateEstimator) GetRate() int { + deadline := e.latestArrival.Add(-e.window) + for len(*e.history) > 0 && (*e.history)[0].arrival.Before(deadline) { + heap.Pop(e.history) + } + earliest := e.latestArrival + sum := 0 + for _, i := range *e.history { + if i.arrival.Before(earliest) { + earliest = i.arrival + } + sum += i.size + } + d := e.latestArrival.Sub(earliest) + rate := 8 * float64(sum) / d.Seconds() + return int(rate) +} diff --git a/pkg/bwe/duplicate_ack_filter.go b/pkg/bwe/duplicate_ack_filter.go new file mode 100644 index 00000000..e6753c28 --- /dev/null +++ b/pkg/bwe/duplicate_ack_filter.go @@ -0,0 +1,27 @@ +package bwe + +import "github.com/pion/interceptor/pkg/ccfb" + +type duplicateAckFilter struct { + highestAckedBySSRC map[uint32]int64 +} + +func newDuplicateAckFilter() *duplicateAckFilter { + return &duplicateAckFilter{ + highestAckedBySSRC: make(map[uint32]int64), + } +} + +func (f *duplicateAckFilter) filter(reports map[uint32]*ccfb.PacketReportList) { + for ssrc, prl := range reports { + n := 0 + for _, report := range prl.Reports { + if highest, ok := f.highestAckedBySSRC[ssrc]; !ok || report.SeqNr > highest { + f.highestAckedBySSRC[ssrc] = report.SeqNr + prl.Reports[n] = report + n++ + } + } + prl.Reports = prl.Reports[:n] + } +} diff --git a/pkg/bwe/exponential_moving_average.go b/pkg/bwe/exponential_moving_average.go new file mode 100644 index 00000000..9448f02d --- /dev/null +++ b/pkg/bwe/exponential_moving_average.go @@ -0,0 +1,17 @@ +package bwe + +type exponentialMovingAverage struct { + alpha float64 + average float64 + variance float64 +} + +func (a *exponentialMovingAverage) update(sample float64) { + if a.average == 0.0 { + a.average = sample + } else { + a.average = a.alpha*sample + (1-a.alpha)*a.average + delta := sample - a.average + a.variance = (1-a.alpha)*a.variance + a.alpha*(1-a.alpha)*(delta*delta) + } +} diff --git a/pkg/bwe/kalman.go b/pkg/bwe/kalman.go new file mode 100644 index 00000000..d6e7dd43 --- /dev/null +++ b/pkg/bwe/kalman.go @@ -0,0 +1,92 @@ +package bwe + +import ( + "math" + "time" +) + +const ( + chi = 0.001 +) + +type kalmanOption func(*kalman) + +type kalman struct { + gain float64 + estimate time.Duration + processUncertainty float64 // Q_i + estimateError float64 + measurementUncertainty float64 + + disableMeasurementUncertaintyUpdates bool +} + +func initEstimate(e time.Duration) kalmanOption { + return func(k *kalman) { + k.estimate = e + } +} + +func initProcessUncertainty(p float64) kalmanOption { + return func(k *kalman) { + k.processUncertainty = p + } +} + +func initEstimateError(e float64) kalmanOption { + return func(k *kalman) { + k.estimateError = e * e // Only need variance from now on + } +} + +func initMeasurementUncertainty(u float64) kalmanOption { + return func(k *kalman) { + k.measurementUncertainty = u + } +} + +func setDisableMeasurementUncertaintyUpdates(b bool) kalmanOption { + return func(k *kalman) { + k.disableMeasurementUncertaintyUpdates = b + } +} + +func newKalman(opts ...kalmanOption) *kalman { + k := &kalman{ + gain: 0, + estimate: 0, + processUncertainty: 1e-3, + estimateError: 0.1, + measurementUncertainty: 0, + disableMeasurementUncertaintyUpdates: false, + } + for _, opt := range opts { + opt(k) + } + return k +} + +func (k *kalman) updateEstimate(measurement time.Duration) time.Duration { + z := measurement - k.estimate + + zms := float64(z.Microseconds()) / 1000.0 + + if !k.disableMeasurementUncertaintyUpdates { + alpha := math.Pow((1 - chi), 30.0/(1000.0*5*float64(time.Millisecond))) + root := math.Sqrt(k.measurementUncertainty) + root3 := 3 * root + if zms > root3 { + k.measurementUncertainty = math.Max(alpha*k.measurementUncertainty+(1-alpha)*root3*root3, 1) + } else { + k.measurementUncertainty = math.Max(alpha*k.measurementUncertainty+(1-alpha)*zms*zms, 1) + } + } + + estimateUncertainty := k.estimateError + k.processUncertainty + k.gain = estimateUncertainty / (estimateUncertainty + k.measurementUncertainty) + + k.estimate += time.Duration(k.gain * zms * float64(time.Millisecond)) + + k.estimateError = (1 - k.gain) * estimateUncertainty + return k.estimate +} diff --git a/pkg/bwe/kalman_test.go b/pkg/bwe/kalman_test.go new file mode 100644 index 00000000..0d7968d3 --- /dev/null +++ b/pkg/bwe/kalman_test.go @@ -0,0 +1,68 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestKalman(t *testing.T) { + cases := []struct { + name string + opts []kalmanOption + measurements []time.Duration + expected []time.Duration + }{ + { + name: "empty", + opts: []kalmanOption{}, + measurements: []time.Duration{}, + expected: []time.Duration{}, + }, + { + name: "kalmanfilter.netExample", + opts: []kalmanOption{ + initEstimate(10 * time.Millisecond), + initEstimateError(100), + initProcessUncertainty(0.15), + initMeasurementUncertainty(0.01), + }, + measurements: []time.Duration{ + time.Duration(50.45 * float64(time.Millisecond)), + time.Duration(50.967 * float64(time.Millisecond)), + time.Duration(51.6 * float64(time.Millisecond)), + time.Duration(52.106 * float64(time.Millisecond)), + time.Duration(52.492 * float64(time.Millisecond)), + time.Duration(52.819 * float64(time.Millisecond)), + time.Duration(53.433 * float64(time.Millisecond)), + time.Duration(54.007 * float64(time.Millisecond)), + time.Duration(54.523 * float64(time.Millisecond)), + time.Duration(54.99 * float64(time.Millisecond)), + }, + expected: []time.Duration{ + time.Duration(50.449959 * float64(time.Millisecond)), + time.Duration(50.936547 * float64(time.Millisecond)), + time.Duration(51.560411 * float64(time.Millisecond)), + time.Duration(52.07324 * float64(time.Millisecond)), + time.Duration(52.466566 * float64(time.Millisecond)), + time.Duration(52.797787 * float64(time.Millisecond)), + time.Duration(53.395303 * float64(time.Millisecond)), + time.Duration(53.970236 * float64(time.Millisecond)), + time.Duration(54.489652 * float64(time.Millisecond)), + time.Duration(54.960137 * float64(time.Millisecond)), + }, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + k := newKalman(append(tc.opts, setDisableMeasurementUncertaintyUpdates(true))...) + estimates := []time.Duration{} + for _, m := range tc.measurements { + estimates = append(estimates, k.updateEstimate(m)) + } + assert.Equal(t, tc.expected, estimates, "%v != %v", tc.expected, estimates) + }) + } +} diff --git a/pkg/bwe/loss_rate_controller.go b/pkg/bwe/loss_rate_controller.go new file mode 100644 index 00000000..45537a64 --- /dev/null +++ b/pkg/bwe/loss_rate_controller.go @@ -0,0 +1,67 @@ +package bwe + +import ( + "sync" +) + +type LossRateController struct { + lock sync.Mutex + bitrate int + min, max int + + packetsSinceLastUpdate int + arrivedSinceLastUpdate int + lostSinceLastUpdate int +} + +func NewLossRateController(initialRate, minRate, maxRate int) *LossRateController { + return &LossRateController{ + lock: sync.Mutex{}, + bitrate: initialRate, + min: minRate, + max: maxRate, + packetsSinceLastUpdate: 0, + arrivedSinceLastUpdate: 0, + lostSinceLastUpdate: 0, + } +} + +func (l *LossRateController) OnPacketAcked() { + l.lock.Lock() + defer l.lock.Unlock() + l.packetsSinceLastUpdate++ + l.arrivedSinceLastUpdate++ +} + +func (l *LossRateController) OnPacketLost() { + l.lock.Lock() + defer l.lock.Unlock() + l.packetsSinceLastUpdate++ + l.lostSinceLastUpdate++ +} + +func (l *LossRateController) Update(lastDeliveryRate int) int { + l.lock.Lock() + defer l.lock.Unlock() + + lossRate := float64(l.lostSinceLastUpdate) / float64(l.packetsSinceLastUpdate) + canIncrease := float64(lastDeliveryRate) >= 0.95*float64(l.bitrate) + if lossRate > 0.1 { + l.bitrate = int(float64(l.bitrate) * (1 - 0.5*lossRate)) + } else if lossRate < 0.02 && canIncrease { + l.bitrate = int(float64(l.bitrate) * 1.05) + } + l.bitrate = max(min(l.bitrate, l.max), l.min) + + l.packetsSinceLastUpdate = 0 + l.arrivedSinceLastUpdate = 0 + l.lostSinceLastUpdate = 0 + + return l.bitrate +} + +func (l *LossRateController) Bitrate() int { + l.lock.Lock() + defer l.lock.Unlock() + return l.bitrate +} diff --git a/pkg/bwe/overuse_detector.go b/pkg/bwe/overuse_detector.go new file mode 100644 index 00000000..eb8e7f3b --- /dev/null +++ b/pkg/bwe/overuse_detector.go @@ -0,0 +1,79 @@ +package bwe + +import ( + "time" +) + +const ( + kU = 0.01 + kD = 0.00018 +) + +type overuseDetector struct { + adaptiveThreshold bool + overUseTimeThreshold time.Duration + delayThreshold time.Duration + lastEstimate time.Duration + lastUpdate time.Time + firstOverUse time.Time + inOveruse bool +} + +func newOveruseDetector() *overuseDetector { + return &overuseDetector{ + adaptiveThreshold: true, + overUseTimeThreshold: 10 * time.Millisecond, + delayThreshold: 12500 * time.Microsecond, + lastEstimate: 0, + lastUpdate: time.Time{}, + firstOverUse: time.Time{}, + inOveruse: false, + } +} + +func (d *overuseDetector) update(ts time.Time, estimate time.Duration) usage { + if d.adaptiveThreshold { + defer d.adaptThreshold(ts, estimate) + } + if estimate >= d.lastEstimate && estimate > d.delayThreshold { + if d.inOveruse && ts.Sub(d.firstOverUse) > d.overUseTimeThreshold { + return usageOver + } + if !d.inOveruse { + d.firstOverUse = ts + } + d.inOveruse = true + return usageNormal + } + if estimate < -d.delayThreshold { + d.inOveruse = false + return usageUnder + } + d.inOveruse = false + return usageNormal +} + +func (d *overuseDetector) adaptThreshold(ts time.Time, estimate time.Duration) { + delta := ts.Sub(d.lastUpdate) + d.lastUpdate = ts + absEstimate := estimate + if absEstimate < 0 { + absEstimate = -absEstimate + } + if absEstimate-d.delayThreshold > 15 { + return + } + var k float64 + if absEstimate < d.delayThreshold { + k = kD + } else { + k = kU + } + d.delayThreshold = d.delayThreshold + delta*time.Duration(k)*(absEstimate-d.delayThreshold) + if d.delayThreshold < 6*time.Millisecond { + d.delayThreshold = 6 * time.Millisecond + } + if d.delayThreshold > 600*time.Millisecond { + d.delayThreshold = 600 * time.Millisecond + } +} diff --git a/pkg/bwe/overuse_detector_test.go b/pkg/bwe/overuse_detector_test.go new file mode 100644 index 00000000..8d762fe1 --- /dev/null +++ b/pkg/bwe/overuse_detector_test.go @@ -0,0 +1,76 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOveruseDetectorWithoutDelay(t *testing.T) { + type estimate struct { + ts time.Time + estimate time.Duration + } + cases := []struct { + name string + values []estimate + expected []usage + }{ + { + name: "noEstimateNoUsage", + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuse", + values: []estimate{ + {time.Time{}, time.Millisecond}, + {time.Time{}.Add(5 * time.Millisecond), 20 * time.Millisecond}, + {time.Time{}.Add(20 * time.Millisecond), 30 * time.Millisecond}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluse", + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuse", + values: []estimate{{estimate: -20 * time.Millisecond}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelay", + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20 * time.Millisecond}, + {time.Time{}.Add(2 * time.Millisecond), 30 * time.Millisecond}, + {time.Time{}.Add(30 * time.Millisecond), 50 * time.Millisecond}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreased", + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20 * time.Millisecond}, + {time.Time{}.Add(10 * time.Millisecond), 40 * time.Millisecond}, + {time.Time{}.Add(20 * time.Millisecond), 50 * time.Millisecond}, + {time.Time{}.Add(30 * time.Millisecond), 3 * time.Millisecond}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + od := newOveruseDetector() + received := []usage{} + for _, e := range tc.values { + usage := od.update(e.ts, e.estimate) + received = append(received, usage) + } + assert.Equal(t, tc.expected, received) + }) + } +} diff --git a/pkg/bwe/rate_controller.go b/pkg/bwe/rate_controller.go new file mode 100644 index 00000000..a9a6f6c1 --- /dev/null +++ b/pkg/bwe/rate_controller.go @@ -0,0 +1,76 @@ +package bwe + +import ( + "math" + "time" +) + +type rateController struct { + s state + rate int + + decreaseFactor float64 // (beta) + lastUpdate time.Time + lastDecrease *exponentialMovingAverage +} + +func newRateController(initialRate int) *rateController { + return &rateController{ + s: stateIncrease, + rate: initialRate, + decreaseFactor: 0.85, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + } +} + +func (c *rateController) update(ts time.Time, u usage, deliveredRate int, rtt time.Duration) int { + nextState := c.s.transition(u) + c.s = nextState + + if c.s == stateIncrease { + var target float64 + if c.canIncreaseMultiplicatively(float64(deliveredRate)) { + window := ts.Sub(c.lastUpdate) + target = c.multiplicativeIncrease(float64(c.rate), window) + } else { + bitsPerFrame := float64(c.rate) / 30.0 + packetsPerFrame := math.Ceil(bitsPerFrame / (1200 * 8)) + expectedPacketSizeBits := bitsPerFrame / packetsPerFrame + target = c.additiveIncrease(float64(c.rate), int(expectedPacketSizeBits), rtt) + } + c.rate = int(min(target, 1.5*float64(deliveredRate))) + } + + if c.s == stateDecrease { + c.rate = int(c.decreaseFactor * float64(deliveredRate)) + c.lastDecrease.update(float64(c.rate)) + } + + c.lastUpdate = ts + + return c.rate +} + +func (c *rateController) canIncreaseMultiplicatively(deliveredRate float64) bool { + if c.lastDecrease.average == 0 { + return true + } + stdDev := math.Sqrt(c.lastDecrease.variance) + lower := c.lastDecrease.average - 3*stdDev + upper := c.lastDecrease.average + 3*stdDev + return deliveredRate < lower || deliveredRate > upper +} + +func (c *rateController) multiplicativeIncrease(rate float64, window time.Duration) float64 { + exponent := min(window.Seconds(), 1.0) + eta := math.Pow(1.08, exponent) + target := eta * rate + return target +} + +func (c *rateController) additiveIncrease(rate float64, expectedPacketSizeBits int, window time.Duration) float64 { + alpha := 0.5 * min(window.Seconds(), 1.0) + target := rate + max(1000, alpha*float64(expectedPacketSizeBits)) + return target +} diff --git a/pkg/bwe/send_side_bwe.go b/pkg/bwe/send_side_bwe.go new file mode 100644 index 00000000..21580439 --- /dev/null +++ b/pkg/bwe/send_side_bwe.go @@ -0,0 +1,76 @@ +package bwe + +import ( + "log" + "time" + + "github.com/pion/interceptor/pkg/ccfb" +) + +type SendSideController struct { + daf *duplicateAckFilter + dre *deliveryRateEstimator + lbc *LossRateController + drc *DelayRateController + rate int +} + +func NewSendSideController(initialRate, minRate, maxRate int) *SendSideController { + return &SendSideController{ + daf: newDuplicateAckFilter(), + dre: newDeliveryRateEstimator(time.Second), + lbc: NewLossRateController(initialRate, minRate, maxRate), + drc: NewDelayRateController(initialRate), + } +} + +func (c *SendSideController) OnFeedbackReport(reports map[uint32]*ccfb.PacketReportList) int { + c.daf.filter(reports) + + var latestReportArrival time.Time + var latestReportDeparture time.Time + var lastAckedArrival time.Time + var lastAckedPacketDeparture time.Time + + acksCount := 0 + for _, prl := range reports { + if prl.Arrival.After(latestReportArrival) { + latestReportArrival = prl.Arrival + latestReportDeparture = prl.Departure + } + for _, r := range prl.Reports { + if r.Arrived && !r.Arrival.IsZero() { // in some cases we might receive acks without timestamps. Ignore them. + c.dre.OnPacketAcked(r.Arrival, int(r.Size)) + c.lbc.OnPacketAcked() + c.drc.OnPacketAcked(acknowledgment{ + seqNr: r.SeqNr, + size: r.Size, + departure: r.Departure, + arrived: r.Arrived, + arrival: r.Arrival, + ecn: r.ECN, + }) + if r.Departure.After(lastAckedPacketDeparture) { + lastAckedPacketDeparture = r.Departure + lastAckedArrival = r.Arrival + } + acksCount++ + } else { + c.lbc.OnPacketLost() + } + } + } + if acksCount == 0 { + return c.rate + } + + pendingTime := latestReportDeparture.Sub(lastAckedArrival) + rtt := latestReportArrival.Sub(lastAckedPacketDeparture) - pendingTime + + delivered := c.dre.GetRate() + lossTarget := c.lbc.Update(delivered) + delayTarget := c.drc.Update(latestReportArrival, delivered, rtt) + c.rate = min(lossTarget, delayTarget) + log.Printf("rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", rtt.Microseconds(), delivered, lossTarget, delayTarget, c.rate) + return c.rate +} diff --git a/pkg/bwe/state.go b/pkg/bwe/state.go new file mode 100644 index 00000000..70f90069 --- /dev/null +++ b/pkg/bwe/state.go @@ -0,0 +1,59 @@ +package bwe + +import "fmt" + +type state int + +const ( + stateIncrease state = iota + stateDecrease + stateHold +) + +func (s state) transition(u usage) state { + switch s { + case stateHold: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + case stateIncrease: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + case stateDecrease: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateHold + case usageUnder: + return stateHold + } + } + return stateIncrease +} + +func (s state) String() string { + switch s { + case stateIncrease: + return "increase" + case stateDecrease: + return "decrease" + case stateHold: + return "hold" + default: + return fmt.Sprintf("invalid state: %d", s) + } +} diff --git a/pkg/bwe/usage.go b/pkg/bwe/usage.go new file mode 100644 index 00000000..0c520950 --- /dev/null +++ b/pkg/bwe/usage.go @@ -0,0 +1,24 @@ +package bwe + +import "fmt" + +type usage int + +const ( + usageOver usage = iota + usageUnder + usageNormal +) + +func (u usage) String() string { + switch u { + case usageOver: + return "overuse" + case usageUnder: + return "underuse" + case usageNormal: + return "normal" + default: + return fmt.Sprintf("invalid usage: %d", u) + } +}