Skip to content

Commit

Permalink
add cond_label mode
Browse files Browse the repository at this point in the history
  • Loading branch information
SevenLJY committed Apr 2, 2024
1 parent 76ff534 commit 2dd510c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions datamodules/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ def _prepare_predict(self):
self.pred_size=len(model_ids)
elif self.hparams.pred_mode == 'cond_graph':
self.pred_size = 1
hb_path = self.hparams.hash_table
hb_path = self.hparams.hash_table # hashtable records models corresponding to each graph topology under each category
hb = json.load(open(hb_path, 'r'))
model_ids = []
for cat in hb.keys():
for h in hb[cat]:
model_ids.append(hb[cat][h][0])
model_ids.append(hb[cat][h][0]) # take one model for each graph topology as the condition
self.pred_dataset = IDPredDataset(self.hparams, model_ids)
else: # condition on node attributes
self.pred_size = 1
hb_path = self.hparams.hash_table
hb = json.load(open(hb_path, 'r'))
model_ids = []
n_examples = 8
n_examples = 5 # load the first 8 models for each graph topology as the condition
for cat in hb.keys():
for h in hb[cat]:
if len(hb[cat][h]) < n_examples:
Expand Down
2 changes: 2 additions & 0 deletions systems/cage.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def pred_cond_other(self, batch, batch_idx):
indices = torch.arange(1, 32*5, step=5) # 1 is type
elif mode == 'cond_axis':
indices = torch.arange(2, 32*5, step=5) # 2 is axis
elif mode == 'cond_label':
indices = torch.arange(4, 32*5, step=5) # 4 is semantic label
else:
raise NotImplementedError
# init the noisy input
Expand Down

0 comments on commit 2dd510c

Please sign in to comment.