diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f9244ed
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+*__pycache__*
+*txt
+*cache*
+*.DS_Store*
+tokenized_datasets/*
+exps/*
+probs/*
+.ipynb_checkpoints/
+tiny-imagenet-200/
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..01ecbb6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Hui-Po Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a9ea6cb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,136 @@
+
+
+
+
+
+
+
+
+
+
+The official implementation of "Language Models as Zero-shot Lossless Gradient Compressors: Towards General Neural Parameter Prior Models" publised at NeurIPS 2024.
+
+
+[Preprint]
+
+
+
+## Overview
+
+data:image/s3,"s3://crabby-images/d91b4/d91b4adf2bfac40b51e2e916e182b6f61e6b70a6" alt="Product Screenshot"
+
+This project provides the source code of LM-GC, the first LLM-powered gradient compressor.
+
+Here are take-aways:
+
+- We demonstrate that large language models (LLMs) hold significant potential as prior models for gradients, a concept that has been widely applied in other modalities but gradients.
+- We introduce an novel serialization method that converts IEEE 754 floating points into hexadecimal format, enabling LLMs to comprehend and achieve state-of-the-art lossless gradient compression.
+- Our LLM-based prior model could unlock new applications for gradients, similar to those in other modalities, such as super-resolution, denoising, generation, and more.
+
+
+
+*If you find the project interesting, don't forget to star and cite our work:*
+
+```bibtex
+@article{wang2024language,
+ title={Language Models as Zero-shot Lossless Gradient Compressors: Towards General Neural Parameter Prior Models},
+ author={Wang, Hui-Po and Fritz, Mario},
+ journal={Advances in Neural Information Processing Systems},
+ year={2024}
+}
+```
+## Getting Started
+### Prerequisites
+
+- torch ≥ 2.12.0
+- transformers ≥ 4.40.1
+- [torchac](https://github.com/fab-jul/torchac)
+- [flash attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2) ≥ 2.5.8 via ```pip install flash-attn --no-build-isolation``` for NVIDIA GPUs
+
+or
+
+- install via ```pip```
+ ```sh
+ pip install -r requirements.txt
+ ```
+**After set up the huggingface access token, ideally, the codebase will download language models automatically via HuggingFace except for LLAMA2. See [More LLMs](#more-llms) for more information.**
+
+### Quickstart
+We provide a quick demo here. Please refer to [Usage](#usage) for the detailed usage.
+```bash
+cd scripts
+# compress gradients of a ConvNet trained on TinyImageNet using TinyLLAMA
+bash pipeline.sh
+```
+## Usage
+It takes three steps to reproduce the experiments in the paper, including (1) train neural networks to collect gradients; (2) serialize and tokenize raw gradients; (3) run LLMs and arithmetic (LM-GC).
+
+### 1. Gradient collection
+This step trains a network (e.g. a ConvNet on TinyImageNet in the following example) and collect gradients for compression later. See ```scripts/run_collect.sh``` for more details.
+```bash
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+for i in 0 1 2
+do
+ python -u train_and_collect_grad.py -cfg settings/gradient_collection/$DATASET-$ARCH.yaml --tag $i --grad-interval 400 --download
+done
+```
+### 2. Serialization and tokenization
+For convenience, we process the data before conducting arithmetic encoding. The data is serialized and tokenized here. We create three preprocessed datasets here. See ```scripts/serialization.sh``` for more details.
+```bash
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon
+BPG=4 # 8
+for i in 1 2 3
+do
+ python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml \
+ --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \
+ --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE \
+ --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i
+done
+```
+### 3. Run compression
+The processed data from the previous step is now divided into several disjoint windows. By default, LLMs see a set of 2048 (including 1 BOS token) tokens every time. The experimented are repeated three times. See ```scripts/compress.sh``` for more details.
+```bash
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon
+BATCHSIZE=4 # depending on your GPUs
+BPG=4 # 8
+for i in 1 2 3
+do
+ python -u compress.py -cfg settings/compression/cifar10-$SEP.yaml --compressor $COMPRESSOR --dataset tokenized_dataset \
+ --data-path ./tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl --batch-size $BATCHSIZE
+done
+```
+
+## Options
+
+### More LLMs
+
+### More models to compress
+
+### Ablation study
+- Bytes per group
+- Context window size
+
+## TO-DO
+- [x] prepare `pipeline.sh`
+- [x] sanity check
+- [ ] how to add more LLMs
+- [ ] provide a runnable encode/decode example
+- [ ] Baseline codec
+## License
+
+Distributed under the MIT License. See [MIT License](https://opensource.org/licenses/MIT) for more information.
+
+## Acknowledgments
+This project is partially built up on [Deepmind's work](), and the readme file template comes from [makeread.me](https://github.com/ShaanCoding/ReadME-Generator).
diff --git a/arithmetic_coder.py b/arithmetic_coder.py
new file mode 100644
index 0000000..b1eccd8
--- /dev/null
+++ b/arithmetic_coder.py
@@ -0,0 +1,324 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements an Arithmetic Encoder and Decoder."""
+
+from typing import Any, Callable
+
+import chex
+import numpy as np
+
+InputFn = Callable[[], int]
+OutputFn = Callable[[int], None]
+IOFn = InputFn | OutputFn
+
+
+def _log_power_of_b(n: int, base: int) -> int:
+ """Returns k assuming n = base ** k.
+
+ We manually implement this function to be faster than a np.log or math.log,
+ which doesn't assume n is an integer.
+
+ Args:
+ n: The integer of which we want the log.
+ base: The base of the log.
+ """
+ log_n = 0
+ while n > 1:
+ n //= base
+ log_n += 1
+ return log_n
+
+
+def _raise_post_terminate_exception(*args: Any, **kwargs: Any) -> None:
+ """Dummy function that raises an error to ensure AC termination."""
+ del args, kwargs
+ raise ValueError(
+ "Arithmetic encoder was terminated. "
+ "Create a new instance for encoding more data. "
+ "Do NOT use an output function that writes to the same data sink "
+ "used by the output function of this instance. "
+ "This will corrupt the arithmetic code as decoding relies on detecting "
+ "when the compressed data stream is exhausted."
+ )
+
+
+class _CoderBase:
+ """Arithmetic coder (AC) base class."""
+
+ def __init__(self, base: int, precision: int, io_fn: IOFn):
+ """Does initialization shared by AC encoder and decoder.
+
+ Args:
+ base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
+ precision: Precision for internal state; on the average this will waste
+ code space worth at most 1/log(base) * base ** -(precision - 2) digits
+ of output per coding step.
+ io_fn: Function to write digits to compressed stream/read digits from
+ compressed stream.
+ """
+ chex.assert_scalar_in(base, 2, np.inf)
+ chex.assert_scalar_in(precision, 2, np.inf)
+
+ self._base: int = base
+ self._base_to_pm1: int = int(base ** (precision - 1))
+ self._base_to_pm2: int = int(base ** (precision - 2))
+ self._io_fn = io_fn
+
+ # NOTE: We represent the AC interval [0, 1) as rational numbers:
+ # [0, 1)
+ # ~ [self._low / base ** precision, (self._high + 1) / base ** precision)
+ # = [self._low / base ** precision, self._high / base ** precision],
+ # where the we represent the upper bound *INCLUSIVE*. This is a subtle
+ # detail required to make the integer arithmetic work correctly given that
+ # all involved integers have `precision` digits in base `base`.
+ self._low: int = 0
+ self._high: int = int(base**precision) - 1
+ self._num_carry_digits: int = 0
+ self._code: int = 0
+
+ def __str__(self) -> str:
+ """Returns string describing internal state."""
+ if self._base > 16:
+ raise ValueError("`__str__` with `base` exceeding 16 not implmeneted.")
+
+ p = 1 + _log_power_of_b(self._base_to_pm1, base=self._base)
+
+ def _to_str(x: int) -> str:
+ """Returns representation of `n` in base `self._base`."""
+ digits = [(x // self._base**i) % self._base for i in range(p)]
+ return f"{digits[-1]:x}" + "".join(
+ f"{d:x}" for d in digits[-2::-1]
+ )
+
+ return (
+ f"[{_to_str(self._low)}, {_to_str(self._high)}) {_to_str(self._code)}"
+ )
+
+ def _get_intervals(self, pdf: np.ndarray) -> np.ndarray:
+ """Partition the current AC interval according to the distribution `pdf`."""
+ if (pdf < 0).any():
+ raise ValueError(
+ "Some probabilities are negative. Please make sure that pdf[x] > 0."
+ )
+ # Compute CPDF s.t. cpdf[x] = sum_y= max(base ** -(precision - 2), np.dtype(x).eps) for any"
+ " symbol by either preprocessing `pdf` or by increasing `precision`."
+ )
+ if qcpdf[-1] > width:
+ raise ValueError(
+ "Cumulative sum of probabilities exceeds 1 after quantisation. "
+ "Please make sure that sum(pdf) <= 1.0 - eps, for a small eps > 0."
+ )
+ return self._low + qcpdf
+
+ def _remove_matching_digits(self, low_pre_split: int, encoding: bool) -> None:
+ """Remove matching most significant digits from AC state [low, high).
+
+ This is the *FIRST* normalization step after encoding a symbol into the AC
+ state.
+
+ When encoding we write the most significant matching digits of the
+ integer representation of [low, high) to the output, widen the integer
+ representation of [low, high) including a (potential) queue of carry digits;
+ when decoding we drop the matching most significant digits of the integer
+ representation of [low, high), widen this interval and keep the current
+ slice of the arithmetic code word `self._code` in sync.
+
+ Args:
+ low_pre_split: Value of `self._low` before encoding a new symbol into the
+ AC state when `encoding` is True; abitrary, otherwise.
+ encoding: Are we encoding (i.e. normalise by writing data) or decoding
+ (i.e. normalise by reading data)?
+ """
+
+ def _shift_left(x: int) -> int:
+ """Shift `x` one digit left."""
+ return (x % self._base_to_pm1) * self._base
+
+ while self._low // self._base_to_pm1 == self._high // self._base_to_pm1:
+ if encoding:
+ low_msd = self._low // self._base_to_pm1
+ self._io_fn(low_msd)
+ # Note that carry digits will only be written in the first round of this
+ # loop.
+ carry_digit = (
+ self._base - 1 + low_msd - low_pre_split // self._base_to_pm1
+ ) % self._base
+ assert carry_digit in {0, self._base - 1} or self._num_carry_digits == 0
+ while self._num_carry_digits > 0:
+ self._io_fn(carry_digit)
+ self._num_carry_digits -= 1
+ else:
+ self._code = _shift_left(self._code) + self._io_fn()
+ self._low = _shift_left(self._low)
+ self._high = _shift_left(self._high) + self._base - 1
+
+ def _remove_carry_digits(self, encoding: bool) -> None:
+ """Remove and record 2nd most significant digits from AC state [low, high).
+
+ This is the *SECOND* normalization step after encoding a symbol into the AC
+ state [low, high).
+
+ If the AC state takes the form
+ low = x B-1 B-1 ... B-1 u ...
+ high = x+1 0 0 0 v ...
+ ^__ prefix __^
+ where x, u and v are base-B digits then low and high can get arbitrarily (
+ well, by means of infinite precision arithmetics) without matching. Since we
+ work with finite precision arithmetics, we must make sure that this doesn't
+ occour and we guarantee sufficient of coding range (`high - low`). To end
+ this we detect the above situation and cut off the highlighted prefix above
+ to widen the integer representation of [low, high) and record the number of
+ prefix digits removed. When decoding we must similarly process the current
+ slice of the arithmetic code word `self._code` to keep it in sync.
+
+ Args:
+ encoding: Are we encoding (i.e. normalise by writing data) or decoding
+ (i.e. normalise by reading data)?
+ """
+
+ def _shift_left_keeping_msd(x: int) -> int:
+ """Shift `x` except MSD, which remains in place, one digit left."""
+ return x - (x % self._base_to_pm1) + (x % self._base_to_pm2) * self._base
+
+ while self._low // self._base_to_pm2 + 1 == self._high // self._base_to_pm2:
+ if encoding:
+ self._num_carry_digits += 1
+ else:
+ self._code = _shift_left_keeping_msd(self._code) + self._io_fn()
+ self._low = _shift_left_keeping_msd(self._low)
+ self._high = _shift_left_keeping_msd(self._high) + self._base - 1
+
+ def _process(self, pdf: np.ndarray, symbol: int | None) -> int:
+ """Perform an AC encoding or decoding step and modify AC state in-place.
+
+ Args:
+ pdf: Probability distribution over input alphabet.
+ symbol: Letter to encode from {0, 1, ..., pdf.size - 1} when encoding or
+ `None` when decoding.
+
+ Returns:
+ y: `symbol` from above when encoding or decoded letter from {0, 1, ...,
+ pdf.size - 1}.
+ """
+
+ encoding = symbol is not None
+ intervals = self._get_intervals(pdf)
+ if not encoding:
+ symbol = np.searchsorted(intervals, self._code, side="right") - 1
+ assert 0 <= symbol < pdf.size
+ low_pre_split = self._low
+ self._low, self._high = intervals[[symbol, symbol + 1]]
+ # Due to integer arithmetics the integer representation of [low, high) has
+ # an inclusive upper bound, so decrease high.
+ self._high -= 1
+ assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
+
+ # Normalize the AC state.
+ self._remove_matching_digits(low_pre_split=low_pre_split, encoding=encoding)
+ assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
+ assert encoding or self._low <= self._code <= self._high
+ assert self._low // self._base_to_pm1 != self._high // self._base_to_pm1
+
+ self._remove_carry_digits(encoding=encoding)
+ assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
+ assert encoding or self._low <= self._code <= self._high
+ assert self._high - self._low > self._base_to_pm2
+
+ return symbol
+
+ @classmethod
+ def p_min(cls, base: int, precision: int) -> float:
+ """Get minimum probability supported by AC config."""
+ # The leading factor 2 is supposed to account for rounding errors and
+ # wouldn't be necessary given infinite float precision.
+ return 2.0 * base ** -(precision - 2)
+
+
+class Encoder(_CoderBase):
+ """Arithmetic encoder."""
+
+ def __init__(self, base: int, precision: int, output_fn: OutputFn):
+ """Constructs arithmetic encoder.
+
+ Args:
+ base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
+ precision: Precision for internal state; on the average this will waste
+ code space worth at most 1/log(base) * base ** -(precision - 2) digits
+ of output per coding step.
+ output_fn: Function that writes a digit from {0, 1, ..., base - 1} to the
+ compressed output.
+ """
+ super().__init__(base, precision, output_fn)
+
+ def encode(self, pdf: np.ndarray, symbol: int) -> None:
+ """Encodes symbol `symbol` assuming coding distribution `pdf`."""
+ self._process(pdf, symbol)
+
+ def terminate(self) -> None:
+ """Finalizes arithmetic code."""
+ # Write outstanding part of the arithmetic code plus one digit to uniquely
+ # determine a code within the interval of the final symbol coded.
+ self._io_fn(self._low // self._base_to_pm1)
+ for _ in range(self._num_carry_digits):
+ self._io_fn(self._base - 1)
+ self.encode = _raise_post_terminate_exception
+ self.terminate = _raise_post_terminate_exception
+
+
+class Decoder(_CoderBase):
+ """Arithmetic decoder."""
+
+ def __init__(self, base: int, precision: int, input_fn: InputFn):
+ """Constructs arithmetic decoder.
+
+ Args:
+ base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
+ precision: Precision for internal state; on the average this will waste
+ code space worth at most 1/log(base) * base ** -(precision - 2) digits
+ of output per coding step.
+ input_fn: Function that reads a digit from {0, 1, ..., base - 1} from the
+ compressed input or returns `None` when the input is exhausted.
+ """
+ # Add padding to ensure the AC state is well-defined when decoding the last
+ # symbol. Note that what exactly we do here depends on how encoder
+ # termination is implemented (see `Encoder.terminate`).
+ trailing_digits = (base - 1 for _ in range(precision - 1))
+
+ def _padded_input_fn() -> int:
+ """Reads digit from input padding the arithmetic code."""
+ digit = input_fn()
+ if digit is None:
+ digit = next(trailing_digits)
+ chex.assert_scalar_in(int(digit), 0, base - 1)
+ return digit
+
+ super().__init__(base, precision, _padded_input_fn)
+ for _ in range(precision):
+ self._code = self._code * base + _padded_input_fn()
+
+ def decode(self, pdf: np.ndarray) -> int:
+ return self._process(pdf, None)
\ No newline at end of file
diff --git a/compress.py b/compress.py
new file mode 100644
index 0000000..1a343c9
--- /dev/null
+++ b/compress.py
@@ -0,0 +1,270 @@
+# Copyright 2024 CISPA Helmholtz Center for Information Security Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# import sys
+# sys.path.append('../gradient-compressors')
+import argparse, os
+os.environ['HF_HOME'] = './cache'
+
+import functools
+import time
+import logging, yaml
+
+from collections.abc import Generator
+from typing import Callable
+
+import numpy as np
+
+import torch
+
+import tqdm
+import constants
+
+from utils import data_loaders
+from utils import deepmind_utils
+from utils.utils import BaseParser
+from compressors import compressor
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)
+
+ parser.add_argument('-data-path', '--data-path', default='./', required=True, type=str)
+ parser.add_argument('-exhaustive-listing', '--exhaustive-listing', action='store_true', help='If read all of the checkpoints in the data path.')
+ parser.add_argument('-num-subsample', '--num-subsample', type=int, default=1)
+
+ parser.add_argument('-download', '--download', action='store_true')
+
+ parser.add_argument('-use_mask', '--use_mask', action='store_true', help='Applying mask functions, particularly for decoding images into ASCII.')
+ parser.add_argument('-use_slow_compression', '--use_slow_compression', action='store_true')
+
+ parser.add_argument('--dataset', '-dataset', default='tokenized_dataset', type=str, help='Indicatge what kind of data to compress.')
+ parser.add_argument('--compressor', '-compressor', default='gpt2', type=str, help='What kind of compressor to use.')
+
+ parser.add_argument('--bytes-per-group', '-bytes-per-group', default=None, type=int, help='Specify after how many bytes a separator will be added.')
+ parser.add_argument('-batch-size', '--batch-size', type=int, default=32)
+ parser.add_argument('-max-tokens', '--max-tokens', type=int, default=2048)
+ parser.add_argument('-preprocess', '--preprocess', type=str, default=None)
+
+ args = parser.parse_args()
+ with open(args.cfg, 'r') as stream:
+ settings = yaml.safe_load(stream)
+
+ args = BaseParser(args, settings)
+
+ # handle bytes_per_group
+ if args.bytes_per_group is None:
+ args.bytes_per_group = 4 if args.dataset == 'tokenized_data' or args.dataset == 'gradient' else 1
+
+ if args.exhaustive_listing:
+ paths = np.array([os.path.join(args.data_path, f) for f in os.listdir(args.data_path) if os.path.isfile(os.path.join(args.data_path, f))])
+ args.sample_index = np.random.choice(len(paths), args.num_subsample, replace=False)
+ print(args.sample_index)
+ args.data_path = paths[args.sample_index]
+
+ print(args)
+ print('chunck size: ', constants.CHUNK_SIZE_BYTES)
+ print('num_chuncks: ',constants.NUM_CHUNKS)
+ return args
+
+def evaluate_compressor_chunked(
+ args: BaseParser,
+ compress_fn: compressor.Compressor,
+ get_data_generator_fn: Callable[[], Generator[bytes, None, None]],
+ num_chunks: int,
+ count_header_only_once: bool = True,
+ mask_fn: Callable[[bytes], tuple[bytes, int]] | None = None,
+ use_tqdm: bool = True,
+) -> tuple[float, float]:
+ """Evaluates the compressor on the chunked dataset.
+
+ Args:
+ compress_fn: The function that evaluates data.
+ get_data_generator_fn: The function that creates a data generator.
+ num_chunks: The number of chunks to consider
+ count_header_only_once: Whether to count the header as part of the
+ compressed output only once for the whole dataset or for every chunk
+ individually.
+ mask_fn: The function that masks the data in case the compressor cannot
+ handle all possible byte values (e.g., language models can only process
+ ASCII-decodable data).
+ use_tqdm: Whether to use a progress bar or not.
+
+ Returns:
+ The compression rate and the total running time.
+ """
+ num_missed_bits = running_time = raw_length = compressed_length = num_samples = 0
+
+ raw_length, data_generator = get_data_generator_fn()
+ print(f'Data to compress has size {raw_length} bytes.')
+
+ if args.dataset == 'tokenized_dataset':
+ for num_bytes, data in tqdm.tqdm(data_generator):
+ num_samples += len(data)
+ if mask_fn is not None:
+ d, missed_bits = mask_fn(data)
+ num_missed_bits += missed_bits
+
+ t0 = time.perf_counter()
+ compressed_data = compress_fn(data, use_slow_lossless_compression=args.use_slow_compression)
+ t1 = time.perf_counter()
+
+ running_time += t1 - t0
+ compressed_length += len(compressed_data)
+ else:
+ for num_bytes, data in tqdm.tqdm(data_generator):
+ num_samples += len(data)
+ for d_size, d in zip(num_bytes, data):
+ if mask_fn is not None:
+ d, missed_bits = mask_fn(d)
+ num_missed_bits += missed_bits
+
+ if isinstance(d, torch.Tensor):
+ d = d.numpy()
+
+ t0 = time.perf_counter()
+ compressed_data = compress_fn(d)
+ t1 = time.perf_counter()
+
+ running_time += t1 - t0
+ compressed_length += len(compressed_data)
+
+ # raw_length += constants.CHUNK_SIZE_BYTES * num_chunks
+
+ # Since language models are trained on ASCII strings, they cannot handle all
+ # byte values. Thus, we mask the data to be ASCII-decodable by zeroing
+ # `num_missed_bits` of the most significant bits. However, this means that we
+ # are effectively only compressing `num_bits - num_missed_bits` bits, so we
+ # rescale the `compressed_length` to account for this.
+ if mask_fn is not None:
+ num_bits = 8 * num_samples * constants.CHUNK_SIZE_BYTES
+ compressed_length *= num_bits / (num_bits - num_missed_bits)
+
+ # We only count the header once for classical compressors.
+ # if count_header_only_once:
+ # header_length = len(compress_fn((0).to_bytes(1, 'little')))
+ # compressed_length -= header_length * (num_samples - 1)
+
+ return compressed_length / raw_length, running_time
+
+
+def evaluate_compressor_unchunked(
+ compress_fn: compressor.Compressor,
+ get_data_generator_fn: Callable[[], Generator[bytes, None, None]],
+ num_chunks: int,
+) -> tuple[float, float]:
+ """Evaluates the compressor on the unchunked dataset.
+
+ Args:
+ compress_fn: The function that compresses data.
+ get_data_generator_fn: The function that creates a data generator.
+ num_chunks: The number of chunks to consider.
+
+ Returns:
+ The compression rate and the total running time.
+ """
+ all_data = None
+
+ raw_length, data_generator = get_data_generator_fn()
+ print(f'Data to compress has size {raw_length} bytes.')
+
+ with tqdm.tqdm(total=constants.NUM_CHUNKS) as pbar:
+ for num_bytes, data in data_generator:
+ for d in data:
+ if all_data is None:
+ all_data = bytearray() if isinstance(d, bytes) else []
+
+ if isinstance(d, bytes):
+ all_data += d
+ else:
+ d = d.numpy()
+ all_data.append(d)
+
+ pbar.update(1)
+
+ all_data = bytes(all_data) if isinstance(all_data[0], bytes) else np.concatenate(all_data)
+
+ t0 = time.perf_counter()
+ compressed_data = compress_fn(all_data)
+ t1 = time.perf_counter()
+
+ return len(compressed_data) / raw_length, t1 - t0
+
+
+def main(args) -> None:
+ print('start')
+ compress_fn = compressor.COMPRESS_FN_DICT[args.compressor]
+
+ if args.compressor in compressor.COMPRESSOR_TYPES['classical']:
+ get_data_generator_fn = functools.partial(
+ data_loaders.GET_DATA_GENERATOR_FN_DICT[args.dataset],
+ args=args,
+ serialization=False,
+ preprocess=args.preprocess,
+ return_fp=False if args.compressor != 'fpzip' else True
+ )
+ unchunked_rate, unchunked_time = evaluate_compressor_unchunked(
+ compress_fn=compress_fn,
+ get_data_generator_fn=get_data_generator_fn,
+ num_chunks=constants.NUM_CHUNKS,
+ )
+ chunked_rate, chunked_time = evaluate_compressor_chunked(
+ args=args,
+ compress_fn=compress_fn,
+ get_data_generator_fn=get_data_generator_fn,
+ num_chunks=constants.NUM_CHUNKS,
+ count_header_only_once=True,
+ mask_fn=None,
+ )
+ print(
+ f'Unchunked: {100 * unchunked_rate:.2f} [{unchunked_time:.1f}s]'
+ )
+ print(f'Chunked: {100 * chunked_rate:.2f} [{chunked_time:.1f}s]')
+
+ elif args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']:
+ get_data_generator_fn = functools.partial(
+ data_loaders.GET_DATA_GENERATOR_FN_DICT[args.dataset],
+ args=args,
+ serialization=True,
+ max_tokens=args.max_tokens,
+ preprocess=args.preprocess,
+ )
+ model = compress_fn(args)
+ if args.use_mask:
+ # To compress bytes data, we convert it first to ASCII.
+ if args.dataset == 'enwik9':
+ # For Enwik9, some characters are UTF-8 but not ASCII, so we still need
+ # to do the conversion.
+ mask_fn = deepmind_utils.zero_most_significant_bit_if_not_ascii_decodable
+ else:
+ mask_fn = deepmind_utils.right_shift_bytes_by_one
+ else:
+ mask_fn = None
+
+ chunked_rate, chunked_time = evaluate_compressor_chunked(
+ args = args,
+ compress_fn=model.compress,
+ get_data_generator_fn=get_data_generator_fn,
+ num_chunks=constants.NUM_CHUNKS,
+ count_header_only_once=False,
+ mask_fn=mask_fn,
+ )
+ print(f'Chunked: {100 * chunked_rate:.2f} [{chunked_time:.1f}s]')
+ else:
+ raise NotImplementedError(f'Unknown codec {args.compressor}.')
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/compressors/compressor.py b/compressors/compressor.py
new file mode 100644
index 0000000..f899c0d
--- /dev/null
+++ b/compressors/compressor.py
@@ -0,0 +1,54 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Defines the compressor interface."""
+import os
+
+import functools
+import gzip
+import lzma
+import fpzip
+from typing import Mapping, Protocol
+
+from compressors import flac
+from compressors import language_model
+from compressors import png
+
+
+class Compressor(Protocol):
+
+ def __call__(self, data: bytes, *args, **kwargs) -> bytes | tuple[bytes, int]:
+ """Returns the compressed version of `data`, with optional padded bits."""
+
+
+COMPRESSOR_TYPES = {
+ 'classical': ['flac', 'gzip', 'lzma', 'png', 'fpzip'],
+ 'arithmetic_coding': ['gpt2', 'openllama3b', 'tinyllama3b', 'tinyllama', 'mistral7b', 'llama2-7b', 'llama3-8b'],
+}
+
+COMPRESS_FN_DICT: Mapping[str, Compressor] = {
+ 'flac': flac.compress,
+ 'gzip': functools.partial(gzip.compress, compresslevel=9),
+ 'gpt2': language_model.LanguageModelCompressor,
+ 'openllama3b': language_model.LanguageModelCompressor,
+ 'tinyllama3b': language_model.LanguageModelCompressor,
+ 'tinyllama': language_model.LanguageModelCompressor,
+ 'mistral7b': language_model.LanguageModelCompressor,
+ 'llama2-7b': language_model.LanguageModelCompressor,
+ 'llama3-8b': language_model.LanguageModelCompressor,
+ 'lzma': lzma.compress,
+ 'png': png.compress,
+ 'fpzip': functools.partial(fpzip.compress, precision=0, order='C'),
+}
\ No newline at end of file
diff --git a/compressors/flac.py b/compressors/flac.py
new file mode 100644
index 0000000..e79d2dd
--- /dev/null
+++ b/compressors/flac.py
@@ -0,0 +1,56 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements a lossless compressor with FLAC."""
+
+import audioop
+import io
+
+import pydub
+
+def compress(data: bytes) -> bytes:
+ """Returns data compressed with the FLAC codec.
+
+ Args:
+ data: Assumes 1 byte per sample (`sample_width`), meaning 256 possible
+ values, and 1 channel and a `frame_rate` of 16kHz.
+ """
+ sample = pydub.AudioSegment(
+ data=data,
+ channels=1,
+ sample_width=1,
+ frame_rate=16000,
+ )
+ return sample.export(
+ format='flac',
+ parameters=['-compression_level', '12'],
+ ).read()
+
+
+def decompress(data: bytes) -> bytes:
+ """Decompresses `data` losslessly using the FLAC codec.
+
+ Args:
+ data: The data to be decompressed. Assumes 2 bytes per sample (16 bit).
+
+ Returns:
+ The decompressed data. Assumes 1 byte per sample (8 bit).
+ """
+ sample = pydub.AudioSegment.from_file(io.BytesIO(data), format='flac')
+ # FLAC assumes that data is 16 bit. However, since our original data is 8 bit,
+ # we need to convert the samples from 16 bit to 8 bit (i.e., changing from two
+ # channels to one channel with `lin2lin`) and add 128 since 16 bit is signed
+ # (i.e., adding 128 using `bias`).
+ return audioop.bias(audioop.lin2lin(sample.raw_data, 2, 1), 1, 128)
\ No newline at end of file
diff --git a/compressors/language_model.py b/compressors/language_model.py
new file mode 100644
index 0000000..b3db48c
--- /dev/null
+++ b/compressors/language_model.py
@@ -0,0 +1,213 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements a lossless compressor with language models (arithmetic coding)."""
+
+from collections.abc import Iterator
+import functools, time
+from typing import Callable
+
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+import torchac
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import numpy as np
+
+import arithmetic_coder
+import constants
+from utils.utils import BaseParser
+from utils import deepmind_utils
+
+from multiprocessing import Pool
+
+MODEL_NAME_DICT = {
+ "gpt2": "gpt2",
+ "openllama3b": "openlm-research/open_llama_3b_v2",
+ "tinyllama3b": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "tinyllama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "mistral7b": "mistralai/Mistral-7B-v0.1",
+ "llama2-7b": "meta-llama/Llama-2-7b-hf",
+ "llama3-8b": "meta-llama/Meta-Llama-3-8B",
+}
+SKIP_LIST = ["openllama3b", "tinyllama3b", "tinyllama", "llama2-7b", "llama3-8b"]
+
+def work(args):
+ pid, cum_prob, input_ids = args
+ byte_stream = torchac.encode_float_cdf(cum_prob, input_ids, check_input_bounds=False)
+ return len(byte_stream)
+
+class LanguageModelCompressor():
+ def __init__(self, args: BaseParser):
+ self.args = args
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(self.device)
+
+ access_token = None
+ assert access_token, 'Please enter the huggingface access token here if you want to use models like LLAMA.'
+
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_DICT[args.compressor], token=access_token)
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ MODEL_NAME_DICT[args.compressor],
+ torch_dtype=torch.bfloat16,
+ # device_map='auto',
+ attn_implementation="flash_attention_2",
+ token=access_token,
+ )
+ self.model.config.pad_token_id = self.model.config.eos_token_id
+ self.model.to(self.device)
+ self.model.eval()
+
+ @torch.no_grad()
+ def infer_prob(
+ self,
+ data: bytes,
+ return_num_padded_bits: bool = False,
+ use_slow_lossless_compression: bool = False,
+ ) -> torch.Tensor:
+ if self.args.dataset == 'tokenized_dataset':
+ input_ids = data.to(self.device)#.view(1, -1)
+ tokenized_data = {
+ 'input_ids': input_ids,
+ 'attention_mask': torch.ones(data.size()),
+ }
+ else:
+ texified_data = data
+ if not self.args.compressor in SKIP_LIST:
+ texified_data = self.tokenizer.bos_token + texified_data
+
+ print('length of texified data after adding eos:', len(texified_data))
+
+ tokenized_data = self.tokenizer(texified_data, padding=False, return_tensors="pt").to(self.device)
+ print(tokenized_data.keys())
+ input_ids = tokenized_data.input_ids
+
+ print('length of tokens', input_ids.shape)
+
+ if use_slow_lossless_compression:
+ # Compress the token at the position idx+1 using tokens_{t <= idx}
+ """Remark!
+ Due to the stochastic operations in huggingface and precision requirements by arithmetic coding,
+ one may use this option to compute the probability; however, in theory, they should not affect the result.
+
+ TO-DO: provide a runnable example here.
+ """
+ for idx in range(len(tokenized_data.input_ids[0])-1):
+ input_to_model = {
+ 'input_ids': tokenized_data.input_ids[:, :idx+1].view(1, -1),
+ 'attention_mask': tokenized_data.attention_mask[:, :idx+1].view(1, -1)
+ }
+ outputs = self.model(**input_to_model)
+
+ logits = outputs.logits[:, idx, :] # next-word prob. of size # num_sentences x
+ pdf = torch.softmax(logits, dim=-1).squeeze().detach().cpu().numpy()
+ symbol = tokenized_data.input_ids[:, idx+1].item()
+ else:
+ # pay attention to the output length
+ outputs = self.model(**tokenized_data)
+
+ # collect the probability of the generated token -- probability at index 0 corresponds to the token at index 1
+ probs = outputs.logits[:, :-1, :] # num_sentences x next-word prob. of size (excl. the last word prediction)
+ input_ids = input_ids[:, 1:] # number_sentences (1?) x max_legnth_chars-1 (excl. bos_token)
+
+ probs = torch.log_softmax(probs, dim=-1).detach()
+ probs = probs.exp()
+
+ assert (input_ids < 65536).all()
+ input_ids = input_ids.squeeze().detach().cpu().short()
+ probs = probs.squeeze().detach().cpu()
+
+ return input_ids, probs
+
+ @torch.no_grad()
+ def compress(
+ self,
+ data: bytes,
+ return_num_padded_bits: bool = False,
+ use_slow_lossless_compression: bool = False,
+ ) -> bytes | tuple[bytes, int]:
+ input_ids, probs = self.infer_prob(data, return_num_padded_bits=return_num_padded_bits, use_slow_lossless_compression=use_slow_lossless_compression)
+
+ cum_prob = probs / torch.cumsum(probs, dim=-1)
+ cum_prob = torch.cat([torch.zeros(cum_prob.size()[:-1]).unsqueeze(-1), cum_prob], dim=-1)
+
+ byte_stream = torchac.encode_float_cdf(cum_prob, input_ids, check_input_bounds=False, needs_normalization=False)
+
+ return byte_stream
+
+ @torch.no_grad()
+ def decompress(
+ self,
+ data: bytes,
+ num_padded_bits: int = 0,
+ uncompressed_length: int = constants.CHUNK_SIZE_BYTES,
+ ) -> bytes:
+ """Decompresses the `data` using arithmetic coding and a pretrained model.
+
+ See https://en.wikipedia.org/wiki/Arithmetic_coding for details.
+
+ Args:
+ data: The data to be decompressed.
+ num_padded_bits: The number of zeros added to the encoded bitstream in order
+ to make it byte-decodeable (i.e., divisble by 8).
+ uncompressed_length: The length of the original data stream (in bytes).
+
+ Returns:
+ The decompressed data.
+ """
+ data_iter = iter(deepmind_utils.bytes_to_bits(data, num_padded_bits=num_padded_bits))
+ # The decoder requires a function that reads digits from {0, 1, ..., base - 1}
+ # from the compressed input and returns `None` when the input is exhausted.
+ def _input_fn(bit_sequence: Iterator[str] = data_iter) -> int | None:
+ try:
+ return int(next(bit_sequence))
+ except StopIteration:
+ return None
+
+ decoder = arithmetic_coder.Decoder(
+ base=constants.ARITHMETIC_CODER_BASE,
+ precision=constants.ARITHMETIC_CODER_PRECISION,
+ input_fn=_input_fn,
+ )
+ # We need a dummy token because the language model right-shifts the sequence
+ # by one when computing the conditional probabilities. Concretely, at every
+ # step, we need the `pdf` of the next token given all currently decompressed
+ # tokens, but without a dummy token, the last `pdf` would be that of the last
+ # already decompressed token. The value of the dummy token is irrelevant.
+ input_ids = torch.empty((1, uncompressed_length), dtype=torch.long, device=self.device)
+ input_ids[:, :] = self.tokenizer.pad_token_id
+ attention_mask = torch.ones(input_ids.shape).to(self.device)
+
+ # In our current implementation, we always begin with a token.
+ input_ids[0, 0] = self.tokenizer.bos_token_id
+
+ for idx in range(0, uncompressed_length-1):
+ input_to_model = {
+ 'input_ids': input_ids[:, :idx+1].view(1, -1),
+ 'attention_mask': attention_mask[:, :idx+1].view(1, -1)
+ }
+ outputs = self.model(**input_to_model)
+
+ pdf = outputs.logits[:, idx, :] # next-word prob. of size # num_sentences x
+ pdf = torch.softmax(pdf, dim=-1).squeeze().detach().cpu().numpy()
+ token = decoder.decode(
+ deepmind_utils.normalize_pdf_for_arithmetic_coding(pdf)
+ )
+ input_ids[0, idx+1] = token
+
+ input_ids = input_ids.cpu().detach().numpy()
+ # Remove the dummy token and convert to bytes.
+ return input_ids[0, :uncompressed_length]
\ No newline at end of file
diff --git a/compressors/png.py b/compressors/png.py
new file mode 100644
index 0000000..666b303
--- /dev/null
+++ b/compressors/png.py
@@ -0,0 +1,72 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements a lossless compressor with PNG."""
+
+import io
+import math
+
+from PIL import Image
+
+
+def _get_the_two_closest_factors(n: int) -> tuple[int, int]:
+ """Returns the 2 closest factors (square root if `n` is a perfect square)."""
+ a = round(math.sqrt(n))
+ while n % a > 0:
+ a -= 1
+ return a, n // a
+
+
+def compress(data: bytes) -> bytes:
+ """Compresses `data` losslessly using the PNG format.
+
+ The data, which is a sequence of bytes, is reshaped into a
+ as-close-to-square-as-possible image before compression with 8-bit pixels
+ (grayscale).
+
+ Args:
+ data: The data to be compressed.
+
+ Returns:
+ The compressed data.
+ """
+ # Compute the height and width of the image.
+ size = _get_the_two_closest_factors(len(data))
+
+ # Load the image using 8-bit grayscale pixels.
+ image = Image.frombytes(mode='L', size=size, data=data)
+
+ with io.BytesIO() as buffer:
+ image.save(
+ buffer,
+ format='PNG',
+ optimize=True,
+ )
+ return buffer.getvalue()
+
+
+def decompress(data: bytes) -> bytes:
+ """Decompresses `data` losslessly using the PNG format.
+
+ To apply the PNG format, the `data` is treated as the compressed sequence of
+ bytes from an image consisting of 8-bit pixels (grayscale).
+
+ Args:
+ data: The data to be decompressed.
+
+ Returns:
+ The decompressed data.
+ """
+ return Image.open(io.BytesIO(data)).tobytes()
\ No newline at end of file
diff --git a/constants.py b/constants.py
new file mode 100644
index 0000000..d6a4f47
--- /dev/null
+++ b/constants.py
@@ -0,0 +1,9 @@
+NUM_CHUNKS = 10000000
+CHUNK_SIZE_BYTES = 512 # always read this amount of data, regardless of data type
+CHUNK_SHAPE_2D = (16, 32)
+UINT8_LOADING = True # only useful for images
+
+# Base 2 means that the coder writes bits.
+ARITHMETIC_CODER_BASE = 2
+# Precision 32 implies 32 bit arithmetic.
+ARITHMETIC_CODER_PRECISION = 32
\ No newline at end of file
diff --git a/images/lm-gc.png b/images/lm-gc.png
new file mode 100644
index 0000000..0f1f911
Binary files /dev/null and b/images/lm-gc.png differ
diff --git a/images/teaser.png b/images/teaser.png
new file mode 100644
index 0000000..2e26614
Binary files /dev/null and b/images/teaser.png differ
diff --git a/scripts/collect_gradients.sh b/scripts/collect_gradients.sh
new file mode 100644
index 0000000..6c8d759
--- /dev/null
+++ b/scripts/collect_gradients.sh
@@ -0,0 +1,8 @@
+cd ..
+
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+for i in 0 1 2
+do
+ python -u train_and_collect_grad.py -cfg settings/gradient_collection/$DATASET-$ARCH.yaml --tag $i --grad-interval 400 --download
+done
\ No newline at end of file
diff --git a/scripts/compress.sh b/scripts/compress.sh
new file mode 100644
index 0000000..3a1263c
--- /dev/null
+++ b/scripts/compress.sh
@@ -0,0 +1,16 @@
+cd ..
+
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon
+BATCHSIZE=4 # depending on your GPUs
+BPG=4 # 8
+for i in 1 2 3
+do
+ python -u compress.py -cfg settings/compression/cifar10-$SEP.yaml --compressor $COMPRESSOR --dataset tokenized_dataset \
+ --data-path ./tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl --batch-size $BATCHSIZE
+done
\ No newline at end of file
diff --git a/scripts/pipeline.sh b/scripts/pipeline.sh
new file mode 100644
index 0000000..5c718c3
--- /dev/null
+++ b/scripts/pipeline.sh
@@ -0,0 +1,3 @@
+bash collect_gradients.sh
+bash serialization.sh
+bash compress.sh
\ No newline at end of file
diff --git a/scripts/serialization.sh b/scripts/serialization.sh
new file mode 100644
index 0000000..8e0762f
--- /dev/null
+++ b/scripts/serialization.sh
@@ -0,0 +1,14 @@
+cd ..
+
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-none" # hex-space # hex-comma+space # iso # hex-semicolon
+BPG=4 # 8
+for i in 1 2 3
+do
+ python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \
+ --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i
+done
\ No newline at end of file
diff --git a/scripts/server_collect_gradients.sh b/scripts/server_collect_gradients.sh
new file mode 100644
index 0000000..4a89612
--- /dev/null
+++ b/scripts/server_collect_gradients.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+#SBATCH --nodes=1 # How many nodes?
+#SBATCH -A hai_lmgc # Who pays for it?
+#SBATCH --partition develbooster
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=2:00:00
+#SBATCH -o output-%x.txt
+#SBATCH -e error-%x.txt
+# Where does the code run?
+# Required for legacy reasons
+# How long?
+
+source /p/home/jusers/wang34/juwels/hai_fedak/huipo/general/activate.sh # path to the environment
+cd /p/home/jusers/wang34/juwels/hai_fedak/huipo/lmgc-to-release # path to the upper folder
+LOG_DIR="scripts/test_collect"
+
+if [ ! -d "$LOG_DIR" ]; then
+ echo "$LOG_DIR does not exist."
+ mkdir "$LOG_DIR"
+else
+ echo "$LOG_DIR does exist."
+fi
+
+DATASET='tinyimagenet'
+ARCH="convnet"
+INDEX=1
+for i in 0 1 2
+do
+ srun --exclusive --gres=gpu:1 --cpus-per-task=6 --ntasks=1 -o "$LOG_DIR/out-$INDEX-$i.txt" -e "$LOG_DIR/err-$INDEX-$i.txt" \
+ python -u train_and_collect_grad.py -cfg settings/gradient_collection/$DATASET-$ARCH.yaml --tag $i --grad-interval 400 --download &
+done
+
+wait
\ No newline at end of file
diff --git a/scripts/server_compress.sh b/scripts/server_compress.sh
new file mode 100644
index 0000000..794af76
--- /dev/null
+++ b/scripts/server_compress.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+#SBATCH --nodes=1 # How many nodes?
+#SBATCH -A hai_lmgc # Who pays for it?
+#SBATCH --partition booster
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=4:00:00
+#SBATCH -o output-%x.txt
+#SBATCH -e error-%x.txt
+# Where does the code run?
+# Required for legacy reasons
+# How long?
+
+source /p/home/jusers/wang34/juwels/hai_fedak/huipo/general/activate.sh # path to the environment
+cd /p/home/jusers/wang34/juwels/hai_fedak/huipo/lmgc-to-release # path to the upper folder
+LOG_DIR="scripts/test_compress"
+
+if [ ! -d "$LOG_DIR" ]; then
+ echo "$LOG_DIR does not exist."
+ mkdir "$LOG_DIR"
+else
+ echo "$LOG_DIR does exist."
+fi
+
+
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-space" # hex-space # hex-comma+space # iso # hex-semicolon
+BATCHSIZE=4 # number of windows processed everytime -- depending on your GPUs
+BPG=4 # 8
+INDEX=2
+for i in 1 2 3
+do
+ srun --exclusive --gres=gpu:1 --cpus-per-task=6 --ntasks=1 -o "$LOG_DIR/out-$INDEX-$i.txt" -e "$LOG_DIR/err-$INDEX-$i.txt" \
+ python -u compress.py -cfg settings/compression/cifar10-$SEP.yaml --compressor $COMPRESSOR --dataset tokenized_dataset \
+ --data-path /p/home/jusers/wang34/juwels/hai_fedak/huipo/gradient_compressors/tokenized_datasets/$ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i.pkl --batch-size $BATCHSIZE &
+done
+
+wait
\ No newline at end of file
diff --git a/scripts/server_serialization.sh b/scripts/server_serialization.sh
new file mode 100644
index 0000000..a43e2bd
--- /dev/null
+++ b/scripts/server_serialization.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --nodes=1 # How many nodes?
+#SBATCH -A hai_lmgc # Who pays for it?
+#SBATCH --partition develbooster
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=2:00:00
+#SBATCH -o output-%x.txt
+#SBATCH -e error-%x.txt
+# Where does the code run?
+# Required for legacy reasons
+# How long?
+
+source /p/home/jusers/wang34/juwels/hai_fedak/huipo/general/activate.sh # path to the environment
+cd /p/home/jusers/wang34/juwels/hai_fedak/huipo/lmgc-to-release # path to the upper folder
+LOG_DIR="scripts/test_serialize"
+
+if [ ! -d "$LOG_DIR" ]; then
+ echo "$LOG_DIR does not exist."
+ mkdir "$LOG_DIR"
+else
+ echo "$LOG_DIR does exist."
+fi
+
+NUM_SUBSAMPLE=10
+DATASET='tinyimagenet' # cifar10 # mnist
+ARCH="convnet" # vgg16 # resnet18 # vit
+TYPE="grad"
+COMPRESSOR="tinyllama" # llama2-7b # openllama3b
+SEP="hex-space" # hex-space # hex-comma+space # iso # hex-semicolon
+BPG=4 # 8
+INDEX=2
+for i in 1 2 3
+do
+ srun --exclusive --gres=gpu:1 --cpus-per-task=6 --ntasks=1 -o "$LOG_DIR/out-$INDEX-$i.txt" -e "$LOG_DIR/err-$INDEX-$i.txt" \
+ python -u tokenize_dataset.py --cfg settings/compression/cifar10-$SEP.yaml --data-path exps/$DATASET-$ARCH/0/grads/ --bytes-per-group $BPG \
+ --compressor $COMPRESSOR --exhaustive-listing --num-subsample $NUM_SUBSAMPLE --output-name $ARCH-$DATASET-$COMPRESSOR-$SEP-$NUM_SUBSAMPLE-$TYPE-$BPG-$i &
+done
+
+wait
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-0x.yaml b/settings/compression/cifar10-hex-0x.yaml
new file mode 100644
index 0000000..7bee1c8
--- /dev/null
+++ b/settings/compression/cifar10-hex-0x.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-0x
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-comma+space.yaml b/settings/compression/cifar10-hex-comma+space.yaml
new file mode 100644
index 0000000..b247421
--- /dev/null
+++ b/settings/compression/cifar10-hex-comma+space.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-comma+space
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-comma.yaml b/settings/compression/cifar10-hex-comma.yaml
new file mode 100644
index 0000000..5236349
--- /dev/null
+++ b/settings/compression/cifar10-hex-comma.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-comma
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-none.yaml b/settings/compression/cifar10-hex-none.yaml
new file mode 100644
index 0000000..170b3e2
--- /dev/null
+++ b/settings/compression/cifar10-hex-none.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-none
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-semicolon.yaml b/settings/compression/cifar10-hex-semicolon.yaml
new file mode 100644
index 0000000..d309b6b
--- /dev/null
+++ b/settings/compression/cifar10-hex-semicolon.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-semicolon
\ No newline at end of file
diff --git a/settings/compression/cifar10-hex-space.yaml b/settings/compression/cifar10-hex-space.yaml
new file mode 100644
index 0000000..5d563cb
--- /dev/null
+++ b/settings/compression/cifar10-hex-space.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: hex-space
\ No newline at end of file
diff --git a/settings/compression/cifar10-iso-comma+space.yaml b/settings/compression/cifar10-iso-comma+space.yaml
new file mode 100644
index 0000000..58e04ba
--- /dev/null
+++ b/settings/compression/cifar10-iso-comma+space.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: iso-comma+space
\ No newline at end of file
diff --git a/settings/compression/cifar10-iso-space.yaml b/settings/compression/cifar10-iso-space.yaml
new file mode 100644
index 0000000..9a64274
--- /dev/null
+++ b/settings/compression/cifar10-iso-space.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: iso-space
\ No newline at end of file
diff --git a/settings/compression/cifar10-iso.yaml b/settings/compression/cifar10-iso.yaml
new file mode 100644
index 0000000..185a07b
--- /dev/null
+++ b/settings/compression/cifar10-iso.yaml
@@ -0,0 +1,2 @@
+dummy: True
+decoding: iso
\ No newline at end of file
diff --git a/settings/gradient_collection/cifar10-VGG16.yaml b/settings/gradient_collection/cifar10-VGG16.yaml
new file mode 100644
index 0000000..7d616bf
--- /dev/null
+++ b/settings/gradient_collection/cifar10-VGG16.yaml
@@ -0,0 +1,2 @@
+dataset: cifar10
+arch: VGG16
\ No newline at end of file
diff --git a/settings/gradient_collection/cifar10-convnet.yaml b/settings/gradient_collection/cifar10-convnet.yaml
new file mode 100644
index 0000000..b1dfc44
--- /dev/null
+++ b/settings/gradient_collection/cifar10-convnet.yaml
@@ -0,0 +1,2 @@
+dataset: cifar10
+arch: ConvNet
\ No newline at end of file
diff --git a/settings/gradient_collection/cifar10-resnet18.yaml b/settings/gradient_collection/cifar10-resnet18.yaml
new file mode 100644
index 0000000..4d1c720
--- /dev/null
+++ b/settings/gradient_collection/cifar10-resnet18.yaml
@@ -0,0 +1,2 @@
+dataset: cifar10
+arch: resnet18
\ No newline at end of file
diff --git a/settings/gradient_collection/cifar10-vit.yaml b/settings/gradient_collection/cifar10-vit.yaml
new file mode 100644
index 0000000..ac16c33
--- /dev/null
+++ b/settings/gradient_collection/cifar10-vit.yaml
@@ -0,0 +1,2 @@
+dataset: cifar10
+arch: ViT
\ No newline at end of file
diff --git a/settings/gradient_collection/mnist-convnet.yaml b/settings/gradient_collection/mnist-convnet.yaml
new file mode 100644
index 0000000..8eea9a6
--- /dev/null
+++ b/settings/gradient_collection/mnist-convnet.yaml
@@ -0,0 +1,2 @@
+dataset: mnist
+arch: ConvNet
\ No newline at end of file
diff --git a/settings/gradient_collection/tinyimagenet-convnet.yaml b/settings/gradient_collection/tinyimagenet-convnet.yaml
new file mode 100644
index 0000000..6d3b0b5
--- /dev/null
+++ b/settings/gradient_collection/tinyimagenet-convnet.yaml
@@ -0,0 +1,2 @@
+dataset: tinyimagenet
+arch: ConvNet
\ No newline at end of file
diff --git a/tokenize_dataset.py b/tokenize_dataset.py
new file mode 100644
index 0000000..1b01ec4
--- /dev/null
+++ b/tokenize_dataset.py
@@ -0,0 +1,126 @@
+import argparse, yaml
+import functools, pickle
+import os
+os.environ['HF_HOME'] = './cache/'
+
+import constants
+import numpy as np
+
+from transformers import AutoTokenizer
+
+from tqdm import tqdm
+from utils import data_loaders
+from utils.utils import BaseParser
+from compressors import language_model
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)
+
+ parser.add_argument('-data-path', '--data-path', default='./', required=True, type=str)
+ parser.add_argument('-exhaustive-listing', '--exhaustive-listing', action='store_true', help='If read all of the checkpoints in the data path.')
+ parser.add_argument('-num-subsample', '--num-subsample', required=True, type=int)
+
+ parser.add_argument('-download', '--download', action='store_true')
+
+ parser.add_argument('-use_mask', '--use_mask', action='store_true', help='Applying mask functions, particularly for decoding images into ASCII.')
+
+ parser.add_argument('--dataset', '-dataset', default='gradient', type=str, help='Indicatge what kind of data to compress.')
+
+ parser.add_argument('--output-name', '-output-name', required=True, help='Name of the output pre-tokenized dataset.')
+
+ parser.add_argument('--compressor', '-compressor', default='tinyllama3b', type=str, help='What kind of compressor to use.')
+ parser.add_argument('--verbose', '-verbose', action='store_true', help='Print first few tokens for debugging.')
+
+ parser.add_argument('--bytes-per-group', '-bytes-per-group', default=None, type=int, help='Specify after how many bytes a separator will be added.')
+ parser.add_argument('-batch-size', '--batch-size', type=int, default=32)
+ parser.add_argument('-preprocess', '--preprocess', type=str, default=None)
+
+ parser.add_argument('-noise-level', '--noise-level', default=None, type=float, help='Standard deviation of gaussian noise to add. Useful only when preprocess "gaussian" is enabled.')
+ parser.add_argument('-clipping-bound', '--clipping-bound', default=None, type=float, help='Clipping bound. Useful only when preprocess "clipping" is enabled.')
+
+ args = parser.parse_args()
+ with open(args.cfg, 'r') as stream:
+ settings = yaml.safe_load(stream)
+
+ args = BaseParser(args, settings)
+
+ args.output_dir = './tokenized_datasets'
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+ # arch-cifar-lm-serializeation-step-[model|grad]-tag.pkl
+ args.output_name = os.path.join(args.output_dir, args.output_name+'.pkl')
+
+ # handle bytes_per_group
+ if args.bytes_per_group is None:
+ args.bytes_per_group = 4 if args.dataset == 'tokenized_data' or args.dataset == 'gradient' else 1
+
+ print(args)
+ print('chunck size: ', constants.CHUNK_SIZE_BYTES)
+ print('num_chuncks: ',constants.NUM_CHUNKS)
+
+ if args.exhaustive_listing:
+ paths = np.array([os.path.join(args.data_path, f) for f in os.listdir(args.data_path) if os.path.isfile(os.path.join(args.data_path, f))])
+ args.sample_index = np.random.choice(len(paths), args.num_subsample, replace=False)
+ args.data_path = paths[args.sample_index]
+
+ # print(args.data_path)
+
+ return args
+
+def main(args):
+ _, data_generator = data_loaders.GET_DATA_GENERATOR_FN_DICT[args.dataset](args=args, serialization=True, preprocess=args.preprocess)
+ tokenizer = AutoTokenizer.from_pretrained(language_model.MODEL_NAME_DICT[args.compressor])
+ tokenizer.pad_token = tokenizer.eos_token
+
+ skip_tokens = {
+ 'tinyllama3b': [1, 29871],
+ 'tinyllama': [1, 29871],
+ 'openllama3b': [1, 29871],
+ 'gpt2': [],
+ 'llama3-8b': [128000],
+ 'llama2-7b': [1, 29871],
+ }
+ # tokenize -> remove strat token if using LLAMA; sometimes tokenizer also add a '' token with id 29871 to the data
+ concat_tokens = []
+
+ assert args.compressor in skip_tokens.keys()
+
+ print(f'Ready to tokenize {len(data_generator)} * 32 samples.')
+ print(f'Each sample consists of {constants.CHUNK_SIZE_BYTES} bytes.')
+ print(f'The preprocess will be hanlded by {args.compressor}\'s tokenizer.')
+
+ num_sample = 0
+ for num_bytes, data in tqdm(data_generator):
+ num_sample += len(data)
+ for d_size, d in zip(num_bytes, data):
+ tokenized_data = tokenizer(d, padding=False, return_tensors="pt")#.view(-1) # size: num_tokens
+ tokenized_data = tokenized_data.input_ids.view(-1)
+ if args.verbose:
+ print(tokenized_data[:10])
+ print(f'\t{d[:20]}')
+ print(f'\t{[tokenizer.decode(t) for t in tokenized_data[:10]]}')
+
+ # while loop pop bos and additional space tokens
+ started_index = 0
+ while tokenized_data[started_index] in skip_tokens[args.compressor]:
+ started_index += 1
+
+ tokenized_data = tokenized_data[started_index:]
+ concat_tokens.append(tokenized_data.cpu().detach().numpy())
+
+ if args.verbose:
+ print(tokenized_data[:10])
+ print(f'\t{[tokenizer.decode(t) for t in tokenized_data[:10]]}')
+
+ concat_tokens = np.concatenate(concat_tokens)
+ print(f'Processed dataset contains {num_sample} samples with {len(concat_tokens)} tokens, containing {num_sample * constants.CHUNK_SIZE_BYTES} byte'
+ f' ~= {num_sample * constants.CHUNK_SIZE_BYTES/1024/1024:.2f} MB.')
+
+ with open(args.output_name, 'wb') as handle:
+ pickle.dump({'tokens': concat_tokens, 'sample_index': args.sample_index, 'total_bytes': num_sample * constants.CHUNK_SIZE_BYTES}, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/train_and_collect_grad.py b/train_and_collect_grad.py
new file mode 100644
index 0000000..3d2e6ca
--- /dev/null
+++ b/train_and_collect_grad.py
@@ -0,0 +1,160 @@
+import argparse, yaml, os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from tqdm import tqdm
+from collections import OrderedDict
+
+from utils.misc import get_network, prepare_dataset
+from utils.utils import BaseParser
+
+def parse_args():
+ # specify the architecture, the frequency of gradient collection
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-cfg', '--cfg', default=None, type=str, required=True)
+ parser.add_argument('-tag', '--tag', default=0, type=int)
+ parser.add_argument('-seed', '--seed', default=None)
+ parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
+
+ parser.add_argument('-data-path', '--data-path', default='./', type=str)
+ parser.add_argument('-download', '--download', action='store_true')
+
+ parser.add_argument('-save_path', '--save_path', default='./saves', type=str)
+ parser.add_argument('-grad-interval', '--grad-interval', default=200, type=int, help="The frequency of saving gradients. Counted by #batch steps.")
+ parser.add_argument('-model-interval', '--model-interval', default=5, type=int, help="The frequency of saving models. Counted by #epoch.")
+
+ parser.add_argument('-start-epoch', '--start-epoch', default=1, type=int, help="Load checkpoints, if sepcified.")
+
+ args = parser.parse_args()
+ with open(args.cfg, 'r') as stream:
+ settings = yaml.safe_load(stream)
+
+ args = BaseParser(args, settings)
+
+ args.name = os.path.basename(args.cfg).split('.')[0]
+ tag = args.tag
+ args.exp_dir = os.path.join('exps/', args.name, f'{tag}')
+
+ if not os.path.exists(args.exp_dir):
+ os.makedirs(args.exp_dir)
+
+ # save checkpoints
+ args.model_dir = os.path.join(args.exp_dir, 'ckpts')
+ if not os.path.exists(args.model_dir):
+ os.makedirs(args.model_dir)
+
+ # save gradients for compression experiments
+ args.grad_dir = os.path.join(args.exp_dir, 'grads')
+ if not os.path.exists(args.grad_dir):
+ os.makedirs(args.grad_dir)
+
+ # logs
+ args.log_dir = os.path.join(args.exp_dir, 'logs')
+ if not os.path.exists(args.log_dir):
+ os.makedirs(args.log_dir)
+
+ return args
+
+def train(args, net, opt, crit, train_loader):
+ net.train()
+ train_loss = 0
+ correct = 0
+ total = 0
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
+ args.batch_step += 1
+
+ inputs, targets = inputs.to(args.device), targets.to(args.device)
+
+ opt.zero_grad()
+
+ outputs = net(inputs)
+ loss = crit(outputs, targets)
+ loss.backward()
+
+ if args.batch_step % args.grad_interval == 0:
+ grad_dict = {k: v.grad.detach().clone().cpu() for k, v in net.named_parameters()}
+ state = {
+ 'state_dict': grad_dict,
+ 'batch_step': args.batch_step,
+ }
+ fname = os.path.join(args.grad_dir, f'grad_{args.batch_step:08d}.ckpt')
+ torch.save(state, fname)
+
+ opt.step()
+
+ train_loss += loss.item()
+ _, predicted = outputs.max(1)
+ total += targets.size(0)
+ correct += predicted.eq(targets).sum().item()
+
+ acc = 100.*correct/total
+ return train_loss/len(train_loader), acc, correct, total
+
+def test(args, net, crit, test_loader):
+ net.eval()
+ test_loss = 0
+ correct = 0
+ total = 0
+ with torch.no_grad():
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
+ inputs, targets = inputs.to(args.device), targets.to(args.device)
+ outputs = net(inputs)
+ loss = crit(outputs, targets)
+
+ test_loss += loss.item()
+ _, predicted = outputs.max(1)
+ total += targets.size(0)
+ correct += predicted.eq(targets).sum().item()
+
+ acc = 100.*correct/total
+ return test_loss/len(test_loader), acc, correct, total
+
+
+def main(args):
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ train_loader, test_loader, data_info = prepare_dataset(args)
+ net = get_network(args.arch, data_info).to(args.device)
+
+ crit = nn.CrossEntropyLoss()
+ if 'vit' in args.arch:
+ param_dict = {pn: p for pn, p in net.named_parameters()}
+ parameters_decay, parameters_no_decay = net.separate_parameters()
+
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": 1e-1},
+ {"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
+ ]
+ opt = optim.AdamW(optim_groups, lr=1e-3)
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3,
+ steps_per_epoch=len(train_loader), epochs=200)
+ else:
+ opt = optim.SGD(net.parameters(), lr=args.lr,
+ momentum=0.9, weight_decay=5e-4)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200)
+
+ args.batch_step = 0 # just to avoid call-by-assign
+ for epoch in tqdm(range(args.start_epoch, 200+1)):
+ train_loss, train_acc, correct, total = train(args, net, opt, crit, train_loader)
+ test_loss, test_acc, correct, total = test(args, net, crit, test_loader)
+ print(f'At epoch {epoch}: Training accuracy {train_acc:.2f} and testing accuracy {test_acc:.2f}.')
+ scheduler.step()
+
+ if epoch % args.model_interval == 0:
+ print(f'Saving models at epoch {epoch} ...')
+ state = {
+ 'state_dict': net.state_dict(),
+ 'train_loss': train_loss,
+ 'train_acc': train_acc,
+ 'test_loss': test_loss,
+ 'test_acc': test_acc,
+ 'epoch': epoch,
+ }
+ fname = os.path.join(args.model_dir, f'model_{epoch:03d}.ckpt')
+ torch.save(state, fname)
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/data_loaders.py b/utils/data_loaders.py
new file mode 100644
index 0000000..9f55aa9
--- /dev/null
+++ b/utils/data_loaders.py
@@ -0,0 +1,397 @@
+"""Implements data loaders."""
+from collections.abc import Iterator
+import sys
+sys.path.append('../')
+import itertools, pickle, functools
+import os, os.path
+import random
+import torch
+import torch.nn as nn
+import torchvision
+
+import numpy as np
+import torchvision.transforms as transforms
+
+from transformers import AutoTokenizer
+
+import constants
+from utils.utils import BaseParser
+from utils.misc import linear_quantization
+from compressors import language_model, compressor
+from typing import List
+
+# number of bits represented by one symbol in a specific decoing scheme
+SPLIT_VALUES = {
+ 'quater': 2,
+ 'oct': 3,
+ 'hex': 4,
+}
+
+SEP_VALUE = {
+ None: '',
+ 'none': '',
+ 'comma': ',',
+ 'space': ' ',
+ 'comma+space': ', ',
+ 'semicolon': ';',
+ '0x': '0x',
+}
+
+def worker_init_fn(worker_id):
+ os.sched_setaffinity(0, range(os.cpu_count()))
+
+def _extract_image_patches(image: np.ndarray) -> Iterator[bytes]:
+ h, w = constants.CHUNK_SHAPE_2D
+ height, width = image.shape
+
+ for row, col in itertools.product(range(height // h), range(width // w)):
+ yield image[row * h : (row + 1) * h, col * w : (col + 1) * w].tobytes()
+
+def _extract_parameter_patches(sample: bytes) -> Iterator[bytes]:
+ patches = np.array_split(
+ np.frombuffer(sample, dtype=np.uint8),
+ range(
+ constants.CHUNK_SIZE_BYTES,
+ len(sample),
+ constants.CHUNK_SIZE_BYTES,
+ ),
+ )
+ if len(patches[-1]) != constants.CHUNK_SIZE_BYTES:
+ # pad the array to have the same size
+ current_size = len(patches[-1])
+ padding_size = constants.CHUNK_SIZE_BYTES - current_size
+ patches[-1] = np.pad(patches[-1], pad_width=(0, padding_size), mode='constant', constant_values=0)
+
+ return map(lambda x: x.tobytes(), patches)
+
+def _convert_bin_fp(data: bytes, precision=64) -> List[float]:
+ converted_fp = np.frombuffer(data, dtype=np.float64 if precision == 64 else np.float32)
+ return converted_fp
+
+def _convert_fp_param(model: nn.Module, fp: float) -> None:
+ # pointer: iterate through the entire dp and parse them to param
+ pt = 0
+
+ for _, p in model.named_parameters():
+ len_p = p.numel()
+
+ param_from_fp = torch.tensor(fp[pt:pt+len_p]).view(p.size()).to(model.device)
+ p = param_from_fp
+
+ pt = pt + len_p
+
+def _serialize(data: bytes, decoding: str, bytes_per_group: int = 1) -> str:
+ '''A function that convert bytes into a hex string list and then the formats that LLMs can understand.
+ '''
+ codec = decoding.split('-')[0]
+ sep = decoding.split('-')[1] if '-' in decoding else None
+ assert sep in SEP_VALUE.keys(), f'Unknown separator {sep}. This is typically controlled by the decoding argument.'
+ sep = SEP_VALUE[sep]
+
+ split = SPLIT_VALUES[codec] if codec in SPLIT_VALUES.keys() else 8
+
+ # dump every byte as two hexi-decimal numbers and remove '0x' prefix
+ hex_str = [hex(n)[2:].zfill(2) for n in data]
+
+ # pre-group bytes according to the hyperparameter "bytes_per_group"
+ hex_str = [''.join(hex_str[i:i+bytes_per_group]) for i in range(0, len(hex_str), bytes_per_group)]
+ out = []
+
+ # read one group, consisting of "bytes_per_group bytes" and represented by hex numbers
+ for group in hex_str:
+ concat_str = []
+ num = int(group, 16)
+ if codec == 'quater':
+ pass
+ elif codec == 'hex':
+ num = hex(num)[2:].zfill(len(group))
+ elif codec == 'iso':
+ num = num.to_bytes(bytes_per_group, 'big') # orders should be taken care when loading the data
+ num = num.decode('iso-8859-1')
+ else:
+ raise NotImplementedError(f'Unknown serialization method: {decoding}.')
+ concat_str.append(num)
+ num = ''.join(concat_str)
+
+ out.append(num)
+
+ out = sep.join(out)
+ return out
+
+def _deserialize(data: str, decoding: str, bytes_per_group: int = 1) -> bytes:
+ '''A function that convert hex string list back to bytes.
+ '''
+ codec = decoding.split('-')[0]
+ sep = decoding.split('-')[1] if '-' in decoding else None
+ assert sep in SEP_VALUE.keys(), f'Unknown separator {sep}. This is typically controlled by the decoding argument.'
+ sep = SEP_VALUE[sep]
+
+ split = SPLIT_VALUES[codec] if codec in SPLIT_VALUES.keys() else 8
+
+ def hex_char_to_bin(hex_char):
+ return bin(int(hex_char, 16))[2:].zfill(4)
+
+ # group data separated by the separators.
+ grouped_data = data.split(sep)
+
+ binary_string = ""
+ for g in grouped_data:
+ for char in g:
+ # print(f'current string: {binary_string}, precoessing {char}')
+ binary_string += hex_char_to_bin(char)
+
+ byte_array = int(binary_string, 2).to_bytes((len(binary_string) + 7) // 8, byteorder='big')
+ return byte_array
+
+
+def get_image_iterator(
+ args: BaseParser = None,
+ serialization: bool = False,
+ max_tokens: int = None,
+ preprocess: str = None,
+):
+ trainset = IterImagePatchDataset(args, serialization)
+ loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, num_workers=1)
+ return loader
+
+def get_gradient_iterator(
+ args: BaseParser = None,
+ serialization: bool = False,
+ max_tokens: int = None,
+ preprocess: str = None,
+ return_fp: bool=False
+):
+ trainset = IterGradientDataset(args, serialization, preprocess=preprocess, return_fp=return_fp)
+ loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, num_workers=1)
+ return trainset.total_bytes, loader
+
+def get_td_iterator(
+ args: BaseParser = None,
+ serialization: bool = False,
+ max_tokens: int = None,
+ preprocess: str = None,
+):
+ trainset = PreTokenizedDataset(args, max_tokens)
+ loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, num_workers=6, worker_init_fn=worker_init_fn)
+ return trainset.total_bytes, loader
+
+class IterImagePatchDataset(torch.utils.data.IterableDataset):
+ def __init__(self, args, serialization):
+ self.args = args
+ self.serialization = serialization
+
+ transform_train = transforms.Compose([
+ transforms.Grayscale(),
+ # transforms.ToTensor(),
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.Grayscale(),
+ # transforms.ToTensor(),
+ ])
+
+ if args.dataset == 'cifar10':
+ self.trainset = torchvision.datasets.CIFAR10(
+ root=args.data_path, train=True, download=args.download, transform=transform_train)
+ self.testset = torchvision.datasets.CIFAR10(
+ root=args.data_path, train=False, download=args.download, transform=transform_test)
+ self.bytes_per_sample = 32 * 32 # only consider gray-scale images here
+ elif args.dataset == 'imagenet':
+ self.trainset = torchvision.datasets.ImageNet(
+ root=args.data_path, train=True, download=args.download, transform=transform_train)
+ self.testset = torchvision.datasets.ImageNet(
+ root=args.data_path, train=False, download=args.download, transform=transform_test)
+ self.bytes_per_sample = 256 * 256 # only consider gray-scale images here
+ else:
+ raise NotImplementedError(f'Unknown dataset: {args.dataset}')
+
+ def __iter__(self):
+ idx = 0
+
+ for data in self.trainset:
+ image, label = data
+ if constants.UINT8_LOADING:
+ image = np.array(image)
+ else:
+ image = image.squeeze().numpy()
+ # print(type(image), image.shape)
+ for patch in _extract_image_patches(image):
+ num_bytes = len(patch)
+ if idx == constants.NUM_CHUNKS:
+ return
+ if self.serialization:
+ yield num_bytes, _serialize(patch, self.args.decoding, bytes_per_group=self.args.bytes_per_group)
+ else:
+ yield num_bytes, patch
+ idx += 1
+
+class IterGradientDataset(torch.utils.data.IterableDataset):
+ def __init__(self, args, serialization, preprocess=None, return_fp=False):
+ self.args = args
+ self.serialization = serialization
+ self.return_fp = return_fp
+ self.ckpt_list = args.data_path
+ if not isinstance(self.ckpt_list, list) and not isinstance(self.ckpt_list, np.ndarray):
+ self.ckpt_list = [self.ckpt_list]
+
+ preprocess_fn = []
+ if preprocess is not None:
+ preprocess = preprocess.split('+')
+ for p in preprocess:
+ print(f'Runing preprocessing {p} ...')
+ if 'sparsification' in p:
+ strength = int(p.split('sparsification')[1]) / 100. # convert to percentage
+ preprocess_fn.append(functools.partial(self._sparsify, strength=strength))
+ elif 'quantization' in p:
+ n_bits = int(p.split('quantization')[1])
+ preprocess_fn.append(functools.partial(self._quantize, n_bits=n_bits))
+ else:
+ raise ValueError(f'Unknown preprocess instruction: {preprocess} specified in the dataloader.')
+
+ self.length, self.total_bytes = 0, 0
+ self.dset = []
+ for ckpt in self.ckpt_list:
+ state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['state_dict']
+
+ sample = []
+ for k, v in state_dict.items():
+ sample.append(v.numpy().flatten())
+ sample = np.concatenate(sample)
+
+ if len(preprocess_fn) > 0:
+ for fn in preprocess_fn:
+ sample = fn(sample)
+
+ print('===============> total bytes: ', len(sample.tobytes()))
+
+ sample = sample.tobytes()
+ self.length += int(np.ceil(len(sample)/constants.CHUNK_SIZE_BYTES))
+ self.total_bytes += len(sample) if self.length < constants.NUM_CHUNKS else constants.NUM_CHUNKS * constants.CHUNK_SIZE_BYTES
+
+ self.dset.append(sample)
+ # sample: List[bytes]; len(sample) = #bytes contained in this ckpt
+
+
+ def _add_gaussian_noise(self, data: np.ndarray, std: float=1):
+ noise = np.random.normal(0, std, size=data.shape)
+ return data + noise
+
+ def _gradient_clipping(self, data: np.ndarray, strength: float=1):
+ norm = np.linalg.norm(data)
+ return data * (strength / norm) if norm > strength else data
+
+ def _sparsify(self, data: np.ndarray, strength: float=0.25):
+ # two way to implement this: set non-masked bits to zero (standard); only transmit masked bits (used by some existing works)
+ # we use the latter here
+ assert strength > 0 and strength <= 1
+ mask = (np.random.uniform(0, 1, size=data.shape) < strength)
+ return data[mask]
+
+ def _is_power_of_two(self, n):
+ return (n & (n - 1)) == 0 and n != 0
+
+ def _quantize(self, data: np.ndarray, n_bits: int=8):
+ assert n_bits in [1, 8, 16]
+ num_levels = 2 ** n_bits
+
+ length = data.max() - data.min()
+ if n_bits != 1:
+ intervals = np.linspace(0, length, num_levels + 1)
+ inds = np.digitize(data - data.min(), intervals)
+ else:
+ inds = data
+ inds[inds>0] = 1
+ inds[inds<=0] = 0
+
+ if n_bits == 1:
+ inds = np.packbits(inds.astype(bool))
+ elif n_bits == 8:
+ inds = inds.astype(np.uint8)
+ elif n_bits == 16:
+ inds = inds.astype(np.uint16)
+ else:
+ raise ValueError(f'Unsupported n_bits: {n_bits} for quantization.')
+
+ return inds
+
+ def __iter__(self):
+ idx = 0
+ for bid, s in enumerate(self.dset):
+ if self.return_fp:
+ step_size = constants.CHUNK_SIZE_BYTES // 4
+ patches = np.array_split(
+ s,
+ range(
+ step_size,
+ len(s),
+ step_size,
+ ),
+ )
+
+ if len(patches[-1]) != step_size:
+ # pad the array to have the same size
+ current_size = len(patches[-1])
+ padding_size = step_size - current_size
+ patches[-1] = np.pad(patches[-1], pad_width=(0, padding_size), mode='constant', constant_values=0)
+
+ for patch in patches:
+ yield sys.getsizeof(patch), patch
+ else:
+ for patch in _extract_parameter_patches(s):
+ num_bytes = len(patch)
+ if idx == constants.NUM_CHUNKS:
+ return
+ if self.serialization:
+ yield num_bytes, _serialize(patch, self.args.decoding, bytes_per_group=self.args.bytes_per_group)
+ else:
+ yield num_bytes, patch
+ idx += 1
+
+ def __len__(self):
+ return self.length
+
+class PreTokenizedDataset(torch.utils.data.Dataset):
+ def __init__(self, args, max_tokens=None):
+ self.args = args
+ with open(args.data_path, 'rb') as handle:
+ loaded_data = pickle.load(handle)
+ self.dset = loaded_data['tokens'].reshape(-1)
+ self.total_bytes = loaded_data['total_bytes']
+
+ if args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']:
+ self.tokenizer = AutoTokenizer.from_pretrained(language_model.MODEL_NAME_DICT[args.compressor])
+
+ # we save one token quota for the bos_token which will be added later
+ self.max_tokens = max_tokens - 1 if max_tokens is not None else 2048 - 1
+ print(self.max_tokens)
+
+ self.num_samples = int(np.ceil(len(self.dset)/(self.max_tokens + 1)))
+ self.eos_token_id = self.tokenizer.eos_token_id
+ else:
+ self.max_tokens = constants.CHUNK_SIZE_BYTES
+ self.num_samples = int(np.ceil(len(self.dset)/self.max_tokens))
+
+ def _add_start_token(self, data: np.ndarray):
+ # for some models like LLAMA, they append more than one tokens at the begining such as a space token
+ # However, we ingore it here and only prepend a bos_token
+ bos_token_id = self.tokenizer.bos_token_id
+ return np.insert(data, 0, bos_token_id)
+
+ def __getitem__(self, index):
+ sample = self.dset[index*self.max_tokens:(index+1)*self.max_tokens]
+ if self.args.compressor in compressor.COMPRESSOR_TYPES['arithmetic_coding']:
+ if len(sample) < self.max_tokens:
+ padding = np.full((self.max_tokens, ), self.tokenizer.eos_token_id)
+ padding[:len(sample)] = sample
+ sample = padding
+ sample = self._add_start_token(sample)
+ return len(sample), sample
+
+ def __len__(self):
+ return self.num_samples
+
+GET_DATA_GENERATOR_FN_DICT = {
+ 'cifar10': get_image_iterator,
+ 'gradient': get_gradient_iterator,
+ 'tokenized_dataset': get_td_iterator,
+}
\ No newline at end of file
diff --git a/utils/deepmind_utils.py b/utils/deepmind_utils.py
new file mode 100644
index 0000000..69a48a0
--- /dev/null
+++ b/utils/deepmind_utils.py
@@ -0,0 +1,103 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions."""
+
+import numpy as np
+
+
+def bits_to_bytes(bits: str) -> tuple[bytes, int]:
+ """Returns the bytes representation of bitstream and number of padded bits."""
+ # Pad the string with zeros if the length is not a multiple of 8.
+ padded_bits = bits.zfill((len(bits) + 7) // 8 * 8)
+ num_padded_bits = len(padded_bits) - len(bits)
+
+ # Split the string into 8-bit chunks.
+ chunks = [padded_bits[i : i + 8] for i in range(0, len(padded_bits), 8)]
+
+ # Convert each chunk to an integer and then to a byte.
+ bytes_data = bytes([int(chunk, base=2) for chunk in chunks])
+
+ return bytes_data, num_padded_bits
+
+
+def bytes_to_bits(data: bytes, num_padded_bits: int = 0) -> str:
+ """Returns the bitstream of bytes data accounting for padded bits."""
+ return ''.join([bin(byte)[2:].zfill(8) for byte in data])[num_padded_bits:]
+
+
+def right_shift_bytes_by_one(data: bytes) -> tuple[bytes, int]:
+ """Returns right-shifted bytes, i.e., divided by 2, and the number of bytes.
+
+ Our language models were trained on ASCII data. However, not all bytes can be
+ decoded to ASCII, so we set the most significant bit (MSB) to 0, to ensure
+ that we can decode the data to ASCII.
+
+ However, for certain data types (e.g., images), masking the MSB and leaving
+ the rest of the byte unchanged will destroy the structure of the data. Thus,
+ we instead divide the number by two (i.e., we shift the bits to the right by
+ one).
+
+ Args:
+ data: The bytes to be shifted.
+ """
+ return bytes([byte >> 1 for byte in data]), len(data)
+
+
+def zero_most_significant_bit_if_not_ascii_decodable(
+ data: bytes,
+) -> tuple[bytes, int]:
+ """Returns ascii-decodable data & the number of zeroed most significant bits.
+
+ Our language models were trained on ASCII data. However, not all bytes can be
+ decoded to ASCII, so we set the most significant bit (MSB) to 0, to ensure
+ that we can decode the data to ASCII.
+
+ Args:
+ data: The bytes to be shifted.
+ """
+ masked_bits = 0
+ masked_data = list()
+
+ for byte in data:
+ if chr(byte).isascii():
+ masked_data.append(byte)
+ else:
+ masked_bits += 1
+ masked_data.append(byte & 0x7F)
+
+ return bytes(masked_data), masked_bits
+
+
+def normalize_pdf_for_arithmetic_coding(pdf: np.array) -> np.array:
+ """Normalizes the probabilities for arithmetic coding.
+
+ Arithmetic coding converts the floating-point pdf to integers to avoid
+ numerical issues. To that end, all pdf values need to be larger than the
+ machine epsilon (to yield different integer values) and the sum of the pdf
+ cannot exceed 1 (minus some precision tolerance).
+
+ Args:
+ pdf: The probabilities to be normalized.
+
+ Returns:
+ The normalized probabilities.
+ """
+ machine_epsilon = np.finfo(np.float32).eps
+ # Normalize the probabilities to avoid floating-point errors.
+ pdf = pdf / np.cumsum(pdf)[-1]
+ # Ensure all probabilities are sufficiently large to yield distinct cdfs.
+ pdf = (1 - 2 * pdf.shape[0] * machine_epsilon) * pdf + machine_epsilon
+ return pdf
\ No newline at end of file
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..52f9057
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,168 @@
+import pickle
+
+import torch
+import torchvision
+
+import torchvision.transforms as transforms
+
+import numpy as np
+
+from utils.models.vgg import VGG
+from utils.models.convnet import ConvNet
+from utils.models.vit import ViT
+from utils.models.resnet import resnet
+from utils.tinyimagenet import TinyImageNet
+
+def get_network(arch, data_info=None):
+ if 'VGG' in arch:
+ return VGG(arch)
+ elif 'ConvNet' in arch:
+ # default settings
+ net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
+ return ConvNet(channel=data_info['channel'], num_classes=data_info['num_classes'], net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=data_info['im_size'])
+ elif 'ViT' in arch:
+ return ViT(data_info['num_classes'], data_info['im_size'][0])
+ elif 'resnet' in arch:
+ return resnet(arch, data_info['channel'], data_info['num_classes'])
+ else:
+ raise NotImplementedError(f'Unknown model architecture {arch}.')
+
+def prepare_dataset(args):
+ data_info = {}
+ if args.dataset == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ trainset = torchvision.datasets.CIFAR10(
+ root=args.data_path, train=True, download=args.download, transform=transform_train)
+ trainloader = torch.utils.data.DataLoader(
+ trainset, batch_size=128, shuffle=True, num_workers=2)
+
+ testset = torchvision.datasets.CIFAR10(
+ root=args.data_path, train=False, download=args.download, transform=transform_test)
+ testloader = torch.utils.data.DataLoader(
+ testset, batch_size=100, shuffle=False, num_workers=2)
+
+ data_info['classes'] = ('plane', 'car', 'bird', 'cat', 'deer',
+ 'dog', 'frog', 'horse', 'ship', 'truck')
+ data_info['num_classes'] = 10
+ data_info['channel'] = 3
+ data_info['im_size'] = (32, 32)
+
+ elif args.dataset == 'mnist':
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
+
+ trainset = torchvision.datasets.MNIST(
+ root=args.data_path, train=True, download=args.download, transform=transform)
+ trainloader = torch.utils.data.DataLoader(
+ trainset, batch_size=128, shuffle=True, num_workers=2)
+
+ testset = torchvision.datasets.MNIST(
+ root=args.data_path, train=False, download=args.download, transform=transform)
+ testloader = torch.utils.data.DataLoader(
+ testset, batch_size=100, shuffle=False, num_workers=2)
+
+ data_info['classes'] = ('0', '1', '2', '3', '4', '5',
+ '6', '7', '8', '9')
+ data_info['num_classes'] = 10
+ data_info['channel'] = 1
+ data_info['im_size'] = (28, 28)
+
+ elif args.dataset == 'tinyimagenet':
+ transform_train = transforms.Compose([
+ transforms.Resize(32),
+ transforms.RandomRotation(20),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.Resize(32),
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+
+ trainset = TinyImageNet(
+ root=args.data_path, split='train', download=args.download, transform=transform_train)
+ trainloader = torch.utils.data.DataLoader(
+ trainset, batch_size=128, shuffle=True, num_workers=6)
+
+ testset = TinyImageNet(
+ root=args.data_path, split='val', download=args.download, transform=transform_test)
+ testloader = torch.utils.data.DataLoader(
+ testset, batch_size=128, shuffle=False, num_workers=6)
+
+ # data_info['classes'] =
+ data_info['num_classes'] = 200
+ data_info['channel'] = 3
+ data_info['im_size'] = (32, 32)
+ else:
+ raise NotImplementedError(f'Unknown dataset type: {args.dataset}')
+
+ return trainloader, testloader, data_info
+
+# n_bits = 2, 4, 8
+# unbiased: apply probabilistic unbiased quantization or not
+# hadamard: apply random hadamard rotation or not
+def linear_quantization(input, n_bits, unbiased=True, hadamard=True):
+ quanti_level = 2 ** n_bits
+ rand_diag = []
+
+ if hadamard:
+ input , rand_diag = hadamard_rotation(input)
+
+ v_max = input.max()
+ v_min = input.min()
+ output = input
+ output = (output - v_min) / (v_max - v_min) * (quanti_level - 1)
+
+ if unbiased:
+ output = prob_quantization(output)
+ else:
+ output = output.round()
+
+ #output = output.reshape(sz)
+
+ return output, v_min, v_max, rand_diag, quanti_level
+
+def hadamard_rotation(input):
+ sz = input.shape
+ sz1 = sz[0]
+ sz2 = int(input.size / sz1)
+ dim = 2 ** np.ceil(np.log2(sz1))
+ hadamard_mat = hadamard(dim)
+ if hadamard_mat.shape[0] != sz1:
+ hadamard_mat = hadamard_mat[:sz1, :sz1]
+
+ x = input.reshape(sz1, sz2)
+ diag = np.random.uniform(0, 1, size=x.shape) < 0.5
+ diag = diag * 2 - 1
+ x = np.matmul(hadamard_mat, x) * diag
+ x = x.reshape(sz)
+ return x, diag
+
+def prob_quantization(input):
+ x = np.ceil(input)
+ p = np.random.uniform(0, 1, size=x.shape)
+ x = x - (p < x - input)
+ return x
+
+def hist_tokenized_dataset(data_path: str) -> np.ndarray:
+ with open(data_path, 'rb') as handle:
+ loaded_data = pickle.load(handle)
+ dset = loaded_data['tokens'].reshape(-1)
+ total_bytes = loaded_data['total_bytes']
+
+ unique_elements, counts = np.unique(dset, return_counts=True)
+
+ return unique_elements, counts
\ No newline at end of file
diff --git a/utils/models/convnet.py b/utils/models/convnet.py
new file mode 100644
index 0000000..7592f3d
--- /dev/null
+++ b/utils/models/convnet.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+
+class ConvNet(nn.Module):
+ def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size=(32, 32)):
+ super(ConvNet, self).__init__()
+ self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling,
+ im_size)
+ num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
+ self.classifier = nn.Linear(num_feat, num_classes)
+
+ def forward(self, x):
+ out = self.features(x)
+ out = out.view(out.size(0), -1)
+ out = self.classifier(out)
+ return out
+
+ def embed(self, x):
+ feat = self.features(x)
+ feat = feat.view(feat.size(0), -1)
+ out = self.classifier(feat)
+ return feat, out
+
+ def _get_activation(self, net_act):
+ if net_act == 'sigmoid':
+ return nn.Sigmoid()
+ elif net_act == 'relu':
+ return nn.ReLU(inplace=True)
+ elif net_act == 'leakyrelu':
+ return nn.LeakyReLU(negative_slope=0.01)
+ elif net_act == 'swish':
+ return Swish()
+ else:
+ exit('unknown activation function: %s' % net_act)
+
+ def _get_pooling(self, net_pooling):
+ if net_pooling == 'maxpooling':
+ return nn.MaxPool2d(kernel_size=2, stride=2)
+ elif net_pooling == 'avgpooling':
+ return nn.AvgPool2d(kernel_size=2, stride=2)
+ elif net_pooling == 'none':
+ return None
+ else:
+ exit('unknown net_pooling: %s' % net_pooling)
+
+ def _get_normlayer(self, net_norm, shape_feat):
+ # shape_feat = (c*h*w)
+ if net_norm == 'batchnorm':
+ return nn.BatchNorm2d(shape_feat[0], affine=True)
+ elif net_norm == 'layernorm':
+ return nn.LayerNorm(shape_feat, elementwise_affine=True)
+ elif net_norm == 'instancenorm':
+ return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
+ elif net_norm == 'groupnorm':
+ return nn.GroupNorm(4, shape_feat[0], affine=True)
+ elif net_norm == 'none':
+ return None
+ else:
+ exit('unknown net_norm: %s' % net_norm)
+
+ def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
+ layers = []
+ in_channels = channel
+ if im_size[0] == 28:
+ im_size = (32, 32)
+ shape_feat = [in_channels, im_size[0], im_size[1]]
+ for d in range(net_depth):
+ layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
+ shape_feat[0] = net_width
+ if net_norm != 'none':
+ layers += [self._get_normlayer(net_norm, shape_feat)]
+ layers += [self._get_activation(net_act)]
+ in_channels = net_width
+ if net_pooling != 'none':
+ layers += [self._get_pooling(net_pooling)]
+ shape_feat[1] //= 2
+ shape_feat[2] //= 2
+
+ return nn.Sequential(*layers), shape_feat
\ No newline at end of file
diff --git a/utils/models/resnet.py b/utils/models/resnet.py
new file mode 100644
index 0000000..5b0344a
--- /dev/null
+++ b/utils/models/resnet.py
@@ -0,0 +1,144 @@
+# The code is adapted from https://github.com/VICO-UoE/DatasetCondensation for CIFAR-10
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
+ super(BasicBlock, self).__init__()
+ self.norm = norm
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+ nn.GroupNorm(self.expansion * planes, self.expansion * planes,
+ affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
+ super(Bottleneck, self).__init__()
+ self.norm = norm
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
+ self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes,
+ affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+ nn.GroupNorm(self.expansion * planes, self.expansion * planes,
+ affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = F.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
+ super(ResNet, self).__init__()
+ self.in_planes = 64
+ self.norm = norm
+
+ self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.classifier = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride, self.norm))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.classifier(out)
+ return out
+
+ def embed(self, x):
+ x = F.relu(self.bn1(self.conv1(x)))
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = F.avg_pool2d(x, kernel_size=1, stride=1) # modification
+ x = x.view(x.size(0), -1)
+
+ out = self.classifier(x)
+ return x, out
+
+
+def ResNet18BN(channel, num_classes):
+ return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm='batchnorm')
+
+
+def ResNet18(channel, num_classes):
+ return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes)
+
+
+def ResNet34(channel, num_classes):
+ return ResNet(BasicBlock, [3, 4, 6, 3], channel=channel, num_classes=num_classes)
+
+
+def ResNet50(channel, num_classes):
+ return ResNet(Bottleneck, [3, 4, 6, 3], channel=channel, num_classes=num_classes)
+
+
+def ResNet101(channel, num_classes):
+ return ResNet(Bottleneck, [3, 4, 23, 3], channel=channel, num_classes=num_classes)
+
+
+def ResNet152(channel, num_classes):
+ return ResNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes)
+
+def resnet(arch, channel, num_classes):
+ arch_choices = {
+ 'resnet18bn': ResNet18BN,
+ 'resnet18': ResNet18,
+ 'resnet34': ResNet34,
+ 'resnet50': ResNet50,
+ 'resnet101': ResNet101,
+ 'resnet152': ResNet152,
+ }
+
+ return arch_choices[arch](channel, num_classes)
\ No newline at end of file
diff --git a/utils/models/vgg.py b/utils/models/vgg.py
new file mode 100644
index 0000000..c38f31e
--- /dev/null
+++ b/utils/models/vgg.py
@@ -0,0 +1,38 @@
+'''VGG11/13/16/19 in Pytorch.'''
+import torch
+import torch.nn as nn
+
+
+cfg = {
+ 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+ 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+class VGG(nn.Module):
+ def __init__(self, vgg_name):
+ super(VGG, self).__init__()
+ self.features = self._make_layers(cfg[vgg_name])
+ self.classifier = nn.Linear(512, 10)
+
+ def forward(self, x):
+ out = self.features(x)
+ out = out.view(out.size(0), -1)
+ out = self.classifier(out)
+ return out
+
+ def _make_layers(self, cfg):
+ layers = []
+ in_channels = 3
+ for x in cfg:
+ if x == 'M':
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+ else:
+ layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
+ nn.BatchNorm2d(x),
+ nn.ReLU(inplace=True)]
+ in_channels = x
+ layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
+ return nn.Sequential(*layers)
\ No newline at end of file
diff --git a/utils/models/vit.py b/utils/models/vit.py
new file mode 100644
index 0000000..62a9884
--- /dev/null
+++ b/utils/models/vit.py
@@ -0,0 +1,196 @@
+# CIFAR-10 ViT model is adapted from https://juliusruseckas.github.io/ml/cifar10-vit.html
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Residual(nn.Module):
+ def __init__(self, *layers):
+ super().__init__()
+ self.residual = nn.Sequential(*layers)
+ self.gamma = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x):
+ return x + self.gamma * self.residual(x)
+
+class LayerNormChannels(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.norm = nn.LayerNorm(channels)
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = self.norm(x)
+ x = x.transpose(-1, 1)
+ return x
+
+class SelfAttention2d(nn.Module):
+ def __init__(self, in_channels, out_channels, head_channels, shape):
+ super().__init__()
+ self.heads = out_channels // head_channels
+ self.head_channels = head_channels
+ self.scale = head_channels**-0.5
+
+ self.to_keys = nn.Conv2d(in_channels, out_channels, 1)
+ self.to_queries = nn.Conv2d(in_channels, out_channels, 1)
+ self.to_values = nn.Conv2d(in_channels, out_channels, 1)
+ self.unifyheads = nn.Conv2d(out_channels, out_channels, 1)
+
+ height, width = shape
+ self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * height - 1) * (2 * width - 1)))
+ self.register_buffer("relative_indices", self.get_indices(height, width))
+
+ def forward(self, x):
+ b, _, h, w = x.shape
+
+ keys = self.to_keys(x).view(b, self.heads, self.head_channels, -1)
+ values = self.to_values(x).view(b, self.heads, self.head_channels, -1)
+ queries = self.to_queries(x).view(b, self.heads, self.head_channels, -1)
+
+ att = keys.transpose(-2, -1) @ queries
+
+ indices = self.relative_indices.expand(self.heads, -1)
+ rel_pos_enc = self.pos_enc.gather(-1, indices)
+ rel_pos_enc = rel_pos_enc.unflatten(-1, (h * w, h * w))
+
+ att = att * self.scale + rel_pos_enc
+ att = F.softmax(att, dim=-2)
+
+ out = values @ att
+ out = out.view(b, -1, h, w)
+ out = self.unifyheads(out)
+ return out
+
+ @staticmethod
+ def get_indices(h, w):
+ y = torch.arange(h, dtype=torch.long)
+ x = torch.arange(w, dtype=torch.long)
+
+ y1, x1, y2, x2 = torch.meshgrid(y, x, y, x, indexing='ij')
+ indices = (y1 - y2 + h - 1) * (2 * w - 1) + x1 - x2 + w - 1
+ indices = indices.flatten()
+
+ return indices
+
+class FeedForward(nn.Sequential):
+ def __init__(self, in_channels, out_channels, mult=4):
+ hidden_channels = in_channels * mult
+ super().__init__(
+ nn.Conv2d(in_channels, hidden_channels, 1),
+ nn.GELU(),
+ nn.Conv2d(hidden_channels, out_channels, 1)
+ )
+
+class TransformerBlock(nn.Sequential):
+ def __init__(self, channels, head_channels, shape, p_drop=0.):
+ super().__init__(
+ Residual(
+ LayerNormChannels(channels),
+ SelfAttention2d(channels, channels, head_channels, shape),
+ nn.Dropout(p_drop)
+ ),
+ Residual(
+ LayerNormChannels(channels),
+ FeedForward(channels, channels),
+ nn.Dropout(p_drop)
+ )
+ )
+
+class TransformerStack(nn.Sequential):
+ def __init__(self, num_blocks, channels, head_channels, shape, p_drop=0.):
+ layers = [TransformerBlock(channels, head_channels, shape, p_drop) for _ in range(num_blocks)]
+ super().__init__(*layers)
+
+class ToPatches(nn.Sequential):
+ def __init__(self, in_channels, channels, patch_size, hidden_channels=32):
+ super().__init__(
+ nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
+ nn.GELU(),
+ nn.Conv2d(hidden_channels, channels, patch_size, stride=patch_size)
+ )
+
+class AddPositionEmbedding(nn.Module):
+ def __init__(self, channels, shape):
+ super().__init__()
+ self.pos_embedding = nn.Parameter(torch.Tensor(channels, *shape))
+
+ def forward(self, x):
+ return x + self.pos_embedding
+
+class ToEmbedding(nn.Sequential):
+ def __init__(self, in_channels, channels, patch_size, shape, p_drop=0.):
+ super().__init__(
+ ToPatches(in_channels, channels, patch_size),
+ AddPositionEmbedding(channels, shape),
+ nn.Dropout(p_drop)
+ )
+
+# main model
+class Head(nn.Sequential):
+ def __init__(self, in_channels, classes, p_drop=0.):
+ super().__init__(
+ LayerNormChannels(in_channels),
+ nn.GELU(),
+ nn.AdaptiveAvgPool2d(1),
+ nn.Flatten(),
+ nn.Dropout(p_drop),
+ nn.Linear(in_channels, classes)
+ )
+class RelViT(nn.Sequential):
+ def __init__(self, classes, image_size, channels, head_channels, num_blocks, patch_size,
+ in_channels=3, emb_p_drop=0., trans_p_drop=0., head_p_drop=0.):
+ reduced_size = image_size // patch_size
+ shape = (reduced_size, reduced_size)
+ super().__init__(
+ ToEmbedding(in_channels, channels, patch_size, shape, emb_p_drop),
+ TransformerStack(num_blocks, channels, head_channels, shape, trans_p_drop),
+ Head(channels, classes, head_p_drop)
+ )
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ nn.init.kaiming_normal_(m.weight)
+ if m.bias is not None: nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, AddPositionEmbedding):
+ nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
+ elif isinstance(m, SelfAttention2d):
+ nn.init.normal_(m.pos_enc, mean=0.0, std=0.02)
+ elif isinstance(m, Residual):
+ nn.init.zeros_(m.gamma)
+
+ def separate_parameters(self):
+ parameters_decay = set()
+ parameters_no_decay = set()
+ modules_weight_decay = (nn.Linear, nn.Conv2d)
+ modules_no_weight_decay = (nn.LayerNorm,)
+
+ for m_name, m in self.named_modules():
+ for param_name, param in m.named_parameters():
+ full_param_name = f"{m_name}.{param_name}" if m_name else param_name
+
+ if isinstance(m, modules_no_weight_decay):
+ parameters_no_decay.add(full_param_name)
+ elif param_name.endswith("bias"):
+ parameters_no_decay.add(full_param_name)
+ elif isinstance(m, Residual) and param_name.endswith("gamma"):
+ parameters_no_decay.add(full_param_name)
+ elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"):
+ parameters_no_decay.add(full_param_name)
+ elif isinstance(m, SelfAttention2d) and param_name.endswith("pos_enc"):
+ parameters_no_decay.add(full_param_name)
+ elif isinstance(m, modules_weight_decay):
+ parameters_decay.add(full_param_name)
+
+ # sanity check
+ assert len(parameters_decay & parameters_no_decay) == 0
+ assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters()))
+
+ return parameters_decay, parameters_no_decay
+
+def ViT(num_classes, image_size):
+ return RelViT(num_classes, image_size, channels=256, head_channels=32, num_blocks=8, patch_size=2,
+ emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3)
\ No newline at end of file
diff --git a/utils/tinyimagenet.py b/utils/tinyimagenet.py
new file mode 100644
index 0000000..2986e83
--- /dev/null
+++ b/utils/tinyimagenet.py
@@ -0,0 +1,105 @@
+# Copyright (C) 2022 Leonardo Romor
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Simple Tiny ImageNet dataset utility class for pytorch."""
+
+import os
+
+import shutil
+
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.utils import verify_str_arg
+from torchvision.datasets.utils import download_and_extract_archive
+
+
+def normalize_tin_val_folder_structure(path,
+ images_folder='images',
+ annotations_file='val_annotations.txt'):
+ # Check if files/annotations are still there to see
+ # if we already run reorganize the folder structure.
+ images_folder = os.path.join(path, images_folder)
+ annotations_file = os.path.join(path, annotations_file)
+
+ # Exists
+ if not os.path.exists(images_folder) \
+ and not os.path.exists(annotations_file):
+ if not os.listdir(path):
+ raise RuntimeError('Validation folder is empty.')
+ return
+
+ # Parse the annotations
+ with open(annotations_file) as f:
+ for line in f:
+ values = line.split()
+ img = values[0]
+ label = values[1]
+ img_file = os.path.join(images_folder, values[0])
+ label_folder = os.path.join(path, label)
+ os.makedirs(label_folder, exist_ok=True)
+ try:
+ shutil.move(img_file, os.path.join(label_folder, img))
+ except FileNotFoundError:
+ continue
+
+ os.sync()
+ assert not os.listdir(images_folder)
+ shutil.rmtree(images_folder)
+ os.remove(annotations_file)
+ os.sync()
+
+
+class TinyImageNet(ImageFolder):
+ """Dataset for TinyImageNet-200"""
+ base_folder = 'tiny-imagenet-200'
+ zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de'
+ splits = ('train', 'val')
+ filename = 'tiny-imagenet-200.zip'
+ url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
+
+ def __init__(self, root, split='train', download=False, **kwargs):
+ self.data_root = os.path.expanduser(root)
+ self.split = verify_str_arg(split, "split", self.splits)
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError('Dataset not found.' +
+ ' You can use download=True to download it')
+ super().__init__(self.split_folder, **kwargs)
+
+ @property
+ def dataset_folder(self):
+ return os.path.join(self.data_root, self.base_folder)
+
+ @property
+ def split_folder(self):
+ return os.path.join(self.dataset_folder, self.split)
+
+ def _check_exists(self):
+ return os.path.exists(self.split_folder)
+
+ def extra_repr(self):
+ return "Split: {split}".format(**self.__dict__)
+
+ def download(self):
+ if self._check_exists():
+ return
+ download_and_extract_archive(
+ self.url, self.data_root, filename=self.filename,
+ remove_finished=True, md5=self.zip_md5)
+ assert 'val' in self.splits
+ normalize_tin_val_folder_structure(
+ os.path.join(self.dataset_folder, 'val'))
\ No newline at end of file
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..61cf98a
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,49 @@
+import argparse
+
+class BaseParser(dict):
+ def __init__(self, cmd_args, *file_args):
+ super(BaseParser, self).__init__()
+ args = [*file_args, cmd_args] # cmd_args has higher priority than file_args
+
+ # default options for the parser, which can be customized for specific applications
+ self.choice_dict = {}
+ self.default_dict = {}
+ self.allowed_overwritten_list = {'seed': None}
+
+ for i_d, d in enumerate(args):
+ # print(i_d)
+ if isinstance(d, argparse.Namespace):
+ d = vars(d)
+
+ for k, v in d.items():
+ assert k not in self.keys() or k in self.allowed_overwritten_list.keys(), f'duplicated arguments {k}, please check the configuration file.'
+
+ if k in self.allowed_overwritten_list.keys() and v == self.allowed_overwritten_list[k]:
+ continue
+ # print(f'\t{k}: {v}')
+ self.add_item(k, v)
+
+ # check whether the default options has been in args; otherswise, add it.
+ for k in self.default_dict.keys():
+ if k not in self.keys():
+ self[k] = self.default_dict[k]
+
+ def add_item(self, k, v):
+ # 1. convert '-' to '_'; 2. replace string 'None' with NoneType
+ k = k.replace('-', '_')
+
+ #check whether arguments match the limited choices
+ if k in self.choice_dict.keys() and v not in self.choice_dict[k]:
+ raise ValueError(f'Illegal argument \'{k}\' for choices {self.choice_dict[k]}')
+
+ # convert string None to Nonetype, which is a side effect of using yaml
+ self[k] = None if v == 'None' else v
+
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(f'{name}')
+
+ def __setattr__(self, key, val):
+ self[key] = val
\ No newline at end of file