Skip to content

Commit

Permalink
style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxymnaumchyk committed Oct 1, 2024
1 parent 303f78f commit 7f9bb74
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/awkward/operations/ak_to_raggedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def _impl(array):

# keep the same device
ak_device = ak.backend(array)
if ak_device not in ['cuda', 'cpu']:
if ak_device not in ["cuda", "cpu"]:
raise ValueError("""Only 'cpu' and 'cuda' backend conversions are allowed""")

if ak_device == 'cpu':
device = 'CPU:0'
if ak_device == "cpu":
device = "CPU:0"
else:
device = 'GPU:0'
device = "GPU:0"

with tf.device(device):
if isinstance(array, ak.contents.numpyarray.NumpyArray):
Expand All @@ -66,22 +66,27 @@ def _impl(array):
values = _cupy_to_tensor(values)

return tf.RaggedTensor.from_row_splits(
values=values, row_splits=[0, array.__len__()]
values=values, row_splits=[0, array.__len__()]
)

else:
flat_values, nested_row_splits = _recursive_call(array, ())

ragged_tensor = tf.RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits)
print(ragged_tensor[0][0].device)
ragged_tensor = tf.RaggedTensor.from_nested_row_splits(
flat_values, nested_row_splits
)
# print(ragged_tensor[0][0].device)
return ragged_tensor


def _cupy_to_tensor(cupy):
# converts cupy directly to tensor,
# since `tf.RaggedTensor.from_nested_row_splits` can not work with Cupy arrays
import tensorflow as tf

return tf.experimental.dlpack.from_dlpack(cupy.toDlpack())


def _recursive_call(layout, offsets_arr):
try:
# change all the possible layout types to ListOffsetArray
Expand Down

0 comments on commit 7f9bb74

Please sign in to comment.