Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rtp header padding #1079

Closed
wants to merge 13 commits into from
5 changes: 3 additions & 2 deletions worker/fuzzer/src/RTC/FuzzerRtpPacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ void Fuzzer::RTC::RtpPacket::Fuzz(const uint8_t* data, size_t len)
packet->HasExtension(14);
packet->GetExtension(14, extenLen);
packet->ReadTransportWideCc01(wideSeqNumber);
packet->UpdateTransportWideCc01(12345u);
packet->SetExtensionLength(14, 2);
//packet->UpdateTransportWideCc01(12345u);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are those tests commented? We cannot leave commented tests. Do they no longer work or what?

Copy link
Author

@buptlsp buptlsp May 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this two line can be replace by setExtensionValue(). should I just delete the 2 line?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should test as much as possible

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should test as much as possible

Extremely agree. This is my first time writing C++, and I have been writing PHP before. So the more tests, the better

//packet->SetExtensionLength(14, 2);
packet->SetExtensionValue(packet->transportWideCc01ExtensionId, 2u, "09");

packet->SetSsrcAudioLevelExtensionId(11);
packet->HasExtension(11);
Expand Down
2 changes: 1 addition & 1 deletion worker/include/RTC/RtpPacket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ namespace RTC
}
}

bool SetExtensionLength(uint8_t id, uint8_t len);
bool SetExtensionValue(uint8_t id, uint8_t len, const std::string& value);

uint8_t* GetPayload() const
{
Expand Down
154 changes: 138 additions & 16 deletions worker/src/RTC/RtpPacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,16 +566,14 @@ namespace RTC
return;
}

std::memcpy(extenValue, mid.c_str(), midLen);

SetExtensionLength(this->midExtensionId, midLen);
SetExtensionValue(this->midExtensionId, midLen, mid);
}

/**
* The caller is responsible of not setting a length higher than the
* available one (taking into account existing padding bytes).
*/
bool RtpPacket::SetExtensionLength(uint8_t id, uint8_t len)
bool RtpPacket::SetExtensionValue(uint8_t id, uint8_t len, const std::string& value)
{
MS_TRACE();

Expand All @@ -600,12 +598,79 @@ namespace RTC

auto currentLen = extension->len + 1;

// Fill with 0's if new length is minor.
if (len < currentLen)
std::memset(extension->value + len, 0, currentLen - len);
MS_DEBUG_DEV("set extension id: %" PRIu8 ", length:%" PRIu8 ", currentLen:%" PRIu8 " value:%s", id, len, currentLen, value.c_str());
if(len != currentLen)
{
// need shift
uint8_t* extensionStart = reinterpret_cast<uint8_t*>(this->headerExtension) + 4;
uint8_t* extensionEnd = extensionStart + GetHeaderExtensionLength();
uint8_t* ptr = extensionStart;
size_t extensionsTotalSize = static_cast<size_t>(len+1);
uint8_t* ptr1 = ptr;
//clear current extension
std::memset((uint8_t*)extension, 0, currentLen+1);
while(ptr < extensionEnd && ptr1<extensionEnd)
{
if(ptr >= extensionEnd)
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case are we entering this condition?

ptr1++;
*ptr1 = 0u;
continue;
}
const uint8_t tempId = (*ptr & 0xF0) >> 4;
const size_t tempLen = static_cast<size_t>(*ptr & 0x0F) + 1;

// id=15 in One-Byte extensions means "stop parsing here".
if (tempId == 15u)
break;
if(tempId != 0u && tempId != id)
{
if(ptr1< ptr)
{
std::memmove(ptr1, ptr, tempLen+1);
this->oneByteExtensions[tempId - 1] = reinterpret_cast<OneByteExtension*>(ptr1);
}
extensionsTotalSize += tempLen + 1;
// move forward templen+1
ptr += tempLen+1;
ptr1 += tempLen+1;
MS_DEBUG_DEV("tempId: %" PRIu8 " tempLen:%zd, offset:%ld", tempId, tempLen, ptr-extensionStart);
}
else
{
ptr++;
}
}
MS_DEBUG_DEV("extensionsTotalSize: %zd headerLength:%zd", extensionsTotalSize, GetHeaderExtensionLength());
auto paddedExtensionsTotalSize =
static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(extensionsTotalSize)));
extensionsTotalSize = paddedExtensionsTotalSize;
int16_t shift = static_cast<int16_t>(extensionsTotalSize - GetHeaderExtensionLength());
MS_DEBUG_DEV("shift:%d paddedExtensionsTotalSize: %zd", shift, paddedExtensionsTotalSize);
//move payload
std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is payloadPadding needed here? it's only valid to point the padding for parsed packets, but has no relevance otherwise.

// clear the shift place, if shift > currentLen, may create some strange extension
if(shift>0)
{
std::memset(this->payload, 0, shift);
}
//begin to set current extension
ptr = extensionStart + extensionsTotalSize - (len + 1);
this->oneByteExtensions[id - 1] = reinterpret_cast<OneByteExtension*>(ptr);
extension = this->oneByteExtensions[id - 1];
*ptr = (id << 4) | ((len - 1) & 0x0F);
ptr++;
std::memcpy(ptr, value.c_str(), len);

this->payload += shift;
this->size += shift;
this->headerExtension->length = htons(extensionsTotalSize / 4);
}else {
// In One-Byte extensions value length 0 means 1.
extension->len = len - 1;
std::memcpy(extension->value, value.c_str(), len);
}

// In One-Byte extensions value length 0 means 1.
extension->len = len - 1;

return true;
}
Expand All @@ -618,13 +683,70 @@ namespace RTC

auto* extension = it->second;
auto currentLen = extension->len;

// Fill with 0's if new length is minor.
if (len < currentLen)
std::memset(extension->value + len, 0, currentLen - len);

extension->len = len;

if(len != currentLen)
{
// need shift
uint8_t* extensionStart = reinterpret_cast<uint8_t*>(this->headerExtension) + 4;
uint8_t* extensionEnd = extensionStart + GetHeaderExtensionLength();
uint8_t* ptr = extensionStart;
size_t extensionsTotalSize = static_cast<size_t>(len + 2);
uint8_t* ptr1 = ptr;
//clear current extension valueLen + 2 byteheader
std::memset((void *)extension, 0, currentLen + 2);
while(ptr + 1 < extensionEnd && ptr1 < extensionEnd)
{
if(ptr + 1 >= extensionEnd) {
ptr1++;
*ptr1 = 0u;
continue;
}
const uint8_t tempId = *ptr;
const uint8_t tempLen = *(ptr + 1);

if(tempId != 0u && tempId != id)
{
if(ptr1< ptr)
{
std::memmove(ptr1, ptr, tempLen+2);
this->mapTwoBytesExtensions[tempId] = reinterpret_cast<TwoBytesExtension*>(ptr1);
}
extensionsTotalSize += tempLen + 2;
// move forward len+1
ptr += tempLen+2;
ptr1 += tempLen+2;
} else
{
ptr++;
}
}
auto paddedExtensionsTotalSize =
static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(extensionsTotalSize)));
extensionsTotalSize = paddedExtensionsTotalSize;
int16_t shift = static_cast<int16_t>(extensionsTotalSize - GetHeaderExtensionLength());

//move payload
std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
if(shift>0)
{
std::memset(this->payload, 0, shift);
}
//begin to set current extension
ptr = extensionStart + extensionsTotalSize - (len + 2);
this->mapTwoBytesExtensions[id] = reinterpret_cast<TwoBytesExtension*>(ptr);
extension = this->mapTwoBytesExtensions[id];
*ptr = id;
ptr++;
*ptr = len;
ptr++;
std::memcpy(ptr, value.c_str(), len);

this->payload += shift;
this->size += shift;
this->headerExtension->length = htons(extensionsTotalSize / 4);
}else {
extension->len = len;
std::memcpy(extension->value, value.c_str(), len);
}
return true;
}
else
Expand Down
53 changes: 34 additions & 19 deletions worker/test/src/RTC/TestRtpPacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,18 +632,21 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE((extenValue = packet->GetExtension(14, extenLen)));
REQUIRE(packet->HasExtension(14) == true);
REQUIRE(extenLen == 4);
REQUIRE(extenValue[0] == 0x01);
REQUIRE(extenValue[1] == 0x02);
REQUIRE(extenValue[2] == 0x03);
REQUIRE(extenValue[3] == 0x04);
REQUIRE(packet->SetExtensionLength(14, 3) == true);
REQUIRE(packet->SetExtensionValue(14, 4, "ABCD"));
REQUIRE((extenValue = packet->GetExtension(14, extenLen)));
REQUIRE(packet->HasExtension(14) == true);
REQUIRE(extenLen == 3);
REQUIRE(extenValue[0] == 0x01);
REQUIRE(extenValue[1] == 0x02);
REQUIRE(extenValue[2] == 0x03);
REQUIRE(extenValue[3] == 0x00);
REQUIRE(extenLen == 4);
REQUIRE(extenValue[0] == 'A');
REQUIRE(extenValue[1] == 'B');
REQUIRE(extenValue[2] == 'C');
REQUIRE(extenValue[3] == 'D');

// Test all kinds of move
const std::string str = "ABCDEFGHIJK";
for(int i=0; i< 10; i++) {
int tempLen = rand()%7 + 1;
REQUIRE(packet->SetExtensionValue(14, tempLen, str.substr(0, tempLen)) == true);
}

delete packet;
}
Expand Down Expand Up @@ -744,18 +747,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE((extenValue = packet->GetExtension(1, extenLen)));
REQUIRE(packet->HasExtension(1) == true);
REQUIRE(extenLen == 4);
REQUIRE(extenValue[0] == 0x01);
REQUIRE(extenValue[1] == 0x02);
REQUIRE(extenValue[2] == 0x03);
REQUIRE(extenValue[3] == 0x04);
REQUIRE(packet->SetExtensionLength(1, 2) == true);
REQUIRE(packet->SetExtensionValue(1, 4, "ABCD") == true);
REQUIRE((extenValue = packet->GetExtension(1, extenLen)));
REQUIRE(packet->HasExtension(1) == true);
REQUIRE(extenLen == 4);
REQUIRE(extenValue[0] == 'A');
REQUIRE(extenValue[1] == 'B');
REQUIRE(extenValue[2] == 'C');
REQUIRE(extenValue[3] == 'D');
REQUIRE(packet->SetExtensionValue(1, 2, "ABCD") == true);
REQUIRE((extenValue = packet->GetExtension(1, extenLen)));
REQUIRE(packet->HasExtension(1) == true);
REQUIRE(extenLen == 2);
REQUIRE(extenValue[0] == 0x01);
REQUIRE(extenValue[1] == 0x02);
REQUIRE(extenValue[2] == 0x00);
REQUIRE(extenValue[3] == 0x00);
REQUIRE(extenValue[0] == 'A');
REQUIRE(extenValue[1] == 'B');
// this may failed
REQUIRE(extenValue[2] != 'C');
REQUIRE(extenValue[3] != 'D');
REQUIRE(packet->GetExtension(22, extenLen));
REQUIRE(packet->HasExtension(22) == true);
REQUIRE(extenLen == 11);
Expand Down Expand Up @@ -790,6 +798,13 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(packet->HasExtension(24) == true);
REQUIRE(extenLen == 4);

// Test all kinds of move
const std::string str = "ABCDEFGHIJK";
for(int i=0; i< 10; i++) {
int tempLen = rand()%7 + 1;
REQUIRE(packet->SetExtensionValue(24, tempLen, str.substr(0, tempLen)) == true);
}

delete packet;
}

Expand Down