# From the patch to python/datafusion/tests/test_expr.py; defined after test_relational_expr.
def test_expr_to_variant():
    """Check ``to_variant()`` on a filter predicate found in an optimized plan.

    Scenario taken from https://github.com/apache/datafusion-python/issues/781:
    an ``IN (...)`` query is planned, the optimized logical plan is searched
    for a ``Filter`` node, and the variant of that node's predicate is
    inspected through ``expr()``, ``list()`` and ``negated()``.
    """
    from datafusion import SessionContext
    from datafusion.expr import Filter

    def find_filter_predicate(node):
        """Return the first Filter predicate variant in pre-order, else None."""
        node_variant = node.to_variant()
        if isinstance(node_variant, Filter):
            return node_variant.predicate().to_variant()
        children = node.inputs() if hasattr(node, "inputs") else ()
        # Generators keep the search lazy: stop descending at the first hit.
        hits = (find_filter_predicate(child) for child in children)
        return next((hit for hit in hits if hit is not None), None)

    session = SessionContext()
    session.from_pydict(
        {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]},
        name="table1",
    )
    frame = session.sql(
        "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
    )
    predicate = find_filter_predicate(frame.optimized_logical_plan())

    assert predicate is not None
    assert predicate.expr().to_variant().qualified_name() == "table1.name"
    expected_items = (
        '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
    )
    assert str(predicate.list()) == expected_items
    assert not predicate.negated()