Skip to content

Commit

Permalink
remove input-output iters
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 25, 2023
1 parent f972e10 commit efdad3d
Showing 1 changed file with 0 additions and 66 deletions.
66 changes: 0 additions & 66 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,72 +1361,6 @@ def si_iter(self, batch_size=1, shuffle=False):
batch_session_items = [[self.uir_tuple[1][i] for i in ids] for ids in batch_mapped_ids]
yield batch_session_indices, batch_mapped_ids, batch_session_items

def io_iter(self, batch_size=1, shuffle=False):
    """Parallelize mini-batches of input-output items.

    Create an iterator over data yielding batch of input item indices,
    batch of output item indices, batch of start masking, batch of end
    masking, and batch of valid ids (relative positions of current
    sequences in the last batch).

    Sequences obtained from ``self.s_iter`` are assigned to ``batch_size``
    parallel slots. At every step each active slot emits one
    (input item, output item) pair of consecutive items from its sequence.
    Sequences with fewer than two items are skipped.

    Parameters
    ----------
    batch_size: int, optional, default = 1
        Number of sequences processed in parallel.

    shuffle: bool, optional, default: False
        Forwarded to ``self.s_iter``.
        If `True`, orders of triplets will be randomized. If `False`, default orders kept.

    Returns
    -------
    iterator : batch of input item indices, batch of output item indices,
        batch of starting sequence mask, batch of ending sequence mask,
        batch of valid ids

    Notes
    -----
    * The mask arrays are reused and modified in place between steps;
      copy them if they must outlive the current step.
    * Once ``self.s_iter`` is exhausted, finished slots are dropped, so the
      batches shrink; the last yielded batch may be empty.
    * Slots that never received a sequence (fewer usable sequences than
      ``batch_size``) are dropped together with finished ones, which keeps
      all yielded arrays the same length and guarantees termination.
    """
    items = self.uir_tuple[1]
    start_mask = np.zeros(batch_size, dtype="int")
    end_mask = np.ones(batch_size, dtype="int")  # 1 = slot is free / sequence ended
    l_pool = []  # sequences waiting for a free slot
    c_pool = [None for _ in range(batch_size)]  # sequence currently held by each slot
    sizes = np.zeros(batch_size, dtype="int")  # items left to consume in each slot

    # Phase 1: every slot is busy before a batch is emitted, so batches are full.
    for _, batch_mapped_ids in self.s_iter(batch_size, shuffle):
        l_pool += batch_mapped_ids
        while len(l_pool) > 0:
            if end_mask.sum() == 0:
                input_iids = items[[mapped_ids[-sizes[idx]] for idx, mapped_ids in enumerate(c_pool)]]
                output_iids = items[[mapped_ids[-sizes[idx] + 1] for idx, mapped_ids in enumerate(c_pool)]]
                sizes -= 1
                # a slot with a single item left has no further (input, output) pair
                end_mask[sizes == 1] = 1
                yield input_iids, output_iids, start_mask, end_mask, np.arange(batch_size, dtype="int")
                start_mask.fill(0)  # reset start masking
            # refill free slots with waiting sequences
            while end_mask.sum() > 0 and len(l_pool) > 0:
                next_seq = l_pool.pop()
                if len(next_seq) > 1:
                    idx = np.nonzero(end_mask)[0][0]
                    end_mask[idx] = 0
                    start_mask[idx] = 1
                    c_pool[idx] = next_seq
                    sizes[idx] = len(next_seq)

    # Phase 2: no more sequences to load; drain the remaining slots.
    while True:
        # Keep only slots that can still emit a pair. Using ``> 1`` (rather than
        # dropping only ``== 1``) also retires slots that were never filled.
        keep = np.nonzero(sizes > 1)[0]
        input_iids = items[[c_pool[idx][-sizes[idx]] for idx in keep]]
        output_iids = items[[c_pool[idx][-sizes[idx] + 1] for idx in keep]]
        start_mask = start_mask[keep]
        end_mask = end_mask[keep]
        sizes = sizes[keep] - 1
        c_pool = [c_pool[idx] for idx in keep]
        end_mask[sizes == 1] = 1
        yield input_iids, output_iids, start_mask, end_mask, keep
        if end_mask.sum() == len(keep):
            break
        start_mask.fill(0)  # reset start masking

def usi_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of user indices, batch of session indices, batch of mapped ids, and batch of sessions' items
Expand Down

0 comments on commit efdad3d

Please sign in to comment.