diff --git a/.pylintrc b/.pylintrc index 8f6ec9e..f8c5e34 100644 --- a/.pylintrc +++ b/.pylintrc @@ -452,8 +452,7 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests. # List of note tags to take in consideration, separated by a comma. notes=FIXME, - XXX, - TODO + XXX # Regular expression of note tags to take in consideration. notes-rgx= diff --git a/Makefile b/Makefile index 3fd1614..3664592 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ test-dependencies: pip install .'[test]' test: test-dependencies - pytest + uv run pytest # Build protocol buffers definitions. build_proto: diff --git a/nada_dsl/ast_util.py b/nada_dsl/ast_util.py index a25a2fa..93915f2 100644 --- a/nada_dsl/ast_util.py +++ b/nada_dsl/ast_util.py @@ -5,7 +5,7 @@ import hashlib from typing import Dict, List from sortedcontainers import SortedDict -from nada_dsl.nada_types import NadaTypeRepr, Party +from nada_dsl.nada_types import DslTypeRepr, Party from nada_dsl.source_ref import SourceRef OPERATION_ID_COUNTER = 0 @@ -41,7 +41,7 @@ class ASTOperation(ABC): id: int source_ref: SourceRef - ty: NadaTypeRepr + ty: DslTypeRepr def child_operations(self) -> List[int]: """Returns the list of identifiers of all the child operations of this operation.""" @@ -377,6 +377,28 @@ def to_mir(self): } +@dataclass +class TupleAccessorASTOperation(ASTOperation): + """AST representation of a tuple accessor operation.""" + + index: int + source: int + + def child_operations(self): + return [self.source] + + def to_mir(self): + return { + "TupleAccessor": { + "id": self.id, + "index": self.index, + "source": self.source, + "type": self.ty, + "source_ref_index": self.source_ref.to_index(), + } + } + + @dataclass class NTupleAccessorASTOperation(ASTOperation): """AST representation of a n tuple accessor operation.""" diff --git a/nada_dsl/compiler_frontend.py b/nada_dsl/compiler_frontend.py index de41075..f8c5b41 100644 --- a/nada_dsl/compiler_frontend.py +++ b/nada_dsl/compiler_frontend.py @@ -20,6 +20,7 @@ InputASTOperation, LiteralASTOperation, MapASTOperation, + TupleAccessorASTOperation, NTupleAccessorASTOperation, NadaFunctionASTOperation, NadaFunctionArgASTOperation, @@ -298,6 +299,7 @@ def process_operation( NewASTOperation, RandomASTOperation, NadaFunctionArgASTOperation, + TupleAccessorASTOperation, NTupleAccessorASTOperation, ObjectAccessorASTOperation, ), diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index f65a47f..1c30396 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -1,5 +1,6 @@ """Nada type definitions.""" +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from typing import Dict, TypeAlias, Union, Type @@ -84,7 +85,7 @@ def __init__(self, name): ] "" -NadaTypeRepr: TypeAlias = str | Dict[str, Dict] +DslTypeRepr: TypeAlias = str | Dict[str, Dict] """Type alias for the NadaType representation. This representation can be either a string ("SecretInteger") @@ -112,8 +113,9 @@ def is_numeric(self) -> bool: return self in (BaseType.INTEGER, BaseType.UNSIGNED_INTEGER) +# TODO: abstract? @dataclass -class NadaType: +class DslType: """Nada type class. This is the parent class of all nada types. @@ -144,20 +146,7 @@ def __init__(self, child: OperationType): """ self.child = child if self.child is not None: - self.child.store_in_ast(self.to_mir()) - - def to_mir(self): - """Default implementation for the Conversion of a type into MIR representation.""" - return self.__class__.class_to_mir() - - @classmethod - def class_to_mir(cls) -> str: - """Converts a class into a MIR Nada type.""" - name = cls.__name__ - # Rename public variables so they are considered as the same as literals. - if name.startswith("Public"): - name = name[len("Public") :].lstrip() - return name + self.child.store_in_ast(self.type().to_mir()) def __bool__(self): raise NotImplementedError @@ -171,3 +160,7 @@ def is_scalar(cls) -> bool: def is_literal(cls) -> bool: """Returns True if the type is a literal.""" return False + + @abstractmethod + def type(self): + """Returns a meta type for this NadaType.""" diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index dbba76a..e8d4d4b 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -1,23 +1,21 @@ """Nada Collection type definitions.""" -import copy from dataclasses import dataclass -import inspect from typing import Any, Dict, Generic, List import typing -from typing import TypeVar from nada_dsl.ast_util import ( AST_OPERATIONS, BinaryASTOperation, MapASTOperation, + TupleAccessorASTOperation, NTupleAccessorASTOperation, NewASTOperation, ObjectAccessorASTOperation, ReduceASTOperation, UnaryASTOperation, ) -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType # Wildcard import due to non-zero types from nada_dsl.nada_types.scalar_types import * # pylint: disable=W0614:wildcard-import @@ -27,9 +25,9 @@ InvalidTypeError, NotAllowedException, ) -from nada_dsl.nada_types.function import NadaFunction, nada_fn +from nada_dsl.nada_types.function import NadaFunction, create_nada_fn from nada_dsl.nada_types.generics import U, T, R -from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType +from . import AllTypes, AllTypesType, DslTypeRepr, OperationType def is_primitive_integer(nada_type_str: str): @@ -47,82 +45,7 @@ def is_primitive_integer(nada_type_str: str): ) -class Collection(NadaType): - """Superclass of collection types""" - - left_type: AllTypesType - right_type: AllTypesType - contained_type: AllTypesType - - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - if isinstance(self, (Array, ArrayType)): - size = {"size": self.size} if self.size else {} - contained_type = self.retrieve_inner_type() - return {"Array": {"inner_type": contained_type, **size}} - if isinstance(self, (Tuple, TupleType)): - return { - "Tuple": { - "left_type": ( - self.left_type.to_mir() - if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) - else self.left_type.class_to_mir() - ), - "right_type": ( - self.right_type.to_mir() - if isinstance( - self.right_type, - (NadaType, ArrayType, TupleType), - ) - else self.right_type.class_to_mir() - ), - } - } - if isinstance(self, NTuple): - return { - "NTuple": { - "types": [ - ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for ty in [ - type(value) - for value in self.values # pylint: disable=E1101 - ] - ] - } - } - if isinstance(self, Object): - return { - "Object": { - "types": { - name: ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for name, ty in [ - (name, type(value)) - for name, value in self.values.items() # pylint: disable=E1101 - ] - } - } - } - raise InvalidTypeError( - f"{self.__class__.__name__} is not a valid Nada Collection" - ) - - def retrieve_inner_type(self): - """Retrieves the child type of this collection""" - if isinstance(self.contained_type, TypeVar): - return "T" - if inspect.isclass(self.contained_type): - return self.contained_type.class_to_mir() - return self.contained_type.to_mir() - - +@dataclass class Map(Generic[T, R]): """The Map operation""" @@ -186,12 +109,17 @@ def store_in_ast(self, ty): ) -@dataclass -class TupleType: +class TupleType(NadaType): """Marker type for Tuples.""" - left_type: NadaType - right_type: NadaType + is_compound = True + + def __init__(self, left_type: DslType, right_type: DslType): + self.left_type = left_type + self.right_type = right_type + + def instantiate(self, child_or_value): + return Tuple(child_or_value, self.left_type, self.right_type) def to_mir(self): """Convert a tuple object into a Nada type.""" @@ -203,7 +131,14 @@ def to_mir(self): } -class Tuple(Generic[T, U], Collection): +def _generate_accessor(ty: Any, accessor: Any) -> DslType: + if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix + raise TypeError("Literals are not supported in accessors") + return ty.instantiate(accessor) + + +@dataclass +class Tuple(Generic[T, U], DslType): """The Tuple type""" left_type: T @@ -216,13 +151,13 @@ def __init__(self, child, left_type: T, right_type: U): super().__init__(self.child) @classmethod - def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]": + def new(cls, left_value: DslType, right_value: DslType) -> "Tuple[T, U]": """Constructs a new Tuple.""" return Tuple( - left_type=left_type, - right_type=right_type, + left_type=left_value.type(), + right_type=right_value.type(), child=TupleNew( - child=(left_type, right_type), + child=(left_value, right_value), source_ref=SourceRef.back_frame(), ), ) @@ -232,56 +167,108 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleType: """Returns the generic type for this Tuple""" return TupleType(left_type=left_type, right_type=right_type) + @property + def left(self) -> DslType: + """The left element of the Tuple.""" + accessor = TupleAccessor( + index=0, + child=self, + source_ref=SourceRef.back_frame(), + ) -def _generate_accessor(value: Any, accessor: Any) -> NadaType: - ty = type(value) + return _generate_accessor(self.left_type, accessor) - if ty.is_scalar(): - if ty.is_literal(): - return value - return ty(child=accessor) - if ty == Array: - return Array( - child=accessor, - contained_type=value.contained_type, - size=value.size, - ) - if ty == NTuple: - return NTuple( - child=accessor, - values=value.values, + @property + def right(self) -> DslType: + """The right element of the Tuple.""" + accessor = TupleAccessor( + index=1, + child=self, + source_ref=SourceRef.back_frame(), ) - if ty == Object: - return Object( - child=accessor, - values=value.values, + + return _generate_accessor(self.right_type, accessor) + + def type(self): + """Metatype for Tuple""" + return TupleType(self.left_type, self.right_type) + + +@dataclass +class TupleAccessor: + """Accessor for Tuple""" + + child: Tuple + index: int + source_ref: SourceRef + + def __init__( + self, + child: Tuple, + index: int, + source_ref: SourceRef, + ): + self.id = next_operation_id() + self.child = child + self.index = index + self.source_ref = source_ref + + def store_in_ast(self, ty: object): + """Store this accessor in the AST.""" + AST_OPERATIONS[self.id] = TupleAccessorASTOperation( + id=self.id, + source=self.child.child.id, + index=self.index, + source_ref=self.source_ref, + ty=ty, ) - raise TypeError(f"Unsupported type for accessor: {ty}") -class NTuple(Collection): +class NTupleType(NadaType): + """Marker type for NTuples.""" + + is_compound = True + + def __init__(self, types: List[DslType]): + self.types = types + + def instantiate(self, child_or_value): + return NTuple(child_or_value, self.types) + + def to_mir(self): + """Convert a tuple object into a Nada type.""" + return { + "NTuple": { + "types": [ty.to_mir() for ty in self.types], + } + } + + +@dataclass +class NTuple(DslType): """The NTuple type""" - values: List[NadaType] + types: List[Any] - def __init__(self, child, values: List[NadaType]): - self.values = values + def __init__(self, child, types: List[Any]): + self.types = types self.child = child super().__init__(self.child) @classmethod - def new(cls, values: List[NadaType]) -> "NTuple": + def new(cls, values: List[Any]) -> "NTuple": """Constructs a new NTuple.""" + types = [value.type() for value in values] return NTuple( - values=values, + types=types, child=NTupleNew( child=values, source_ref=SourceRef.back_frame(), ), ) - def __getitem__(self, index: int) -> NadaType: - if index >= len(self.values): + def __getitem__(self, index: int) -> DslType: + if index >= len(self.types): raise IndexError(f"Invalid index {index} for NTuple.") accessor = NTupleAccessor( @@ -290,7 +277,11 @@ def __getitem__(self, index: int) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[index], accessor) + return _generate_accessor(self.types[index], accessor) + + def type(self): + """Metatype for NTuple""" + return NTupleType(self.types) @dataclass @@ -323,29 +314,49 @@ def store_in_ast(self, ty: object): ) -class Object(Collection): +class ObjectType(NadaType): + """Marker type for Objects.""" + + is_compound = True + + def __init__(self, types: Dict[str, DslType]): + self.types = types + + def to_mir(self): + """Convert an object into a Nada type.""" + return { + "Object": {"types": {name: ty.to_mir() for name, ty in self.types.items()}} + } + + def instantiate(self, child_or_value): + return Object(child_or_value, self.types) + + +@dataclass +class Object(DslType): """The Object type""" - values: Dict[str, NadaType] + types: Dict[str, Any] - def __init__(self, child, values: Dict[str, NadaType]): - self.values = values + def __init__(self, child, types: Dict[str, Any]): + self.types = types self.child = child super().__init__(self.child) @classmethod - def new(cls, values: Dict[str, NadaType]) -> "Object": + def new(cls, values: Dict[str, Any]) -> "Object": """Constructs a new Object.""" + types = {key: value.type() for key, value in values.items()} return Object( - values=values, + types=types, child=ObjectNew( child=values, source_ref=SourceRef.back_frame(), ), ) - def __getattr__(self, attr: str) -> NadaType: - if attr not in self.values: + def __getattr__(self, attr: str) -> DslType: + if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" ) @@ -356,7 +367,11 @@ def __getattr__(self, attr: str) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[attr], accessor) + return _generate_accessor(self.types[attr], accessor) + + def type(self): + """Metatype for Object""" + return ObjectType(types=self.types) @dataclass @@ -389,15 +404,6 @@ def store_in_ast(self, ty: object): ) -# pylint: disable=W0511 -# TODO: remove this -def get_inner_type(inner_type): - """Utility that returns the inner type for a composite type.""" - inner_type = copy.copy(inner_type) - setattr(inner_type, "inner", None) - return inner_type - - class Zip: """The Zip operation.""" @@ -407,7 +413,7 @@ def __init__(self, left: AllTypes, right: AllTypes, source_ref: SourceRef): self.right = right self.source_ref = source_ref - def store_in_ast(self, ty: NadaTypeRepr): + def store_in_ast(self, ty: DslTypeRepr): """Store a Zip object in the AST.""" AST_OPERATIONS[self.id] = BinaryASTOperation( id=self.id, @@ -427,7 +433,7 @@ def __init__(self, child: AllTypes, source_ref: SourceRef): self.child = child self.source_ref = source_ref - def store_in_ast(self, ty: NadaTypeRepr): + def store_in_ast(self, ty: DslTypeRepr): """Store an Unzip object in the AST.""" AST_OPERATIONS[self.id] = UnaryASTOperation( id=self.id, @@ -447,7 +453,7 @@ def __init__(self, left: AllTypes, right: AllTypes, source_ref: SourceRef): self.right = right self.source_ref = source_ref - def store_in_ast(self, ty: NadaTypeRepr): + def store_in_ast(self, ty: DslTypeRepr): """Store the InnerProduct object in the AST.""" AST_OPERATIONS[self.id] = BinaryASTOperation( id=self.id, @@ -459,15 +465,23 @@ def store_in_ast(self, ty: NadaTypeRepr): ) -@dataclass -class ArrayType: +class ArrayType(NadaType): """Marker type for arrays.""" - contained_type: AllTypesType - size: int + is_compound = True + + def __init__(self, contained_type: AllTypesType, size: int): + self.contained_type = contained_type + self.size = size def to_mir(self): """Convert this generic type into a MIR Nada type.""" + # TODO size is None when array used in function argument and used @nada_fn + # So you know the type but not the size, we should stop using @nada_fn decorator + # and apply the same logic when the function gets passed to .map() or .reduce() + # so we now the size of the array + if self.size is None: + raise NotImplementedError("ArrayType.to_mir") return { "Array": { "inner_type": self.contained_type.to_mir(), @@ -475,8 +489,12 @@ def to_mir(self): } } + def instantiate(self, child_or_value): + return Array(child_or_value, self.size, self.contained_type) -class Array(Generic[T], Collection): + +@dataclass +class Array(Generic[T], DslType): """Nada Array type. This is the representation of arrays in Nada MIR. @@ -497,27 +515,32 @@ class Array(Generic[T], Collection): def __init__(self, child, size: int, contained_type: T = None): self.contained_type = ( - contained_type - if (child is None or contained_type is not None) - else get_inner_type(child) + contained_type if contained_type is not None else child.type() ) + self.size = size self.child = ( child if contained_type is not None else getattr(child, "child", None) ) if self.child is not None: - self.child.store_in_ast(self.to_mir()) + self.child.store_in_ast(self.type().to_mir()) def __iter__(self): raise NotAllowedException( "Cannot loop over a Nada Array, use functional style Array operations (map, reduce, zip)." ) + def check_not_constant(self, ty): + """Checks that a type is not a constant""" + if ty.is_constant: + raise NotAllowedException( + "functors (map and reduce) can't be called with constant args" + ) + def map(self: "Array[T]", function) -> "Array": """The map operation for Arrays.""" - nada_function = function - if not isinstance(function, NadaFunction): - nada_function = nada_fn(function) + self.check_not_constant(self.contained_type) + nada_function = create_nada_fn(function, args_ty=[self.contained_type]) return Array( size=self.size, contained_type=nada_function.return_type, @@ -526,9 +549,12 @@ def map(self: "Array[T]", function) -> "Array": def reduce(self: "Array[T]", function, initial: R) -> R: """The Reduce operation for arrays.""" - if not isinstance(function, NadaFunction): - function = nada_fn(function) - return function.return_type( + self.check_not_constant(self.contained_type) + self.check_not_constant(initial.type()) + function = create_nada_fn( + function, args_ty=[initial.type(), self.contained_type] + ) + return function.return_type.instantiate( Reduce( child=self, fn=function, @@ -543,10 +569,9 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]": raise IncompatibleTypesError("Cannot zip arrays of different size") return Array( size=self.size, - contained_type=Tuple( + contained_type=TupleType( left_type=self.contained_type, right_type=other.contained_type, - child=None, ), child=Zip(left=self, right=other, source_ref=SourceRef.back_frame()), ) @@ -558,18 +583,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Cannot do child product of arrays of different size" ) - if is_primitive_integer(self.retrieve_inner_type()) and is_primitive_integer( - other.retrieve_inner_type() + if is_primitive_integer(self.contained_type) and is_primitive_integer( + other.contained_type ): - contained_type = ( - self.contained_type - if inspect.isclass(self.contained_type) - else self.contained_type.__class__ - ) - return contained_type( - child=InnerProduct( - left=self, right=other, source_ref=SourceRef.back_frame() - ) + return self.contained_type.instantiate( + InnerProduct(left=self, right=other, source_ref=SourceRef.back_frame()) ) # type: ignore raise InvalidTypeError( @@ -587,7 +605,7 @@ def new(cls, *args) -> "Array[T]": raise TypeError("All arguments must be of the same type") return Array( - contained_type=first_arg, + contained_type=first_arg.type(), size=len(args), child=ArrayNew( child=args, @@ -595,12 +613,12 @@ def new(cls, *args) -> "Array[T]": ), ) - @classmethod - def init_as_template_type(cls, contained_type) -> "Array[T]": - """Construct an empty template array with the given child type.""" - return Array(child=None, contained_type=contained_type, size=None) + def type(self): + """Metatype for Array""" + return ArrayType(self.contained_type, self.size) +@dataclass class TupleNew(Generic[T, U]): """MIR Tuple new operation. @@ -626,16 +644,17 @@ def store_in_ast(self, ty: object): ) +@dataclass class NTupleNew: """MIR NTuple new operation. Represents the creation of a new Tuple. """ - child: List[NadaType] + child: List[DslType] source_ref: SourceRef - def __init__(self, child: List[NadaType], source_ref: SourceRef): + def __init__(self, child: List[DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -651,16 +670,17 @@ def store_in_ast(self, ty: object): ) +@dataclass class ObjectNew: """MIR Object new operation. Represents the creation of a new Object. """ - child: Dict[str, NadaType] + child: Dict[str, DslType] source_ref: SourceRef - def __init__(self, child: Dict[str, NadaType], source_ref: SourceRef): + def __init__(self, child: Dict[str, DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -692,6 +712,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: ) +@dataclass class ArrayNew(Generic[T]): """MIR Array new operation""" @@ -703,7 +724,7 @@ def __init__(self, child: List[T], source_ref: SourceRef): self.child = child self.source_ref = source_ref - def store_in_ast(self, ty: NadaType): + def store_in_ast(self, ty: DslType): """Store this ArrayNew object in the AST.""" AST_OPERATIONS[self.id] = NewASTOperation( id=self.id, diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index 4209d8d..f9454a4 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -6,7 +6,6 @@ import inspect from dataclasses import dataclass from typing import Generic, List, Callable -from copy import copy from nada_dsl import SourceRef from nada_dsl.ast_util import ( AST_OPERATIONS, @@ -16,9 +15,7 @@ next_operation_id, ) from nada_dsl.nada_types.generics import T, R -from nada_dsl.nada_types import Mode, NadaType -from nada_dsl.nada_types.scalar_types import ScalarType -from nada_dsl.errors import NotAllowedException +from nada_dsl.nada_types import DslType class NadaFunctionArg(Generic[T]): @@ -53,8 +50,6 @@ class NadaFunction(Generic[T, R]): Represents a Nada Function. Nada functions are special types of functions that are used in map / reduce operations. - - They are decorated using the `@nada_fn` decorator. """ id: int @@ -70,22 +65,8 @@ def __init__( function: Callable[[T], R], return_type: R, source_ref: SourceRef, - child: NadaType, + child: DslType, ): - if issubclass(return_type, ScalarType) and return_type.mode == Mode.CONSTANT: - raise NotAllowedException( - "Nada functions with literal return types are not allowed" - ) - # Nada functions with literal argument types are not supported. - # This is because the compiler consolidates operations between literals. - if all( - issubclass(arg.type.__class__, ScalarType) - and arg.type.mode == Mode.CONSTANT - for arg in args - ): - raise NotAllowedException( - "Nada functions with literal argument types are not allowed" - ) self.child = child self.id = function_id self.args = args @@ -101,7 +82,7 @@ def store_in_ast(self): name=self.function.__name__, args=[arg.id for arg in self.args], id=self.id, - ty=self.return_type.class_to_mir(), + ty=self.return_type.to_mir(), source_ref=self.source_ref, child=self.child.child.id, ) @@ -117,7 +98,7 @@ class NadaFunctionCall(Generic[R]): """Represents a call to a Nada Function.""" fn: NadaFunction - args: List[NadaType] + args: List[DslType] source_ref: SourceRef def __init__(self, nada_function, args, source_ref): @@ -125,7 +106,7 @@ def __init__(self, nada_function, args, source_ref): self.args = args self.fn = nada_function self.source_ref = source_ref - self.store_in_ast(nada_function.return_type.class_to_mir()) + self.store_in_ast(nada_function.return_type.type().to_mir()) def store_in_ast(self, ty): """Store this function call in the AST.""" @@ -138,20 +119,7 @@ def store_in_ast(self, ty): ) -def contained_types(ty): - """Utility function that calculates the child type for a function argument.""" - - origin_ty = getattr(ty, "__origin__", ty) - if not issubclass(origin_ty, ScalarType): - inner_ty = getattr(ty, "__args__", None) - inner_ty = contained_types(inner_ty[0]) if inner_ty else T - return origin_ty.init_as_template_type(inner_ty) - if origin_ty.mode == Mode.CONSTANT: - return origin_ty(value=0) - return origin_ty(child=None) - - -def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]: +def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]: """ Can be used also for lambdas ```python @@ -165,28 +133,21 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]: args = inspect.getfullargspec(fn) nada_args = [] function_id = next_operation_id() - for arg in args.args: - arg_type = args_ty[arg] if args_ty else args.annotations[arg] - arg_type = contained_types(arg_type) + nada_args_type_wrapped = [] + for arg, arg_ty in zip(args.args, args_ty): # We'll get the function source ref for now nada_arg = NadaFunctionArg( function_id, name=arg, - arg_type=arg_type, + arg_type=arg_ty, source_ref=SourceRef.back_frame(), ) nada_args.append(nada_arg) - - nada_args_type_wrapped = [] - - for arg in nada_args: - arg_type = copy(arg.type) - arg_type.child = arg - nada_args_type_wrapped.append(arg_type) + nada_args_type_wrapped.append(arg_ty.instantiate(nada_arg)) child = fn(*nada_args_type_wrapped) - return_type = return_ty if return_ty else args.annotations["return"] + return_type = child.type() return NadaFunction( function_id, function=fn, diff --git a/nada_dsl/nada_types/generics.py b/nada_dsl/nada_types/generics.py index 6923c60..ab53633 100644 --- a/nada_dsl/nada_types/generics.py +++ b/nada_dsl/nada_types/generics.py @@ -2,8 +2,8 @@ from typing import TypeVar -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType -R = TypeVar("R", bound=NadaType) -T = TypeVar("T", bound=NadaType) -U = TypeVar("U", bound=NadaType) +R = TypeVar("R", bound=DslType) +T = TypeVar("T", bound=DslType) +U = TypeVar("U", bound=DslType) diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 3d40991..7980150 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -1,13 +1,14 @@ # pylint:disable=W0401,W0614 """The Nada Scalar type definitions.""" +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Union, TypeVar from typing_extensions import Self from nada_dsl.operations import * from nada_dsl.program_io import Literal from nada_dsl import SourceRef -from . import NadaType, Mode, BaseType, OperationType +from . import DslType, Mode, BaseType, OperationType # Constant dictionary that stores all the Nada types and is use to # convert from the (mode, base_type) representation to the concrete Nada type @@ -46,7 +47,7 @@ AnyBoolean = Union["Boolean", "PublicBoolean", "SecretBoolean"] -class ScalarType(NadaType): +class ScalarDslType(DslType): """The Nada Scalar type. This is the super class for all scalar types in Nada. These are: @@ -82,7 +83,7 @@ def is_scalar(cls) -> bool: def equals_operation( - operation, operator, left: ScalarType, right: ScalarType, f + operation, operator, left: ScalarDslType, right: ScalarDslType, f ) -> AnyBoolean: """This function is an abstraction for the equality operations""" base_type = left.base_type @@ -122,7 +123,7 @@ def new_scalar_type(mode: Mode, base_type: BaseType) -> type[AnyScalarType]: return SCALAR_TYPES[(mode, base_type)] -class NumericType(ScalarType): +class NumericDslType(ScalarDslType): """The superclass of all the numeric types in Nada: - Integer, PublicInteger, SecretInteger - UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger @@ -213,8 +214,8 @@ def __radd__(self, other): def binary_arithmetic_operation( - operation, operator, left: ScalarType, right: ScalarType, f -) -> ScalarType: + operation, operator, left: ScalarDslType, right: ScalarDslType, f +) -> ScalarDslType: """This function is an abstraction for the binary arithmetic operations. Arithmetic operations apply to Numeric types only in Nada.""" @@ -233,8 +234,8 @@ def binary_arithmetic_operation( def shift_operation( - operation, operator, left: ScalarType, right: ScalarType, f -) -> ScalarType: + operation, operator, left: ScalarDslType, right: ScalarDslType, f +) -> ScalarDslType: """This function is an abstraction for the shift operations""" base_type = left.base_type right_base_type = right.base_type @@ -255,7 +256,7 @@ def shift_operation( def binary_relational_operation( - operation, operator, left: ScalarType, right: ScalarType, f + operation, operator, left: ScalarDslType, right: ScalarDslType, f ) -> AnyBoolean: """This function is an abstraction for the binary relational operations""" base_type = left.base_type @@ -272,7 +273,9 @@ def binary_relational_operation( return new_scalar_type(mode, BaseType.BOOLEAN)(child) # type: ignore -def public_equals_operation(left: ScalarType, right: ScalarType) -> "PublicBoolean": +def public_equals_operation( + left: ScalarDslType, right: ScalarDslType +) -> "PublicBoolean": """This function is an abstraction for the public_equals operation for all types.""" base_type = left.base_type if base_type != right.base_type: @@ -287,7 +290,7 @@ def public_equals_operation(left: ScalarType, right: ScalarType) -> "PublicBoole ) -class BooleanType(ScalarType): +class BooleanDslType(ScalarDslType): """This abstraction represents all boolean types: - Boolean, PublicBoolean, SecretBoolean It provides common operation implementations for all the boolean types, defined above. @@ -327,8 +330,8 @@ def if_else(self, arg_0: _AnyScalarType, arg_1: _AnyScalarType) -> _AnyScalarTyp def binary_logical_operation( - operation, operator, left: ScalarType, right: ScalarType, f -) -> ScalarType: + operation, operator, left: ScalarDslType, right: ScalarDslType, f +) -> ScalarDslType: """This function is an abstraction for the logical operations.""" base_type = left.base_type if base_type != right.base_type or not base_type == BaseType.BOOLEAN: @@ -348,8 +351,42 @@ def binary_logical_operation( return SecretBoolean(child=operation) +class NadaType(ABC): + """Abstract meta type""" + + is_constant = False + is_scalar = False + is_compound = False + + @abstractmethod + def instantiate(self, child_or_value): + """Creates a value corresponding to this meta type""" + + @abstractmethod + def to_mir(self): + """Returns a MIR representation of this meta type""" + + +class TypePassthroughMixin(NadaType): + """Mixin for meta types""" + + def instantiate(self, child_or_value): + """Creates a value corresponding to this meta type""" + return self.ty(child_or_value) + + def to_mir(self): + name = self.ty.__name__ + # Rename public variables so they are considered as the same as literals. + if name.startswith("Public"): + name = name[len("Public") :].lstrip() + + if name.endswith("Type"): + name = name[: -len("Type")].rstrip() + return name + + @register_scalar_type(Mode.CONSTANT, BaseType.INTEGER) -class Integer(NumericType): +class Integer(NumericDslType): """The Nada Integer type. Represents a constant (literal) integer.""" @@ -364,16 +401,27 @@ def __init__(self, value): self.value = value def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) @classmethod def is_literal(cls) -> bool: return True + def type(self): + return IntegerType() + + +class IntegerType(TypePassthroughMixin): + """Meta type for integers""" + + ty = Integer + is_constant = True + is_scalar = True + @dataclass @register_scalar_type(Mode.CONSTANT, BaseType.UNSIGNED_INTEGER) -class UnsignedInteger(NumericType): +class UnsignedInteger(NumericDslType): """The Nada Unsigned Integer type. Represents a constant (literal) unsigned integer.""" @@ -390,15 +438,26 @@ def __init__(self, value): self.value = value def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) @classmethod def is_literal(cls) -> bool: return True + def type(self): + return UnsignedIntegerType() + + +class UnsignedIntegerType(TypePassthroughMixin): + """Meta type for unsigned integers""" + + ty = UnsignedInteger + is_constant = True + is_scalar = True + @register_scalar_type(Mode.CONSTANT, BaseType.BOOLEAN) -class Boolean(BooleanType): +class Boolean(BooleanDslType): """The Nada Boolean type. Represents a constant (literal) boolean.""" @@ -418,7 +477,7 @@ def __bool__(self) -> bool: return self.value def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def __invert__(self: "Boolean") -> "Boolean": return Boolean(value=bool(not self.value)) @@ -427,19 +486,30 @@ def __invert__(self: "Boolean") -> "Boolean": def is_literal(cls) -> bool: return True + def type(self): + return BooleanType() + + +class BooleanType(TypePassthroughMixin): + """Meta type for booleans""" + + ty = Boolean + is_constant = True + is_scalar = True + @register_scalar_type(Mode.PUBLIC, BaseType.INTEGER) -class PublicInteger(NumericType): +class PublicInteger(NumericDslType): """The Nada Public Unsigned Integer type. Represents a public unsigned integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def public_equals( self, other: Union["PublicInteger", "SecretInteger"] @@ -447,19 +517,29 @@ def public_equals( """Implementation of public equality for Public integer types.""" return public_equals_operation(self, other) + def type(self): + return PublicIntegerType() + + +class PublicIntegerType(TypePassthroughMixin): + """Meta type for public integers""" + + ty = PublicInteger + is_scalar = True + @register_scalar_type(Mode.PUBLIC, BaseType.UNSIGNED_INTEGER) -class PublicUnsignedInteger(NumericType): +class PublicUnsignedInteger(NumericDslType): """The Nada Public Integer type. Represents a public integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def public_equals( self, other: Union["PublicUnsignedInteger", "SecretUnsignedInteger"] @@ -467,20 +547,30 @@ def public_equals( """Implementation of public equality for Public unsigned integer types.""" return public_equals_operation(self, other) + def type(self): + return PublicUnsignedIntegerType() + + +class PublicUnsignedIntegerType(TypePassthroughMixin): + """Meta type for public unsigned integers""" + + ty = PublicUnsignedInteger + is_scalar = True + @dataclass @register_scalar_type(Mode.PUBLIC, BaseType.BOOLEAN) -class PublicBoolean(BooleanType): +class PublicBoolean(BooleanDslType): """The Nada Public Boolean type. Represents a public boolean in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def __invert__(self: "PublicBoolean") -> "PublicBoolean": operation = Not(this=self, source_ref=SourceRef.back_frame()) @@ -492,17 +582,27 @@ def public_equals( """Implementation of public equality for Public boolean types.""" return public_equals_operation(self, other) + def type(self): + return PublicBooleanType() + + +class PublicBooleanType(TypePassthroughMixin): + """Meta type for public booleans""" + + ty = PublicBoolean + is_scalar = True + @dataclass @register_scalar_type(Mode.SECRET, BaseType.INTEGER) -class SecretInteger(NumericType): +class SecretInteger(NumericDslType): """The Nada secret integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def public_equals( self, other: Union["PublicInteger", "SecretInteger"] @@ -537,17 +637,27 @@ def to_public(self: "SecretInteger") -> "PublicInteger": operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicInteger(child=operation) + def type(self): + return SecretIntegerType() + + +class SecretIntegerType(TypePassthroughMixin): + """Meta type for secret integers""" + + ty = SecretInteger + is_scalar = True + @dataclass @register_scalar_type(Mode.SECRET, BaseType.UNSIGNED_INTEGER) -class SecretUnsignedInteger(NumericType): +class SecretUnsignedInteger(NumericDslType): """The Nada Secret Unsigned integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def public_equals( self, other: Union["PublicUnsignedInteger", "SecretUnsignedInteger"] @@ -584,17 +694,27 @@ def to_public( operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicUnsignedInteger(child=operation) + def type(self): + return SecretUnsignedIntegerType() + + +class SecretUnsignedIntegerType(TypePassthroughMixin): + """Meta type for secret unsigned integers""" + + ty = SecretUnsignedInteger + is_scalar = True + @dataclass @register_scalar_type(Mode.SECRET, BaseType.BOOLEAN) -class SecretBoolean(BooleanType): +class SecretBoolean(BooleanDslType): """The SecretBoolean Nada MIR type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: - return ScalarType.__eq__(self, other) + return ScalarDslType.__eq__(self, other) def __invert__(self: "SecretBoolean") -> "SecretBoolean": operation = Not(this=self, source_ref=SourceRef.back_frame()) @@ -610,25 +730,53 @@ def random(cls) -> "SecretBoolean": """Generate a random secret boolean.""" return SecretBoolean(child=Random(source_ref=SourceRef.back_frame())) + def type(self): + return SecretBooleanType() + + +class SecretBooleanType(TypePassthroughMixin): + """Meta type for secret booleans""" + + ty = SecretBoolean + is_scalar = True + @dataclass -class EcdsaSignature(NadaType): +class EcdsaSignature(DslType): """The EcdsaSignature Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) + def type(self): + return EcdsaSignatureType() + + +class EcdsaSignatureType(TypePassthroughMixin): + """Meta type for EcdsaSignatures""" + + ty = EcdsaSignature + @dataclass -class EcdsaDigestMessage(NadaType): +class EcdsaDigestMessage(DslType): """The EcdsaDigestMessage Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) + def type(self): + return EcdsaDigestMessageType() + + +class EcdsaDigestMessageType(TypePassthroughMixin): + """Meta type for EcdsaDigestMessages""" + + ty = EcdsaDigestMessage + @dataclass -class EcdsaPrivateKey(NadaType): +class EcdsaPrivateKey(DslType): """The EcdsaPrivateKey Nada MIR type.""" def __init__(self, child: OperationType): @@ -639,3 +787,12 @@ def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature": return EcdsaSignature( child=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame()) ) + + def type(self): + return EcdsaPrivateKeyType() + + +class EcdsaPrivateKeyType(TypePassthroughMixin): + """Meta type for EcdsaPrivateKeys""" + + ty = EcdsaPrivateKey diff --git a/nada_dsl/program_io.py b/nada_dsl/program_io.py index cd4de48..9848ab6 100644 --- a/nada_dsl/program_io.py +++ b/nada_dsl/program_io.py @@ -15,11 +15,11 @@ ) from nada_dsl.errors import InvalidTypeError from nada_dsl.nada_types import AllTypes, Party -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType from nada_dsl.source_ref import SourceRef -class Input(NadaType): +class Input(DslType): """ Represents an input to the computation. @@ -55,7 +55,8 @@ def store_in_ast(self, ty: object): ) -class Literal(NadaType): +@dataclass +class Literal(DslType): """ Represents a literal value. @@ -102,7 +103,7 @@ class Output: def __init__(self, child, name, party): self.source_ref = SourceRef.back_frame() - if not issubclass(type(child), NadaType): + if not issubclass(type(child), DslType): raise InvalidTypeError( f"{self.source_ref.file}:{self.source_ref.lineno}: Output value " f"{child} of type {type(child)} is not " diff --git a/test-programs/map_simple.py b/test-programs/map_simple.py index bc72c11..2b7b33c 100644 --- a/test-programs/map_simple.py +++ b/test-programs/map_simple.py @@ -6,7 +6,6 @@ def nada_main(): my_array_1 = Array(SecretInteger(Input(name="my_array_1", party=party1)), size=3) my_int = SecretInteger(Input(name="my_int", party=party1)) - @nada_fn def inc(a: SecretInteger) -> SecretInteger: return a + my_int diff --git a/test-programs/nada_fn_literal.py b/test-programs/nada_fn_literal.py deleted file mode 100644 index f7258ad..0000000 --- a/test-programs/nada_fn_literal.py +++ /dev/null @@ -1,12 +0,0 @@ -from nada_dsl import * - - -def nada_main(): - party1 = Party(name="Party1") - - @nada_fn - def add(a: Integer, b: Integer) -> Integer: - return a + b - - new_int = add(Integer(2), Integer(-5)) - return [Output(new_int, "my_output", party1)] diff --git a/test-programs/nada_fn_simple.py b/test-programs/nada_fn_simple.py deleted file mode 100644 index cf53bbf..0000000 --- a/test-programs/nada_fn_simple.py +++ /dev/null @@ -1,14 +0,0 @@ -from nada_dsl import * - - -def nada_main(): - party1 = Party(name="Party1") - my_int1 = SecretInteger(Input(name="my_int1", party=party1)) - my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - - @nada_fn - def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - new_int = add(my_int1, my_int2) - return [Output(new_int, "my_output", party1)] diff --git a/test-programs/ntuple_accessor.py b/test-programs/ntuple_accessor.py index 6edef15..08640b2 100644 --- a/test-programs/ntuple_accessor.py +++ b/test-programs/ntuple_accessor.py @@ -9,16 +9,23 @@ def nada_main(): array = Array.new(my_int1, my_int1) # Store a scalar, a compound type and a literal. - tuple = NTuple.new([my_int1, array, Integer(42)]) + tup = NTuple.new([my_int1, array, my_int2]) - scalar = tuple[0] - array = tuple[1] - literal = tuple[2] + scalar = tup[0] + array = tup[1] + scalar2 = tup[2] - @nada_fn - def add(a: PublicInteger) -> PublicInteger: - return a + my_int2 + def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: + return a + acc - sum = array.reduce(add, Integer(0)) + result = array.reduce(add, my_int1) - return [Output(scalar + literal + sum, "my_output", party1)] + scalar_sum = scalar + scalar2 + + final = result + scalar_sum + + return [Output(final, "my_output", party1)] + + +if __name__ == "__main__": + nada_main() diff --git a/test-programs/object_accessor.py b/test-programs/object_accessor.py index 0258b8e..8c006b5 100644 --- a/test-programs/object_accessor.py +++ b/test-programs/object_accessor.py @@ -9,16 +9,15 @@ def nada_main(): array = Array.new(my_int1, my_int1) # Store a scalar, a compound type and a literal. - object = Object.new({"a": my_int1, "b": array, "c": Integer(42)}) + object = Object.new({"a": my_int1, "b": array, "c": my_int2}) scalar = object.a array = object.b - literal = object.c + scalar_2 = object.c - @nada_fn - def add(a: PublicInteger) -> PublicInteger: - return a + my_int2 + def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: + return acc + a - sum = array.reduce(add, Integer(0)) + sum = array.reduce(add, my_int2) - return [Output(scalar + literal + sum, "my_output", party1)] + return [Output(scalar + scalar_2 + sum, "my_output", party1)] diff --git a/tests/compile_test.py b/tests/compile_test.py index 225e94f..3045d29 100644 --- a/tests/compile_test.py +++ b/tests/compile_test.py @@ -29,14 +29,6 @@ def get_test_programs_folder(): return this_directory + "../test-programs/" -def test_compile_nada_fn_simple(): - mir_str = compile_script(f"{get_test_programs_folder()}/nada_fn_simple.py").mir - assert mir_str != "" - mir = json.loads(mir_str) - mir_functions = mir["functions"] - assert len(mir_functions) == 1 - - def test_compile_sum_integers(): mir_str = compile_script(f"{get_test_programs_folder()}/sum_integers.py").mir assert mir_str != "" @@ -88,11 +80,9 @@ def nada_main(): my_int1 = SecretInteger(Input(name="my_int1", party=party1)) my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - @nada_fn def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a + b - @nada_fn def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a * add(a, b) @@ -104,33 +94,6 @@ def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: compile_string(encoded_program_str) -# TODO recursive programs fail with `NameError` for now. This is incorrect. -def test_compile_program_with_recursion(): - program_str = """from nada_dsl import * - -def nada_main(): - party1 = Party(name="Party1") - my_int1 = SecretInteger(Input(name="my_int1", party=party1)) - my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - - @nada_fn - def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a * add_times(a, b) - - new_int = add_times(my_int1, my_int2) - return [Output(new_int, "my_output", party1)] -""" - encoded_program_str = base64.b64encode(bytes(program_str, "utf-8")).decode("utf_8") - - with pytest.raises(NameError): - compile_string(encoded_program_str) - - -def test_compile_nada_fn_literals(): - with pytest.raises(NotAllowedException): - mir_str = compile_script(f"{get_test_programs_folder()}/nada_fn_literal.py").mir - - def test_compile_map_simple(): mir_str = compile_script(f"{get_test_programs_folder()}/map_simple.py").mir assert mir_str != "" diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index e4a2199..5b664d0 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -31,7 +31,11 @@ ) from nada_dsl.nada_types import AllTypes, Party from nada_dsl.nada_types.collections import Array, Tuple, NTuple, Object, unzip -from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, nada_fn +from nada_dsl.nada_types.function import ( + NadaFunctionArg, + NadaFunctionCall, + create_nada_fn, +) @pytest.fixture(autouse=True) @@ -125,7 +129,7 @@ def test_duplicated_inputs_checks(): def test_array_type_conversion(input_type, type_name, size): inner_input = create_input(SecretInteger, "name", "party", **{}) collection = create_collection(input_type, inner_input, size, **{}) - converted_input = collection.to_mir() + converted_input = collection.type().to_mir() assert list(converted_input.keys()) == [type_name] @@ -197,7 +201,6 @@ def test_unzip(input_type: type[Array]): ], ) def test_map(input_type, input_name): - @nada_fn def nada_function(a: SecretInteger) -> SecretInteger: return a + a @@ -227,7 +230,6 @@ def nada_function(a: SecretInteger) -> SecretInteger: def test_reduce(input_type: type[Array]): c = create_input(SecretInteger, "c", "party", **{}) - @nada_fn def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a + b @@ -262,191 +264,6 @@ def check_nada_function_arg_ref(arg_ref, function_id, name, ty): assert arg_ref["NadaFunctionArgRef"]["type"] == ty -def nada_function_to_mir(function_name: str): - nada_function: NadaFunctionASTOperation = find_function_in_ast(function_name) - assert isinstance(nada_function, NadaFunctionASTOperation) - fn_ops = {} - traverse_and_process_operations(nada_function.child, fn_ops, {}) - return nada_function.to_mir(fn_ops) - - -def test_nada_function_simple(): - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - nada_function = nada_function_to_mir("nada_function") - assert nada_function["function"] == "nada_function" - args = nada_function["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function["return_type"] == "SecretInteger" - - operations = nada_function["operations"] - return_op = operations[nada_function["return_operation_id"]] - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - - check_nada_function_arg_ref( - operations[addition["left"]], nada_function["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operations[addition["right"]], nada_function["id"], "b", "SecretInteger" - ) - - -def test_nada_function_using_inputs(): - c = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b + c - - nada_function = nada_function_to_mir("nada_function") - assert nada_function["function"] == "nada_function" - args = nada_function["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function["return_type"] == "SecretInteger" - - operation = nada_function["operations"] - return_op = operation[nada_function["return_operation_id"]] - - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - addition_right = operation[addition["right"]] - assert input_reference(addition_right) == "c" - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - - addition = addition_left["Addition"] - - check_nada_function_arg_ref( - operation[addition["left"]], nada_function["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operation[addition["right"]], nada_function["id"], "b", "SecretInteger" - ) - - -def test_nada_function_call(): - c = create_input(SecretInteger, "c", "party", **{}) - d = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - nada_fn_call_return = nada_function(c, d) - nada_fn_type = nada_function_to_mir("nada_function") - - nada_function_call = nada_fn_call_return.child - assert isinstance(nada_function_call, NadaFunctionCall) - assert nada_function_call.fn.id == nada_fn_type["id"] - - -def test_nada_function_using_operations(): - c = create_input(SecretInteger, "c", "party", **{}) - d = create_input(SecretInteger, "d", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b + c + d - - nada_function_ast = nada_function_to_mir("nada_function") - assert nada_function_ast["function"] == "nada_function" - args = nada_function_ast["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function_ast["return_type"] == "SecretInteger" - - operation = nada_function_ast["operations"] - return_op = operation[nada_function_ast["return_operation_id"]] - - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - - assert input_reference(operation[addition["right"]]) == "d" - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - addition = addition_left["Addition"] - assert input_reference(operation[addition["right"]]) == "c" - - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - addition = addition_left["Addition"] - - check_nada_function_arg_ref( - operation[addition["left"]], nada_function_ast["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operation[addition["right"]], nada_function_ast["id"], "b", "SecretInteger" - ) - - -def find_function_in_ast(fn_name: str): - for op in AST_OPERATIONS.values(): - if isinstance(op, NadaFunctionASTOperation) and op.name == fn_name: - return op - return None - - -@pytest.mark.parametrize( - ("input_type", "input_name"), - [ - (Array, "Array"), - ], -) -def test_nada_function_using_matrix(input_type, input_name): - c = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - @nada_fn - def matrix_addition( - a: input_type[SecretInteger], b: input_type[SecretInteger] - ) -> SecretInteger: - return a.zip(b).map(add).reduce(add, c) - - add_fn = nada_function_to_mir("add") - matrix_addition_fn = nada_function_to_mir("matrix_addition") - assert matrix_addition_fn["function"] == "matrix_addition" - args = matrix_addition_fn["args"] - assert len(args) == 2 - a_arg_type = {input_name: {"inner_type": "SecretInteger"}} - check_arg(args[0], "a", a_arg_type) - b_arg_type = {input_name: {"inner_type": "SecretInteger"}} - check_arg(args[1], "b", b_arg_type) - assert matrix_addition_fn["return_type"] == "SecretInteger" - - operations = matrix_addition_fn["operations"] - return_op = operations[matrix_addition_fn["return_operation_id"]] - assert list(return_op.keys()) == ["Reduce"] - reduce_op = return_op["Reduce"] - reduce_op["function_id"] = add_fn["id"] - reduce_op["type"] = "SecretInteger" - - reduce_op_inner = operations[reduce_op["inner"]] - assert list(reduce_op_inner.keys()) == ["Map"] - map_op = reduce_op_inner["Map"] - map_op["function_id"] = add_fn["id"] - map_op["type"] = {input_name: {"inner_type": "SecretInteger", "size": None}} - - map_op_inner = operations[map_op["inner"]] - assert list(map_op_inner.keys()) == ["Zip"] - zip_op = map_op_inner["Zip"] - - zip_op_left = operations[zip_op["left"]] - zip_op_right = operations[zip_op["right"]] - check_nada_function_arg_ref(zip_op_left, matrix_addition_fn["id"], "a", a_arg_type) - check_nada_function_arg_ref(zip_op_right, matrix_addition_fn["id"], "b", b_arg_type) - - def test_array_new(): first_input = create_input(SecretInteger, "first", "party", **{}) second_input = create_input(SecretInteger, "second", "party", **{}) @@ -511,7 +328,7 @@ def test_tuple_new_empty(): Tuple.new() assert ( str(e.value) - == "Tuple.new() missing 2 required positional arguments: 'left_type' and 'right_type'" + == "Tuple.new() missing 2 required positional arguments: 'left_value' and 'right_value'" ) diff --git a/tests/nada_type_test.py b/tests/nada_type_test.py deleted file mode 100644 index ad20ea5..0000000 --- a/tests/nada_type_test.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for NadaType.""" - -import pytest -from nada_dsl.nada_types import NadaType -from nada_dsl.nada_types.scalar_types import Integer, PublicBoolean, SecretInteger - - -@pytest.mark.parametrize( - ("cls", "expected"), - [ - (SecretInteger, "SecretInteger"), - (Integer, "Integer"), - (PublicBoolean, "Boolean"), - ], -) -def test_class_to_mir(cls: NadaType, expected: str): - """Tests `NadaType.class_to_mir()""" - assert cls.class_to_mir() == expected diff --git a/tests/scalar_type_test.py b/tests/scalar_type_test.py index 94b9aa3..0a27c47 100644 --- a/tests/scalar_type_test.py +++ b/tests/scalar_type_test.py @@ -7,6 +7,7 @@ from nada_dsl import Input, Party from nada_dsl.nada_types import BaseType, Mode from nada_dsl.nada_types.scalar_types import ( + BooleanDslType, Integer, PublicInteger, SecretInteger, @@ -16,7 +17,7 @@ UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger, - ScalarType, + ScalarDslType, BooleanType, ) @@ -119,7 +120,9 @@ def combine_lists(list1, list2): @pytest.mark.parametrize("left, right, operation", binary_arithmetic_operations) -def test_binary_arithmetic_operations(left: ScalarType, right: ScalarType, operation): +def test_binary_arithmetic_operations( + left: ScalarDslType, right: ScalarDslType, operation +): result = operation(left, right) assert result.base_type, left.base_type assert result.base_type, right.base_type @@ -136,7 +139,7 @@ def test_binary_arithmetic_operations(left: ScalarType, right: ScalarType, opera @pytest.mark.parametrize("left, right", allowed_pow_operands) -def test_pow(left: ScalarType, right: ScalarType): +def test_pow(left: ScalarDslType, right: ScalarDslType): result = left**right assert result.base_type, left.base_type assert result.base_type, right.base_type @@ -161,7 +164,7 @@ def test_pow(left: ScalarType, right: ScalarType): @pytest.mark.parametrize("left, right, operation", allowed_shift_operands) -def test_shift(left: ScalarType, right: ScalarType, operation): +def test_shift(left: ScalarDslType, right: ScalarDslType, operation): result = operation(left, right) assert result.base_type, left.base_type assert result.mode, left.mode @@ -188,7 +191,9 @@ def test_shift(left: ScalarType, right: ScalarType, operation): @pytest.mark.parametrize("left, right, operation", binary_relational_operations) -def test_binary_relational_operations(left: ScalarType, right: ScalarType, operation): +def test_binary_relational_operations( + left: ScalarDslType, right: ScalarDslType, operation +): result = operation(left, right) assert result.base_type, BaseType.BOOLEAN assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -206,7 +211,7 @@ def test_binary_relational_operations(left: ScalarType, right: ScalarType, opera @pytest.mark.parametrize("left, right, operation", equals_operations) -def test_equals_operations(left: ScalarType, right: ScalarType, operation): +def test_equals_operations(left: ScalarDslType, right: ScalarDslType, operation): result = operation(left, right) assert result.base_type, BaseType.BOOLEAN assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -253,7 +258,7 @@ def test_public_equals( @pytest.mark.parametrize("left, right, operation", binary_logic_operations) -def test_logic_operations(left: BooleanType, right: BooleanType, operation): +def test_logic_operations(left: BooleanDslType, right: BooleanDslType, operation): result = operation(left, right) assert result.base_type, BaseType.BOOLEAN assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -317,7 +322,7 @@ def test_to_public(operand): @pytest.mark.parametrize("condition, left, right", if_else_operands) -def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): +def test_if_else(condition: BooleanType, left: ScalarDslType, right: ScalarDslType): result = condition.if_else(left, right) assert left.base_type == right.base_type assert result.base_type == left.base_type @@ -525,7 +530,7 @@ def test_not_allowed_random(operand): @pytest.mark.parametrize("condition, left, right", not_allowed_if_else_operands) -def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): +def test_if_else(condition: BooleanType, left: ScalarDslType, right: ScalarDslType): with pytest.raises(Exception) as invalid_operation: condition.if_else(left, right) assert invalid_operation.type == TypeError