Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
ruocheng.guo committed Jun 11, 2024
1 parent 4a59745 commit a2c87a1
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,8 @@ iDCF/README.md
iDCF/save_as_params.ipynb
iDCF/tune_script.py
iDCF/vae_exposure.py
debug_results/
job_def.yaml
job_params.yaml
rh2_entrypoint.py
dist_figs/
22 changes: 22 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Bottleneck==1.3.7
bytedrh2==1.18.7
densratio==0.3.0
jax==0.4.19
joblib==1.3.2
matplotlib==1.2.0
matplotlib==3.8.2
numpy==1.11.1
numpy==1.6.2
numpy==1.25.0
pandas==0.14.1
pyreadr==0.5.0
quantile_forest==1.2.3
ray==2.9.0
Requests==2.32.3
scikit_learn==0.12.1
scikit_learn==1.3.2
scipy==0.18.0
scipy==0.11.0
seaborn==0.13.2
torch==2.1.2
tqdm==4.66.1
15 changes: 2 additions & 13 deletions run_syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def main(args):
np.random.seed(args.seed)

n_observation = args.n_obs
n_intervention_list = [100, 250, 500, 750, 1000]
n_intervention_list = [100, 250, 500, 750, 1000, 10, 20, 50]
# n_intervention_list = [10,20,50]

print(n_intervention_list)

Expand All @@ -78,18 +79,6 @@ def main(args):
d=args.x_dim,
err_scale=err_scale,
hidden_conf=args.HC)
elif args.dataset == 'ihdp':
# as ihdp is a small dataset w. 740+ samples
# we only allow the n_intervention to be no larger than 500
if n_intervention > 500:
print("n_intervention must be no larger than 500 for ihdp dataset")
return

df_o, df_i = IHDP_w_HC(n_intervention, args.seed, d=24,
hidden_confounding=True, beta_u=args.conf_strength,
root="/mnt/bn/confrank2/causal_TCP/data/IHDP")

n_observation = df_o.shape[0]

else:
raise ValueError('select a dataset from [synthetic]')
Expand Down
3 changes: 2 additions & 1 deletion run_syn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_frac=0.02
# methods=("inexact" "exact" "wcp" "naive")
# methods=("TCP")

dr_model="MLP"
dr_model="DR"

n_obs=10000

Expand All @@ -23,6 +23,7 @@ methods=("${@:4}")
seeds=($(seq 1234 1238))

# each run considers n_int = (100 500 1000 5000)

for seed in "${seeds[@]}"
do
for method in "${methods[@]}"
Expand Down

0 comments on commit a2c87a1

Please sign in to comment.