From b695f44c188efc8df8e2e2c149904bb82d2dc58b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 29 Mar 2024 02:03:23 +0100 Subject: [PATCH] Support old PTX compression scheme (#188) --- zluda_dark_api/Cargo.toml | 1 + zluda_dark_api/src/lib.rs | 65 +++++++++++++++++++++++---------------- zluda_dump/src/log.rs | 8 ++--- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/zluda_dark_api/Cargo.toml b/zluda_dark_api/Cargo.toml index 0aef25ef..8266e36a 100644 --- a/zluda_dark_api/Cargo.toml +++ b/zluda_dark_api/Cargo.toml @@ -14,6 +14,7 @@ either = "1.9" bit-vec = "0.6.3" paste = "1.0" lz4-sys = "1.9" +cloudflare-zlib = "0.2.10" thread-id = "4.1.0" # we don't need elf32, but goblin has a bug where elf64 does not build without elf32 goblin = { version = "0.5.1", default-features = false, features = ["elf64", "elf32"] } diff --git a/zluda_dark_api/src/lib.rs b/zluda_dark_api/src/lib.rs index 6849e0e0..15c60911 100644 --- a/zluda_dark_api/src/lib.rs +++ b/zluda_dark_api/src/lib.rs @@ -687,13 +687,19 @@ pub enum FatbinModule { pub struct FatbinFile { data: *const u8, pub kind: FatbinFileKind, - pub compressed: bool, + pub compression: FatbinCompression, pub sm_version: u32, padded_payload_size: usize, payload_size: usize, uncompressed_payload: usize, } +pub enum FatbinCompression { + None, + Zlib, + Lz4, +} + impl FatbinFile { unsafe fn try_new(fatbin_file: &FatbinFileHeader) -> Result { let fatbin_file_version = fatbin_file.version; @@ -719,22 +725,19 @@ impl FatbinFile { }); } }; - if fatbin_file + let compression = if fatbin_file .flags .contains(FatbinFileHeaderFlags::CompressedOld) { - return Err(UnexpectedFieldError { - name: "FATBIN_FILE_HEADER_FLAGS", - expected: vec![ - AnyUInt::U64(FatbinFileHeaderFlags::empty().bits()), - AnyUInt::U64(FatbinFileHeaderFlags::CompressedNew.bits()), - ], - observed: AnyUInt::U64(fatbin_file.flags.bits()), - }); - } - let compressed = fatbin_file + FatbinCompression::Zlib + } else if fatbin_file .flags - .contains(FatbinFileHeaderFlags::CompressedNew); + .contains(FatbinFileHeaderFlags::CompressedNew) + { + FatbinCompression::Lz4 + } else { + FatbinCompression::None + }; let data = (fatbin_file as *const _ as *const u8).add(fatbin_file.header_size as usize); let padded_payload_size = fatbin_file.padded_payload_size as usize; let payload_size = fatbin_file.payload_size as usize; @@ -743,7 +746,7 @@ impl FatbinFile { Ok(Self { data, kind, - compressed, + compression, padded_payload_size, payload_size, uncompressed_payload, @@ -753,28 +756,36 @@ impl FatbinFile { // Returning static lifetime here because all known uses of this are related to fatbin files that // are constants inside files - pub unsafe fn get_or_decompress(&self) -> Result, Lz4DecompressionFailure> { - if self.compressed { - match self.decompress_kernel_module() { - Some(mut decompressed) => { - if self.kind == FatbinFileKind::Ptx { - decompressed.pop(); // remove trailing zero + pub unsafe fn get_or_decompress(&self) -> Result, DecompressionFailure> { + match self.compression { + FatbinCompression::Lz4 => { + match self.decompress_kernel_module_lz4() { + Some(mut decompressed) => { + if self.kind == FatbinFileKind::Ptx { + decompressed.pop(); // remove trailing zero + } + Ok(Cow::Owned(decompressed)) } - Ok(Cow::Owned(decompressed)) + None => Err(DecompressionFailure), } - None => Err(Lz4DecompressionFailure), } - } else { - Ok(Cow::Borrowed(slice::from_raw_parts( + FatbinCompression::Zlib => { + let compressed = + std::slice::from_raw_parts(self.data.cast(), self.padded_payload_size); + Ok(Cow::Owned( + cloudflare_zlib::inflate(compressed).map_err(|_| DecompressionFailure)?, + )) + } + FatbinCompression::None => Ok(Cow::Borrowed(slice::from_raw_parts( self.data, self.padded_payload_size as usize, - ))) + ))), } } const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024; - unsafe fn decompress_kernel_module(&self) -> Option> { + unsafe fn decompress_kernel_module_lz4(&self) -> Option> { let decompressed_size = usize::max(1024, self.uncompressed_payload as usize); let mut decompressed_vec = vec![0u8; decompressed_size]; loop { @@ -801,7 +812,7 @@ impl FatbinFile { } #[derive(Debug)] -pub struct Lz4DecompressionFailure; +pub struct DecompressionFailure; pub fn anti_zluda_hash AntiZludaHashInputDevice>( return_known_value: bool, diff --git a/zluda_dump/src/log.rs b/zluda_dump/src/log.rs index 2cfbda62..7777a61a 100644 --- a/zluda_dump/src/log.rs +++ b/zluda_dump/src/log.rs @@ -19,7 +19,7 @@ use std::path::PathBuf; use std::str::Utf8Error; use zluda_dark_api::AnyUInt; use zluda_dark_api::FatbinFileKind; -use zluda_dark_api::Lz4DecompressionFailure; +use zluda_dark_api::DecompressionFailure; use zluda_dark_api::UnexpectedFieldError; const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] "; @@ -447,7 +447,7 @@ impl Display for LogEntry { file_name ) } - LogEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), + LogEntry::Lz4DecompressionFailure => write!(f, "Decompression failure"), LogEntry::UnknownExportTableFn => write!(f, "Unknown export table function"), LogEntry::UnexpectedBinaryField { field_name, @@ -591,8 +591,8 @@ impl From for LogEntry { } } -impl From for LogEntry { - fn from(_err: Lz4DecompressionFailure) -> Self { +impl From for LogEntry { + fn from(_err: DecompressionFailure) -> Self { LogEntry::Lz4DecompressionFailure } }