Skip to content

Commit

Permalink
Add input-output items iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 19, 2023
1 parent a4bccf3 commit 5c41405
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,3 +1360,35 @@ def si_iter(self, batch_size=1, shuffle=False):
for batch_session_indices, batch_mapped_ids in self.s_iter(batch_size, shuffle):
batch_session_items = [[self.uir_tuple[1][i] for i in ids] for ids in batch_mapped_ids]
yield batch_session_indices, batch_mapped_ids, batch_session_items

def io_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of input item indices, batch of output item indices.
A sequence `a b c d` produces [a, b, c] and [b, c, d] as input items and output items respectively.
Parameters
----------
batch_size: int, optional, default = 1
shuffle: bool, optional, default: False
If `True`, orders of triplets will be randomized. If `False`, default orders kept.
Returns
-------
iterator : batch of input item indices, batch of output item indices
"""
input_iids = np.asarray([], dtype="int")
output_iids = np.asarray([], dtype="int")
for _, [mapped_ids] in self.s_iter(1, shuffle):
if len(mapped_ids) < 2:
continue
input_iids = np.concatenate([input_iids, self.uir_tuple[1][mapped_ids[:-1]]])
output_iids = np.concatenate([output_iids, self.uir_tuple[1][mapped_ids[1:]]])
if len(input_iids) >= batch_size:
batch_input_iids = input_iids[:batch_size]
batch_output_iids = output_iids[:batch_size]
input_iids = input_iids[batch_size:]
output_iids = output_iids[batch_size:]
yield batch_input_iids, batch_output_iids
if len(input_iids) >= 0:
yield input_iids, output_iids

0 comments on commit 5c41405

Please sign in to comment.