Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample fix #3

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,22 @@ Assuming you have [miniconda](https://docs.conda.io/en/latest/miniconda.html) in

### Training
The entrypoint `train` is the main driver for training and accepts parameters using Hydra syntax.
The available parameters for configuration can be found by running `train` --help or by looking in the `src/walkjump/hydra_config` directory
The available parameters for configuration can be found by running `train` --help (```walkjump_train --help```) or by looking in the `src/walkjump/hydra_config` directory

### Sampling
The entrypoint `sample` is the main driver for training and accepts parameters using Hydra syntax.
The available parameters for configuration can be found by running `sample` --help or by looking in the `src/walkjump/hydra_config` directory
The available parameters for configuration can be found by running `sample` --help (```walkjump_sample --help```) or by looking in the `src/walkjump/hydra_config` directory

### Example
```bash
conda activate wj
walkjump_train data.csv_data_path="data/poas.csv.gz"
```
then
```bash
walkjump_sample 'model.checkpoint_path="checkpoints/epoch=17-step=363937-val_loss=0.0040.ckpt"' designs.output_csv=my_samples.csv
```
(Extra quotation marks to handle "=" in file path)

## Contributing

Expand Down
4 changes: 2 additions & 2 deletions src/walkjump/cmdline/_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf

from walkjump.cmdline.utils import instantiate_redesign_mask, instantiate_seeds
from walkjump.cmdline.utils import instantiate_redesign_mask, instantiate_seeds, instantiate_model_for_sample_mode
from walkjump.sampling import walkjump


Expand All @@ -28,7 +28,7 @@ def sample(cfg: DictConfig) -> bool:
seeds = instantiate_seeds(cfg.designs)

if not cfg.dryrun:
model = hydra.utils.instantiate(cfg.model).to(device)
model = instantiate_model_for_sample_mode(cfg.model).to(device)
sample_df = walkjump(
seeds,
model,
Expand Down
17 changes: 9 additions & 8 deletions src/walkjump/hydra_config/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,25 @@ defaults:
- setup: default

model:
_target_: walkjump.cmdline.utils.instantiate_model_for_sample_mode
model_type: denoise
checkpoint_path: ???
denoise_path: null

langevin:
sigma: 1.0
delta: 0.5
lipschitz: 1.0
friction: 1.0
steps: 20
chunksize: 8
sigma: 1.0 # Noise level
delta: 0.5 # Step size
lipschitz: 1.0 # Lipschitz constant, related to mass: u = pow(lipschitz, -1)
friction: 1.0 # (Gamma) Dampening term
steps: 20 # (K) Number of steps in chain
chunksize: 8 # Used for chunking the batch to save memory. Providing chunksize = N will force the sampling to occur in N batches.

designs:
output_csv: samples.csv
redesign_regions: null
seeds: denovo
num_samples: 100
limit_seeds: 10
chunksize: 8

device: null
device: null
dryrun: false