Skip to content

Commit

Permalink
Add interceptor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Jan 20, 2025
1 parent 7ad4eec commit 902265e
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 20 deletions.
9 changes: 5 additions & 4 deletions internal/test/mock_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type RTPWithError struct {
// RTCPWithError is used to send a batch of rtcp packets or an error on a channel
type RTCPWithError struct {
Packets []rtcp.Packet
Attr interceptor.Attributes
Err error
}

Expand Down Expand Up @@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc
go func() {
buf := make([]byte, 1500)
for {
i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
if err != nil {
if !errors.Is(err, io.EOF) {
s.rtcpInModified <- RTCPWithError{Err: err}
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}

Check warning on line 114 in internal/test/mock_stream.go

View check run for this annotation

Codecov / codecov/patch

internal/test/mock_stream.go#L114

Added line #L114 was not covered by tests
}
return
}

pkts, err := rtcp.Unmarshal(buf[:i])
if err != nil {
s.rtcpInModified <- RTCPWithError{Err: err}
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}
return
}

s.rtcpInModified <- RTCPWithError{Packets: pkts}
s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts}
}
}()
go func() {
Expand Down
12 changes: 6 additions & 6 deletions pkg/ccfb/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type sentPacket struct {
departure time.Time
}

type history struct {
type historyList struct {
lock sync.Mutex
size int
evictList *list.List
Expand All @@ -40,8 +40,8 @@ type history struct {
ackedSeqNr *sequencenumber.Unwrapper
}

func newHistory(size int) *history {
return &history{
func newHistoryList(size int) *historyList {
return &historyList{
lock: sync.Mutex{},
size: size,
evictList: list.New(),
Expand All @@ -51,7 +51,7 @@ func newHistory(size int) *history {
}
}

func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
func (h *historyList) add(seqNr uint16, size uint16, departure time.Time) error {
h.lock.Lock()
defer h.lock.Unlock()

Expand All @@ -76,7 +76,7 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
}

// Must be called while holding the lock
func (h *history) removeOldest() {
func (h *historyList) removeOldest() {
if ent := h.evictList.Front(); ent != nil {
v := h.evictList.Remove(ent)
if sp, ok := v.(sentPacket); ok {
Expand All @@ -85,7 +85,7 @@ func (h *history) removeOldest() {
}
}

func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
func (h *historyList) getReportForAck(al acknowledgementList) PacketReportList {
h.lock.Lock()
defer h.lock.Unlock()

Expand Down
6 changes: 3 additions & 3 deletions pkg/ccfb/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func TestHistory(t *testing.T) {
t.Run("errorOnDecreasingSeqNr", func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)
assert.NoError(t, h.add(10, 1200, time.Now()))
assert.NoError(t, h.add(11, 1200, time.Now()))
assert.Error(t, h.add(9, 1200, time.Now()))
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestHistory(t *testing.T) {
}
for i, tc := range cases {
t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)
for _, op := range tc.outgoing {
assert.NoError(t, h.add(op.seqNr, op.size, op.ts))
}
Expand All @@ -97,7 +97,7 @@ func TestHistory(t *testing.T) {
})

t.Run("garbageCollection", func(t *testing.T) {
h := newHistory(200)
h := newHistoryList(200)

for i := uint16(0); i < 300; i++ {
assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond)))
Expand Down
72 changes: 65 additions & 7 deletions pkg/ccfb/interceptor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ccfb

import (
"log"
"sync"
"time"

Expand All @@ -15,8 +16,48 @@ type ccfbAttributesKeyType uint32

const CCFBAttributesKey ccfbAttributesKeyType = iota

type history interface {
add(seqNr uint16, size uint16, departure time.Time) error
getReportForAck(al acknowledgementList) PacketReportList
}

type Option func(*Interceptor) error

func HistorySize(size int) Option {
return func(i *Interceptor) error {
i.historySize = size
return nil
}

Check warning on line 30 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L26-L30

Added lines #L26 - L30 were not covered by tests
}

func timeFactory(f func() time.Time) Option {
return func(i *Interceptor) error {
i.timestamp = f
return nil
}
}

func historyFactory(f func(int) history) Option {
return func(i *Interceptor) error {
i.historyFactory = f
return nil
}
}

func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)) Option {
return func(i *Interceptor) error {
i.convertCCFB = f
return nil
}
}

func twccConverterFactory(f func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)) Option {
return func(i *Interceptor) error {
i.convertTWCC = f
return nil
}
}

type InterceptorFactory struct {
opts []Option
}
Expand All @@ -30,8 +71,15 @@ func NewInterceptor(opts ...Option) (*InterceptorFactory, error) {
func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &Interceptor{
NoOp: interceptor.NoOp{},
lock: sync.Mutex{},
timestamp: time.Now,
ssrcToHistory: make(map[uint32]*history),
convertCCFB: convertCCFB,
convertTWCC: convertTWCC,
ssrcToHistory: make(map[uint32]history),
historySize: 200,
historyFactory: func(size int) history {
return newHistoryList(size)
},

Check warning on line 82 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L81-L82

Added lines #L81 - L82 were not covered by tests
}
for _, opt := range f.opts {
if err := opt(i); err != nil {
Expand All @@ -43,9 +91,13 @@ func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor,

type Interceptor struct {
interceptor.NoOp
lock sync.Mutex
timestamp func() time.Time
ssrcToHistory map[uint32]*history
lock sync.Mutex
timestamp func() time.Time
convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)
convertTWCC func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)
ssrcToHistory map[uint32]history
historySize int
historyFactory func(int) history
}

// BindLocalStream implements interceptor.Interceptor.
Expand All @@ -67,7 +119,7 @@ func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter
if useTWCC {
ssrc = 0
}
i.ssrcToHistory[ssrc] = newHistory(200)
i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize)

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
i.lock.Lock()
Expand Down Expand Up @@ -106,17 +158,23 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.
attr = make(interceptor.Attributes)
}

Check warning on line 159 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L158-L159

Added lines #L158 - L159 were not covered by tests

log.Printf("%v", buf)

Check failure on line 161 in pkg/ccfb/interceptor.go

View workflow job for this annotation

GitHub Actions / lint / Go

use of `log.Printf` forbidden by pattern `^log.(Panic|Fatal|Print)(f|ln)?$` (forbidigo)
pktReportLists := map[uint32]*PacketReportList{}

pkts, err := attr.GetRTCPPackets(buf)
if err != nil {
return n, attr, err
}

Check warning on line 167 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L166-L167

Added lines #L166 - L167 were not covered by tests
log.Printf("got rtcp packets: %v, %v", pkts, err)

Check failure on line 168 in pkg/ccfb/interceptor.go

View workflow job for this annotation

GitHub Actions / lint / Go

use of `log.Printf` forbidden by pattern `^log.(Panic|Fatal|Print)(f|ln)?$` (forbidigo)
for _, pkt := range pkts {
var reportLists map[uint32]acknowledgementList
var reportDeparture time.Time
switch fb := pkt.(type) {
case *rtcp.CCFeedbackReport:
reportDeparture, reportLists = convertCCFB(now, fb)
reportDeparture, reportLists = i.convertCCFB(now, fb)

Check warning on line 174 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L173-L174

Added lines #L173 - L174 were not covered by tests
case *rtcp.TransportLayerCC:
reportDeparture, reportLists = convertTWCC(now, fb)
reportDeparture, reportLists = i.convertTWCC(now, fb)
default:

Check warning on line 177 in pkg/ccfb/interceptor.go

View check run for this annotation

Codecov / codecov/patch

pkg/ccfb/interceptor.go#L177

Added line #L177 was not covered by tests
}
for ssrc, reportList := range reportLists {
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
Expand Down
Loading

0 comments on commit 902265e

Please sign in to comment.