-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,3 +95,4 @@ for f in mfiles | |
# compute and save the score | ||
save_sample_score(f, data, seed, ac) | ||
end | ||
@info "DONE" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
using DrWatson | ||
@quickactivate | ||
using ArgParse | ||
using GenerativeAD | ||
import StatsBase: fit!, predict | ||
using StatsBase | ||
using BSON, FileIO | ||
using Flux | ||
using GenerativeModels | ||
using DistributionsAD | ||
using ValueHistories | ||
|
||
# Command-line interface: three positional arguments (each with a default) select the
# experiment directory; two flags narrow down which saved models are processed.
s = ArgParseSettings()
@add_arg_table! s begin
    "modelname"
        default = "vae"
        arg_type = String
        help = "model name"
    "datatype"
        default = "tabular"
        arg_type = String
        help = "tabular or image"
    "dataset"
        default = "iris"
        arg_type = String
        help = "dataset"
    "--seed"
        default = nothing
        help = "if specified, only results for a given seed will be recomputed"
    "--anomaly_class"
        default = nothing
        help = "if specified, only results for a given anomaly class will be recomputed"
end
parsed_args = parse_args(ARGS, s)
@unpack dataset, datatype, modelname, seed, anomaly_class = parsed_args

# Collect all files below the experiment directory and keep those whose path
# contains "model".
masterpath = datadir("experiments/$(datatype)/$(modelname)/$(dataset)")
files = GenerativeAD.Evaluation.collect_files(masterpath)
mfiles = filter(f -> occursin("model", f), files)

# `seed` / `anomaly_class` stay `nothing` unless given on the command line. They are
# only interpolated into the "/seed=…/" and "/ac=…/" path patterns, never used as numbers.
# NOTE(review): no `arg_type` is declared for them, so they presumably arrive as raw
# strings - confirm against ArgParse.
if !isnothing(seed)
    filter!(x -> occursin("/seed=$seed/", x), mfiles)
end
if !isnothing(anomaly_class)
    filter!(x -> occursin("/ac=$(anomaly_class)/", x), mfiles)
end
|
||
"""
    sample_score_batched(m, x, L, batchsize)

Iterate over `x` in mini-batches produced by `Flux.Data.DataLoader`, evaluate
`GenerativeAD.Models.latent_score(m, batch, L)` on every batch and concatenate the
per-batch results with `vcat`.

The score function is called through `Base.invokelatest`.
"""
function sample_score_batched(m, x, L, batchsize)
    batches = Flux.Data.DataLoader(x, batchsize=batchsize)
    scores = map(batches) do batch
        Base.invokelatest(GenerativeAD.Models.latent_score, m, batch, L)
    end
    return vcat(scores...)
end
"""
    sample_score_batched_gpu(m, x, L, batchsize)

Same batching scheme as `sample_score_batched`, but every batch is materialized with
`Array`, moved with `gpu` before `GenerativeAD.Models.latent_score(m, batch, L)` is
evaluated, and the result is brought back with `cpu` before the final `vcat`.

The model `m` is used as passed in; this function does not move it.
"""
function sample_score_batched_gpu(m, x, L, batchsize)
    batches = Flux.Data.DataLoader(x, batchsize=batchsize)
    scores = map(batches) do batch
        device_batch = gpu(Array(batch))
        cpu(Base.invokelatest(GenerativeAD.Models.latent_score, m, device_batch, L))
    end
    return vcat(scores...)
end
|
||
"""
    save_sample_score(f::String, data, seed::Int, ac=nothing)

Load the model stored in file `f` and, unless the matching result file already exists
in the same directory, compute the "latent-sampled" score through
`GenerativeAD.experiment`.

# Arguments
- `f`: path to a saved model file; it must provide the entries `"model"`, `"fit_t"`,
  `"history"` and `"parameters"`.
- `data`: forwarded unchanged to `GenerativeAD.experiment`.
- `seed`: stored in the saved entries.
- `ac`: anomaly class or `nothing`. When given, it is stored in the saved entries and
  the GPU scoring path with batch size 256 is used; otherwise scoring uses
  `sample_score_batched` with batch size 512.

Reads the script-level globals `modelname` and `dataset`.

Returns `nothing` when the result file already exists, otherwise whatever
`GenerativeAD.experiment` returns.
"""
function save_sample_score(f::String, data, seed::Int, ac=nothing)
    savepath = dirname(f)
    mdata = load(f)
    model = mdata["model"]
    L = 100  # passed as the third argument of `latent_score`

    # The result file name is derived from these parameters only, so they are built
    # once and shared by both scoring paths.
    parameters = merge(mdata["parameters"], (L = L, score = "latent-sampled"))

    # Nothing to do if the scores were already computed.
    savef = joinpath(savepath, savename(parameters, "bson", digits=5))
    isfile(savef) && return nothing

    # Entries stored alongside the scores; the model itself is not saved again.
    save_entries = (
        modelname = modelname,
        fit_t = mdata["fit_t"],
        history = mdata["history"],
        dataset = dataset,
        npars = sum(length, Flux.params(model)),
        model = nothing,
        seed = seed
    )
    if !isnothing(ac)
        save_entries = merge(save_entries, (ac = ac,))
    end

    # NOTE(review): the GPU path is tied to `ac` being set - presumably because only
    # image datasets carry an anomaly class; confirm.
    # `gpu(model)` sits inside the closure, so it is evaluated on every call of the
    # score function.
    score_fun = if isnothing(ac)
        x -> sample_score_batched(model, x, L, 512)
    else
        x -> sample_score_batched_gpu(gpu(model), x, L, 256)
    end

    @info "computing sample score for $f"
    return GenerativeAD.experiment(score_fun, parameters, data, savepath; save_entries...)
end
|
||
# Process every selected model file. Seed and anomaly class are recovered from the
# directory names: the model's directory is named "seed=<int>" and, when the path
# contains "ac=", its parent directory is named "ac=<int>".
for f in mfiles
    savepath = dirname(f)
    # `local` marks this as a per-iteration variable, distinct from the global `seed`
    # unpacked from the command-line arguments.
    local seed = parse(Int, replace(basename(savepath), "seed=" => ""))
    ac = occursin("ac=", savepath) ?
        parse(Int, replace(basename(dirname(savepath)), "ac=" => "")) : nothing

    # get data
    data = isnothing(ac) ?
        GenerativeAD.load_data(dataset, seed=seed) :
        GenerativeAD.load_data(dataset, seed=seed, anomaly_class_ind=ac)

    # compute and save the score
    save_sample_score(f, data, seed, ac)
end
@info "DONE"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash
#SBATCH --time=24:00:00
#SBATCH --nodes=1 --ntasks-per-node=2 --cpus-per-task=2
#SBATCH --gres=gpu:1
#SBATCH --partition=gpu
#SBATCH --mem=80G

# Slurm job for the GPU partition: loads the Julia and Python modules and runs
# sample_score_latent.jl with the five positional arguments of this script.
model=$1
datatype=$2
dataset=$3
seed=$4
ac=$5

module load Julia/1.5.1-linux-x86_64
module load Python/3.8.2-GCCcore-9.3.0

# Expansions are left unquoted on purpose: arguments that were not supplied expand
# to nothing instead of being passed on as empty strings.
julia ./sample_score_latent.jl $model $datatype $dataset $seed $ac
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/bin/bash
#SBATCH --time=24:00:00
#SBATCH --nodes=1 --ntasks-per-node=2 --cpus-per-task=1
#SBATCH --mem=20G

# Slurm job without a GPU request: loads the Julia and Python modules and runs
# sample_score_latent.jl with the five positional arguments of this script.
model=$1
datatype=$2
dataset=$3
seed=$4
ac=$5

module load Julia/1.5.1-linux-x86_64
module load Python/3.8.2-GCCcore-9.3.0

# Expansions are left unquoted on purpose: arguments that were not supplied expand
# to nothing instead of being passed on as empty strings.
julia ./sample_score_latent.jl $model $datatype $dataset $seed $ac
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#!/bin/bash
# Submits one Slurm job per dataset listed in DATASET_FILE. Each job runs the
# sample_score_latent run script (the GPU variant for MNIST, FashionMNIST,
# CIFAR10 and SVHN2, the plain variant for everything else).
# USAGE
# 	<this script> MODEL DATATYPE DATASET_FILE [SEED] [AC]
# EXAMPLE
# 	<this script> vae tabular datasets_tabular.txt
# NOTE(review): SEED and AC are forwarded verbatim to sample_score_latent.jl, which
# declares them as the options --seed / --anomaly_class - presumably they have to be
# given in that form (e.g. --seed=1); confirm.
MODEL=$1 		# which model to run
DATATYPE=$2 	# tabular/image
DATASET_FILE=$3 # file with dataset list, one dataset name per line
SEED=$4
AC=$5

if [ -z "$DATASET_FILE" ] || [ ! -r "$DATASET_FILE" ]; then
    echo "usage: $0 MODEL DATATYPE DATASET_FILE [SEED] [AC]" >&2
    exit 1
fi

LOG_DIR="${HOME}/logs/sample_score"

# -p also creates missing parent directories and is a no-op if the directory exists
mkdir -p "$LOG_DIR"

# `|| [ -n "$d" ]` keeps the last line even when the file lacks a trailing newline
while read -r d || [ -n "$d" ]; do
    # skip blank lines instead of submitting a job without a dataset
    [ -z "$d" ] && continue

    case "$d" in
        MNIST|FashionMNIST|CIFAR10|SVHN2)
            RUNSCRIPT="./sample_score_latent_gpu_run.sh"
            ;;
        *)
            RUNSCRIPT="./sample_score_latent_run.sh"
            ;;
    esac

    # submit to slurm; MODEL, DATATYPE, SEED and AC stay unquoted so that omitted
    # ones expand to nothing rather than to empty arguments
    sbatch \
        --output="${LOG_DIR}/${DATATYPE}_${MODEL}_${d}-%A.out" \
        "$RUNSCRIPT" $MODEL $DATATYPE "$d" $SEED $AC
done < "${DATASET_FILE}"