diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go index fba41b71..8d946dbf 100644 --- a/pkg/ccfb/interceptor.go +++ b/pkg/ccfb/interceptor.go @@ -9,6 +9,8 @@ import ( "github.com/pion/rtp" ) +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + type ccfbAttributesKeyType uint32 const CCFBAttributesKey ccfbAttributesKeyType = iota @@ -48,6 +50,16 @@ type Interceptor struct { // BindLocalStream implements interceptor.Interceptor. func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + var twccHdrExtID uint8 + var useTWCC bool + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + twccHdrExtID = uint8(e.ID) + useTWCC = true + break + } + } + i.lock.Lock() defer i.lock.Unlock() i.ssrcToHistory[info.SSRC] = newHistory(200) @@ -55,7 +67,20 @@ func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { i.lock.Lock() defer i.lock.Unlock() - i.ssrcToHistory[header.SSRC].add(header.SequenceNumber, uint16(header.MarshalSize()+len(payload)), i.timestamp()) + + // If we are using TWCC, we use the sequence number from the TWCC header + // extension and save all TWCC sequence numbers with the same SSRC (0). + // If we are not using TWCC, we save a history per SSRC and use the + // normal RTP sequence numbers. + ssrc := header.SSRC + seqNr := header.SequenceNumber + if useTWCC { + ssrc = 0 + var twccHdrExt rtp.TransportCCExtension + twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID)) + seqNr = twccHdrExt.TransportSequence + } + i.ssrcToHistory[ssrc].add(seqNr, uint16(header.MarshalSize()+len(payload)), i.timestamp()) return writer.Write(header, payload, attributes) }) } @@ -80,16 +105,19 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor. pkts, err := attr.GetRTCPPackets(buf) for _, pkt := range pkts { + var reportLists map[uint32]acknowledgementList switch fb := pkt.(type) { case *rtcp.CCFeedbackReport: - reportLists := convertCCFB(now, fb) - for ssrc, reportList := range reportLists { - prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) - if l, ok := pktReportLists[ssrc]; !ok { - pktReportLists[ssrc] = &prl - } else { - l.Reports = append(l.Reports, prl.Reports...) - } + reportLists = convertCCFB(now, fb) + case *rtcp.TransportLayerCC: + reportLists = convertTWCC(now, fb) + } + for ssrc, reportList := range reportLists { + prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) + if l, ok := pktReportLists[ssrc]; !ok { + pktReportLists[ssrc] = &prl + } else { + l.Reports = append(l.Reports, prl.Reports...) } } } diff --git a/pkg/ccfb/twcc_receiver.go b/pkg/ccfb/twcc_receiver.go new file mode 100644 index 00000000..22b58ab1 --- /dev/null +++ b/pkg/ccfb/twcc_receiver.go @@ -0,0 +1,90 @@ +package ccfb + +import ( + "time" + + "github.com/pion/rtcp" +) + +func convertTWCC(ts time.Time, feedback *rtcp.TransportLayerCC) map[uint32]acknowledgementList { + if feedback == nil { + return nil + } + var acks []acknowledgement + + referenceTime := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond) + recvDeltaIndex := 0 + + offset := 0 + for _, pc := range feedback.PacketChunks { + switch chunk := pc.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + seqNr := feedback.BaseSequenceNumber + uint16(offset) + offset++ + switch chunk.PacketStatusSymbol { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: referenceTime.Add(time.Duration(delta.Delta) * time.Millisecond), + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + case *rtcp.StatusVectorChunk: + for _, s := range chunk.SymbolList { + seqNr := feedback.BaseSequenceNumber + uint16(offset) + offset++ + switch s { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: referenceTime.Add(time.Duration(delta.Delta) * time.Millisecond), + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + } + } + + return map[uint32]acknowledgementList{ + feedback.MediaSSRC: { + ts: ts, + acks: acks, + }, + } +} diff --git a/pkg/ccfb/twcc_receiver_test.go b/pkg/ccfb/twcc_receiver_test.go new file mode 100644 index 00000000..d042e958 --- /dev/null +++ b/pkg/ccfb/twcc_receiver_test.go @@ -0,0 +1,141 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertTWCC(t *testing.T) { + timeZero := time.Now() + cases := []struct { + ts time.Time + feedback *rtcp.TransportLayerCC + expect map[uint32]acknowledgementList + }{ + {}, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{}, + RecvDeltas: []*rtcp.RecvDelta{}, + }, + expect: map[uint32]acknowledgementList{ + 2: { + ts: timeZero.Add(2 * time.Second), + acks: []acknowledgement{}, + }, + }, + }, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 3, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0}, + }, + }, + expect: map[uint32]acknowledgementList{ + 2: { + ts: timeZero.Add(2 * time.Second), + acks: []acknowledgement{ + // first run length chunk + {seqNr: 178, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 179, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 180, arrived: true, arrival: time.Time{}, ecn: 0}, + + // first status vector chunk + {seqNr: 181, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 182, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 183, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0}, + + // second status vector chunk + {seqNr: 189, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 190, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0}, + }, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertTWCC(tc.ts, tc.feedback) + + // Can't directly check equality since arrival timestamp conversions + // may be slightly off due to ntp conversions. + assert.Equal(t, len(tc.expect), len(res)) + for i, ee := range tc.expect { + assert.Equal(t, ee.ts, res[i].ts) + for j, ack := range ee.acks { + assert.Equal(t, ack.seqNr, res[i].acks[j].seqNr) + assert.Equal(t, ack.arrived, res[i].acks[j].arrived) + assert.Equal(t, ack.ecn, res[i].acks[j].ecn) + assert.InDelta(t, ack.arrival.UnixNano(), res[i].acks[j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds())) + } + } + }) + } + +}