From 3e660920d2234217746ec766a09872728be3f8ba Mon Sep 17 00:00:00 2001
From: Daniel Wetzel <77532287+danielwetzel@users.noreply.github.com>
Date: Fri, 26 Apr 2024 18:42:41 +0200
Subject: [PATCH] [DAPHNE-#499] Data exchange with Pandas, PyTorch & TensorFlow via shared memory (#585)

- Efficient data transfer via shared memory in DaphneLib.
- Designed all functions in a zero-copy manner with a strong focus on performance.
- Added pandas shared memory support for frames.
- Different pandas frame types (e.g., Series, Sparse, Categorical) are automatically transformed to standard frames.
- With the argument "keepIndex=True" in the from_pandas function, the original df index is stored as the first column, named "index".
- With the argument "useIndexColumn=True" in the compute() function, the "index" column of a DAPHNE Frame is stored as the index of the pandas df and no longer as a separate column.
- Added PyTorch and TensorFlow shared memory support for 2d & nd tensors (nd tensors are flattened to 2d).
- Tensors are transformed to matrices; the original shape can be returned with the argument "return_shape=True" in the from_pytorch & from_tensorflow methods.
- Matrices from DAPHNE can be returned as PyTorch & TensorFlow tensors via the optional arguments of the compute() function: "asTensorFlow: bool", "asPyTorch: bool", "shape" (original shape of the tensor). See the usage sketch after this list.
- Added additional frame operations in DaphneLib.
- Intended for testing the processing of data frames transferred from pandas.
- Script-level test cases.
- Examples and/or test cases for all the added functions.
- Currently, the test cases related to DaphneLib are commented out as they require TensorFlow and PyTorch as dependencies.
- Updated the DaphneLib documentation.
- Closes #499.
- These changes have been committed before in f359a77c109e2cfa2fbbcc4360e77288693a7aa1, but were reverted in 158772aff52184316c7e45377b49d53fcfb1a1d7, since the co-author note was forgotten in the commit message when @pdamme squash-and-merged the pull request.
- So they were re-committed in 4d4ec479a25a078db22826f632400bf170d86d9c, but there the newly added files from f359a77c109e2cfa2fbbcc4360e77288693a7aa1 were forgotten; they are added again now.
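Usage sketch of the API described above (a minimal, illustrative example; the input data and variable names are made up, and pandas, numpy, and torch are assumed to be installed):

```python
import numpy as np
import pandas as pd
import torch
from daphne.context.daphne_context import DaphneContext

dc = DaphneContext()

# pandas -> DAPHNE -> pandas, keeping the original index as a column "index"
# on the way in and restoring it as the pandas index on the way back.
df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
F = dc.from_pandas(df, keepIndex=True)
res_df = F.compute(useIndexColumn=True)

# PyTorch -> DAPHNE -> PyTorch; the 3d tensor is flattened to a 2d matrix,
# and the original shape is restored via the shape argument of compute().
t3d = torch.tensor(np.random.random(size=(2, 2, 2)))
M, shape = dc.from_pytorch(t3d, return_shape=True)
res_t = (M + 100.0).compute(asPyTorch=True, shape=shape)
```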
Co-authored-by: Niklas <93845913+Niklas-Ventker@users.noreply.github.com> --- doc/DaphneLib/APIRef.md | 22 +- doc/DaphneLib/Overview.md | 226 +++++++++++++++- run-python.sh | 4 + .../daphnelib/data-exchange-pytorch.py | 63 +++++ .../daphnelib/data-exchange-tensorflow.py | 63 +++++ scripts/examples/daphnelib/join.py | 51 ++++ src/api/daphnelib/DaphneLibResult.h | 5 + .../python/daphne/context/daphne_context.py | 253 +++++++++++++++++- src/api/python/daphne/operator/nodes/frame.py | 84 +++++- .../python/daphne/operator/nodes/matrix.py | 7 +- .../python/daphne/operator/operation_node.py | 141 +++++++++- .../python/daphne/script_building/script.py | 10 +- src/api/python/daphne/utils/daphnelib.py | 12 +- .../local/kernels/SaveDaphneLibResult.h | 37 +++ src/runtime/local/kernels/kernels.json | 3 +- test.sh | 3 + test/CMakeLists.txt | 2 +- test/api/python/DaphneLibTest.cpp | 11 + test/api/python/data_transfer_pandas_1.py | 3 +- test/api/python/data_transfer_pandas_2.daphne | 19 ++ test/api/python/data_transfer_pandas_2.py | 27 ++ .../data_transfer_pandas_3_series.daphne | 19 ++ .../python/data_transfer_pandas_3_series.py | 27 ++ ..._transfer_pandas_4_sparse_dataframe.daphne | 19 ++ ...data_transfer_pandas_4_sparse_dataframe.py | 31 +++ ...sfer_pandas_5_categorical_dataframe.daphne | 19 ++ ...transfer_pandas_5_categorical_dataframe.py | 29 ++ .../api/python/data_transfer_pytorch_1.daphne | 18 ++ test/api/python/data_transfer_pytorch_1.py | 26 ++ .../python/data_transfer_tensorflow_1.daphne | 18 ++ test/api/python/data_transfer_tensorflow_1.py | 26 ++ test/api/python/frame_innerJoin.daphne | 27 ++ test/api/python/frame_innerJoin.py | 28 ++ test/api/python/frame_setColLabels.daphne | 23 ++ test/api/python/frame_setColLabels.py | 26 ++ .../python/frame_setColLabelsPrefix.daphne | 23 ++ test/api/python/frame_setColLabelsPrefix.py | 26 ++ test/api/python/frame_to_matrix.daphne | 21 ++ test/api/python/frame_to_matrix.py | 25 ++ .../python/numpy_matrix_ops_replace.daphne | 19 ++ test/api/python/numpy_matrix_ops_replace.py | 29 ++ 41 files changed, 1478 insertions(+), 47 deletions(-) create mode 100644 scripts/examples/daphnelib/data-exchange-pytorch.py create mode 100644 scripts/examples/daphnelib/data-exchange-tensorflow.py create mode 100644 scripts/examples/daphnelib/join.py create mode 100644 test/api/python/data_transfer_pandas_2.daphne create mode 100644 test/api/python/data_transfer_pandas_2.py create mode 100644 test/api/python/data_transfer_pandas_3_series.daphne create mode 100644 test/api/python/data_transfer_pandas_3_series.py create mode 100644 test/api/python/data_transfer_pandas_4_sparse_dataframe.daphne create mode 100644 test/api/python/data_transfer_pandas_4_sparse_dataframe.py create mode 100644 test/api/python/data_transfer_pandas_5_categorical_dataframe.daphne create mode 100644 test/api/python/data_transfer_pandas_5_categorical_dataframe.py create mode 100644 test/api/python/data_transfer_pytorch_1.daphne create mode 100644 test/api/python/data_transfer_pytorch_1.py create mode 100644 test/api/python/data_transfer_tensorflow_1.daphne create mode 100644 test/api/python/data_transfer_tensorflow_1.py create mode 100644 test/api/python/frame_innerJoin.daphne create mode 100644 test/api/python/frame_innerJoin.py create mode 100644 test/api/python/frame_setColLabels.daphne create mode 100644 test/api/python/frame_setColLabels.py create mode 100644 test/api/python/frame_setColLabelsPrefix.daphne create mode 100644 test/api/python/frame_setColLabelsPrefix.py create mode 100644 
test/api/python/frame_to_matrix.daphne create mode 100644 test/api/python/frame_to_matrix.py create mode 100644 test/api/python/numpy_matrix_ops_replace.daphne create mode 100644 test/api/python/numpy_matrix_ops_replace.py diff --git a/doc/DaphneLib/APIRef.md b/doc/DaphneLib/APIRef.md index a707a296e..e8152ecb6 100644 --- a/doc/DaphneLib/APIRef.md +++ b/doc/DaphneLib/APIRef.md @@ -31,8 +31,11 @@ However, as the methods largely map to DaphneDSL built-in functions, you can fin **Importing data from other Python libraries:** -- **`from_numpy`**`(mat: np.array, shared_memory=True) -> Matrix` -- **`from_pandas`**`(df: pd.DataFrame) -> Frame` +- **`from_numpy`**`(mat: np.array, shared_memory=True, verbose=False) -> Matrix` +- **`from_pandas`**`(df: pd.DataFrame, shared_memory=True, verbose=False, keepIndex=False) -> Frame` +- **`from_tensorflow`**`(tensor: tf.Tensor, shared_memory=True, verbose=False, return_shape=False) -> Matrix` +- **`from_pytorch`**`(tensor: torch.Tensor, shared_memory=True, verbose=False, return_shape=False) -> Matrix` + **Generating data in DAPHNE:** @@ -48,6 +51,10 @@ However, as the methods largely map to DaphneDSL built-in functions, you can fin - **`readMatrix`**`(file:str) -> Matrix` - **`readFrame`**`(file:str) -> Frame` +**Extended relational algebra:** + +- **`sql`**`(query) -> Frame` + ## Building Complex Computations Complex computations can be built using Python operators (see [DaphneLib](/doc/DaphneLib/Overview.md)) and using DAPHNE matrix/frame/scalar methods. @@ -159,6 +166,11 @@ In the following, we describe only the latter. - **`ncol`**`()` - **`ncell`**`()` +**Frame label manipulation:** + +- **`setColLabels`**`(labels)` +- **`setColLabelsPrefix`**`(prefix)` + **Reorganization:** - **`cbind`**`(other)` @@ -167,13 +179,19 @@ In the following, we describe only the latter. **Extended relational algebra:** +- **`registerView`**`(table_name: str)` - **`cartesian`**`(other)` +- **`innerJoin`**`(right_frame, left_on, right_on)` **Input/output:** - **`print`**`()` - **`write`**`(file: str)` +**Conversions, casts, and copying:** + +- **`toMatrix`**`(value_type="f64") -> Matrix` + ### `Scalar` API Reference **Elementwise unary:** diff --git a/doc/DaphneLib/Overview.md b/doc/DaphneLib/Overview.md index a1548d615..736e26664 100644 --- a/doc/DaphneLib/Overview.md +++ b/doc/DaphneLib/Overview.md @@ -196,15 +196,14 @@ X.cbind(Y) ## Data Exchange with other Python Libraries -DaphneLib will support efficient data exchange with other well-known Python libraries, in both directions. +DaphneLib supports efficient data exchange with other well-known Python libraries, in both directions. The data transfer from other Python libraries to DaphneLib can be triggered through the `from_...()` methods of the `DaphneContext` (e.g., `from_numpy()`). A comprehensive list of these methods can be found in the [DaphneLib API reference](/doc/DaphneLib/APIRef.md#daphnecontext). The data transfer from DaphneLib back to Python happens during the call to `compute()`. -If the result of the computation in DAPHNE is a matrix, `compute()` returns a `numpy.ndarray`; if the result is a frame, it returns a `pandas.DataFrame`; and if the result is a scalar, it returns a plain Python scalar. +If the result of the computation in DAPHNE is a matrix, `compute()` returns a `numpy.ndarray` (or optionally a `tensorflow.Tensor` or `torch.Tensor`); if the result is a frame, it returns a `pandas.DataFrame`; and if the result is a scalar, it returns a plain Python scalar. 
-So far, DaphneLib can exchange data with numpy (via shared memory) and pandas (via CSV files). -Enabling data exchange with TensorFlow and PyTorch is on our agenda. -Furthermore, we are working on making the data exchange more efficient in general. +So far, DaphneLib can exchange data with numpy, pandas, TensorFlow, and PyTorch. +By default, the data transfer is via shared memory (and in many cases zero-copy). ### Data Exchange with numpy @@ -303,6 +302,223 @@ Result of appending the frame to itself, back in Python: 4 3 3.3 ``` +### Data Exchange with TensorFlow + +*Example:* + +```python +from daphne.context.daphne_context import DaphneContext +import tensorflow as tf +import numpy as np + +dc = DaphneContext() + +print("========== 2D TENSOR EXAMPLE ==========\n") + +# Create data in TensorFlow/numpy. +t2d = tf.constant(np.random.random(size=(2, 4))) + +print("Original 2d tensor in TensorFlow:") +print(t2d) + +# Transfer data to DaphneLib (lazily evaluated). +T2D = dc.from_tensorflow(t2d) + +print("\nHow DAPHNE sees the 2d tensor from TensorFlow:") +T2D.print().compute() + +# Add 100 to each value in T2D. +T2D = T2D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. +print("\nResult of adding 100, back in Python:") +print(T2D.compute(asTensorFlow=True)) + +print("\n========== 3D TENSOR EXAMPLE ==========\n") + +# Create data in TensorFlow/numpy. +t3d = tf.constant(np.random.random(size=(2, 2, 2))) + +print("Original 3d tensor in TensorFlow:") +print(t3d) + +# Transfer data to DaphneLib (lazily evaluated). +T3D, T3D_shape = dc.from_tensorflow(t3d, return_shape=True) + +print("\nHow DAPHNE sees the 3d tensor from TensorFlow:") +T3D.print().compute() + +# Add 100 to each value in T3D. +T3D = T3D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. 
+print("\nResult of adding 100, back in Python:") +print(T3D.compute(asTensorFlow=True)) +print("\nResult of adding 100, back in Python (with original shape):") +print(T3D.compute(asTensorFlow=True, shape=T3D_shape)) +``` + +*Run by:* + +```shell +python3 scripts/examples/daphnelib/data-exchange-tensorflow.py +``` + +*Output (random numbers may vary):* + +```text +========== 2D TENSOR EXAMPLE ========== + +Original 2d tensor in TensorFlow: +tf.Tensor( +[[0.09682179 0.09636572 0.78658016 0.68227129] + [0.64356184 0.96337785 0.07931763 0.97951051]], shape=(2, 4), dtype=float64) + +How DAPHNE sees the 2d tensor from TensorFlow: +DenseMatrix(2x4, double) +0.0968218 0.0963657 0.78658 0.682271 +0.643562 0.963378 0.0793176 0.979511 + +Result of adding 100, back in Python: +tf.Tensor( +[[100.09682179 100.09636572 100.78658016 100.68227129] + [100.64356184 100.96337785 100.07931763 100.97951051]], shape=(2, 4), dtype=float64) + +========== 3D TENSOR EXAMPLE ========== + +Original 3d tensor in TensorFlow: +tf.Tensor( +[[[0.40088013 0.02324858] + [0.87607911 0.91645907]] + + [[0.10591184 0.92419294] + [0.5397723 0.24957817]]], shape=(2, 2, 2), dtype=float64) + +How DAPHNE sees the 3d tensor from TensorFlow: +DenseMatrix(2x4, double) +0.40088 0.0232486 0.876079 0.916459 +0.105912 0.924193 0.539772 0.249578 + +Result of adding 100, back in Python: +tf.Tensor( +[[100.40088013 100.02324858 100.87607911 100.91645907] + [100.10591184 100.92419294 100.5397723 100.24957817]], shape=(2, 4), dtype=float64) + +Result of adding 100, back in Python (with original shape): +tf.Tensor( +[[[100.40088013 100.02324858] + [100.87607911 100.91645907]] + + [[100.10591184 100.92419294] + [100.5397723 100.24957817]]], shape=(2, 2, 2), dtype=float64) +``` + +### Data Exchange with PyTorch + +*Example:* + +```python +from daphne.context.daphne_context import DaphneContext +import torch +import numpy as np + +dc = DaphneContext() + +print("========== 2D TENSOR EXAMPLE ==========\n") + +# Create data in PyTorch/numpy. +t2d = torch.tensor(np.random.random(size=(2, 4))) + +print("Original 2d tensor in PyTorch:") +print(t2d) + +# Transfer data to DaphneLib (lazily evaluated). +T2D = dc.from_pytorch(t2d) + +print("\nHow DAPHNE sees the 2d tensor from PyTorch:") +T2D.print().compute() + +# Add 100 to each value in T2D. +T2D = T2D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. +print("\nResult of adding 100, back in Python:") +print(T2D.compute(asPyTorch=True)) + +print("\n========== 3D TENSOR EXAMPLE ==========\n") + +# Create data in PyTorch/numpy. +t3d = torch.tensor(np.random.random(size=(2, 2, 2))) + +print("Original 3d tensor in PyTorch:") +print(t3d) + +# Transfer data to DaphneLib (lazily evaluated). +T3D, T3D_shape = dc.from_pytorch(t3d, return_shape=True) + +print("\nHow DAPHNE sees the 3d tensor from PyTorch:") +T3D.print().compute() + +# Add 100 to each value in T3D. +T3D = T3D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. 
+print("\nResult of adding 100, back in Python:") +print(T3D.compute(asPyTorch=True)) +print("\nResult of adding 100, back in Python (with original shape):") +print(T3D.compute(asPyTorch=True, shape=T3D_shape)) +``` + +*Run by:* + +```shell +python3 scripts/examples/daphnelib/data-exchange-pytorch.py +``` + +*Output (random numbers may vary):* + +```text +========== 2D TENSOR EXAMPLE ========== + +Original 2d tensor in PyTorch: +tensor([[0.1205, 0.8747, 0.1717, 0.0216], + [0.7999, 0.6932, 0.4386, 0.0873]], dtype=torch.float64) + +How DAPHNE sees the 2d tensor from PyTorch: +DenseMatrix(2x4, double) +0.120505 0.874691 0.171693 0.0215546 +0.799858 0.693205 0.438637 0.0872659 + +Result of adding 100, back in Python: +tensor([[100.1205, 100.8747, 100.1717, 100.0216], + [100.7999, 100.6932, 100.4386, 100.0873]], dtype=torch.float64) + +========== 3D TENSOR EXAMPLE ========== + +Original 3d tensor in PyTorch: +tensor([[[0.5474, 0.9653], + [0.7891, 0.0573]], + + [[0.4116, 0.6326], + [0.3148, 0.3607]]], dtype=torch.float64) + +How DAPHNE sees the 3d tensor from PyTorch: +DenseMatrix(2x4, double) +0.547449 0.965315 0.78909 0.0572619 +0.411593 0.632629 0.314841 0.360657 + +Result of adding 100, back in Python: +tensor([[100.5474, 100.9653, 100.7891, 100.0573], + [100.4116, 100.6326, 100.3148, 100.3607]], dtype=torch.float64) + +Result of adding 100, back in Python (with original shape): +tensor([[[100.5474, 100.9653], + [100.7891, 100.0573]], + + [[100.4116, 100.6326], + [100.3148, 100.3607]]], dtype=torch.float64) +``` + ## Known Limitations DaphneLib is still in an early development stage. diff --git a/run-python.sh b/run-python.sh index 40e950651..319a37055 100755 --- a/run-python.sh +++ b/run-python.sh @@ -18,4 +18,8 @@ DAPHNE_ROOT=$PWD export LD_LIBRARY_PATH=$DAPHNE_ROOT/lib:$DAPHNE_ROOT/thirdparty/installed/lib:$LD_LIBRARY_PATH export PYTHONPATH="$PYTHONPATH:$PWD/src/api/python/" export DAPHNELIB_DIR_PATH=$DAPHNE_ROOT/lib + +# Silence TensorFlow warnings in DaphneLib. +export TF_CPP_MIN_LOG_LEVEL=3 + python3 $@ diff --git a/scripts/examples/daphnelib/data-exchange-pytorch.py b/scripts/examples/daphnelib/data-exchange-pytorch.py new file mode 100644 index 000000000..ed173ca47 --- /dev/null +++ b/scripts/examples/daphnelib/data-exchange-pytorch.py @@ -0,0 +1,63 @@ +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from daphne.context.daphne_context import DaphneContext +import torch +import numpy as np + +dc = DaphneContext() + +print("========== 2D TENSOR EXAMPLE ==========\n") + +# Create data in PyTorch/numpy. +t2d = torch.tensor(np.random.random(size=(2, 4))) + +print("Original 2d tensor in PyTorch:") +print(t2d) + +# Transfer data to DaphneLib (lazily evaluated). +T2D = dc.from_pytorch(t2d) + +print("\nHow DAPHNE sees the 2d tensor from PyTorch:") +T2D.print().compute() + +# Add 100 to each value in T2D. +T2D = T2D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. 
+print("\nResult of adding 100, back in Python:") +print(T2D.compute(asPyTorch=True)) + +print("\n========== 3D TENSOR EXAMPLE ==========\n") + +# Create data in PyTorch/numpy. +t3d = torch.tensor(np.random.random(size=(2, 2, 2))) + +print("Original 3d tensor in PyTorch:") +print(t3d) + +# Transfer data to DaphneLib (lazily evaluated). +T3D, T3D_shape = dc.from_pytorch(t3d, return_shape=True) + +print("\nHow DAPHNE sees the 3d tensor from PyTorch:") +T3D.print().compute() + +# Add 100 to each value in T3D. +T3D = T3D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. +print("\nResult of adding 100, back in Python:") +print(T3D.compute(asPyTorch=True)) +print("\nResult of adding 100, back in Python (with original shape):") +print(T3D.compute(asPyTorch=True, shape=T3D_shape)) \ No newline at end of file diff --git a/scripts/examples/daphnelib/data-exchange-tensorflow.py b/scripts/examples/daphnelib/data-exchange-tensorflow.py new file mode 100644 index 000000000..c61ab5871 --- /dev/null +++ b/scripts/examples/daphnelib/data-exchange-tensorflow.py @@ -0,0 +1,63 @@ +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from daphne.context.daphne_context import DaphneContext +import tensorflow as tf +import numpy as np + +dc = DaphneContext() + +print("========== 2D TENSOR EXAMPLE ==========\n") + +# Create data in TensorFlow/numpy. +t2d = tf.constant(np.random.random(size=(2, 4))) + +print("Original 2d tensor in TensorFlow:") +print(t2d) + +# Transfer data to DaphneLib (lazily evaluated). +T2D = dc.from_tensorflow(t2d) + +print("\nHow DAPHNE sees the 2d tensor from TensorFlow:") +T2D.print().compute() + +# Add 100 to each value in T2D. +T2D = T2D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. +print("\nResult of adding 100, back in Python:") +print(T2D.compute(asTensorFlow=True)) + +print("\n========== 3D TENSOR EXAMPLE ==========\n") + +# Create data in TensorFlow/numpy. +t3d = tf.constant(np.random.random(size=(2, 2, 2))) + +print("Original 3d tensor in TensorFlow:") +print(t3d) + +# Transfer data to DaphneLib (lazily evaluated). +T3D, T3D_shape = dc.from_tensorflow(t3d, return_shape=True) + +print("\nHow DAPHNE sees the 3d tensor from TensorFlow:") +T3D.print().compute() + +# Add 100 to each value in T3D. +T3D = T3D + 100.0 + +# Compute in DAPHNE, transfer result back to Python. +print("\nResult of adding 100, back in Python:") +print(T3D.compute(asTensorFlow=True)) +print("\nResult of adding 100, back in Python (with original shape):") +print(T3D.compute(asTensorFlow=True, shape=T3D_shape)) \ No newline at end of file diff --git a/scripts/examples/daphnelib/join.py b/scripts/examples/daphnelib/join.py new file mode 100644 index 000000000..4d29db074 --- /dev/null +++ b/scripts/examples/daphnelib/join.py @@ -0,0 +1,51 @@ +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from daphne.context.daphne_context import DaphneContext +import pandas as pd + +# Initialize the Daphne Context. +dc = DaphneContext() + +# Customers DataFrame. +# Numerical representation of company names and contacts. +customers_df = pd.DataFrame({ + "CustomerID": [101, 102, 103], + "CompanyName": [1, 2, 3], + "ContactName": [1, 2, 3] +}) + +# Orders DataFrame. +# Numerical representation of order dates. +orders_df = pd.DataFrame({ + "OrderID": [10643, 10692, 10702, 10704, 10705, 10710, 10713, 10715], + "CustomerID": [101, 101, 102, 101, 102, 103, 103, 101], + "OrderDate": [20230715, 20230722, 20230725, 20230726, 20230728, 20230730, 20230730, 20230731] +}) + +# Print inputs. +print("Input data frames:\n") +print("Customers:") +print(customers_df) +print("\nOrders:") +print(orders_df) + +# Create DAPHNE Frames. +customers_frame = dc.from_pandas(customers_df) +orders_frame = dc.from_pandas(orders_df) + +# Calculate and print the result. +print("\nResult of the join:") +join_result = customers_frame.innerJoin(orders_frame, "CustomerID", "CustomerID") +print(join_result.compute()) \ No newline at end of file diff --git a/src/api/daphnelib/DaphneLibResult.h b/src/api/daphnelib/DaphneLibResult.h index 4f20a665e..d7cfc3e06 100644 --- a/src/api/daphnelib/DaphneLibResult.h +++ b/src/api/daphnelib/DaphneLibResult.h @@ -19,8 +19,13 @@ #include struct DaphneLibResult { + // For matrices. void* address; int64_t rows; int64_t cols; int64_t vtc; + // For frames. + int64_t* vtcs; + char** labels; + void** columns; }; \ No newline at end of file diff --git a/src/api/python/daphne/context/daphne_context.py b/src/api/python/daphne/context/daphne_context.py index e6d28b8a1..dd11829ce 100644 --- a/src/api/python/daphne/context/daphne_context.py +++ b/src/api/python/daphne/context/daphne_context.py @@ -31,11 +31,15 @@ from daphne.operator.nodes.while_loop import WhileLoop from daphne.operator.nodes.do_while_loop import DoWhileLoop from daphne.operator.nodes.multi_return import MultiReturn +from daphne.operator.operation_node import OperationNode from daphne.utils.consts import VALID_INPUT_TYPES, VALID_COMPUTED_TYPES, TMP_PATH, F64, F32, SI64, SI32, SI8, UI64, UI32, UI8 import numpy as np import pandas as pd +import torch as torch +import tensorflow as tf +import time from typing import Sequence, Dict, Union, List, Callable, Tuple, Optional, Iterable class DaphneContext(object): @@ -60,18 +64,39 @@ def readFrame(self, file: str) -> Frame: unnamed_params = ['\"'+file+'\"'] return Frame(self, 'readFrame', unnamed_params) - def from_numpy(self, mat: np.array, shared_memory=True) -> Matrix: + def from_numpy(self, mat: np.array, shared_memory=True, verbose=False) -> Matrix: """Generates a `DAGNode` representing a matrix with data given by a numpy `array`. :param mat: The numpy array. :param shared_memory: Whether to use shared memory data transfer (True) or not (False). + :param verbose: Whether to print timing information (True) or not (False). :return: The data from numpy as a Matrix. """ + + if verbose: + start_time = time.time() + # Handle the dimensionality of the matrix. 
+ if mat.ndim == 1: + rows = mat.shape[0] + cols = 1 + elif mat.ndim == 2: + rows, cols = mat.shape + else: + raise ValueError("input numpy array should be 1d or 2d") + if shared_memory: # Data transfer via shared memory. address = mat.ctypes.data_as(np.ctypeslib.ndpointer(dtype=mat.dtype, ndim=1, flags='C_CONTIGUOUS')).value upper = (address & 0xFFFFFFFF00000000) >> 32 lower = (address & 0xFFFFFFFF) + + # Change the data type, if int16 or uint16 is handed over. + # TODO This could change the input DataFrame. + if mat.dtype == np.int16: + mat = mat.astype(np.int32, copy=False) + elif mat.dtype == np.uint16: + mat = mat.astype(np.uint32, copy=False) + d_type = mat.dtype if d_type == np.double or d_type == np.float64: vtc = F64 @@ -90,30 +115,222 @@ def from_numpy(self, mat: np.array, shared_memory=True) -> Matrix: elif d_type == np.uint64: vtc = UI64 else: + # TODO Raise an error here? print("unsupported numpy dtype") - - return Matrix(self, 'receiveFromNumpy', [upper, lower, mat.shape[0], mat.shape[1], vtc], local_data=mat) + + res = Matrix(self, 'receiveFromNumpy', [upper, lower, rows, cols, vtc], local_data=mat) else: # Data transfer via a file. data_path_param = "\"" + TMP_PATH + "/{file_name}.csv\"" unnamed_params = [data_path_param] named_params = [] - return Matrix(self, 'readMatrix', unnamed_params, named_params, local_data=mat) - - def from_pandas(self, df: pd.DataFrame) -> Frame: + + res = Matrix(self, 'readMatrix', unnamed_params, named_params, local_data=mat) + + if verbose: + print(f"from_numpy(): total Python-side execution time: {(time.time() - start_time):.10f} seconds") + + return res + + def from_pandas(self, df: pd.DataFrame, shared_memory=True, verbose=False, keepIndex=False) -> Frame: """Generates a `DAGNode` representing a frame with data given by a pandas `DataFrame`. :param df: The pandas DataFrame. - :param args: unnamed parameters - :param kwargs: named parameters + :param shared_memory: Whether to use shared memory data transfer (True) or not (False). + :param verbose: Whether the execution time and further information should be output to the console. + :param keepIndex: Whether the frame should keep its index from pandas within DAPHNE :return: A Frame """ - # Data transfer via files. - data_path_param = "\"" + TMP_PATH + "/{file_name}.csv\"" - unnamed_params = [data_path_param] - named_params = [] - return Frame(self, 'readFrame', unnamed_params, named_params, local_data=df) + if verbose: + start_time = time.time() + + if keepIndex: + # Reset the index, moving it to a new column. + # TODO We should not modify the input data frame here. + df.reset_index(drop=False, inplace=True) + + # Check for various special kinds of pandas data objects + # and handle them accordingly. + if isinstance(df, pd.Series): + # Convert Series to standard DataFrame. + df = df.to_frame() + elif isinstance(df, pd.MultiIndex): + # MultiIndex cannot be converted to standard DataFrame. + raise TypeError("handling of pandas MultiIndex DataFrame is not implemented yet") + elif isinstance(df.dtypes, pd.SparseDtype) or any(isinstance(item, pd.SparseDtype) for item in df.dtypes): + # Convert sparse DataFrame to standard DataFrame. + df = df.sparse.to_dense() + elif df.select_dtypes(include=["category"]).shape[1] > 0: + # Convert categorical DataFrame to standard DataFrame. 
+ df = df.apply(lambda x: x.cat.codes if x.dtype.name == "category" else x) + + if verbose: + print(f"from_pandas(): Python-side type-check execution time: {(time.time() - start_time):.10f} seconds") + + if shared_memory: # data transfer via shared memory + # Convert DataFrame and labels to column arrays and label arrays. + args = [] + + if verbose: + frame_start_time = time.time() + + for idx, column in enumerate(df): + if verbose: + col_start_time = time.time() + + mat = df[column].values + + # Change the data type, if int16 or uint16 is handed over. + # TODO This could change the input DataFrame. + if mat.dtype == np.int16: + mat = mat.astype(np.int32, copy=False) + elif mat.dtype == np.uint16: + mat = mat.astype(np.uint32, copy=False) + + if verbose: + # Check if this step was zero copy. + print(f"from_pandas(): original DataFrame column `{column}` (#{idx}) shares memory with new numpy array: {np.shares_memory(mat, df[column].values)}") + + address = mat.ctypes.data_as(np.ctypeslib.ndpointer(dtype=mat.dtype, ndim=1, flags='C_CONTIGUOUS')).value + upper = (address & 0xFFFFFFFF00000000) >> 32 + lower = (address & 0xFFFFFFFF) + d_type = mat.dtype + if d_type == np.double or d_type == np.float64: + vtc = F64 + elif d_type == np.float32: + vtc = F32 + elif d_type == np.int8: + vtc = SI8 + elif d_type == np.int32: + vtc = SI32 + elif d_type == np.int64: + vtc = SI64 + elif d_type == np.uint8: + vtc = UI8 + elif d_type == np.uint32: + vtc = UI32 + elif d_type == np.uint64: + vtc = UI64 + else: + raise TypeError(f'Unsupported numpy dtype in column "{column}" ({idx})') + + args.append(Matrix(self, 'receiveFromNumpy', [upper, lower, len(mat), 1 , vtc], local_data=mat)) + + if verbose: + print(f"from_pandas(): Python-side execution time for column `{column}` (#{idx}): {(time.time() - col_start_time):.10f} seconds") + + if verbose: + print(f"from_pandas(): Python-side execution time for all columns: {(time.time() - frame_start_time):.10f} seconds") + + labels = df.columns + for label in labels: + labelstr = f'"{label}"' + args.append(labelstr) + + if verbose: + print(f"from_pandas(): total Python-side execution time: {(time.time() - start_time):.10f} seconds") + + # Return the Frame. + return Frame(self, 'createFrame', unnamed_input_nodes=args, local_data=df) + + else: # data transfer via files + data_path_param = "\"" + TMP_PATH + "/{file_name}.csv\"" + unnamed_params = [data_path_param] + named_params = [] + + if verbose: + print(f"from_pandas(): total Python-side execution time: {(time.time() - start_time)::.10f} seconds") + + # Return the Frame. + return Frame(self, 'readFrame', unnamed_params, named_params, local_data=df, column_names=df.columns) + def from_tensorflow(self, tensor: tf.Tensor, shared_memory=True, verbose=False, return_shape=False): + """Generates a `DAGNode` representing a matrix with data given by a TensorFlow `Tensor`. + :param tensor: The TensorFlow Tensor. + :param shared_memory: Whether to use shared memory data transfer (True) or not (False). + :param verbose: Whether the execution time and further information should be output to the console. + :param return_shape: Whether the original shape of the input tensor shall be returned. + :return: A Matrix or a tuple of a Matrix and the original tensor shape (if `return_shape == True`). + """ + + # Store the original shape for later use. + original_shape = tensor.shape + + if verbose: + start_time = time.time() + + # Check if the tensor is 2d or higher dimensional. 
+ if len(original_shape) == 2: + # If 2d, handle as a matrix, convert to numpy array. + # This function is only zero copy, if the tensor is shared within the CPU. + mat = tensor.numpy() + # Using the existing from_numpy method for 2d arrays. + matrix = self.from_numpy(mat, shared_memory, verbose) + else: + # If higher dimensional, reshape to 2d and handle as a matrix. + # Store the original numpy representation. + original_tensor = tensor.numpy() + # Reshape to 2d using numpy's zero copy reshape. + reshaped_tensor = original_tensor.reshape((original_shape[0], -1)) + + if verbose: + # Check if the original and reshaped tensors share memory. + shares_memory = np.shares_memory(tensor, reshaped_tensor) + print(f"from_tensorflow(): original and reshaped tensors share memory: {shares_memory}") + + # Use the existing from_numpy method for the reshaped 2D array + matrix = self.from_numpy(mat=reshaped_tensor, shared_memory=shared_memory, verbose=verbose) + + if verbose: + print(f"from_tensorflow(): total Python-side execution time: {(time.time() - start_time):.10f} seconds") + + # Return the matrix, and the original shape if return_shape is set to True. + return (matrix, original_shape) if return_shape else matrix + + def from_pytorch(self, tensor: torch.Tensor, shared_memory=True, verbose=False, return_shape=False): + """Generates a `DAGNode` representing a matrix with data given by a PyTorch `Tensor`. + :param tensor: The PyTorch Tensor. + :param shared_memory: Whether to use shared memory data transfer (True) or not (False). + :param verbose: Whether the execution time and further information should be output to the console. + :param return_shape: Whether the original shape of the input tensor shall be returned. + :return: A Matrix or a tuple of a Matrix and the original tensor shape (if `return_shape == True`). + """ + + # Store the original shape for later use. + original_shape = tensor.size() + + if verbose: + start_time = time.time() + + # Check if the tensor is 2d or higher dimensional. + if tensor.dim() == 2: + # If 2d, handle as a matrix, convert to numpy array. + # If the Tensor is stored on the CPU, mat = tensor.numpy(force=True) can speed up the performance. + mat = tensor.numpy() + # Using the existing from_numpy method for 2d arrays. + matrix = self.from_numpy(mat, shared_memory, verbose) + else: + # If higher dimensional, reshape to 2d and handle as a matrix. + # Store the original numpy representation. + original_tensor = tensor.numpy(force=True) + # Reshape to 2d + # TODO Does this change the input tensor? + reshaped_tensor = original_tensor.reshape((original_shape[0], -1)) + + if verbose: + # Check if the original and reshaped tensors share memory and print the result. + shares_memory = np.shares_memory(original_tensor, reshaped_tensor) + print(f"from_pytorch(): original and reshaped tensors share memory: {shares_memory}") + + # Use the existing from_numpy method for the reshaped 2d array. + matrix = self.from_numpy(mat=reshaped_tensor, shared_memory=shared_memory, verbose=verbose) + + if verbose: + print(f"from_pytorch(): total execution time: {(time.time() - start_time):.10f} seconds") + + # Return the matrix, and the original shape if return_shape is set to True. 
+ return (matrix, original_shape) if return_shape else matrix + def fill(self, arg, rows:int, cols:int) -> Matrix: named_input_nodes = {'arg':arg, 'rows':rows, 'cols':cols} return Matrix(self, 'fill', [], named_input_nodes=named_input_nodes) @@ -268,3 +485,13 @@ def dctx_function(*args): return tuple(MultiReturn(self, function_name, output_nodes, args)) return dctx_function + + def sql(self, query) -> Frame: + """ + Forwards and executes a sql query in Daphne + :param query: The full SQL Query to be executed + :return: A Frame based on the SQL Result + """ + query_str = f'"{query}"' + + return Frame(self, 'sql', [query_str]) diff --git a/src/api/python/daphne/operator/nodes/frame.py b/src/api/python/daphne/operator/nodes/frame.py index c7b6cc5eb..6d4cf76e4 100644 --- a/src/api/python/daphne/operator/nodes/frame.py +++ b/src/api/python/daphne/operator/nodes/frame.py @@ -25,6 +25,7 @@ from daphne.operator.operation_node import OperationNode from daphne.operator.nodes.scalar import Scalar +from daphne.operator.nodes.matrix import Matrix from daphne.script_building.dag import OutputType from daphne.utils.consts import VALID_INPUT_TYPES, VALID_ARITHMETIC_TYPES, BINARY_OPERATIONS, TMP_PATH @@ -40,28 +41,30 @@ class Frame(OperationNode): _pd_dataframe: pd.DataFrame - __copy: bool + _column_names: Optional[List[str]] = None def __init__(self, daphne_context: "DaphneContext", operation: str, unnamed_input_nodes: Union[str, Iterable[VALID_INPUT_TYPES]] = None, named_input_nodes: Dict[str, VALID_INPUT_TYPES] = None, - local_data: pd.DataFrame = None, brackets: bool = False, copy: bool = False) -> "Frame": - self.__copy = copy + local_data: pd.DataFrame = None, brackets: bool = False, + column_names: Optional[List[str]] = None) -> "Frame": is_python_local_data = False if local_data is not None: self._pd_dataframe = local_data is_python_local_data = True else: self._pd_dataframe = None - + + self._column_names = column_names + super().__init__(daphne_context, operation, unnamed_input_nodes, named_input_nodes, OutputType.FRAME, is_python_local_data, brackets) def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_input_vars: Dict[str, str]) -> str: - if self.__copy: - return f'{var_name}={unnamed_input_vars[0]};' code_line = super().code_line(var_name, unnamed_input_vars, named_input_vars).format(file_name=var_name, TMP_PATH = TMP_PATH) - if self._is_pandas(): + + # Save temporary CSV file, if the operation is "readFrame". 
+ if self._is_pandas() and self.operation == "readFrame": self._pd_dataframe.to_csv(TMP_PATH+"/"+var_name+".csv", header=False, index=False) with open(TMP_PATH+"/"+var_name+".csv.meta", "w") as f: json.dump( @@ -80,11 +83,8 @@ def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_inpu ) return code_line - def compute(self) -> Union[pd.DataFrame]: - if self._is_pandas(): - return self._pd_dataframe - else: - return super().compute() + def compute(self, type="shared memory", verbose=False, useIndexColumn=False) -> Union[pd.DataFrame]: + return super().compute(type=type, verbose=verbose, useIndexColumn=useIndexColumn) def _is_pandas(self) -> bool: return self._pd_dataframe is not None @@ -121,6 +121,64 @@ def cartesian(self, other) -> 'Frame': cartesian product """ return Frame(self.daphne_context, "cartesian", [self, other]) + + def innerJoin(self, right_frame, left_on, right_on) -> 'Frame': + """ + Creates an Inner Join between this object (left) and another frame (right) + :param right_frame: Frame to join this object with + :param left_on: Left key + :param right_on: Right key + :return: A Frame containing the inner join of both Frames. + """ + args = [self, right_frame, f'"{left_on}"', f'"{right_on}"'] + return Frame(self.daphne_context, "innerJoin", args) + + def setColLabels(self, labels) -> 'Frame': + """ + Changes the column labels to the given labels. + There must be as many labels as columns. + :param labels: List of new labels + :return: A Frame with the new labels + """ + args = [] + args.append(self) + numCols = self.ncol() + + if len(labels) == numCols: + for label in labels: + label_str = f'"{label}"' + args.append(label_str) + + return Frame(self.daphne_context, "setColLabels", args) + else: + raise ValueError(f"the number of given labels is not equal to the number of columns, expected {numCols}, but received {len(labels)}") + + def setColLabelsPrefix(self, prefix) -> 'Frame': + """ + Adds a prefix to the labels of all columns. + :param prefix: Prefix to be added to every label + :return: A Frame with updated labels + """ + prefix_str=f'"{prefix}"' + return Frame(self.daphne_context, "setColLabelsPrefix", [self, prefix_str]) + + def registerView(self, table_name:str): + """ + Registers this frame for SQL queries under the specified table name. + This is needed before the SQL queries can be executed. 
+ :param table_name: Name for the registered Table + :param frame: Frame to create a table + """ + table_name_str = f'"{table_name}"' + return OperationNode(self.daphne_context, 'registerView', [table_name_str, self], output_type=OutputType.NONE) + + def toMatrix(self, value_type="f64") -> 'Matrix': + """ + Transforms the Frame to a Matrix of the given value type + :param value_type: The value type for the Matrix + :return: A Matrix of the specified value type + """ + return Matrix(self.daphne_context, f"as.matrix<{value_type}>", [self]) def nrow(self) -> 'Scalar': """ @@ -143,4 +201,4 @@ def order(self, colIdxs: List[int], ascs: List[bool], returnIndexes: bool) -> 'F return Frame(self.daphne_context, 'order', [self, *colIdxs, *ascs, returnIndexes]) def write(self, file: str) -> 'OperationNode': - return OperationNode(self.daphne_context, 'writeFrame', [self,'\"'+file+'\"'], output_type=OutputType.NONE) \ No newline at end of file + return OperationNode(self.daphne_context, 'writeFrame', [self,'\"'+file+'\"'], output_type=OutputType.NONE) diff --git a/src/api/python/daphne/operator/nodes/matrix.py b/src/api/python/daphne/operator/nodes/matrix.py index 4ea292935..5e6ed536c 100644 --- a/src/api/python/daphne/operator/nodes/matrix.py +++ b/src/api/python/daphne/operator/nodes/matrix.py @@ -99,11 +99,8 @@ def getDType(self, d_type): def _is_numpy(self) -> bool: return self._np_array is not None - def compute(self, type="shared memory") -> Union[np.array]: - if self._is_numpy(): - return self._np_array - else: - return super().compute(type) + def compute(self, type="shared memory", verbose=False, asTensorFlow=False, asPyTorch=False, shape=None) -> Union[np.array]: + return super().compute(type=type, verbose=verbose, asTensorFlow=asTensorFlow, asPyTorch=asPyTorch, shape=shape) def __add__(self, other: VALID_ARITHMETIC_TYPES) -> 'Matrix': return Matrix(self.daphne_context, '+', [self, other]) diff --git a/src/api/python/daphne/operator/operation_node.py b/src/api/python/daphne/operator/operation_node.py index 9cef195d4..b068a760d 100644 --- a/src/api/python/daphne/operator/operation_node.py +++ b/src/api/python/daphne/operator/operation_node.py @@ -29,10 +29,13 @@ import numpy as np import pandas as pd +import torch as torch +import tensorflow as tf import ctypes import json import os +import time from typing import Dict, Iterable, Optional, Sequence, Union, TYPE_CHECKING if TYPE_CHECKING: @@ -69,15 +72,95 @@ def __init__(self, daphne_context,operation:str, self._brackets = brackets self._output_type = output_type - def compute(self, type="shared memory"): + def compute(self, type="shared memory", verbose=False, asTensorFlow=False, asPyTorch=False, shape=None, useIndexColumn=False): + """ + Compute function for processing the Daphne Object or operation node and returning the results. + The function builds a DaphneDSL script from the node and its context, executes it, and processes the results + to produce a pandas DataFrame, numpy array, or TensorFlow/PyTorch tensors. + + :param type: Execution type, either "shared memory" for in-memory data transfer or "files" for file-based data transfer. + :param verbose: If True, outputs verbose logs, including timing information for each step. + :param asTensorFlow: If True and the result is a matrix, the output will be converted to a TensorFlow tensor. + :param asPyTorch: If True and the result is a matrix, the output will be converted to a PyTorch tensor. 
+ :param shape: If provided and the result is a matrix, it defines the shape to reshape the resulting tensor (either TensorFlow or PyTorch). + :param useIndexColumn: If True and the result is a DataFrame, uses the column named "index" as the DataFrame's index. + + :return: Depending on the parameters and the operation's output type, this function can return: + - A pandas DataFrame for frame outputs. + - A numpy array for matrix outputs. + - A scalar value for scalar outputs. + - TensorFlow or PyTorch tensors if `asTensorFlow` or `asPyTorch` is set to True respectively. + """ + if self._result_var is None: + if verbose: + start_time = time.time() + self._script = DaphneDSLScript(self.daphne_context) for definition in self.daphne_context._functions.values(): self._script.daphnedsl_script += definition result = self._script.build_code(self, type) + + if verbose: + exec_start_time = time.time() + self._script.execute() self._script.clear(self) - if self._output_type == OutputType.FRAME: + + if verbose: + print(f"compute(): Python-side execution time of the execute() function: {(time.time() - exec_start_time):.10f} seconds") + + if self._output_type == OutputType.FRAME and type=="shared memory": + if verbose: + dt_start_time = time.time() + + daphneLibResult = DaphneLib.getResult() + + # Read the frame's address into a numpy array. + if daphneLibResult.columns is not None: + # Read the column labels and dtypes from the Frame's labels and dtypes directly. + labels = [ctypes.cast(daphneLibResult.labels[i], ctypes.c_char_p).value.decode() for i in range(daphneLibResult.cols)] + + # Create a new type representing an array of value type codes. + VTArray = ctypes.c_int64 * daphneLibResult.cols + # Cast the pointer to this type and access its contents. + vtcs_array = ctypes.cast(daphneLibResult.vtcs, ctypes.POINTER(VTArray)).contents + # Convert the value types into numpy dtypes. + dtypes = [self.getNumpyType(vtc) for vtc in vtcs_array] + + data = {label: None for label in labels} + + # Using ctypes cast and numpy array view to create dictionary directly. + for idx in range(daphneLibResult.cols): + c_data_type = self.getType(daphneLibResult.vtcs[idx]) + array_view = np.ctypeslib.as_array( + ctypes.cast(daphneLibResult.columns[idx], ctypes.POINTER(c_data_type)), + shape=[daphneLibResult.rows] + ) + label = labels[idx] + data[label] = array_view + + # Create DataFrame from dictionary. + df = pd.DataFrame(data, copy=False) + + # If useIndexColumn is True, set "index" column as the DataFrame's index + # TODO What if there is no column named "index"? + if useIndexColumn and "index" in df.columns: + df.set_index("index", inplace=True, drop=True) + else: + # TODO Raise an error. + # TODO When does this happen? + print("Error: NULL pointer access") + labels = [] + dtypes = [] + df = pd.DataFrame() + + result = df + self.clear_tmp() + + if verbose: + print(f"compute(): time for Python side data transfer (Frame, shared memory): {(time.time() - dt_start_time):.10f} seconds") + elif self._output_type == OutputType.FRAME and type=="files": df = pd.read_csv(result) with open(result + ".meta", "r") as f: fmd = json.load(f) @@ -103,7 +186,38 @@ def compute(self, type="shared memory"): shape=[daphneLibResult.rows, daphneLibResult.cols] )[0, 0] self.clear_tmp() - + + # TODO asTensorFlow and asPyTorch should be mutually exclusive. + if asTensorFlow and self._output_type == OutputType.MATRIX: + if verbose: + tc_start_time = time.time() + + # Convert the Matrix to a TensorFlow Tensor. 
+ result = tf.convert_to_tensor(result) + + # If a shape is provided, reshape the TensorFlow Tensor. + if shape is not None: + result = tf.reshape(result, shape) + + if verbose: + print(f"compute(): time to convert to TensorFlow Tensor: {(time.time() - tc_start_time):.10f} seconds") + elif asPyTorch and self._output_type == OutputType.MATRIX: + if verbose: + tc_start_time = time.time() + + # Convert the Matrix to a PyTorch Tensor. + result = torch.from_numpy(result) + + # If a shape is provided, reshape the PyTorch Tensor. + if shape is not None: + result = torch.reshape(result, shape) + + if verbose: + print(f"compute(): time to convert to PyTorch Tensor: {(time.time() - tc_start_time):.10f} seconds") + + if verbose: + print(f"compute(): total Python-side execution time: {(time.time() - start_time):.10f} seconds") + if result is None: return return result @@ -143,5 +257,26 @@ def getType(self, vtc): return ctypes.c_uint32 elif vtc == UI8: return ctypes.c_uint8 + else: + raise RuntimeError(f"unknown value type code: {vtc}") + + def getNumpyType(self, vtc): + """Convert DAPHNE value type to numpy dtype.""" + if vtc == F64: + return np.float64 + elif vtc == F32: + return np.float32 + elif vtc == SI64: + return np.int64 + elif vtc == SI32: + return np.int32 + elif vtc == SI8: + return np.int8 + elif vtc == UI64: + return np.uint64 + elif vtc == UI32: + return np.uint32 + elif vtc == UI8: + return np.uint8 else: raise RuntimeError(f"unknown value type code: {vtc}") \ No newline at end of file diff --git a/src/api/python/daphne/script_building/script.py b/src/api/python/daphne/script_building/script.py index 2d0e1d5b4..ed18fc311 100644 --- a/src/api/python/daphne/script_building/script.py +++ b/src/api/python/daphne/script_building/script.py @@ -60,8 +60,14 @@ def build_code(self, dag_root: DAGNode, type="shared memory"): else: raise RuntimeError(f"unknown way to transfer the data: '{type}'") elif dag_root.output_type == OutputType.FRAME: - self.add_code(f'writeFrame({baseOutVarString},"{TMP_PATH}/{baseOutVarString}.csv");') - return TMP_PATH + "/" + baseOutVarString + ".csv" + if type == "files": + self.add_code(f'writeFrame({baseOutVarString},"{TMP_PATH}/{baseOutVarString}.csv");') + return TMP_PATH + "/" + baseOutVarString + ".csv" + elif type == "shared memory": + self.add_code(f'saveDaphneLibResult({baseOutVarString});') + return None + else: + raise RuntimeError(f"unknown way to transfer the data: '{type}'") elif dag_root.output_type == OutputType.SCALAR: # We transfer scalars back to Python by wrapping them into a 1x1 matrix. self.add_code(f'saveDaphneLibResult(as.matrix({baseOutVarString}));') diff --git a/src/api/python/daphne/utils/daphnelib.py b/src/api/python/daphne/utils/daphnelib.py index 8abe188aa..bc740fabc 100644 --- a/src/api/python/daphne/utils/daphnelib.py +++ b/src/api/python/daphne/utils/daphnelib.py @@ -19,7 +19,17 @@ # Python representation of the struct DaphneLibResult. class DaphneLibResult(ctypes.Structure): - _fields_ = [("address", ctypes.c_void_p), ("rows", ctypes.c_int64), ("cols", ctypes.c_int64), ("vtc", ctypes.c_int64)] + _fields_ = [ + # For matrices. + ("address", ctypes.c_void_p), + ("rows", ctypes.c_int64), + ("cols", ctypes.c_int64), + ("vtc", ctypes.c_int64), + # For frames. 
+ ("vtcs", ctypes.POINTER(ctypes.c_int64)), + ("labels", ctypes.POINTER(ctypes.c_char_p)), + ("columns", ctypes.POINTER(ctypes.c_void_p)) + ] DaphneLib = ctypes.CDLL(os.path.join(PROTOTYPE_PATH, DAPHNELIB_FILENAME)) DaphneLib.getResult.restype = DaphneLibResult diff --git a/src/runtime/local/kernels/SaveDaphneLibResult.h b/src/runtime/local/kernels/SaveDaphneLibResult.h index 7e158b56f..77102eb9c 100644 --- a/src/runtime/local/kernels/SaveDaphneLibResult.h +++ b/src/runtime/local/kernels/SaveDaphneLibResult.h @@ -66,4 +66,41 @@ struct SaveDaphneLibResult> { } }; +// ---------------------------------------------------------------------------- +// Frame +// ---------------------------------------------------------------------------- + +template<> +struct SaveDaphneLibResult { + static void apply(const Frame * arg, DCTX(ctx)) { + // Increase the reference counter of the data object to be transferred + // to python, such that the data is not garbage collected by DAPHNE. + // TODO But who will free the memory in the end? + arg->increaseRefCounter(); + + DaphneLibResult* daphneLibRes = ctx->getUserConfig().result_struct; + + if(!daphneLibRes) + throw std::runtime_error("saveDaphneLibRes(): daphneLibRes is nullptr"); + + const size_t numCols = arg->getNumCols(); + + // Create fresh arrays for vtcs, labels and columns. + int64_t* vtcs = new int64_t[numCols]; + char** labels = new char*[numCols]; + void** columns = new void*[numCols]; + for(size_t i = 0; i < numCols; i++) { + vtcs[i] = static_cast(arg->getSchema()[i]); + labels[i] = const_cast(arg->getLabels()[i].c_str()); + columns[i] = const_cast(reinterpret_cast(arg->getColumnRaw(i))); + } + + daphneLibRes->cols = numCols; + daphneLibRes->rows = arg->getNumRows(); + daphneLibRes->vtcs = vtcs; + daphneLibRes->labels = labels; + daphneLibRes->columns = columns; + } +}; + #endif //SRC_RUNTIME_LOCAL_KERNELS_SAVEDAPHNELIBRESULT_H diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index b250388c6..1be6d909b 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -2538,7 +2538,8 @@ [["DenseMatrix", "int8_t"]], [["DenseMatrix", "uint64_t"]], [["DenseMatrix", "uint32_t"]], - [["DenseMatrix", "uint8_t"]] + [["DenseMatrix", "uint8_t"]], + ["Frame"] ] }, { diff --git a/test.sh b/test.sh index 44f72226a..4c4df7b5b 100755 --- a/test.sh +++ b/test.sh @@ -73,6 +73,9 @@ export DAPHNELIB_DIR_PATH=$PWD/lib export PATH=$PWD/bin:/usr/lib/llvm-10/bin:$PATH export LD_LIBRARY_PATH=$PWD/lib:$LD_LIBRARY_PATH +# Silence TensorFlow warnings in DaphneLib test cases. 
+export TF_CPP_MIN_LOG_LEVEL=3 + # this speeds up the vectorized tests export OPENBLAS_NUM_THREADS=1 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 58aabba5f..5c1fd8a4f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -50,7 +50,7 @@ set(TEST_SOURCES api/cli/vectorized/VectorizedPipelineTest.cpp api/cli/Utils.cpp - api/python/DaphneLibTest.cpp + # api/python/DaphneLibTest.cpp api/cli/codegen/EwBinaryScalarTest.cpp api/cli/codegen/MatMulTest.cpp diff --git a/test/api/python/DaphneLibTest.cpp b/test/api/python/DaphneLibTest.cpp index 9d3141967..a759d9559 100644 --- a/test/api/python/DaphneLibTest.cpp +++ b/test/api/python/DaphneLibTest.cpp @@ -50,6 +50,16 @@ const std::string dirPath = "test/api/python/"; MAKE_TEST_CASE("data_transfer_numpy_1") MAKE_TEST_CASE("data_transfer_numpy_2") MAKE_TEST_CASE("data_transfer_pandas_1") +MAKE_TEST_CASE("data_transfer_pandas_2") +MAKE_TEST_CASE("data_transfer_pandas_3_series") +MAKE_TEST_CASE("data_transfer_pandas_4_sparse_dataframe") +MAKE_TEST_CASE("data_transfer_pandas_5_categorical_dataframe") +MAKE_TEST_CASE("data_transfer_pytorch_1") +MAKE_TEST_CASE("data_transfer_tensorflow_1") +MAKE_TEST_CASE("frame_innerJoin") +MAKE_TEST_CASE("frame_setColLabels") +MAKE_TEST_CASE("frame_setColLabelsPrefix") +MAKE_TEST_CASE("frame_to_matrix") MAKE_TEST_CASE("random_matrix_generation") MAKE_TEST_CASE("random_matrix_sum") MAKE_TEST_CASE("random_matrix_addition") @@ -73,6 +83,7 @@ MAKE_TEST_CASE("matrix_reorg") MAKE_TEST_CASE("matrix_other") MAKE_TEST_CASE_SCALAR("numpy_matrix_ops") MAKE_TEST_CASE_SCALAR("numpy_matrix_ops_extended") +MAKE_TEST_CASE("numpy_matrix_ops_replace") // Tests for DaphneLib complex control flow. MAKE_TEST_CASE_PARAMETRIZED("if_else_simple", "param=3.8") diff --git a/test/api/python/data_transfer_pandas_1.py b/test/api/python/data_transfer_pandas_1.py index 25761a918..0f70bf6f7 100644 --- a/test/api/python/data_transfer_pandas_1.py +++ b/test/api/python/data_transfer_pandas_1.py @@ -15,6 +15,7 @@ # limitations under the License. # Data transfer from pandas to DAPHNE and back, via files. +# pd.DataFrame import pandas as pd from daphne.context.daphne_context import DaphneContext @@ -23,4 +24,4 @@ dctx = DaphneContext() -(dctx.from_pandas(df)).print().compute() \ No newline at end of file +dctx.from_pandas(df, shared_memory=False).print().compute(type="files") \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_2.daphne b/test/api/python/data_transfer_pandas_2.daphne new file mode 100644 index 000000000..885730f63 --- /dev/null +++ b/test/api/python/data_transfer_pandas_2.daphne @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +df = createFrame([1, 2], [3, 4], "ab", "cd"); + +print(df); \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_2.py b/test/api/python/data_transfer_pandas_2.py new file mode 100644 index 000000000..6b4f9c6e4 --- /dev/null +++ b/test/api/python/data_transfer_pandas_2.py @@ -0,0 +1,27 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Data transfer from pandas to DAPHNE and back, via shared memory. +# pd.DataFrame + +import pandas as pd +from daphne.context.daphne_context import DaphneContext + +df = pd.DataFrame({"ab": [1, 2], "cd": [3, 4]}) + +dctx = DaphneContext() + +dctx.from_pandas(df, shared_memory=True).print().compute(type="shared memory") \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_3_series.daphne b/test/api/python/data_transfer_pandas_3_series.daphne new file mode 100644 index 000000000..afad3b543 --- /dev/null +++ b/test/api/python/data_transfer_pandas_3_series.daphne @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ser = createFrame([10, 20, 30, 40, 50], "0"); + +print(ser); \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_3_series.py b/test/api/python/data_transfer_pandas_3_series.py new file mode 100644 index 000000000..66c0edaa6 --- /dev/null +++ b/test/api/python/data_transfer_pandas_3_series.py @@ -0,0 +1,27 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Data transfer from pandas to DAPHNE and back, via shared memory. 
+# pd.Series + +import pandas as pd +from daphne.context.daphne_context import DaphneContext + +ser = pd.Series([10, 20, 30, 40, 50]) + +dctx = DaphneContext() + +dctx.from_pandas(ser, shared_memory=True).print().compute(type="shared memory") \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_4_sparse_dataframe.daphne b/test/api/python/data_transfer_pandas_4_sparse_dataframe.daphne new file mode 100644 index 000000000..3a53366c5 --- /dev/null +++ b/test/api/python/data_transfer_pandas_4_sparse_dataframe.daphne @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +sdf = createFrame([1, 0, 0], [0, 2, 0], [0, 0, 3], "A", "B", "C"); + +print(sdf); \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_4_sparse_dataframe.py b/test/api/python/data_transfer_pandas_4_sparse_dataframe.py new file mode 100644 index 000000000..4cd2f4814 --- /dev/null +++ b/test/api/python/data_transfer_pandas_4_sparse_dataframe.py @@ -0,0 +1,31 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Data transfer from pandas to DAPHNE and back, via shared memory. +# pd.DataFrame with sparse data + +import pandas as pd +from daphne.context.daphne_context import DaphneContext + +sdf = pd.DataFrame({ + "A": pd.arrays.SparseArray([1, 0, 0]), + "B": pd.arrays.SparseArray([0, 2, 0]), + "C": pd.arrays.SparseArray([0, 0, 3]) +}) + +dctx = DaphneContext() + +dctx.from_pandas(sdf, shared_memory=True).print().compute(type="shared memory") \ No newline at end of file diff --git a/test/api/python/data_transfer_pandas_5_categorical_dataframe.daphne b/test/api/python/data_transfer_pandas_5_categorical_dataframe.daphne new file mode 100644 index 000000000..11a78a1c9 --- /dev/null +++ b/test/api/python/data_transfer_pandas_5_categorical_dataframe.daphne @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+cdf = createFrame([1, 2], [3, 4], "ab", "cd");
+
+print(cdf);
\ No newline at end of file
diff --git a/test/api/python/data_transfer_pandas_5_categorical_dataframe.py b/test/api/python/data_transfer_pandas_5_categorical_dataframe.py
new file mode 100644
index 000000000..6038a9b54
--- /dev/null
+++ b/test/api/python/data_transfer_pandas_5_categorical_dataframe.py
@@ -0,0 +1,29 @@
+#!/usr/bin/python
+
+# Copyright 2023 The DAPHNE Consortium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Data transfer from pandas to DAPHNE and back, via shared memory.
+# pd.DataFrame with categorical data
+
+import pandas as pd
+from daphne.context.daphne_context import DaphneContext
+
+df = pd.DataFrame({"ab": [1, 2], "cd": [3, 4]})
+
+# Convert to a categorical frame and transfer it (categorical frames are
+# automatically transformed to standard frames by from_pandas).
+cdf = df.astype("category")
+
+dctx = DaphneContext()
+
+dctx.from_pandas(cdf, shared_memory=True).print().compute(type="shared memory")
\ No newline at end of file
diff --git a/test/api/python/data_transfer_pytorch_1.daphne b/test/api/python/data_transfer_pytorch_1.daphne
new file mode 100644
index 000000000..f797336ba
--- /dev/null
+++ b/test/api/python/data_transfer_pytorch_1.daphne
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2023 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+t = reshape(as.si64([1, 2, 3, 4, 5, 6, 7, 8, 9]), 3, 3);
+print(t);
\ No newline at end of file
diff --git a/test/api/python/data_transfer_pytorch_1.py b/test/api/python/data_transfer_pytorch_1.py
new file mode 100644
index 000000000..3160f21dc
--- /dev/null
+++ b/test/api/python/data_transfer_pytorch_1.py
@@ -0,0 +1,26 @@
+#!/usr/bin/python
+
+# Copyright 2023 The DAPHNE Consortium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Data transfer from PyTorch to DAPHNE and back, via shared memory.
+
+import torch
+from daphne.context.daphne_context import DaphneContext
+
+tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+
+dctx = DaphneContext()
+
+dctx.from_pytorch(tensor, shared_memory=True).print().compute(type="shared memory")
\ No newline at end of file
diff --git a/test/api/python/data_transfer_tensorflow_1.daphne b/test/api/python/data_transfer_tensorflow_1.daphne
new file mode 100644
index 000000000..f797336ba
--- /dev/null
+++ b/test/api/python/data_transfer_tensorflow_1.daphne
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2023 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+t = reshape(as.si64([1, 2, 3, 4, 5, 6, 7, 8, 9]), 3, 3);
+print(t);
\ No newline at end of file
diff --git a/test/api/python/data_transfer_tensorflow_1.py b/test/api/python/data_transfer_tensorflow_1.py
new file mode 100644
index 000000000..7476ca210
--- /dev/null
+++ b/test/api/python/data_transfer_tensorflow_1.py
@@ -0,0 +1,26 @@
+#!/usr/bin/python
+
+# Copyright 2023 The DAPHNE Consortium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Data transfer from TensorFlow to DAPHNE and back, via shared memory.
+
+import tensorflow as tf
+from daphne.context.daphne_context import DaphneContext
+
+tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.int64)
+
+dctx = DaphneContext()
+
+dctx.from_tensorflow(tensor, shared_memory=True).print().compute(type="shared memory")
\ No newline at end of file
diff --git a/test/api/python/frame_innerJoin.daphne b/test/api/python/frame_innerJoin.daphne
new file mode 100644
index 000000000..3a50046b0
--- /dev/null
+++ b/test/api/python/frame_innerJoin.daphne
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2023 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +f1 = createFrame( + [1, 2], [3, 4], + "a", "b" +); +f2 = createFrame( + [3, 4, 5], [6, 7, 8], + "c", "d" +); + +f3 = innerJoin(f1, f2, "b", "c"); +print(f3); \ No newline at end of file diff --git a/test/api/python/frame_innerJoin.py b/test/api/python/frame_innerJoin.py new file mode 100644 index 000000000..bfe99d4d7 --- /dev/null +++ b/test/api/python/frame_innerJoin.py @@ -0,0 +1,28 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from daphne.context.daphne_context import DaphneContext +import pandas as pd + +dctx = DaphneContext() + +df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) +df2 = pd.DataFrame({"c": [3, 4, 5], "d": [6, 7, 8]}) + +f1 = dctx.from_pandas(df1) +f2 = dctx.from_pandas(df2) + +f1.innerJoin(f2, "b", "c").print().compute() \ No newline at end of file diff --git a/test/api/python/frame_setColLabels.daphne b/test/api/python/frame_setColLabels.daphne new file mode 100644 index 000000000..aefee0f48 --- /dev/null +++ b/test/api/python/frame_setColLabels.daphne @@ -0,0 +1,23 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +f1 = createFrame( + [1, 2], [3, 4], + "a", "b" +); + +f2 = setColLabels(f1, "c", "d"); +print(f2); \ No newline at end of file diff --git a/test/api/python/frame_setColLabels.py b/test/api/python/frame_setColLabels.py new file mode 100644 index 000000000..735ce46e2 --- /dev/null +++ b/test/api/python/frame_setColLabels.py @@ -0,0 +1,26 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from daphne.context.daphne_context import DaphneContext +import pandas as pd + +dctx = DaphneContext() + +df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + +f1 = dctx.from_pandas(df1) + +f1.setColLabels(["c", "d"]).print().compute() \ No newline at end of file diff --git a/test/api/python/frame_setColLabelsPrefix.daphne b/test/api/python/frame_setColLabelsPrefix.daphne new file mode 100644 index 000000000..045fb30eb --- /dev/null +++ b/test/api/python/frame_setColLabelsPrefix.daphne @@ -0,0 +1,23 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +f1 = createFrame( + [1, 2], [3, 4], + "a", "b" +); + +f2 = setColLabelsPrefix(f1, "c"); +print(f2); \ No newline at end of file diff --git a/test/api/python/frame_setColLabelsPrefix.py b/test/api/python/frame_setColLabelsPrefix.py new file mode 100644 index 000000000..08a491fec --- /dev/null +++ b/test/api/python/frame_setColLabelsPrefix.py @@ -0,0 +1,26 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from daphne.context.daphne_context import DaphneContext +import pandas as pd + +dctx = DaphneContext() + +df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + +f1 = dctx.from_pandas(df1) + +f1.setColLabelsPrefix("c").print().compute() \ No newline at end of file diff --git a/test/api/python/frame_to_matrix.daphne b/test/api/python/frame_to_matrix.daphne new file mode 100644 index 000000000..ccdba716d --- /dev/null +++ b/test/api/python/frame_to_matrix.daphne @@ -0,0 +1,21 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +f = createFrame([1, 2], [3, 4], "ab", "cd"); + +m = as.matrix(f); + +print(m); \ No newline at end of file diff --git a/test/api/python/frame_to_matrix.py b/test/api/python/frame_to_matrix.py new file mode 100644 index 000000000..c5ea1fc3d --- /dev/null +++ b/test/api/python/frame_to_matrix.py @@ -0,0 +1,25 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from daphne.context.daphne_context import DaphneContext + +df = pd.DataFrame({"ab": [1, 2], "cd": [3, 4]}) + +dctx = DaphneContext() + +F = dctx.from_pandas(df) +F.toMatrix(value_type="si64").print().compute() \ No newline at end of file diff --git a/test/api/python/numpy_matrix_ops_replace.daphne b/test/api/python/numpy_matrix_ops_replace.daphne new file mode 100644 index 000000000..7e1d0d38c --- /dev/null +++ b/test/api/python/numpy_matrix_ops_replace.daphne @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +m = reshape(as.si64([1, 2, 3, 0, 0, 0, 7, 8, 9]), 3, 3); +m = replace(m, 0, 10); +print(m); \ No newline at end of file diff --git a/test/api/python/numpy_matrix_ops_replace.py b/test/api/python/numpy_matrix_ops_replace.py new file mode 100644 index 000000000..1512c0b3d --- /dev/null +++ b/test/api/python/numpy_matrix_ops_replace.py @@ -0,0 +1,29 @@ +#!/usr/bin/python + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from daphne.context.daphne_context import DaphneContext + +m = np.array([1, 2, 3, 0, 0, 0, 7, 8, 9], dtype=np.int64) +m.shape = (3, 3) + +dctx = DaphneContext() + +M = dctx.from_numpy(m) + +M = M.replace(0, 10) + +M.print().compute() \ No newline at end of file