Skip to content

Commit

Permalink
update ood predict dataloader and some examples
Browse files Browse the repository at this point in the history
  • Loading branch information
SevenLJY committed Mar 19, 2024
1 parent 9a72cab commit a11c309
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
69 changes: 68 additions & 1 deletion datamodules/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __getitem__(self, idx):

return data, cond


class IDPredDataset(BaseDataset):
'''In-distribution prediction dataset'''
def __init__(self, hparams, model_ids):
Expand Down Expand Up @@ -187,3 +186,71 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.model_ids)

class OODPredDataset(Dataset):
'''Out-of-distribution prediction dataset'''
def __init__(self, hparams, ref_file):
self.hparams = hparams
ref = json.load(open(ref_file, 'r'))
tree = ref['diffuse_tree']
self.cats, self.adjs, self.adjs_plot, self.hashes = [], [], [], []
self.num_nodes = []
for cat in tree:
for edges in tree[cat]:
adj, adj_plot = self.get_adj(edges)
h, n_nodes = self.get_hashcode(edges)
self.hashes.append(h)
self.adjs.append(adj)
self.adjs_plot.append(adj_plot)
self.cats.append(cat)
self.num_nodes.append(n_nodes)

def get_hashcode(self, edges):
G = nx.DiGraph()
G.add_edges_from(edges)
hashcode = nx.weisfeiler_lehman_graph_hash(G)
n_nodes = len(G.nodes)
return hashcode, n_nodes

def get_adj(self, edges):
K = self.hparams.K
adj = np.zeros((K, K))
adj_plot = np.zeros((K, K))
for edge in edges:
adj[edge[0], edge[1]] = 1
adj[edge[1], edge[0]] = 1
adj_plot[edge[1], edge[0]] = 1
adj[0][0] = 1
adj_plot[0][0] = 1
return adj.astype(np.float32), adj_plot.astype(np.float32)

def __getitem__(self, idx):
K = self.hparams.K
adj = self.adjs[idx]
cat = self.cats[idx]
adj_plot = self.adjs_plot[idx]
h = self.hashes[idx]
n_nodes = self.num_nodes[idx]
cond = {}
cond['obj_cat'] = cat
cond['cat'] = cat_ref[cat]
cond['adj'] = adj
cond['adj_plot'] = adj_plot
cond['n_nodes'] = n_nodes
cond['tree_hash'] = h
# key padding mask
pad_mask = np.zeros((K*5, K*5))
pad_mask[:, :n_nodes*5] = 1
cond['key_pad_mask'] = pad_mask.astype(np.float32)
# adj mask
adj_mask = cond['adj'].copy()
adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1)
cond['adj_mask'] = adj_mask.astype(np.float32)
# attr mask
attr_mask = np.eye(K, K)
attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
cond['attr_mask'] = attr_mask.astype(np.float32)
data = np.zeros((K * 5, 6)).astype(np.float32)
return data, cond

def __len__(self):
return len(self.cats
3 changes: 2 additions & 1 deletion datamodules/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def setup(self, stage=None):
elif stage == 'predict':
mode = self.hparams.pred_mode
if mode == 'ood':
self.pred_dataset = OODPredDataset(self.hparams, 'indexes/ood_graphs.json')
# load out-of-distribution graphs (manually constructed)
self.pred_dataset = OODPredDataset(self.hparams, 'datamodules/ood_graphs.json')
self.pred_size = 1
else:
self._prepare_predict()
Expand Down
56 changes: 56 additions & 0 deletions datamodules/ood_graphs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"diffuse_tree": {
"Dishwasher": [
[[0,1], [0,2]],
[[0,1], [0,2], [0,3], [0,4]],
[[0,1], [0,2], [0,3], [2,4], [3,5]],
[[0,1], [0,2], [0,3], [1,4], [2,5], [3,6]]
],
"Microwave": [
[[0,1], [1,2]],
[[0,1], [0,2], [1,3], [2,4]],
[[0,1], [0,2], [0,3], [0,4], [4,5]],
[[0,1], [0,2], [0,3], [0,4], [0,5], [0,6]]
],
"Oven": [
[[0,1], [0,2]],
[[0,1], [0,2], [0,3]],
[[0,1], [0,2], [2,3]],
[[0,1], [0,2], [0,3], [3,4]],
[[0,1], [0,2], [0,3], [0,4], [4,5]],
[[0,1], [0,2], [0,3], [0,4], [4,5], [3,6]]
],
"Refrigerator":[
[[0,1], [0,2], [0,3], [0,4]],
[[0,1], [0,2], [2,3]],
[[0,1], [0,2], [1,3], [1,4]],
[[0,1], [0,2], [1,3], [1,4], [2,5], [2,6]]
],
"Safe": [
[[0,1], [0,2], [0,3]],
[[0,1], [0,2]],
[[0,1], [0,2], [1,3], [2,4]],
[[0,1], [0,2], [0,3], [0,4]]
],
"StorageFurniture": [
[[0,1], [0,2], [1,3], [1,4], [2,5], [2,6]],
[[0,1], [0,2], [0,3], [0,4], [0,5]],
[[0,1], [0,2], [0,3], [0,4], [1,5], [2,6], [3,7]],
[[0,1], [0,2], [0,3], [3,4], [3,5]],
[[0,1], [0,2], [1,3], [1,4], [2,5]]
],
"Table": [
[[0,1], [0,2], [1,3], [2,4], [2,5]],
[[0,1], [0,2], [0,3], [2,4], [2,5], [3,6]],
[[0,1], [1,2]],
[[0,1], [0,2], [0,3]],
[[0,1], [0,2], [1,3], [1,4]]
],
"WashingMachine": [
[[0,1], [1,2]],
[[0,1], [1,2], [1,3]],
[[0,1], [0,2], [2,3]],
[[0,1], [0,2], [1,3], [2,4]]
]
}
}

0 comments on commit a11c309

Please sign in to comment.