Skip to content

Commit

Permalink
Distributed Compressed Sparse Column Matrix (#1377)
Browse files Browse the repository at this point in the history
* modified factory method to include compressed column type

* Created a base class for both DCSR_matrix and DCSC_matrix

* Updated docstrings and type annotations. The class, methods and dense/sparse conversion is complete.

* refactoring changes

* Arithmetric operations implemented

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* PyTorch CSC tensors do not support arithmetic ops yet

* tests for csc matrix - manipulations

* tests for DCSC_matrix class methods

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests for sparse_csc factory method

* added name to CITATION.cff

* fix: fixed dtype conversion bug in astype method

* skip type conversion test for DCSC_matrix if torch < 2.0

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fabian Hoppe <[email protected]>
Co-authored-by: Claudia Comito <[email protected]>
  • Loading branch information
4 people authored Jun 7, 2024
1 parent ee0d72a commit 0e40d14
Show file tree
Hide file tree
Showing 12 changed files with 1,460 additions and 276 deletions.
2 changes: 2 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ preferred-citation:
given-names: Achim
- family-names: Streit
given-names: Achim
- family-names: Vaithinathan Aravindan
given-names: Ashwath
year: 2020
collection-title: 2020 IEEE International Conference on Big Data (IEEE Big Data 2020)
collection-doi: 10.1109/BigData50022.2020.9378050
Expand Down
2 changes: 1 addition & 1 deletion heat/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""add sparse heat function to the ht.sparse namespace"""

from .arithmetics import *
from .dcsr_matrix import *
from .dcsx_matrix import *
from .factories import *
from ._operations import *
from .manipulations import *
116 changes: 78 additions & 38 deletions heat/sparse/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
import torch
import numpy as np

from heat.sparse.dcsr_matrix import DCSR_matrix
from heat.sparse.dcsx_matrix import DCSC_matrix, DCSR_matrix, __DCSX_matrix

from . import factories
from ..core.communication import MPI
from ..core.dndarray import DNDarray
from ..core import types

from typing import Callable, Optional, Dict

__all__ = []


def __binary_op_csr(
def __binary_op_csx(
operation: Callable,
t1: DCSR_matrix,
t2: DCSR_matrix,
out: Optional[DCSR_matrix] = None,
t1: __DCSX_matrix,
t2: __DCSX_matrix,
out: Optional[__DCSX_matrix] = None,
orientation: str = "row",
fn_kwargs: Optional[Dict] = {},
) -> DCSR_matrix:
) -> __DCSX_matrix:
"""
Generic wrapper for element-wise binary operations of two operands.
Takes the operation function and the two operands involved in the operation as arguments.
Expand All @@ -31,37 +31,60 @@ def __binary_op_csr(
operation : PyTorch function
The operation to be performed. Function that performs operation elements-wise on the involved tensors,
e.g. add values from other to self
t1: DCSR_matrix
t1: __DCSX_matrix or scalar
The first operand involved in the operation.
t2: DCSR_matrix
t2: __DCSX_matrix or scalar
The second operand involved in the operation.
out: DCSR_matrix, optional
out: __DCSX_matrix, optional
Output buffer in which the result is placed. If not provided, a freshly allocated matrix is returned.
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'
fn_kwargs: Dict, optional
keyword arguments used for the given operation
Default: {} (empty dictionary)
Returns
-------
result: ht.sparse.DCSR_matrix
A DCSR_matrix containing the results of element-wise operation.
result: ht.sparse.__DCSX_matrix
A __DCSX_matrix containing the results of element-wise operation.
Raises
------
ValueError
If the orientation is invalid
ValueError
If the input types are not supported
ValueError
If the input shapes are not compatible
ValueError
If the output buffer shape is not compatible with the result
"""
if not np.isscalar(t1) and not isinstance(t1, DCSR_matrix):
if orientation not in ["row", "col"]:
raise ValueError(f"Invalid orientation: '{orientation}'. Options: 'row' or 'col'")

if not np.isscalar(t1) and not isinstance(t1, __DCSX_matrix):
raise TypeError(
f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t1)}"
)
if not np.isscalar(t2) and not isinstance(t2, DCSR_matrix):
if not np.isscalar(t2) and not isinstance(t2, __DCSX_matrix):
raise TypeError(
f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t2)}"
)

if not isinstance(t1, DCSR_matrix) and not isinstance(t2, DCSR_matrix):
if not isinstance(t1, __DCSX_matrix) and not isinstance(t2, __DCSX_matrix):
raise TypeError(
f"Operator only to be used with Dcsr_matrices, but input types were {type(t1)} and {type(t2)}"
)

promoted_type = types.result_type(t1, t2).torch_type()

torch_constructor = torch.sparse_csr_tensor if orientation == "row" else torch.sparse_csc_tensor
factory_method = (
factories.sparse_csr_matrix if orientation == "row" else factories.sparse_csc_matrix
)
split_axis = 0 if orientation == "row" else 1

# If one of the inputs is a scalar
# just perform the operation on the data tensor
# and create a new sparse matrix
Expand All @@ -74,15 +97,15 @@ def __binary_op_csr(
scalar = t1

res_values = operation(matrix.larray.values().to(promoted_type), scalar, **fn_kwargs)
res_torch_sparse_csr = torch.sparse_csr_tensor(
res_torch_sparse_csx = torch_constructor(
matrix.lindptr,
matrix.lindices,
res_values,
size=matrix.lshape,
device=matrix.device.torch_device,
)
return factories.sparse_csr_matrix(
res_torch_sparse_csr, is_split=matrix.split, comm=matrix.comm, device=matrix.device
return factory_method(
res_torch_sparse_csx, is_split=matrix.split, comm=matrix.comm, device=matrix.device
)

if t1.shape != t2.shape:
Expand All @@ -93,10 +116,10 @@ def __binary_op_csr(

if t1.split is not None or t2.split is not None:
if t1.split is None:
t1 = factories.sparse_csr_matrix(t1.larray, split=0)
t1 = factory_method(t1.larray, split=split_axis)

if t2.split is None:
t2 = factories.sparse_csr_matrix(t2.larray, split=0)
t2 = factory_method(t2.larray, split=split_axis)

output_split = t1.split
output_device = t1.device
Expand All @@ -113,10 +136,10 @@ def __binary_op_csr(

if out.split != output_split:
if out.split is None:
out = factories.sparse_csr_matrix(out.larray, split=0)
out = factory_method(out.larray, split=split_axis)
else:
out = factories.sparse_csr_matrix(
torch.sparse_csr_tensor(
out = factory_method(
torch_constructor(
torch.tensor(out.indptr, dtype=torch.int64),
torch.tensor(out.indices, dtype=torch.int64),
torch.tensor(out.data),
Expand Down Expand Up @@ -146,21 +169,38 @@ def __binary_op_csr(
output_type = types.canonical_heat_type(result.dtype)

if out is None:
return DCSR_matrix(
array=torch.sparse_csr_tensor(
result.crow_indices().to(torch.int64),
result.col_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
if orientation == "row":
return DCSR_matrix(
array=torch_constructor(
result.crow_indices().to(torch.int64),
result.col_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
else:
return DCSC_matrix(
array=torch_constructor(
result.ccol_indices().to(torch.int64),
result.row_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)

out.larray.copy_(result)
out.gnnz = output_gnnz
Expand Down
24 changes: 15 additions & 9 deletions heat/sparse/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .dcsr_matrix import DCSR_matrix
from .dcsx_matrix import DCSC_matrix, DCSR_matrix

from . import _operations

Expand All @@ -14,7 +14,7 @@
]


def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
def add(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix:
"""
Element-wise addition of values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be added
Expand All @@ -26,6 +26,9 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
The first operand involved in the addition
t2: DCSR_matrix
The second operand involved in the addition
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'
Examples
--------
Expand All @@ -43,16 +46,16 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
DNDarray([[2., 0., 4.],
[0., 0., 6.]], dtype=ht.float32, device=cpu:0, split=0)
"""
return _operations.__binary_op_csr(torch.add, t1, t2)
return _operations.__binary_op_csx(torch.add, t1, t2, orientation=orientation)


DCSR_matrix.__add__ = lambda self, other: add(self, other)
DCSR_matrix.__add__ = lambda self, other: add(self, other, orientation="row")
DCSR_matrix.__add__.__doc__ = add.__doc__
DCSR_matrix.__radd__ = lambda self, other: add(self, other)
DCSR_matrix.__radd__ = lambda self, other: add(self, other, orientation="row")
DCSR_matrix.__radd__.__doc__ = add.__doc__


def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
def mul(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix:
"""
Element-wise multiplication (NOT matrix multiplication) of values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be
Expand All @@ -64,6 +67,9 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
The first operand involved in the multiplication
t2: DCSR_matrix
The second operand involved in the multiplication
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'
Examples
--------
Expand All @@ -81,10 +87,10 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
DNDarray([[1., 0., 4.],
[0., 0., 9.]], dtype=ht.float32, device=cpu:0, split=0)
"""
return _operations.__binary_op_csr(torch.mul, t1, t2)
return _operations.__binary_op_csx(torch.mul, t1, t2, orientation=orientation)


DCSR_matrix.__mul__ = lambda self, other: mul(self, other)
DCSR_matrix.__mul__ = lambda self, other: mul(self, other, orientation="row")
DCSR_matrix.__mul__.__doc__ = mul.__doc__
DCSR_matrix.__rmul__ = lambda self, other: mul(self, other)
DCSR_matrix.__rmul__ = lambda self, other: mul(self, other, orientation="row")
DCSR_matrix.__rmul__.__doc__ = mul.__doc__
Loading

0 comments on commit 0e40d14

Please sign in to comment.