Skip to content

Commit

Permalink
Add interceptor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Jan 27, 2025
1 parent eb37e5d commit f2d1167
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 22 deletions.
9 changes: 5 additions & 4 deletions internal/test/mock_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type RTPWithError struct {
// RTCPWithError is used to send a batch of rtcp packets or an error on a channel
type RTCPWithError struct {
Packets []rtcp.Packet
Attr interceptor.Attributes
Err error
}

Expand Down Expand Up @@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc
go func() {
buf := make([]byte, 1500)
for {
i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
if err != nil {
if !errors.Is(err, io.EOF) {
s.rtcpInModified <- RTCPWithError{Err: err}
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}
}
return
}

pkts, err := rtcp.Unmarshal(buf[:i])
if err != nil {
s.rtcpInModified <- RTCPWithError{Err: err}
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}
return
}

s.rtcpInModified <- RTCPWithError{Packets: pkts}
s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts}
}
}()
go func() {
Expand Down
12 changes: 6 additions & 6 deletions pkg/ccfb/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type sentPacket struct {
departure time.Time
}

type history struct {
type historyList struct {
lock sync.Mutex
size int
evictList *list.List
Expand All @@ -40,8 +40,8 @@ type history struct {
ackedSeqNr *sequencenumber.Unwrapper
}

func newHistory(size int) *history {
return &history{
func newHistoryList(size int) *historyList {
return &historyList{
lock: sync.Mutex{},
size: size,
evictList: list.New(),
Expand All @@ -51,7 +51,7 @@ func newHistory(size int) *history {
}
}

func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
func (h *historyList) add(seqNr uint16, size uint16, departure time.Time) error {
h.lock.Lock()
defer h.lock.Unlock()

Expand All @@ -76,7 +76,7 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
}

// Must be called while holding the lock
func (h *history) removeOldest() {
func (h *historyList) removeOldest() {
if ent := h.evictList.Front(); ent != nil {
v := h.evictList.Remove(ent)
if sp, ok := v.(sentPacket); ok {
Expand All @@ -85,7 +85,7 @@ func (h *history) removeOldest() {
}
}

func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
func (h *historyList) getReportForAck(al acknowledgementList) PacketReportList {
h.lock.Lock()
defer h.lock.Unlock()

Expand Down
6 changes: 3 additions & 3 deletions pkg/ccfb/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func TestHistory(t *testing.T) {
t.Run("errorOnDecreasingSeqNr", func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)
assert.NoError(t, h.add(10, 1200, time.Now()))
assert.NoError(t, h.add(11, 1200, time.Now()))
assert.Error(t, h.add(9, 1200, time.Now()))
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestHistory(t *testing.T) {
}
for i, tc := range cases {
t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)
for _, op := range tc.outgoing {
assert.NoError(t, h.add(op.seqNr, op.size, op.ts))
}
Expand All @@ -97,7 +97,7 @@ func TestHistory(t *testing.T) {
})

t.Run("garbageCollection", func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)

for i := uint16(0); i < 300; i++ {
assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond)))
Expand Down
69 changes: 62 additions & 7 deletions pkg/ccfb/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,48 @@ type ccfbAttributesKeyType uint32

const CCFBAttributesKey ccfbAttributesKeyType = iota

type history interface {
add(seqNr uint16, size uint16, departure time.Time) error
getReportForAck(al acknowledgementList) PacketReportList
}

type Option func(*Interceptor) error

func HistorySize(size int) Option {
return func(i *Interceptor) error {
i.historySize = size
return nil
}
}

func timeFactory(f func() time.Time) Option {
return func(i *Interceptor) error {
i.timestamp = f
return nil
}
}

func historyFactory(f func(int) history) Option {
return func(i *Interceptor) error {
i.historyFactory = f
return nil
}
}

func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)) Option {
return func(i *Interceptor) error {
i.convertCCFB = f
return nil
}
}

func twccConverterFactory(f func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)) Option {
return func(i *Interceptor) error {
i.convertTWCC = f
return nil
}
}

type InterceptorFactory struct {
opts []Option
}
Expand All @@ -30,8 +70,15 @@ func NewInterceptor(opts ...Option) (*InterceptorFactory, error) {
func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &Interceptor{
NoOp: interceptor.NoOp{},
lock: sync.Mutex{},
timestamp: time.Now,
ssrcToHistory: make(map[uint32]*history),
convertCCFB: convertCCFB,
convertTWCC: convertTWCC,
ssrcToHistory: make(map[uint32]history),
historySize: 200,
historyFactory: func(size int) history {
return newHistoryList(size)
},
}
for _, opt := range f.opts {
if err := opt(i); err != nil {
Expand All @@ -43,9 +90,13 @@ func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor,

type Interceptor struct {
interceptor.NoOp
lock sync.Mutex
timestamp func() time.Time
ssrcToHistory map[uint32]*history
lock sync.Mutex
timestamp func() time.Time
convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)
convertTWCC func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)
ssrcToHistory map[uint32]history
historySize int
historyFactory func(int) history
}

// BindLocalStream implements interceptor.Interceptor.
Expand All @@ -67,7 +118,7 @@ func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter
if useTWCC {
ssrc = 0
}
i.ssrcToHistory[ssrc] = newHistory(200)
i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize)

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
i.lock.Lock()
Expand Down Expand Up @@ -109,14 +160,18 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.
pktReportLists := map[uint32]*PacketReportList{}

pkts, err := attr.GetRTCPPackets(buf)
if err != nil {
return n, attr, err
}
for _, pkt := range pkts {
var reportLists map[uint32]acknowledgementList
var reportDeparture time.Time
switch fb := pkt.(type) {
case *rtcp.CCFeedbackReport:
reportDeparture, reportLists = convertCCFB(now, fb)
reportDeparture, reportLists = i.convertCCFB(now, fb)
case *rtcp.TransportLayerCC:
reportDeparture, reportLists = convertTWCC(now, fb)
reportDeparture, reportLists = i.convertTWCC(now, fb)
default:
}
for ssrc, reportList := range reportLists {
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
Expand Down
Loading

0 comments on commit f2d1167

Please sign in to comment.