Skip to content

Commit

Permalink
chagne default granurity to None
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Feb 2, 2024
1 parent e75394c commit d306eb2
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions python/python/lance/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
with_row_id: bool = False,
rank: Optional[int] = None,
world_size: Optional[int] = None,
shard_granularity: Optional[Literal["fragment", "batch"]] = "fragment",
shard_granularity: Optional[Literal["fragment", "batch"]] = None,
batch_readehead: int = 16,
to_tensor_fn: Optional[
callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]]
Expand Down Expand Up @@ -215,20 +215,21 @@ def __init__(
self.rank = rank
self.world_size = world_size
self.shard_granularity = shard_granularity
if not sampler:
if sampler is None:
if shard_granularity is None:
sampler = FullScanSampler()
if (rank is not None or world_size is not None):
warnings.warn(
"rank and world_size are deprecated,"
+ " use SharedFragmentSampler instead.",
DeprecationWarning,
)
sampler = ShardedFragmentSampler(rank=rank, world_size=world_size)
else:
sampler = FullScanSampler()
elif shard_granularity == "batch":
sampler = ShardedBatchSampler(rank, world_size)
elif shard_granularity == "fragment":
sampler = ShardedFragmentSampler(rank, world_size)
elif rank is not None and world_size is not None:
warnings.warn(
"rank and world_size are deprecated,"
+ " use SharedFragmentSampler instead.",
DeprecationWarning,
)
sampler = ShardedFragmentSampler(rank, world_size)
else:
raise ValueError("Invalid shard_granularity: {}")

Expand Down

0 comments on commit d306eb2

Please sign in to comment.