diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9244ed --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*__pycache__* +*txt +*cache* +*.DS_Store* +tokenized_datasets/* +exps/* +probs/* +.ipynb_checkpoints/ +tiny-imagenet-200/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..01ecbb6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Hui-Po Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a9ea6cb --- /dev/null +++ b/README.md @@ -0,0 +1,136 @@ + +
+
+ +
+ +Logo + +

+

+The official implementation of "Language Models as Zero-shot Lossless Gradient Compressors: Towards General Neural Parameter Prior Models" published at NeurIPS 2024. +
+
+[Preprint] +

+
+ +## Overview + +![Product Screenshot](images/teaser.png) + +This project provides the source code of LM-GC, the first LLM-powered gradient compressor. + +Here are the key takeaways: + +- We demonstrate that large language models (LLMs) hold significant potential as prior models for gradients, a concept that has been widely applied in other modalities but not yet to gradients. +- We introduce a novel serialization method that converts IEEE 754 floating points into hexadecimal format, enabling LLMs to comprehend and achieve state-of-the-art lossless gradient compression. +- Our LLM-based prior model could unlock new applications for gradients, similar to those in other modalities, such as super-resolution, denoising, generation, and more. + +
+ +*If you find the project interesting, don't forget to star and cite our work:* + +```bibtex +@article{wang2024language, + title={Language Models as Zero-shot Lossless Gradient Compressors: Towards General Neural Parameter Prior Models}, + author={Wang, Hui-Po and Fritz, Mario}, + journal={Advances in Neural Information Processing Systems}, + year={2024} +} +``` +## Getting Started +### Prerequisites + +- torch ≥ 2.12.0 +- transformers ≥ 4.40.1 +- [torchac](https://github.com/fab-jul/torchac) +- [flash attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2) ≥ 2.5.8 via ```pip install flash-attn --no-build-isolation``` for NVIDIA GPUs + +or + +- install via ```pip``` + ```sh + pip install -r requirements.txt + ``` +**After you set up the Hugging Face access token, the codebase should download language models automatically via HuggingFace, except for LLAMA2. See [More LLMs](#more-llms) for more information.** + +### Quickstart +We provide a quick demo here. Please refer to [Usage](#usage) for the detailed usage. +```bash +cd scripts +# compress gradients of a ConvNet trained on TinyImageNet using TinyLLAMA +bash pipeline.sh +``` +## Usage +It takes three steps to reproduce the experiments in the paper, including (1) train neural networks to collect gradients; (2) serialize and tokenize raw gradients; (3) run LLMs and arithmetic coding (LM-GC). + +### 1. Gradient collection +This step trains a network (e.g. a ConvNet on TinyImageNet in the following example) and collects gradients for compression later. See ```scripts/run_collect.sh``` for more details. +```bash +DATASET='tinyimagenet' # cifar10 # mnist +ARCH="convnet" # vgg16 # resnet18 # vit +for i in 0 1 2 +do + python -u train_and_collect_grad.py -cfg settings/gradient_collection/$DATASET-$ARCH.yaml --tag $i --grad-interval 400 --download +done +``` +### 2. Serialization and tokenization +For convenience, we process the data before conducting arithmetic encoding. 
The data is serialized and tokenized here. We create three preprocessed datasets here. See ```scripts/serialization.sh``` for more details. +```bash +NUM_SUBSAMPLE=10 +DATASET='tinyimagenet' # cifar10 # mnist +ARCH="convnet" # vgg16 # resnet18 # vit +TYPE="grad" +COMPRESSOR="tinyllama" # llama2-7b # openllama3b +SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon +BPG=4 # 8 +for i in 1 2 3 +do + python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml \ + --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \ + --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE \ + --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i +done +``` +### 3. Run compression +The processed data from the previous step is now divided into several disjoint windows. By default, LLMs see a set of 2048 (including 1 BOS token) tokens every time. The experiments are repeated three times. See ```scripts/compress.sh``` for more details. +```bash +HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +NUM_SUBSAMPLE=10 +DATASET='tinyimagenet' # cifar10 # mnist +ARCH="convnet" # vgg16 # resnet18 # vit +TYPE="grad" +COMPRESSOR="tinyllama" # llama2-7b # openllama3b +SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon +BATCHSIZE=4 # depending on your GPUs +BPG=4 # 8 +for i in 1 2 3 +do + python -u compress.py -cfg settings/compression/cifar10-$SEP.yaml --compressor $COMPRESSOR --dataset tokenized_dataset \ + --data-path ./tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl --batch-size $BATCHSIZE +done +``` + +## Options + +### More LLMs + +### More models to compress + +### Ablation study +- Bytes per group +- Context window size + +## TO-DO +- [x] prepare `pipeline.sh` +- [x] sanity check +- [ ] how to add more LLMs +- [ ] provide a runnable encode/decode example +- [ ] Baseline codec +## License + +Distributed under the MIT License. 
See [MIT License](https://opensource.org/licenses/MIT) for more information. + +## Acknowledgments +This project is partially built up on [Deepmind's work](), and the readme file template comes from [makeread.me](https://github.com/ShaanCoding/ReadME-Generator). diff --git a/arithmetic_coder.py b/arithmetic_coder.py new file mode 100644 index 0000000..b1eccd8 --- /dev/null +++ b/arithmetic_coder.py @@ -0,0 +1,324 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Implements an Arithmetic Encoder and Decoder.""" + +from typing import Any, Callable + +import chex +import numpy as np + +InputFn = Callable[[], int] +OutputFn = Callable[[int], None] +IOFn = InputFn | OutputFn + + +def _log_power_of_b(n: int, base: int) -> int: + """Returns k assuming n = base ** k. + + We manually implement this function to be faster than a np.log or math.log, + which doesn't assume n is an integer. + + Args: + n: The integer of which we want the log. + base: The base of the log. + """ + log_n = 0 + while n > 1: + n //= base + log_n += 1 + return log_n + + +def _raise_post_terminate_exception(*args: Any, **kwargs: Any) -> None: + """Dummy function that raises an error to ensure AC termination.""" + del args, kwargs + raise ValueError( + "Arithmetic encoder was terminated. " + "Create a new instance for encoding more data. 
" + "Do NOT use an output function that writes to the same data sink " + "used by the output function of this instance. " + "This will corrupt the arithmetic code as decoding relies on detecting " + "when the compressed data stream is exhausted." + ) + + +class _CoderBase: + """Arithmetic coder (AC) base class.""" + + def __init__(self, base: int, precision: int, io_fn: IOFn): + """Does initialization shared by AC encoder and decoder. + + Args: + base: The arithmetic coder will output digits in {0, 1, ..., base - 1}. + precision: Precision for internal state; on the average this will waste + code space worth at most 1/log(base) * base ** -(precision - 2) digits + of output per coding step. + io_fn: Function to write digits to compressed stream/read digits from + compressed stream. + """ + chex.assert_scalar_in(base, 2, np.inf) + chex.assert_scalar_in(precision, 2, np.inf) + + self._base: int = base + self._base_to_pm1: int = int(base ** (precision - 1)) + self._base_to_pm2: int = int(base ** (precision - 2)) + self._io_fn = io_fn + + # NOTE: We represent the AC interval [0, 1) as rational numbers: + # [0, 1) + # ~ [self._low / base ** precision, (self._high + 1) / base ** precision) + # = [self._low / base ** precision, self._high / base ** precision], + # where the we represent the upper bound *INCLUSIVE*. This is a subtle + # detail required to make the integer arithmetic work correctly given that + # all involved integers have `precision` digits in base `base`. 
+ self._low: int = 0 + self._high: int = int(base**precision) - 1 + self._num_carry_digits: int = 0 + self._code: int = 0 + + def __str__(self) -> str: + """Returns string describing internal state.""" + if self._base > 16: + raise ValueError("`__str__` with `base` exceeding 16 not implmeneted.") + + p = 1 + _log_power_of_b(self._base_to_pm1, base=self._base) + + def _to_str(x: int) -> str: + """Returns representation of `n` in base `self._base`.""" + digits = [(x // self._base**i) % self._base for i in range(p)] + return f"{digits[-1]:x}" + "".join( + f"{d:x}" for d in digits[-2::-1] + ) + + return ( + f"[{_to_str(self._low)}, {_to_str(self._high)}) {_to_str(self._code)}" + ) + + def _get_intervals(self, pdf: np.ndarray) -> np.ndarray: + """Partition the current AC interval according to the distribution `pdf`.""" + if (pdf < 0).any(): + raise ValueError( + "Some probabilities are negative. Please make sure that pdf[x] > 0." + ) + # Compute CPDF s.t. cpdf[x] = sum_y= max(base ** -(precision - 2), np.dtype(x).eps) for any" + " symbol by either preprocessing `pdf` or by increasing `precision`." + ) + if qcpdf[-1] > width: + raise ValueError( + "Cumulative sum of probabilities exceeds 1 after quantisation. " + "Please make sure that sum(pdf) <= 1.0 - eps, for a small eps > 0." + ) + return self._low + qcpdf + + def _remove_matching_digits(self, low_pre_split: int, encoding: bool) -> None: + """Remove matching most significant digits from AC state [low, high). + + This is the *FIRST* normalization step after encoding a symbol into the AC + state. + + When encoding we write the most significant matching digits of the + integer representation of [low, high) to the output, widen the integer + representation of [low, high) including a (potential) queue of carry digits; + when decoding we drop the matching most significant digits of the integer + representation of [low, high), widen this interval and keep the current + slice of the arithmetic code word `self._code` in sync. 
+ + Args: + low_pre_split: Value of `self._low` before encoding a new symbol into the + AC state when `encoding` is True; abitrary, otherwise. + encoding: Are we encoding (i.e. normalise by writing data) or decoding + (i.e. normalise by reading data)? + """ + + def _shift_left(x: int) -> int: + """Shift `x` one digit left.""" + return (x % self._base_to_pm1) * self._base + + while self._low // self._base_to_pm1 == self._high // self._base_to_pm1: + if encoding: + low_msd = self._low // self._base_to_pm1 + self._io_fn(low_msd) + # Note that carry digits will only be written in the first round of this + # loop. + carry_digit = ( + self._base - 1 + low_msd - low_pre_split // self._base_to_pm1 + ) % self._base + assert carry_digit in {0, self._base - 1} or self._num_carry_digits == 0 + while self._num_carry_digits > 0: + self._io_fn(carry_digit) + self._num_carry_digits -= 1 + else: + self._code = _shift_left(self._code) + self._io_fn() + self._low = _shift_left(self._low) + self._high = _shift_left(self._high) + self._base - 1 + + def _remove_carry_digits(self, encoding: bool) -> None: + """Remove and record 2nd most significant digits from AC state [low, high). + + This is the *SECOND* normalization step after encoding a symbol into the AC + state [low, high). + + If the AC state takes the form + low = x B-1 B-1 ... B-1 u ... + high = x+1 0 0 0 v ... + ^__ prefix __^ + where x, u and v are base-B digits then low and high can get arbitrarily ( + well, by means of infinite precision arithmetics) without matching. Since we + work with finite precision arithmetics, we must make sure that this doesn't + occour and we guarantee sufficient of coding range (`high - low`). To end + this we detect the above situation and cut off the highlighted prefix above + to widen the integer representation of [low, high) and record the number of + prefix digits removed. When decoding we must similarly process the current + slice of the arithmetic code word `self._code` to keep it in sync. 
+ + Args: + encoding: Are we encoding (i.e. normalise by writing data) or decoding + (i.e. normalise by reading data)? + """ + + def _shift_left_keeping_msd(x: int) -> int: + """Shift `x` except MSD, which remains in place, one digit left.""" + return x - (x % self._base_to_pm1) + (x % self._base_to_pm2) * self._base + + while self._low // self._base_to_pm2 + 1 == self._high // self._base_to_pm2: + if encoding: + self._num_carry_digits += 1 + else: + self._code = _shift_left_keeping_msd(self._code) + self._io_fn() + self._low = _shift_left_keeping_msd(self._low) + self._high = _shift_left_keeping_msd(self._high) + self._base - 1 + + def _process(self, pdf: np.ndarray, symbol: int | None) -> int: + """Perform an AC encoding or decoding step and modify AC state in-place. + + Args: + pdf: Probability distribution over input alphabet. + symbol: Letter to encode from {0, 1, ..., pdf.size - 1} when encoding or + `None` when decoding. + + Returns: + y: `symbol` from above when encoding or decoded letter from {0, 1, ..., + pdf.size - 1}. + """ + + encoding = symbol is not None + intervals = self._get_intervals(pdf) + if not encoding: + symbol = np.searchsorted(intervals, self._code, side="right") - 1 + assert 0 <= symbol < pdf.size + low_pre_split = self._low + self._low, self._high = intervals[[symbol, symbol + 1]] + # Due to integer arithmetics the integer representation of [low, high) has + # an inclusive upper bound, so decrease high. + self._high -= 1 + assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base + + # Normalize the AC state. 
+ self._remove_matching_digits(low_pre_split=low_pre_split, encoding=encoding) + assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base + assert encoding or self._low <= self._code <= self._high + assert self._low // self._base_to_pm1 != self._high // self._base_to_pm1 + + self._remove_carry_digits(encoding=encoding) + assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base + assert encoding or self._low <= self._code <= self._high + assert self._high - self._low > self._base_to_pm2 + + return symbol + + @classmethod + def p_min(cls, base: int, precision: int) -> float: + """Get minimum probability supported by AC config.""" + # The leading factor 2 is supposed to account for rounding errors and + # wouldn't be necessary given infinite float precision. + return 2.0 * base ** -(precision - 2) + + +class Encoder(_CoderBase): + """Arithmetic encoder.""" + + def __init__(self, base: int, precision: int, output_fn: OutputFn): + """Constructs arithmetic encoder. + + Args: + base: The arithmetic coder will output digits in {0, 1, ..., base - 1}. + precision: Precision for internal state; on the average this will waste + code space worth at most 1/log(base) * base ** -(precision - 2) digits + of output per coding step. + output_fn: Function that writes a digit from {0, 1, ..., base - 1} to the + compressed output. + """ + super().__init__(base, precision, output_fn) + + def encode(self, pdf: np.ndarray, symbol: int) -> None: + """Encodes symbol `symbol` assuming coding distribution `pdf`.""" + self._process(pdf, symbol) + + def terminate(self) -> None: + """Finalizes arithmetic code.""" + # Write outstanding part of the arithmetic code plus one digit to uniquely + # determine a code within the interval of the final symbol coded. 
+ self._io_fn(self._low // self._base_to_pm1) + for _ in range(self._num_carry_digits): + self._io_fn(self._base - 1) + self.encode = _raise_post_terminate_exception + self.terminate = _raise_post_terminate_exception + + +class Decoder(_CoderBase): + """Arithmetic decoder.""" + + def __init__(self, base: int, precision: int, input_fn: InputFn): + """Constructs arithmetic decoder. + + Args: + base: The arithmetic coder will output digits in {0, 1, ..., base - 1}. + precision: Precision for internal state; on the average this will waste + code space worth at most 1/log(base) * base ** -(precision - 2) digits + of output per coding step. + input_fn: Function that reads a digit from {0, 1, ..., base - 1} from the + compressed input or returns `None` when the input is exhausted. + """ + # Add padding to ensure the AC state is well-defined when decoding the last + # symbol. Note that what exactly we do here depends on how encoder + # termination is implemented (see `Encoder.terminate`). + trailing_digits = (base - 1 for _ in range(precision - 1)) + + def _padded_input_fn() -> int: + """Reads digit from input padding the arithmetic code.""" + digit = input_fn() + if digit is None: + digit = next(trailing_digits) + chex.assert_scalar_in(int(digit), 0, base - 1) + return digit + + super().__init__(base, precision, _padded_input_fn) + for _ in range(precision): + self._code = self._code * base + _padded_input_fn() + + def decode(self, pdf: np.ndarray) -> int: + return self._process(pdf, None) \ No newline at end of file diff --git a/compress.py b/compress.py new file mode 100644 index 0000000..1a343c9 --- /dev/null +++ b/compress.py @@ -0,0 +1,270 @@ +# Copyright 2024 CISPA Helmholtz Center for Information Security Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# import sys
+# sys.path.append('../gradient-compressors')
+import argparse, os
+os.environ['HF_HOME'] = './cache'
+
+import functools
+import time
+import logging, yaml
+
+from collections.abc import Generator
+from typing import Callable
+
+import numpy as np
+
+import torch
+
+import tqdm
+import constants
+
+from utils import data_loaders
+from utils import deepmind_utils
+from utils.utils import BaseParser
+from compressors import compressor
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)
+
+    parser.add_argument('-data-path', '--data-path', default='./', required=True, type=str)
+    parser.add_argument('-exhaustive-listing', '--exhaustive-listing', action='store_true', help='If read all of the checkpoints in the data path.')
+    parser.add_argument('-num-subsample', '--num-subsample', type=int, default=1)
+
+    parser.add_argument('-download', '--download', action='store_true')
+
+    parser.add_argument('-use_mask', '--use_mask', action='store_true', help='Applying mask functions, particularly for decoding images into ASCII.')
+    parser.add_argument('-use_slow_compression', '--use_slow_compression', action='store_true')
+
+    parser.add_argument('--dataset', '-dataset', default='tokenized_dataset', type=str, help='Indicatge what kind of data to compress.')
+    parser.add_argument('--compressor', '-compressor', default='gpt2', type=str, help='What kind of compressor to use.')
+
+    parser.add_argument('--bytes-per-group', '-bytes-per-group', default=None, type=int, help='Specify after how many bytes a separator will be added.')
+    parser.add_argument('-batch-size', '--batch-size', type=int, default=32)
+    parser.add_argument('-max-tokens', '--max-tokens', type=int, default=2048)
+    parser.add_argument('-preprocess', '--preprocess', type=str, default=None)
+
+    args = parser.parse_args()
+    with open(args.cfg, 'r') as stream:
+        settings = yaml.safe_load(stream)
+
+    args = BaseParser(args, settings)
+
+    # handle bytes_per_group: default to 4 for tokenized/gradient data, 1 otherwise
+    if args.bytes_per_group is None:
+        args.bytes_per_group = 4 if args.dataset in ('tokenized_dataset', 'tokenized_data', 'gradient') else 1  # 'tokenized_dataset' is the name used by --dataset and every other check; the old spelling is kept for compatibility
+
+    if args.exhaustive_listing:
+        paths = np.array([os.path.join(args.data_path, f) for f in os.listdir(args.data_path) if os.path.isfile(os.path.join(args.data_path, f))])
+        args.sample_index = np.random.choice(len(paths), args.num_subsample, replace=False)
+        print(args.sample_index)
+        args.data_path = paths[args.sample_index]
+
+    print(args)
+    print('chunck size: ', constants.CHUNK_SIZE_BYTES)
+    print('num_chuncks: ',constants.NUM_CHUNKS)
+    return args
+
+def evaluate_compressor_chunked(
+    args: BaseParser,
+    compress_fn: compressor.Compressor,
+    get_data_generator_fn: Callable[[], Generator[bytes, None, None]],
+    num_chunks: int,
+    count_header_only_once: bool = True,
+    mask_fn: Callable[[bytes], tuple[bytes, int]] | None = None,
+    use_tqdm: bool = True,
+) -> tuple[float, float]:
+    """Evaluates the compressor chunk by chunk; `args.dataset` selects the batching mode.
+
+    Args:
+        compress_fn: The function that compresses one chunk (or one batch of tokens).
+        get_data_generator_fn: Returns `(raw_length, data_generator)` when called.
+        num_chunks: The number of chunks to consider (currently unused).
+        count_header_only_once: Whether to count the header as part of the
+            compressed output only once for the whole dataset or for every chunk
+            individually (currently unused; the adjustment is commented out).
+        mask_fn: The function that masks the data in case the compressor cannot
+            handle all possible byte values (e.g., language models can only process
+            ASCII-decodable data).
+        use_tqdm: Whether to use a progress bar or not (currently unused).
+
+    Returns:
+        The compression rate and the total running time.
+    """
+    num_missed_bits = running_time = raw_length = compressed_length = num_samples = 0
+
+    raw_length, data_generator = get_data_generator_fn()
+    print(f'Data to compress has size {raw_length} bytes.')
+
+    if args.dataset == 'tokenized_dataset':
+        for num_bytes, data in tqdm.tqdm(data_generator):
+            num_samples += len(data)
+            if mask_fn is not None:
+                d, missed_bits = mask_fn(data)  # NOTE(review): masked result `d` is never used below - confirm intent
+                num_missed_bits += missed_bits
+
+            t0 = time.perf_counter()
+            compressed_data = compress_fn(data, use_slow_lossless_compression=args.use_slow_compression)
+            t1 = time.perf_counter()
+
+            running_time += t1 - t0
+            compressed_length += len(compressed_data)
+    else:
+        for num_bytes, data in tqdm.tqdm(data_generator):
+            num_samples += len(data)
+            for d_size, d in zip(num_bytes, data):
+                if mask_fn is not None:
+                    d, missed_bits = mask_fn(d)
+                    num_missed_bits += missed_bits
+
+                if isinstance(d, torch.Tensor):
+                    d = d.numpy()
+
+                t0 = time.perf_counter()
+                compressed_data = compress_fn(d)
+                t1 = time.perf_counter()
+
+                running_time += t1 - t0
+                compressed_length += len(compressed_data)
+
+    # raw_length += constants.CHUNK_SIZE_BYTES * num_chunks
+
+    # Since language models are trained on ASCII strings, they cannot handle all
+    # byte values. Thus, we mask the data to be ASCII-decodable by zeroing
+    # `num_missed_bits` of the most significant bits. However, this means that we
+    # are effectively only compressing `num_bits - num_missed_bits` bits, so we
+    # rescale the `compressed_length` to account for this.
+    if mask_fn is not None:
+        num_bits = 8 * num_samples * constants.CHUNK_SIZE_BYTES
+        compressed_length *= num_bits / (num_bits - num_missed_bits)
+
+    # We only count the header once for classical compressors.
+    # if count_header_only_once:
+    #     header_length = len(compress_fn((0).to_bytes(1, 'little')))
+    #     compressed_length -= header_length * (num_samples - 1)
+
+    return compressed_length / raw_length, running_time
+
+
+def evaluate_compressor_unchunked(
+    compress_fn: compressor.Compressor,
+    get_data_generator_fn: Callable[[], Generator[bytes, None, None]],
+    num_chunks: int,
+) -> tuple[float, float]:
+    """Evaluates the compressor on the whole dataset joined into one buffer.
+
+    Args:
+        compress_fn: The function that compresses data.
+        get_data_generator_fn: Returns `(raw_length, data_generator)` when called.
+        num_chunks: The number of chunks to consider (currently unused).
+
+    Returns:
+        The compression rate and the total running time.
+    """
+    all_data = None
+
+    raw_length, data_generator = get_data_generator_fn()
+    print(f'Data to compress has size {raw_length} bytes.')
+
+    with tqdm.tqdm(total=constants.NUM_CHUNKS) as pbar:
+        for num_bytes, data in data_generator:
+            for d in data:
+                if all_data is None:
+                    all_data = bytearray() if isinstance(d, bytes) else []
+
+                if isinstance(d, bytes):
+                    all_data += d
+                else:
+                    d = d.numpy()
+                    all_data.append(d)
+
+                pbar.update(1)
+
+    all_data = bytes(all_data) if isinstance(all_data, bytearray) else np.concatenate(all_data)  # test the container: indexing a bytearray yields an int, never bytes
+
+    t0 = time.perf_counter()
+    compressed_data = compress_fn(all_data)
+    t1 = time.perf_counter()
+
+    return len(compressed_data) / raw_length, t1 - t0
+
+
+def main(args) -> None:
+    print('start')
+    compress_fn = compressor.COMPRESS_FN_DICT[args.compressor]
+
+    if args.compressor in compressor.COMPRESSOR_TYPES['classical']:
+        get_data_generator_fn = functools.partial(
+            data_loaders.GET_DATA_GENERATOR_FN_DICT[args.dataset],
+            args=args,
+            serialization=False,
+            preprocess=args.preprocess,
+            return_fp=False if args.compressor != 'fpzip' else True
+        )
+        unchunked_rate, unchunked_time = evaluate_compressor_unchunked(
+            compress_fn=compress_fn,
+            get_data_generator_fn=get_data_generator_fn,
+            num_chunks=constants.NUM_CHUNKS,
+        )
+        chunked_rate, chunked_time = evaluate_compressor_chunked(
+            args=args,
+            compress_fn=compress_fn,
+            get_data_generator_fn=get_data_generator_fn,
+            num_chunks=constants.NUM_CHUNKS,
+            count_header_only_once=True,
+            mask_fn=None,
+        )
+        print(
+            f'Unchunked: {100 * unchunked_rate:.2f} [{unchunked_time:.1f}s]'
+        )
+        print(f'Chunked: {100 * chunked_rate:.2f} [{chunked_time:.1f}s]')
+
+    elif args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']:
+        get_data_generator_fn = functools.partial(
+            data_loaders.GET_DATA_GENERATOR_FN_DICT[args.dataset],
+            args=args,
+            serialization=True,
+            max_tokens=args.max_tokens,
+            preprocess=args.preprocess,
+        )
+        model = compress_fn(args)
+        if args.use_mask:
+            # To compress bytes data, we convert it first to ASCII.
+            if args.dataset == 'enwik9':
+                # For Enwik9, some characters are UTF-8 but not ASCII, so we still need
+                # to do the conversion.
+                mask_fn = deepmind_utils.zero_most_significant_bit_if_not_ascii_decodable
+            else:
+                mask_fn = deepmind_utils.right_shift_bytes_by_one
+        else:
+            mask_fn = None
+
+        chunked_rate, chunked_time = evaluate_compressor_chunked(
+            args = args,
+            compress_fn=model.compress,
+            get_data_generator_fn=get_data_generator_fn,
+            num_chunks=constants.NUM_CHUNKS,
+            count_header_only_once=False,
+            mask_fn=mask_fn,
+        )
+        print(f'Chunked: {100 * chunked_rate:.2f} [{chunked_time:.1f}s]')
+    else:
+        raise NotImplementedError(f'Unknown codec {args.compressor}.')
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
\ No newline at end of file
diff --git a/compressors/compressor.py b/compressors/compressor.py
new file mode 100644
index 0000000..f899c0d
--- /dev/null
+++ b/compressors/compressor.py
@@ -0,0 +1,54 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Defines the compressor interface and the name -> codec registries."""
+import os
+
+import functools
+import gzip
+import lzma
+import fpzip
+from typing import Mapping, Protocol
+
+from compressors import flac
+from compressors import language_model
+from compressors import png
+
+
+class Compressor(Protocol):  # structural type: any callable matching __call__ below
+
+    def __call__(self, data: bytes, *args, **kwargs) -> bytes | tuple[bytes, int]:
+        """Returns the compressed version of `data`, with optional padded bits."""
+
+
+COMPRESSOR_TYPES = {  # codec family -> compressor names belonging to it
+    'classical': ['flac', 'gzip', 'lzma', 'png', 'fpzip'],
+    'arithmetic_coding': ['gpt2', 'openllama3b', 'tinyllama3b', 'tinyllama', 'mistral7b', 'llama2-7b', 'llama3-8b'],
+}
+
+COMPRESS_FN_DICT: Mapping[str, Compressor] = {  # compressor name -> callable looked up by name
+    'flac': flac.compress,
+    'gzip': functools.partial(gzip.compress, compresslevel=9),
+    'gpt2': language_model.LanguageModelCompressor,  # language-model entries map to the class itself, not to a compress function
+    'openllama3b': language_model.LanguageModelCompressor,
+    'tinyllama3b': language_model.LanguageModelCompressor,
+    'tinyllama': language_model.LanguageModelCompressor,
+    'mistral7b': language_model.LanguageModelCompressor,
+    'llama2-7b': language_model.LanguageModelCompressor,
+    'llama3-8b': language_model.LanguageModelCompressor,
+    'lzma': lzma.compress,
+    'png': png.compress,
+    'fpzip': functools.partial(fpzip.compress, precision=0, order='C'),
+}
\ No newline at end of file
diff --git a/compressors/flac.py b/compressors/flac.py
new file mode 100644
index 0000000..e79d2dd
--- /dev/null
+++ b/compressors/flac.py
@@ -0,0 +1,56 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Implements a lossless compressor with FLAC.""" + +import audioop +import io + +import pydub + +def compress(data: bytes) -> bytes: + """Returns data compressed with the FLAC codec. + + Args: + data: Assumes 1 byte per sample (`sample_width`), meaning 256 possible + values, and 1 channel and a `frame_rate` of 16kHz. + """ + sample = pydub.AudioSegment( + data=data, + channels=1, + sample_width=1, + frame_rate=16000, + ) + return sample.export( + format='flac', + parameters=['-compression_level', '12'], + ).read() + + +def decompress(data: bytes) -> bytes: + """Decompresses `data` losslessly using the FLAC codec. + + Args: + data: The data to be decompressed. Assumes 2 bytes per sample (16 bit). + + Returns: + The decompressed data. Assumes 1 byte per sample (8 bit). + """ + sample = pydub.AudioSegment.from_file(io.BytesIO(data), format='flac') + # FLAC assumes that data is 16 bit. However, since our original data is 8 bit, + # we need to convert the samples from 16 bit to 8 bit (i.e., changing from two + # channels to one channel with `lin2lin`) and add 128 since 16 bit is signed + # (i.e., adding 128 using `bias`). 
# Maps each compressor name to the Hugging Face checkpoint id that is handed to
# `from_pretrained`. 'tinyllama3b' and 'tinyllama' are two names for the same
# checkpoint.
MODEL_NAME_DICT = {
    "gpt2": "gpt2",
    "openllama3b": "openlm-research/open_llama_3b_v2",
    "tinyllama3b": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "tinyllama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mistral7b": "mistralai/Mistral-7B-v0.1",
    "llama2-7b": "meta-llama/Llama-2-7b-hf",
    "llama3-8b": "meta-llama/Meta-Llama-3-8B",
}
# Compressors for which `infer_prob` does NOT prepend the tokenizer's BOS token
# to text input.
SKIP_LIST = ["openllama3b", "tinyllama3b", "tinyllama", "llama2-7b", "llama3-8b"]
check_input_bounds=False) + return len(byte_stream) + +class LanguageModelCompressor(): + def __init__(self, args: BaseParser): + self.args = args + self.device = "cuda" if torch.cuda.is_available() else "cpu" + print(self.device) + + access_token = None + assert access_token, 'Please enter the huggingface access token here if you want to use models like LLAMA.' + + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_DICT[args.compressor], token=access_token) + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME_DICT[args.compressor], + torch_dtype=torch.bfloat16, + # device_map='auto', + attn_implementation="flash_attention_2", + token=access_token, + ) + self.model.config.pad_token_id = self.model.config.eos_token_id + self.model.to(self.device) + self.model.eval() + + @torch.no_grad() + def infer_prob( + self, + data: bytes, + return_num_padded_bits: bool = False, + use_slow_lossless_compression: bool = False, + ) -> torch.Tensor: + if self.args.dataset == 'tokenized_dataset': + input_ids = data.to(self.device)#.view(1, -1) + tokenized_data = { + 'input_ids': input_ids, + 'attention_mask': torch.ones(data.size()), + } + else: + texified_data = data + if not self.args.compressor in SKIP_LIST: + texified_data = self.tokenizer.bos_token + texified_data + + print('length of texified data after adding eos:', len(texified_data)) + + tokenized_data = self.tokenizer(texified_data, padding=False, return_tensors="pt").to(self.device) + print(tokenized_data.keys()) + input_ids = tokenized_data.input_ids + + print('length of tokens', input_ids.shape) + + if use_slow_lossless_compression: + # Compress the token at the position idx+1 using tokens_{t <= idx} + """Remark! + Due to the stochastic operations in huggingface and precision requirements by arithmetic coding, + one may use this option to compute the probability; however, in theory, they should not affect the result. 
+ + TO-DO: provide a runnable example here. + """ + for idx in range(len(tokenized_data.input_ids[0])-1): + input_to_model = { + 'input_ids': tokenized_data.input_ids[:, :idx+1].view(1, -1), + 'attention_mask': tokenized_data.attention_mask[:, :idx+1].view(1, -1) + } + outputs = self.model(**input_to_model) + + logits = outputs.logits[:, idx, :] # next-word prob. of size # num_sentences x + pdf = torch.softmax(logits, dim=-1).squeeze().detach().cpu().numpy() + symbol = tokenized_data.input_ids[:, idx+1].item() + else: + # pay attention to the output length + outputs = self.model(**tokenized_data) + + # collect the probability of the generated token -- probability at index 0 corresponds to the token at index 1 + probs = outputs.logits[:, :-1, :] # num_sentences x next-word prob. of size (excl. the last word prediction) + input_ids = input_ids[:, 1:] # number_sentences (1?) x max_legnth_chars-1 (excl. bos_token) + + probs = torch.log_softmax(probs, dim=-1).detach() + probs = probs.exp() + + assert (input_ids < 65536).all() + input_ids = input_ids.squeeze().detach().cpu().short() + probs = probs.squeeze().detach().cpu() + + return input_ids, probs + + @torch.no_grad() + def compress( + self, + data: bytes, + return_num_padded_bits: bool = False, + use_slow_lossless_compression: bool = False, + ) -> bytes | tuple[bytes, int]: + input_ids, probs = self.infer_prob(data, return_num_padded_bits=return_num_padded_bits, use_slow_lossless_compression=use_slow_lossless_compression) + + cum_prob = probs / torch.cumsum(probs, dim=-1) + cum_prob = torch.cat([torch.zeros(cum_prob.size()[:-1]).unsqueeze(-1), cum_prob], dim=-1) + + byte_stream = torchac.encode_float_cdf(cum_prob, input_ids, check_input_bounds=False, needs_normalization=False) + + return byte_stream + + @torch.no_grad() + def decompress( + self, + data: bytes, + num_padded_bits: int = 0, + uncompressed_length: int = constants.CHUNK_SIZE_BYTES, + ) -> bytes: + """Decompresses the `data` using arithmetic coding and a 
pretrained model. + + See https://en.wikipedia.org/wiki/Arithmetic_coding for details. + + Args: + data: The data to be decompressed. + num_padded_bits: The number of zeros added to the encoded bitstream in order + to make it byte-decodeable (i.e., divisble by 8). + uncompressed_length: The length of the original data stream (in bytes). + + Returns: + The decompressed data. + """ + data_iter = iter(deepmind_utils.bytes_to_bits(data, num_padded_bits=num_padded_bits)) + # The decoder requires a function that reads digits from {0, 1, ..., base - 1} + # from the compressed input and returns `None` when the input is exhausted. + def _input_fn(bit_sequence: Iterator[str] = data_iter) -> int | None: + try: + return int(next(bit_sequence)) + except StopIteration: + return None + + decoder = arithmetic_coder.Decoder( + base=constants.ARITHMETIC_CODER_BASE, + precision=constants.ARITHMETIC_CODER_PRECISION, + input_fn=_input_fn, + ) + # We need a dummy token because the language model right-shifts the sequence + # by one when computing the conditional probabilities. Concretely, at every + # step, we need the `pdf` of the next token given all currently decompressed + # tokens, but without a dummy token, the last `pdf` would be that of the last + # already decompressed token. The value of the dummy token is irrelevant. + input_ids = torch.empty((1, uncompressed_length), dtype=torch.long, device=self.device) + input_ids[:, :] = self.tokenizer.pad_token_id + attention_mask = torch.ones(input_ids.shape).to(self.device) + + # In our current implementation, we always begin with a token. + input_ids[0, 0] = self.tokenizer.bos_token_id + + for idx in range(0, uncompressed_length-1): + input_to_model = { + 'input_ids': input_ids[:, :idx+1].view(1, -1), + 'attention_mask': attention_mask[:, :idx+1].view(1, -1) + } + outputs = self.model(**input_to_model) + + pdf = outputs.logits[:, idx, :] # next-word prob. 
def _get_the_two_closest_factors(n: int) -> tuple[int, int]:
  """Returns the 2 closest factors (square root if `n` is a perfect square).

  Args:
    n: A positive integer.

  Returns:
    A pair `(a, b)` with `a * b == n` and `a <= b`, where `a` is the largest
    divisor of `n` that does not exceed its square root.

  Raises:
    ValueError: If `n` is smaller than 1 (there is no factor pair to return).
  """
  if n < 1:
    raise ValueError(f'Expected a positive integer, got {n}.')
  # `math.isqrt` is exact for arbitrarily large integers, unlike rounding a
  # float square root.
  a = math.isqrt(n)
  # Walk down to the nearest divisor; terminates at 1 at the latest.
  while n % a:
    a -= 1
  return a, n // a
def decompress(data: bytes) -> bytes:
  """Decodes PNG-encoded `data` and returns the raw pixel bytes.

  Args:
    data: The bytes of a PNG image, e.g. as produced by `compress`.

  Returns:
    The result of `tobytes()` on the opened image.
  """
  png_stream = io.BytesIO(data)
  image = Image.open(png_stream)
  return image.tobytes()
# Compress pre-tokenized gradient datasets with a language-model compressor.
# Run from inside scripts/; the Python entry points live one level up.
cd .. || exit 1

# `export` is required: a bare `VAR=1` line only sets shell-local variables,
# so the Python process would never see the offline flags.
export HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
NUM_SUBSAMPLE=10
DATASET='tinyimagenet' # cifar10 # mnist
ARCH="convnet" # vgg16 # resnet18 # vit
TYPE="grad"
COMPRESSOR="tinyllama" # llama2-7b # openllama3b
SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon
BATCHSIZE=4 # depending on your GPUs
BPG=4 # 8
# One run per pre-tokenized dataset produced by serialization.sh (tags 1..3).
for i in 1 2 3
do
    python -u compress.py -cfg "settings/compression/cifar10-$SEP.yaml" --compressor "$COMPRESSOR" --dataset tokenized_dataset \
        --data-path "./tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl" --batch-size "$BATCHSIZE"
done
+ +NUM_SUBSAMPLE=10 +DATASET='tinyimagenet' # cifar10 # mnist +ARCH="convnet" # vgg16 # resnet18 # vit +TYPE="grad" +COMPRESSOR="tinyllama" # llama2-7b # openllama3b +SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon +BPG=4 # 8 +for i in 1 2 3 +do + python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \ + --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i +done \ No newline at end of file diff --git a/scripts/server_collect_gradients.sh b/scripts/server_collect_gradients.sh new file mode 100644 index 0000000..4a89612 --- /dev/null +++ b/scripts/server_collect_gradients.sh @@ -0,0 +1,34 @@ +#!/bin/bash +#SBATCH --nodes=1 # How many nodes? +#SBATCH -A hai_lmgc # Who pays for it? +#SBATCH --partition develbooster +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=2:00:00 +#SBATCH -o output-%x.txt +#SBATCH -e error-%x.txt +# Where does the code run? +# Required for legacy reasons +# How long? + +source /p/home/jusers/wang34/juwels/hai_fedak/huipo/general/activate.sh # path to the environment +cd /p/home/jusers/wang34/juwels/hai_fedak/huipo/lmgc-to-release # path to the upper folder +LOG_DIR="scripts/test_collect" + +if [ ! -d "$LOG_DIR" ]; then + echo "$LOG_DIR does not exist." + mkdir "$LOG_DIR" +else + echo "$LOG_DIR does exist." 
#!/bin/bash
#SBATCH --nodes=1 # How many nodes?
#SBATCH -A hai_lmgc # Who pays for it?
#SBATCH --partition booster
#SBATCH --ntasks=1
#SBATCH --gres=gpu:4
#SBATCH --time=4:00:00
#SBATCH -o output-%x.txt
#SBATCH -e error-%x.txt
# Where does the code run?
# Required for legacy reasons
# How long?

source /p/home/jusers/wang34/juwels/hai_fedak/huipo/general/activate.sh # path to the environment
cd /p/home/jusers/wang34/juwels/hai_fedak/huipo/lmgc-to-release || exit 1 # path to the upper folder
LOG_DIR="scripts/test_compress"

if [ ! -d "$LOG_DIR" ]; then
    echo "$LOG_DIR does not exist."
    # -p: no failure if a parallel job created the directory in the meantime.
    mkdir -p "$LOG_DIR"
else
    echo "$LOG_DIR does exist."
fi


# `export` is required: a bare `VAR=1` line only sets shell-local variables,
# so the srun/python child processes would never see the offline flags.
export HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
NUM_SUBSAMPLE=10
DATASET='tinyimagenet' # cifar10 # mnist
ARCH="convnet" # vgg16 # resnet18 # vit
TYPE="grad"
COMPRESSOR="tinyllama" # llama2-7b # openllama3b
SEP="hex-space" # hex-space # hex-comma+space # iso # hex-semicolon
BATCHSIZE=4 # number of windows processed everytime -- depending on your GPUs
BPG=4 # 8
INDEX=2
# Launch one single-GPU job step per pre-tokenized dataset, in the background.
for i in 1 2 3
do
    srun --exclusive --gres=gpu:1 --cpus-per-task=6 --ntasks=1 -o "$LOG_DIR/out-$INDEX-$i.txt" -e "$LOG_DIR/err-$INDEX-$i.txt" \
        python -u compress.py -cfg "settings/compression/cifar10-$SEP.yaml" --compressor "$COMPRESSOR" --dataset tokenized_dataset \
        --data-path "/p/home/jusers/wang34/juwels/hai_fedak/huipo/gradient_compressors/tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl" --batch-size "$BATCHSIZE" &
done

# Block until every background job step has finished.
wait
+fi + +NUM_SUBSAMPLE=10 +DATASET='tinyimagenet' # cifar10 # mnist +ARCH="convnet" # vgg16 # resnet18 # vit +TYPE="grad" +COMPRESSOR="tinyllama" # llama2-7b # openllama3b +SEP="hex-space" # hex-space # hex-comma+space # iso # hex-semicolon +BPG=4 # 8 +INDEX=2 +for i in 1 2 3 +do + srun --exclusive --gres=gpu:1 --cpus-per-task=6 --ntasks=1 -o "$LOG_DIR/out-$INDEX-$i.txt" -e "$LOG_DIR/err-$INDEX-$i.txt" \ + python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \ + --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i & +done + +wait \ No newline at end of file diff --git a/settings/compression/cifar10-hex-0x.yaml b/settings/compression/cifar10-hex-0x.yaml new file mode 100644 index 0000000..7bee1c8 --- /dev/null +++ b/settings/compression/cifar10-hex-0x.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-0x \ No newline at end of file diff --git a/settings/compression/cifar10-hex-comma+space.yaml b/settings/compression/cifar10-hex-comma+space.yaml new file mode 100644 index 0000000..b247421 --- /dev/null +++ b/settings/compression/cifar10-hex-comma+space.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-comma+space \ No newline at end of file diff --git a/settings/compression/cifar10-hex-comma.yaml b/settings/compression/cifar10-hex-comma.yaml new file mode 100644 index 0000000..5236349 --- /dev/null +++ b/settings/compression/cifar10-hex-comma.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-comma \ No newline at end of file diff --git a/settings/compression/cifar10-hex-none.yaml b/settings/compression/cifar10-hex-none.yaml new file mode 100644 index 0000000..170b3e2 --- /dev/null +++ b/settings/compression/cifar10-hex-none.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-none \ No newline at end of file diff --git a/settings/compression/cifar10-hex-semicolon.yaml 
b/settings/compression/cifar10-hex-semicolon.yaml new file mode 100644 index 0000000..d309b6b --- /dev/null +++ b/settings/compression/cifar10-hex-semicolon.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-semicolon \ No newline at end of file diff --git a/settings/compression/cifar10-hex-space.yaml b/settings/compression/cifar10-hex-space.yaml new file mode 100644 index 0000000..5d563cb --- /dev/null +++ b/settings/compression/cifar10-hex-space.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: hex-space \ No newline at end of file diff --git a/settings/compression/cifar10-iso-comma+space.yaml b/settings/compression/cifar10-iso-comma+space.yaml new file mode 100644 index 0000000..58e04ba --- /dev/null +++ b/settings/compression/cifar10-iso-comma+space.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: iso-comma+space \ No newline at end of file diff --git a/settings/compression/cifar10-iso-space.yaml b/settings/compression/cifar10-iso-space.yaml new file mode 100644 index 0000000..9a64274 --- /dev/null +++ b/settings/compression/cifar10-iso-space.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: iso-space \ No newline at end of file diff --git a/settings/compression/cifar10-iso.yaml b/settings/compression/cifar10-iso.yaml new file mode 100644 index 0000000..185a07b --- /dev/null +++ b/settings/compression/cifar10-iso.yaml @@ -0,0 +1,2 @@ +dummy: True +decoding: iso \ No newline at end of file diff --git a/settings/gradient_collection/cifar10-VGG16.yaml b/settings/gradient_collection/cifar10-VGG16.yaml new file mode 100644 index 0000000..7d616bf --- /dev/null +++ b/settings/gradient_collection/cifar10-VGG16.yaml @@ -0,0 +1,2 @@ +dataset: cifar10 +arch: VGG16 \ No newline at end of file diff --git a/settings/gradient_collection/cifar10-convnet.yaml b/settings/gradient_collection/cifar10-convnet.yaml new file mode 100644 index 0000000..b1dfc44 --- /dev/null +++ b/settings/gradient_collection/cifar10-convnet.yaml @@ -0,0 +1,2 @@ +dataset: cifar10 +arch: ConvNet \ No newline at end of 
def parse_args():
    """Parses the command line, merges the YAML settings and prepares the output path.

    Besides parsing, this function:
      * wraps the parsed arguments and the YAML settings in a `BaseParser`;
      * creates `./tokenized_datasets` and turns `args.output_name` into a
        `.pkl` path inside it;
      * fills in a default `bytes_per_group` when none was given;
      * with `--exhaustive-listing`, replaces `args.data_path` by a random
        subsample of `--num-subsample` files from the given directory and
        records the chosen indices in `args.sample_index`.

    `args.sample_index` is always defined on return (`None` without
    `--exhaustive-listing`), because `main` stores it in the output pickle.

    Returns:
        The populated `BaseParser` instance.

    Raises:
        ValueError: If `--num-subsample` exceeds the number of files found.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)

    parser.add_argument('-data-path', '--data-path', default='./', required=True, type=str)
    parser.add_argument('-exhaustive-listing', '--exhaustive-listing', action='store_true', help='If read all of the checkpoints in the data path.')
    parser.add_argument('-num-subsample', '--num-subsample', required=True, type=int)

    parser.add_argument('-download', '--download', action='store_true')

    parser.add_argument('-use_mask', '--use_mask', action='store_true', help='Applying mask functions, particularly for decoding images into ASCII.')

    parser.add_argument('--dataset', '-dataset', default='gradient', type=str, help='Indicatge what kind of data to compress.')

    parser.add_argument('--output-name', '-output-name', required=True, help='Name of the output pre-tokenized dataset.')

    parser.add_argument('--compressor', '-compressor', default='tinyllama3b', type=str, help='What kind of compressor to use.')
    parser.add_argument('--verbose', '-verbose', action='store_true', help='Print first few tokens for debugging.')

    parser.add_argument('--bytes-per-group', '-bytes-per-group', default=None, type=int, help='Specify after how many bytes a separator will be added.')
    parser.add_argument('-batch-size', '--batch-size', type=int, default=32)
    parser.add_argument('-preprocess', '--preprocess', type=str, default=None)

    parser.add_argument('-noise-level', '--noise-level', default=None, type=float, help='Standard deviation of gaussian noise to add. Useful only when preprocess "gaussian" is enabled.')
    parser.add_argument('-clipping-bound', '--clipping-bound', default=None, type=float, help='Clipping bound. Useful only when preprocess "clipping" is enabled.')

    args = parser.parse_args()
    with open(args.cfg, 'r') as stream:
        settings = yaml.safe_load(stream)

    args = BaseParser(args, settings)

    args.output_dir = './tokenized_datasets'
    # exist_ok avoids the check-then-create race when several jobs start at once.
    os.makedirs(args.output_dir, exist_ok=True)
    # arch-cifar-lm-serializeation-step-[model|grad]-tag.pkl
    args.output_name = os.path.join(args.output_dir, args.output_name+'.pkl')

    # handle bytes_per_group
    # NOTE(review): other parts of the project spell this dataset name
    # 'tokenized_dataset'; confirm whether 'tokenized_data' is intended here.
    if args.bytes_per_group is None:
        args.bytes_per_group = 4 if args.dataset == 'tokenized_data' or args.dataset == 'gradient' else 1

    print(args)
    print('chunck size: ', constants.CHUNK_SIZE_BYTES)
    print('num_chuncks: ',constants.NUM_CHUNKS)

    if args.exhaustive_listing:
        candidates = (os.path.join(args.data_path, f) for f in os.listdir(args.data_path))
        paths = np.array([p for p in candidates if os.path.isfile(p)])
        if args.num_subsample > len(paths):
            raise ValueError(
                f'--num-subsample ({args.num_subsample}) exceeds the number of files '
                f'found in the data path ({len(paths)}).'
            )
        args.sample_index = np.random.choice(len(paths), args.num_subsample, replace=False)
        args.data_path = paths[args.sample_index]
    else:
        # `main` pickles `args.sample_index` unconditionally, so it must exist
        # even when no subsampling took place.
        args.sample_index = None

    # print(args.data_path)

    return args
def parse_args():
    """Parses the command line, merges the YAML settings and creates the experiment folders.

    The experiment directory is `exps/<config file stem>/<tag>`; the
    sub-folders `ckpts`, `grads` and `logs` are created inside it and exposed
    as `args.model_dir`, `args.grad_dir` and `args.log_dir`.

    Returns:
        The populated `BaseParser` instance.
    """
    # specify the architecture, the frequency of gradient collection
    parser = argparse.ArgumentParser()

    parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)
    parser.add_argument('-tag', '--tag', default=0, type=int)
    parser.add_argument('-seed', '--seed', default=None)
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')

    parser.add_argument('-data-path', '--data-path', default='./', type=str)
    parser.add_argument('-download', '--download', action='store_true')

    parser.add_argument('-save_path', '--save_path', default='./saves', type=str)
    parser.add_argument('-grad-interval', '--grad-interval', default=200, type=int, help="The frequency of saving gradients. Counted by #batch steps.")
    parser.add_argument('-model-interval', '--model-interval', default=5, type=int, help="The frequency of saving models. Counted by #epoch.")

    parser.add_argument('-start-epoch', '--start-epoch', default=1, type=int, help="Load checkpoints, if sepcified.")

    cli_args = parser.parse_args()
    with open(cli_args.cfg, 'r') as stream:
        settings = yaml.safe_load(stream)

    args = BaseParser(cli_args, settings)

    def _ensure_dir(path):
        # Create `path` (and parents) unless something already exists there.
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    args.name = os.path.basename(args.cfg).split('.')[0]
    args.exp_dir = _ensure_dir(os.path.join('exps/', args.name, f'{args.tag}'))

    # save checkpoints
    args.model_dir = _ensure_dir(os.path.join(args.exp_dir, 'ckpts'))

    # save gradients for compression experiments
    args.grad_dir = _ensure_dir(os.path.join(args.exp_dir, 'grads'))

    # logs
    args.log_dir = _ensure_dir(os.path.join(args.exp_dir, 'logs'))

    return args
def main(args):
    """Trains the configured network, saving a checkpoint every `args.model_interval` epochs.

    Gradient snapshots are written by `train`; this function builds the data
    loaders, network, optimizer and scheduler, runs the epoch loop and stores
    model checkpoints in `args.model_dir`.

    Args:
        args: Run configuration from `parse_args`. `args.device` and
            `args.batch_step` are set here.
    """
    num_epochs = 200

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_loader, test_loader, data_info = prepare_dataset(args)
    net = get_network(args.arch, data_info).to(args.device)

    crit = nn.CrossEntropyLoss()
    # Case-insensitive: the shipped settings spell the architecture 'ViT', which
    # a plain `'vit' in args.arch` never matches.
    if 'vit' in args.arch.lower():
        param_dict = {pn: p for pn, p in net.named_parameters()}
        parameters_decay, parameters_no_decay = net.separate_parameters()

        optim_groups = [
            {"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": 1e-1},
            {"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
        ]
        opt = optim.AdamW(optim_groups, lr=1e-3)
        # NOTE(review): this schedule is sized for one step per batch
        # (steps_per_epoch * epochs), but `scheduler.step()` below runs once per
        # epoch, so only the first `num_epochs` steps of the cycle are ever
        # used. Confirm the intended granularity.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
                                                        steps_per_epoch=len(train_loader), epochs=num_epochs)
    else:
        opt = optim.SGD(net.parameters(), lr=args.lr,
                        momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs)

    args.batch_step = 0 # just to avoid call-by-assign
    # NOTE(review): `start_epoch` only shifts the loop start; no checkpoint is
    # loaded and the scheduler is not fast-forwarded.
    for epoch in tqdm(range(args.start_epoch, num_epochs + 1)):
        train_loss, train_acc, _, _ = train(args, net, opt, crit, train_loader)
        test_loss, test_acc, _, _ = test(args, net, crit, test_loader)
        print(f'At epoch {epoch}: Training accuracy {train_acc:.2f} and testing accuracy {test_acc:.2f}.')
        scheduler.step()

        if epoch % args.model_interval == 0:
            print(f'Saving models at epoch {epoch} ...')
            state = {
                'state_dict': net.state_dict(),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'test_loss': test_loss,
                'test_acc': test_acc,
                'epoch': epoch,
            }
            fname = os.path.join(args.model_dir, f'model_{epoch:03d}.ckpt')
            torch.save(state, fname)
np.ndarray) -> Iterator[bytes]: + h, w = constants.CHUNK_SHAPE_2D + height, width = image.shape + + for row, col in itertools.product(range(height // h), range(width // w)): + yield image[row * h : (row + 1) * h, col * w : (col + 1) * w].tobytes() + +def _extract_parameter_patches(sample: bytes) -> Iterator[bytes]: + patches = np.array_split( + np.frombuffer(sample, dtype=np.uint8), + range( + constants.CHUNK_SIZE_BYTES, + len(sample), + constants.CHUNK_SIZE_BYTES, + ), + ) + if len(patches[-1]) != constants.CHUNK_SIZE_BYTES: + # pad the array to have the same size + current_size = len(patches[-1]) + padding_size = constants.CHUNK_SIZE_BYTES - current_size + patches[-1] = np.pad(patches[-1], pad_width=(0, padding_size), mode='constant', constant_values=0) + + return map(lambda x: x.tobytes(), patches) + +def _convert_bin_fp(data: bytes, precision=64) -> List[float]: + converted_fp = np.frombuffer(data, dtype=np.float64 if precision == 64 else np.float32) + return converted_fp + +def _convert_fp_param(model: nn.Module, fp: float) -> None: + # pointer: iterate through the entire dp and parse them to param + pt = 0 + + for _, p in model.named_parameters(): + len_p = p.numel() + + param_from_fp = torch.tensor(fp[pt:pt+len_p]).view(p.size()).to(model.device) + p = param_from_fp + + pt = pt + len_p + +def _serialize(data: bytes, decoding: str, bytes_per_group: int = 1) -> str: + '''A function that convert bytes into a hex string list and then the formats that LLMs can understand. + ''' + codec = decoding.split('-')[0] + sep = decoding.split('-')[1] if '-' in decoding else None + assert sep in SEP_VALUE.keys(), f'Unknown separator {sep}. This is typically controlled by the decoding argument.' 
def _deserialize(data: str, decoding: str, bytes_per_group: int = 1) -> bytes:
    '''Convert a (possibly separator-delimited) hex string back to bytes.

    This inverts the ``hex`` serialization: separators selected by
    ``decoding`` (``"<codec>-<separator>"``) are dropped and every pair of
    hexadecimal digits becomes one byte.

    Args:
        data: Serialized hex digits, optionally joined by the separator.
        decoding: Serialization scheme, e.g. ``"hex"`` or ``"hex-comma"``.
        bytes_per_group: Unused here; kept so the signature mirrors
            ``_serialize``.

    Returns:
        The decoded bytes; ``b''`` for empty input. An odd number of digits
        is treated as having an implicit leading zero digit.

    Raises:
        ValueError: If ``data`` contains non-hexadecimal characters.
    '''
    sep = decoding.split('-')[1] if '-' in decoding else None
    assert sep in SEP_VALUE.keys(), f'Unknown separator {sep}. This is typically controlled by the decoding argument.'
    sep = SEP_VALUE[sep]

    # str.split('') raises ValueError, so the no-separator schemes must skip
    # the separator handling entirely.
    hex_digits = data.replace(sep, '') if sep else data
    if not hex_digits:
        return b''

    # Left-pad so the digits align to whole bytes.
    if len(hex_digits) % 2:
        hex_digits = '0' + hex_digits

    return bytes.fromhex(hex_digits)
def __iter__(self):
    """Yield ``(num_bytes, payload)`` for successive patches of the training images.

    ``payload`` is the raw patch bytes, or their serialized string form when
    ``self.serialization`` is set. Iteration stops once
    ``constants.NUM_CHUNKS`` patches have been produced.
    """
    emitted = 0

    for image, _ in self.trainset:
        if constants.UINT8_LOADING:
            pixels = np.array(image)
        else:
            pixels = image.squeeze().numpy()

        for chunk in _extract_image_patches(pixels):
            # stop as soon as the global chunk budget is used up
            if emitted == constants.NUM_CHUNKS:
                return

            payload = chunk
            if self.serialization:
                payload = _serialize(chunk, self.args.decoding, bytes_per_group=self.args.bytes_per_group)

            yield len(chunk), payload
            emitted += 1
# convert to percentage + preprocess_fn.append(functools.partial(self._sparsify, strength=strength)) + elif 'quantization' in p: + n_bits = int(p.split('quantization')[1]) + preprocess_fn.append(functools.partial(self._quantize, n_bits=n_bits)) + else: + raise ValueError(f'Unknown preprocess instruction: {preprocess} specified in the dataloader.') + + self.length, self.total_bytes = 0, 0 + self.dset = [] + for ckpt in self.ckpt_list: + state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['state_dict'] + + sample = [] + for k, v in state_dict.items(): + sample.append(v.numpy().flatten()) + sample = np.concatenate(sample) + + if len(preprocess_fn) > 0: + for fn in preprocess_fn: + sample = fn(sample) + + print('===============> total bytes: ', len(sample.tobytes())) + + sample = sample.tobytes() + self.length += int(np.ceil(len(sample)/constants.CHUNK_SIZE_BYTES)) + self.total_bytes += len(sample) if self.length < constants.NUM_CHUNKS else constants.NUM_CHUNKS * constants.CHUNK_SIZE_BYTES + + self.dset.append(sample) + # sample: List[bytes]; len(sample) = #bytes contained in this ckpt + + + def _add_gaussian_noise(self, data: np.ndarray, std: float=1): + noise = np.random.normal(0, std, size=data.shape) + return data + noise + + def _gradient_clipping(self, data: np.ndarray, strength: float=1): + norm = np.linalg.norm(data) + return data * (strength / norm) if norm > strength else data + + def _sparsify(self, data: np.ndarray, strength: float=0.25): + # two way to implement this: set non-masked bits to zero (standard); only transmit masked bits (used by some existing works) + # we use the latter here + assert strength > 0 and strength <= 1 + mask = (np.random.uniform(0, 1, size=data.shape) < strength) + return data[mask] + + def _is_power_of_two(self, n): + return (n & (n - 1)) == 0 and n != 0 + + def _quantize(self, data: np.ndarray, n_bits: int=8): + assert n_bits in [1, 8, 16] + num_levels = 2 ** n_bits + + length = data.max() - data.min() + if n_bits 
!= 1: + intervals = np.linspace(0, length, num_levels + 1) + inds = np.digitize(data - data.min(), intervals) + else: + inds = data + inds[inds>0] = 1 + inds[inds<=0] = 0 + + if n_bits == 1: + inds = np.packbits(inds.astype(bool)) + elif n_bits == 8: + inds = inds.astype(np.uint8) + elif n_bits == 16: + inds = inds.astype(np.uint16) + else: + raise ValueError(f'Unsupported n_bits: {n_bits} for quantization.') + + return inds + + def __iter__(self): + idx = 0 + for bid, s in enumerate(self.dset): + if self.return_fp: + step_size = constants.CHUNK_SIZE_BYTES // 4 + patches = np.array_split( + s, + range( + step_size, + len(s), + step_size, + ), + ) + + if len(patches[-1]) != step_size: + # pad the array to have the same size + current_size = len(patches[-1]) + padding_size = step_size - current_size + patches[-1] = np.pad(patches[-1], pad_width=(0, padding_size), mode='constant', constant_values=0) + + for patch in patches: + yield sys.getsizeof(patch), patch + else: + for patch in _extract_parameter_patches(s): + num_bytes = len(patch) + if idx == constants.NUM_CHUNKS: + return + if self.serialization: + yield num_bytes, _serialize(patch, self.args.decoding, bytes_per_group=self.args.bytes_per_group) + else: + yield num_bytes, patch + idx += 1 + + def __len__(self): + return self.length + +class PreTokenizedDataset(torch.utils.data.Dataset): + def __init__(self, args, max_tokens=None): + self.args = args + with open(args.data_path, 'rb') as handle: + loaded_data = pickle.load(handle) + self.dset = loaded_data['tokens'].reshape(-1) + self.total_bytes = loaded_data['total_bytes'] + + if args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']: + self.tokenizer = AutoTokenizer.from_pretrained(language_model.MODEL_NAME_DICT[args.compressor]) + + # we save one token quota for the bos_token which will be added later + self.max_tokens = max_tokens - 1 if max_tokens is not None else 2048 - 1 + print(self.max_tokens) + + self.num_samples = 
int(np.ceil(len(self.dset)/(self.max_tokens + 1))) + self.eos_token_id = self.tokenizer.eos_token_id + else: + self.max_tokens = constants.CHUNK_SIZE_BYTES + self.num_samples = int(np.ceil(len(self.dset)/self.max_tokens)) + + def _add_start_token(self, data: np.ndarray): + # for some models like LLAMA, they append more than one tokens at the begining such as a space token + # However, we ingore it here and only prepend a bos_token + bos_token_id = self.tokenizer.bos_token_id + return np.insert(data, 0, bos_token_id) + + def __getitem__(self, index): + sample = self.dset[index*self.max_tokens:(index+1)*self.max_tokens] + if self.args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']: + if len(sample) < self.max_tokens: + padding = np.full((self.max_tokens, ), self.tokenizer.eos_token_id) + padding[:len(sample)] = sample + sample = padding + sample = self._add_start_token(sample) + return len(sample), sample + + def __len__(self): + return self.num_samples + +GET_DATA_GENERATOR_FN_DICT = { + 'cifar10': get_image_iterator, + 'gradient': get_gradient_iterator, + 'tokenized_dataset': get_td_iterator, +} \ No newline at end of file diff --git a/utils/deepmind_utils.py b/utils/deepmind_utils.py new file mode 100644 index 0000000..69a48a0 --- /dev/null +++ b/utils/deepmind_utils.py @@ -0,0 +1,103 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def bits_to_bytes(bits: str) -> tuple[bytes, int]:
    """Packs a bitstream into bytes.

    The bit string is left-padded with zeros up to the next multiple of
    eight before being split into bytes.

    Args:
        bits: A string of '0'/'1' characters.

    Returns:
        The packed bytes and the number of zero bits that were prepended.
    """
    num_padded_bits = -len(bits) % 8
    padded = bits.zfill(len(bits) + num_padded_bits)

    # One byte per consecutive 8-character slice.
    octets = (padded[start:start + 8] for start in range(0, len(padded), 8))
    packed = bytes(int(octet, base=2) for octet in octets)

    return packed, num_padded_bits
def normalize_pdf_for_arithmetic_coding(pdf: np.array) -> np.array:
    """Rescales probabilities so they are safe to feed to an arithmetic coder.

    The coder turns the floating-point pdf into integers. Every entry must
    therefore exceed the float32 machine epsilon (so neighbouring cdf values
    stay distinct) and the total must not exceed 1 (up to a tolerance).

    Args:
        pdf: The probabilities to be normalized.

    Returns:
        A new array: `pdf` divided by its total, shrunk slightly and shifted
        up by one machine epsilon per entry.
    """
    eps = np.finfo(np.float32).eps

    # Divide by the running total's last element to remove float drift.
    total = np.cumsum(pdf)[-1]
    normalized = pdf / total

    # Shrink first so that adding eps to each entry cannot push the sum over 1.
    shrink = 1 - 2 * normalized.shape[0] * eps
    return shrink * normalized + eps
torchvision.datasets.CIFAR10( + root=args.data_path, train=False, download=args.download, transform=transform_test) + testloader = torch.utils.data.DataLoader( + testset, batch_size=100, shuffle=False, num_workers=2) + + data_info['classes'] = ('plane', 'car', 'bird', 'cat', 'deer', + 'dog', 'frog', 'horse', 'ship', 'truck') + data_info['num_classes'] = 10 + data_info['channel'] = 3 + data_info['im_size'] = (32, 32) + + elif args.dataset == 'mnist': + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])]) + + trainset = torchvision.datasets.MNIST( + root=args.data_path, train=True, download=args.download, transform=transform) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=128, shuffle=True, num_workers=2) + + testset = torchvision.datasets.MNIST( + root=args.data_path, train=False, download=args.download, transform=transform) + testloader = torch.utils.data.DataLoader( + testset, batch_size=100, shuffle=False, num_workers=2) + + data_info['classes'] = ('0', '1', '2', '3', '4', '5', + '6', '7', '8', '9') + data_info['num_classes'] = 10 + data_info['channel'] = 1 + data_info['im_size'] = (28, 28) + + elif args.dataset == 'tinyimagenet': + transform_train = transforms.Compose([ + transforms.Resize(32), + transforms.RandomRotation(20), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + + transform_test = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + + trainset = TinyImageNet( + root=args.data_path, split='train', download=args.download, transform=transform_train) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=128, shuffle=True, num_workers=6) + + testset = TinyImageNet( + root=args.data_path, split='val', download=args.download, transform=transform_test) + testloader = 
def hadamard_rotation(input):
    """Apply a Hadamard transform along axis 0, then random sign flips.

    ``input`` is viewed as a ``(shape[0], rest)`` matrix, multiplied from the
    left by a +/-1 Hadamard matrix and then multiplied element-wise by a
    random +/-1 array.

    Args:
        input: NumPy array with a non-empty first axis.

    Returns:
        ``(rotated, diag)``: the transformed array in the shape of ``input``
        and the 2-D array of random signs that was applied.
    """
    sz = input.shape
    sz1 = sz[0]
    sz2 = int(input.size / sz1)

    # Sylvester construction of the smallest power-of-two Hadamard matrix
    # that covers sz1 rows. The file never imported a `hadamard` helper, so
    # the matrix is built here with NumPy only.
    hadamard_mat = np.ones((1, 1), dtype=int)
    while hadamard_mat.shape[0] < sz1:
        hadamard_mat = np.block([[hadamard_mat, hadamard_mat],
                                 [hadamard_mat, -hadamard_mat]])
    # NOTE(review): when sz1 is not a power of two the cropped matrix is no
    # longer orthogonal -- confirm this is acceptable for callers.
    if hadamard_mat.shape[0] != sz1:
        hadamard_mat = hadamard_mat[:sz1, :sz1]

    x = input.reshape(sz1, sz2)
    diag = np.random.uniform(0, 1, size=x.shape) < 0.5
    diag = diag * 2 - 1  # booleans -> {-1, +1}
    x = np.matmul(hadamard_mat, x) * diag
    x = x.reshape(sz)
    return x, diag
class ConvNet(nn.Module):
    """Plain convolutional classifier: ``net_depth`` conv blocks and a linear head.

    Each block is Conv3x3 -> optional normalisation -> activation ->
    optional 2x2 pooling. The flattened feature map feeds one linear layer.

    Args:
        channel: Number of input channels.
        num_classes: Number of output logits.
        net_width: Number of filters in every conv layer.
        net_depth: Number of conv blocks.
        net_act: 'sigmoid', 'relu', 'leakyrelu' or 'swish'.
        net_norm: 'batchnorm', 'layernorm', 'instancenorm', 'groupnorm' or 'none'.
        net_pooling: 'maxpooling', 'avgpooling' or 'none'.
        im_size: Input (height, width); 28-pixel inputs are treated as 32
            and padded by the first conv.

    Raises:
        ValueError: If an activation, pooling or normalisation name is unknown.
    """

    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size=(32, 32)):
        super().__init__()
        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling,
                                                      im_size)
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def embed(self, x):
        """Return ``(flattened features, logits)`` for a batch of images."""
        feat = self.features(x)
        feat = feat.view(feat.size(0), -1)
        out = self.classifier(feat)
        return feat, out

    def _get_activation(self, net_act):
        """Map an activation name to a new module instance."""
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == 'swish':
            # No Swish class is defined in this file; SiLU is x * sigmoid(x).
            return nn.SiLU()
        # Raise instead of exit(): a bad config must not kill the interpreter.
        raise ValueError('unknown activation function: %s' % net_act)

    def _get_pooling(self, net_pooling):
        """Map a pooling name to a new module instance (``None`` for 'none')."""
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        raise ValueError('unknown net_pooling: %s' % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        """Map a normalisation name to a module sized for ``shape_feat`` = [c, h, w]."""
        if net_norm == 'batchnorm':
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            # one group per channel
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            return None
        raise ValueError('unknown net_norm: %s' % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        """Build the feature extractor; return it with its output shape [c, h, w]."""
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            # single-channel inputs get padding=3 on the first conv (28 -> 32)
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The block outputs ``expansion * planes`` channels. ``norm`` selects
    per-channel GroupNorm ('instancenorm') or BatchNorm2d (anything else).
    A 1x1 projection is used on the shortcut whenever the stride or the
    channel count changes.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super().__init__()
        self.norm = norm
        out_planes = self.expansion * planes

        def make_norm(width):
            # one group per channel when 'instancenorm' is requested
            if self.norm == 'instancenorm':
                return nn.GroupNorm(width, width, affine=True)
            return nn.BatchNorm2d(width)

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = make_norm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = make_norm(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = make_norm(out_planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                make_norm(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the three conv stages and add the (possibly projected) input."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
def forward(self, x):
    """Return class logits: stem, four residual stages, 4x4 average pool, classifier."""
    feats = F.relu(self.bn1(self.conv1(x)))
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        feats = stage(feats)

    pooled = F.avg_pool2d(feats, 4)
    flat = pooled.view(pooled.size(0), -1)
    return self.classifier(flat)
class VGG(nn.Module):
    """VGG-style classifier built from one of the layer lists in ``cfg``.

    Args:
        vgg_name: Key into the module-level ``cfg`` table (e.g. 'VGG16').
        num_classes: Number of output logits. Defaults to 10, the previously
            hard-coded value.
        in_channels: Number of input image channels. Defaults to 3, the
            previously hard-coded value.

    The linear head takes 512 features, i.e. it expects the five pooling
    stages to reduce the input to 1x1 (32x32 images).
    """

    def __init__(self, vgg_name, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name], in_channels)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg, in_channels=3):
        """Turn a layer list into an ``nn.Sequential``.

        Integers become Conv3x3 -> BatchNorm -> ReLU with that many output
        channels; 'M' becomes a 2x2 max-pool. A no-op 1x1 average pool is
        appended at the end.
        """
        layers = []
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
def forward(self, x): + return x + self.gamma * self.residual(x) + +class LayerNormChannels(nn.Module): + def __init__(self, channels): + super().__init__() + self.norm = nn.LayerNorm(channels) + + def forward(self, x): + x = x.transpose(1, -1) + x = self.norm(x) + x = x.transpose(-1, 1) + return x + +class SelfAttention2d(nn.Module): + def __init__(self, in_channels, out_channels, head_channels, shape): + super().__init__() + self.heads = out_channels // head_channels + self.head_channels = head_channels + self.scale = head_channels**-0.5 + + self.to_keys = nn.Conv2d(in_channels, out_channels, 1) + self.to_queries = nn.Conv2d(in_channels, out_channels, 1) + self.to_values = nn.Conv2d(in_channels, out_channels, 1) + self.unifyheads = nn.Conv2d(out_channels, out_channels, 1) + + height, width = shape + self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * height - 1) * (2 * width - 1))) + self.register_buffer("relative_indices", self.get_indices(height, width)) + + def forward(self, x): + b, _, h, w = x.shape + + keys = self.to_keys(x).view(b, self.heads, self.head_channels, -1) + values = self.to_values(x).view(b, self.heads, self.head_channels, -1) + queries = self.to_queries(x).view(b, self.heads, self.head_channels, -1) + + att = keys.transpose(-2, -1) @ queries + + indices = self.relative_indices.expand(self.heads, -1) + rel_pos_enc = self.pos_enc.gather(-1, indices) + rel_pos_enc = rel_pos_enc.unflatten(-1, (h * w, h * w)) + + att = att * self.scale + rel_pos_enc + att = F.softmax(att, dim=-2) + + out = values @ att + out = out.view(b, -1, h, w) + out = self.unifyheads(out) + return out + + @staticmethod + def get_indices(h, w): + y = torch.arange(h, dtype=torch.long) + x = torch.arange(w, dtype=torch.long) + + y1, x1, y2, x2 = torch.meshgrid(y, x, y, x, indexing='ij') + indices = (y1 - y2 + h - 1) * (2 * w - 1) + x1 - x2 + w - 1 + indices = indices.flatten() + + return indices + +class FeedForward(nn.Sequential): + def __init__(self, in_channels, 
out_channels, mult=4): + hidden_channels = in_channels * mult + super().__init__( + nn.Conv2d(in_channels, hidden_channels, 1), + nn.GELU(), + nn.Conv2d(hidden_channels, out_channels, 1) + ) + +class TransformerBlock(nn.Sequential): + def __init__(self, channels, head_channels, shape, p_drop=0.): + super().__init__( + Residual( + LayerNormChannels(channels), + SelfAttention2d(channels, channels, head_channels, shape), + nn.Dropout(p_drop) + ), + Residual( + LayerNormChannels(channels), + FeedForward(channels, channels), + nn.Dropout(p_drop) + ) + ) + +class TransformerStack(nn.Sequential): + def __init__(self, num_blocks, channels, head_channels, shape, p_drop=0.): + layers = [TransformerBlock(channels, head_channels, shape, p_drop) for _ in range(num_blocks)] + super().__init__(*layers) + +class ToPatches(nn.Sequential): + def __init__(self, in_channels, channels, patch_size, hidden_channels=32): + super().__init__( + nn.Conv2d(in_channels, hidden_channels, 3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_channels, channels, patch_size, stride=patch_size) + ) + +class AddPositionEmbedding(nn.Module): + def __init__(self, channels, shape): + super().__init__() + self.pos_embedding = nn.Parameter(torch.Tensor(channels, *shape)) + + def forward(self, x): + return x + self.pos_embedding + +class ToEmbedding(nn.Sequential): + def __init__(self, in_channels, channels, patch_size, shape, p_drop=0.): + super().__init__( + ToPatches(in_channels, channels, patch_size), + AddPositionEmbedding(channels, shape), + nn.Dropout(p_drop) + ) + +# main model +class Head(nn.Sequential): + def __init__(self, in_channels, classes, p_drop=0.): + super().__init__( + LayerNormChannels(in_channels), + nn.GELU(), + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Dropout(p_drop), + nn.Linear(in_channels, classes) + ) +class RelViT(nn.Sequential): + def __init__(self, classes, image_size, channels, head_channels, num_blocks, patch_size, + in_channels=3, emb_p_drop=0., trans_p_drop=0., 
head_p_drop=0.): + reduced_size = image_size // patch_size + shape = (reduced_size, reduced_size) + super().__init__( + ToEmbedding(in_channels, channels, patch_size, shape, emb_p_drop), + TransformerStack(num_blocks, channels, head_channels, shape, trans_p_drop), + Head(channels, classes, head_p_drop) + ) + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1.) + nn.init.zeros_(m.bias) + elif isinstance(m, AddPositionEmbedding): + nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02) + elif isinstance(m, SelfAttention2d): + nn.init.normal_(m.pos_enc, mean=0.0, std=0.02) + elif isinstance(m, Residual): + nn.init.zeros_(m.gamma) + + def separate_parameters(self): + parameters_decay = set() + parameters_no_decay = set() + modules_weight_decay = (nn.Linear, nn.Conv2d) + modules_no_weight_decay = (nn.LayerNorm,) + + for m_name, m in self.named_modules(): + for param_name, param in m.named_parameters(): + full_param_name = f"{m_name}.{param_name}" if m_name else param_name + + if isinstance(m, modules_no_weight_decay): + parameters_no_decay.add(full_param_name) + elif param_name.endswith("bias"): + parameters_no_decay.add(full_param_name) + elif isinstance(m, Residual) and param_name.endswith("gamma"): + parameters_no_decay.add(full_param_name) + elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"): + parameters_no_decay.add(full_param_name) + elif isinstance(m, SelfAttention2d) and param_name.endswith("pos_enc"): + parameters_no_decay.add(full_param_name) + elif isinstance(m, modules_weight_decay): + parameters_decay.add(full_param_name) + + # sanity check + assert len(parameters_decay & parameters_no_decay) == 0 + assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters())) + + return 
parameters_decay, parameters_no_decay + +def ViT(num_classes, image_size): + return RelViT(num_classes, image_size, channels=256, head_channels=32, num_blocks=8, patch_size=2, + emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3) \ No newline at end of file diff --git a/utils/tinyimagenet.py b/utils/tinyimagenet.py new file mode 100644 index 0000000..2986e83 --- /dev/null +++ b/utils/tinyimagenet.py @@ -0,0 +1,105 @@ +# Copyright (C) 2022 Leonardo Romor +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Simple Tiny ImageNet dataset utility class for pytorch.""" + +import os + +import shutil + +from torchvision.datasets import ImageFolder +from torchvision.datasets.utils import verify_str_arg +from torchvision.datasets.utils import download_and_extract_archive + + +def normalize_tin_val_folder_structure(path, + images_folder='images', + annotations_file='val_annotations.txt'): + # Check if files/annotations are still there to see + # if we already run reorganize the folder structure. 
+ images_folder = os.path.join(path, images_folder) + annotations_file = os.path.join(path, annotations_file) + + # Exists + if not os.path.exists(images_folder) \ + and not os.path.exists(annotations_file): + if not os.listdir(path): + raise RuntimeError('Validation folder is empty.') + return + + # Parse the annotations + with open(annotations_file) as f: + for line in f: + values = line.split() + img = values[0] + label = values[1] + img_file = os.path.join(images_folder, values[0]) + label_folder = os.path.join(path, label) + os.makedirs(label_folder, exist_ok=True) + try: + shutil.move(img_file, os.path.join(label_folder, img)) + except FileNotFoundError: + continue + + os.sync() + assert not os.listdir(images_folder) + shutil.rmtree(images_folder) + os.remove(annotations_file) + os.sync() + + +class TinyImageNet(ImageFolder): + """Dataset for TinyImageNet-200""" + base_folder = 'tiny-imagenet-200' + zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de' + splits = ('train', 'val') + filename = 'tiny-imagenet-200.zip' + url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip' + + def __init__(self, root, split='train', download=False, **kwargs): + self.data_root = os.path.expanduser(root) + self.split = verify_str_arg(split, "split", self.splits) + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError('Dataset not found.' 
+ + ' You can use download=True to download it') + super().__init__(self.split_folder, **kwargs) + + @property + def dataset_folder(self): + return os.path.join(self.data_root, self.base_folder) + + @property + def split_folder(self): + return os.path.join(self.dataset_folder, self.split) + + def _check_exists(self): + return os.path.exists(self.split_folder) + + def extra_repr(self): + return "Split: {split}".format(**self.__dict__) + + def download(self): + if self._check_exists(): + return + download_and_extract_archive( + self.url, self.data_root, filename=self.filename, + remove_finished=True, md5=self.zip_md5) + assert 'val' in self.splits + normalize_tin_val_folder_structure( + os.path.join(self.dataset_folder, 'val')) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..61cf98a --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,49 @@ +import argparse + +class BaseParser(dict): + def __init__(self, cmd_args, *file_args): + super(BaseParser, self).__init__() + args = [*file_args, cmd_args] # cmd_args has higher priority than file_args + + # default options for the parser, which can be customized for specific applications + self.choice_dict = {} + self.default_dict = {} + self.allowed_overwritten_list = {'seed': None} + + for i_d, d in enumerate(args): + # print(i_d) + if isinstance(d, argparse.Namespace): + d = vars(d) + + for k, v in d.items(): + assert k not in self.keys() or k in self.allowed_overwritten_list.keys(), f'duplicated arguments {k}, please check the configuration file.' + + if k in self.allowed_overwritten_list.keys() and v == self.allowed_overwritten_list[k]: + continue + # print(f'\t{k}: {v}') + self.add_item(k, v) + + # check whether the default options has been in args; otherswise, add it. + for k in self.default_dict.keys(): + if k not in self.keys(): + self[k] = self.default_dict[k] + + def add_item(self, k, v): + # 1. convert '-' to '_'; 2. 
replace string 'None' with NoneType + k = k.replace('-', '_') + + #check whether arguments match the limited choices + if k in self.choice_dict.keys() and v not in self.choice_dict[k]: + raise ValueError(f'Illegal argument \'{k}\' for choices {self.choice_dict[k]}') + + # convert string None to Nonetype, which is a side effect of using yaml + self[k] = None if v == 'None' else v + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(f'{name}') + + def __setattr__(self, key, val): + self[key] = val \ No newline at end of file