Skip to content

Commit

Permalink
Support old PTX compression scheme (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen authored Mar 29, 2024
1 parent 7d4147c commit b695f44
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 31 deletions.
1 change: 1 addition & 0 deletions zluda_dark_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ either = "1.9"
bit-vec = "0.6.3"
paste = "1.0"
lz4-sys = "1.9"
cloudflare-zlib = "0.2.10"
thread-id = "4.1.0"
# we don't need elf32, but goblin has a bug where elf64 does not build without elf32
goblin = { version = "0.5.1", default-features = false, features = ["elf64", "elf32"] }
65 changes: 38 additions & 27 deletions zluda_dark_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,13 +687,19 @@ pub enum FatbinModule {
pub struct FatbinFile {
data: *const u8,
pub kind: FatbinFileKind,
pub compressed: bool,
pub compression: FatbinCompression,
pub sm_version: u32,
padded_payload_size: usize,
payload_size: usize,
uncompressed_payload: usize,
}

/// Compression scheme applied to the payload of a [`FatbinFile`].
///
/// The variant is selected from the fatbin file header flags when the file is
/// parsed, and decides how `FatbinFile::get_or_decompress` produces the payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FatbinCompression {
    /// Neither compression flag is set; the payload is returned as-is.
    None,
    /// The `CompressedOld` header flag is set; the payload is inflated with zlib.
    Zlib,
    /// The `CompressedNew` header flag is set; the payload is decompressed with LZ4.
    Lz4,
}

impl FatbinFile {
unsafe fn try_new(fatbin_file: &FatbinFileHeader) -> Result<Self, UnexpectedFieldError> {
let fatbin_file_version = fatbin_file.version;
Expand All @@ -719,22 +725,19 @@ impl FatbinFile {
});
}
};
if fatbin_file
let compression = if fatbin_file
.flags
.contains(FatbinFileHeaderFlags::CompressedOld)
{
return Err(UnexpectedFieldError {
name: "FATBIN_FILE_HEADER_FLAGS",
expected: vec![
AnyUInt::U64(FatbinFileHeaderFlags::empty().bits()),
AnyUInt::U64(FatbinFileHeaderFlags::CompressedNew.bits()),
],
observed: AnyUInt::U64(fatbin_file.flags.bits()),
});
}
let compressed = fatbin_file
FatbinCompression::Zlib
} else if fatbin_file
.flags
.contains(FatbinFileHeaderFlags::CompressedNew);
.contains(FatbinFileHeaderFlags::CompressedNew)
{
FatbinCompression::Lz4
} else {
FatbinCompression::None
};
let data = (fatbin_file as *const _ as *const u8).add(fatbin_file.header_size as usize);
let padded_payload_size = fatbin_file.padded_payload_size as usize;
let payload_size = fatbin_file.payload_size as usize;
Expand All @@ -743,7 +746,7 @@ impl FatbinFile {
Ok(Self {
data,
kind,
compressed,
compression,
padded_payload_size,
payload_size,
uncompressed_payload,
Expand All @@ -753,28 +756,36 @@ impl FatbinFile {

// Returning static lifetime here because all known uses of this are related to fatbin files that
// are constants inside files
pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, Lz4DecompressionFailure> {
if self.compressed {
match self.decompress_kernel_module() {
Some(mut decompressed) => {
if self.kind == FatbinFileKind::Ptx {
decompressed.pop(); // remove trailing zero
pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, DecompressionFailure> {
match self.compression {
FatbinCompression::Lz4 => {
match self.decompress_kernel_module_lz4() {
Some(mut decompressed) => {
if self.kind == FatbinFileKind::Ptx {
decompressed.pop(); // remove trailing zero
}
Ok(Cow::Owned(decompressed))
}
Ok(Cow::Owned(decompressed))
None => Err(DecompressionFailure),
}
None => Err(Lz4DecompressionFailure),
}
} else {
Ok(Cow::Borrowed(slice::from_raw_parts(
FatbinCompression::Zlib => {
let compressed =
std::slice::from_raw_parts(self.data.cast(), self.padded_payload_size);
Ok(Cow::Owned(
cloudflare_zlib::inflate(compressed).map_err(|_| DecompressionFailure)?,
))
}
FatbinCompression::None => Ok(Cow::Borrowed(slice::from_raw_parts(
self.data,
self.padded_payload_size as usize,
)))
))),
}
}

const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;

unsafe fn decompress_kernel_module(&self) -> Option<Vec<u8>> {
unsafe fn decompress_kernel_module_lz4(&self) -> Option<Vec<u8>> {
let decompressed_size = usize::max(1024, self.uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
Expand All @@ -801,7 +812,7 @@ impl FatbinFile {
}

#[derive(Debug)]
pub struct Lz4DecompressionFailure;
pub struct DecompressionFailure;

pub fn anti_zluda_hash<F: FnMut(u32) -> AntiZludaHashInputDevice>(
return_known_value: bool,
Expand Down
8 changes: 4 additions & 4 deletions zluda_dump/src/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::path::PathBuf;
use std::str::Utf8Error;
use zluda_dark_api::AnyUInt;
use zluda_dark_api::FatbinFileKind;
use zluda_dark_api::Lz4DecompressionFailure;
use zluda_dark_api::DecompressionFailure;
use zluda_dark_api::UnexpectedFieldError;

const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] ";
Expand Down Expand Up @@ -447,7 +447,7 @@ impl Display for LogEntry {
file_name
)
}
LogEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
LogEntry::Lz4DecompressionFailure => write!(f, "Decompression failure"),
LogEntry::UnknownExportTableFn => write!(f, "Unknown export table function"),
LogEntry::UnexpectedBinaryField {
field_name,
Expand Down Expand Up @@ -591,8 +591,8 @@ impl From<io::Error> for LogEntry {
}
}

impl From<Lz4DecompressionFailure> for LogEntry {
fn from(_err: Lz4DecompressionFailure) -> Self {
impl From<DecompressionFailure> for LogEntry {
fn from(_err: DecompressionFailure) -> Self {
LogEntry::Lz4DecompressionFailure
}
}
Expand Down

0 comments on commit b695f44

Please sign in to comment.