Skip to content

Commit

Permalink
Add missing exports for wrapper modules (apache#782)
Browse files Browse the repository at this point in the history
* Add imports in base file to match those in internal

* Correct class capitalization and exports for substrait

* Add exports for common to match internal

* Add exports for Expr to match internal

* Add __all__ to functions

* Add exports for object store to match internal

* Add pytest to ensure all pyo3 exposed objects are also exposed in our wrappers so we don't miss any functions or classes

* Add license
  • Loading branch information
timsaucer authored Aug 1, 2024
1 parent 951d6b9 commit 9a6805e
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 14 deletions.
22 changes: 21 additions & 1 deletion python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@
SQLOptions,
)

from .catalog import Catalog, Database, Table

# The following imports are okay to remain as opaque to the user.
from ._internal import Config
from ._internal import Config, LogicalPlan, ExecutionPlan, runtime

from .record_batch import RecordBatchStream, RecordBatch

from .udf import ScalarUDF, AggregateUDF, Accumulator

Expand All @@ -49,6 +53,8 @@
WindowFrame,
)

from . import functions, object_store, substrait

__version__ = importlib_metadata.version(__name__)

__all__ = [
Expand All @@ -65,6 +71,20 @@
"column",
"literal",
"DFSchema",
"runtime",
"Catalog",
"Database",
"Table",
"AggregateUDF",
"LogicalPlan",
"ExecutionPlan",
"RecordBatch",
"RecordBatchStream",
"common",
"expr",
"functions",
"object_store",
"substrait",
]


Expand Down
32 changes: 29 additions & 3 deletions python/datafusion/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,34 @@
# under the License.
"""Common data types used throughout the DataFusion project."""

from ._internal import common
from ._internal import common as common_internal

# TODO these should all have proper wrapper classes

def __getattr__(name):
return getattr(common, name)
DFSchema = common_internal.DFSchema
DataType = common_internal.DataType
DataTypeMap = common_internal.DataTypeMap
NullTreatment = common_internal.NullTreatment
PythonType = common_internal.PythonType
RexType = common_internal.RexType
SqlFunction = common_internal.SqlFunction
SqlSchema = common_internal.SqlSchema
SqlStatistics = common_internal.SqlStatistics
SqlTable = common_internal.SqlTable
SqlType = common_internal.SqlType
SqlView = common_internal.SqlView

__all__ = [
"DFSchema",
"DataType",
"DataTypeMap",
"RexType",
"PythonType",
"SqlType",
"NullTreatment",
"SqlTable",
"SqlSchema",
"SqlView",
"SqlStatistics",
"SqlFunction",
]
74 changes: 74 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
CrossJoin = expr_internal.CrossJoin
Distinct = expr_internal.Distinct
DropTable = expr_internal.DropTable
EmptyRelation = expr_internal.EmptyRelation
Exists = expr_internal.Exists
Explain = expr_internal.Explain
Extension = expr_internal.Extension
Expand All @@ -58,6 +59,7 @@
InSubquery = expr_internal.InSubquery
IsFalse = expr_internal.IsFalse
IsNotTrue = expr_internal.IsNotTrue
IsNull = expr_internal.IsNull
IsTrue = expr_internal.IsTrue
IsUnknown = expr_internal.IsUnknown
IsNotFalse = expr_internal.IsNotFalse
Expand All @@ -83,6 +85,70 @@
TableScan = expr_internal.TableScan
TryCast = expr_internal.TryCast
Union = expr_internal.Union
Unnest = expr_internal.Unnest
Window = expr_internal.Window

__all__ = [
"Expr",
"Column",
"Literal",
"BinaryExpr",
"Literal",
"AggregateFunction",
"Not",
"IsNotNull",
"IsNull",
"IsTrue",
"IsFalse",
"IsUnknown",
"IsNotTrue",
"IsNotFalse",
"IsNotUnknown",
"Negative",
"Like",
"ILike",
"SimilarTo",
"ScalarVariable",
"Alias",
"InList",
"Exists",
"Subquery",
"InSubquery",
"ScalarSubquery",
"Placeholder",
"GroupingSet",
"Case",
"CaseBuilder",
"Cast",
"TryCast",
"Between",
"Explain",
"Limit",
"Aggregate",
"Sort",
"Analyze",
"EmptyRelation",
"Join",
"JoinType",
"JoinConstraint",
"CrossJoin",
"Union",
"Unnest",
"Extension",
"Filter",
"Projection",
"TableScan",
"CreateMemoryTable",
"CreateView",
"Distinct",
"SubqueryAlias",
"DropTable",
"Partitioning",
"Repartition",
"Window",
"WindowFrame",
"WindowFrameBound",
]


class Expr:
Expand Down Expand Up @@ -246,6 +312,14 @@ def __lt__(self, rhs: Any) -> Expr:
rhs = Expr.literal(rhs)
return Expr(self.expr.__lt__(rhs.expr))

__radd__ = __add__
__rand__ = __and__
__rmod__ = __mod__
__rmul__ = __mul__
__ror__ = __or__
__rsub__ = __sub__
__rtruediv__ = __truediv__

@staticmethod
def literal(value: Any) -> Expr:
"""Creates a new expression representing a scalar value.
Expand Down
Loading

0 comments on commit 9a6805e

Please sign in to comment.