From f85c0539067d59757f30e4dd7bf62798f7b213d8 Mon Sep 17 00:00:00 2001 From: Tom Owers Date: Thu, 2 Jan 2025 18:58:19 +0100 Subject: [PATCH] fix(hogql): Allow simple usage of `ast.Call` in table joins (#27204) --- posthog/hogql/database/database.py | 18 ++++++++++++------ posthog/warehouse/api/test/test_view_link.py | 2 +- posthog/warehouse/api/view_link.py | 6 +++--- posthog/warehouse/models/join.py | 18 ++++++++++++------ 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index dd8ffc8a377a9..b88fc89788e30 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -437,14 +437,20 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable): joining_table = database.get_table(join.joining_table_name) field = parse_expr(join.source_table_key) - if not isinstance(field, ast.Field): - raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node") - from_field = field.chain + if isinstance(field, ast.Field): + from_field = field.chain + elif isinstance(field, ast.Call) and isinstance(field.args[0], ast.Field): + from_field = field.args[0].chain + else: + raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node") field = parse_expr(join.joining_table_key) - if not isinstance(field, ast.Field): - raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node") - to_field = field.chain + if isinstance(field, ast.Field): + to_field = field.chain + elif isinstance(field, ast.Call) and isinstance(field.args[0], ast.Field): + to_field = field.args[0].chain + else: + raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node") source_table.fields[join.field_name] = LazyJoin( from_field=from_field, diff --git a/posthog/warehouse/api/test/test_view_link.py b/posthog/warehouse/api/test/test_view_link.py index 4bf4f697ef4a8..35f7beb5ac1d4 100644 --- a/posthog/warehouse/api/test/test_view_link.py +++ b/posthog/warehouse/api/test/test_view_link.py @@ -100,7 +100,7 @@ def test_create_saved_query_join_key_function(self): "field_name": "some_field", }, ) - self.assertEqual(response.status_code, 400, response.content) + self.assertEqual(response.status_code, 201, response.content) def test_update_with_configuration(self): join = DataWarehouseJoin.objects.create( diff --git a/posthog/warehouse/api/view_link.py b/posthog/warehouse/api/view_link.py index a249dbf9d3859..657fc8b378d4e 100644 --- a/posthog/warehouse/api/view_link.py +++ b/posthog/warehouse/api/view_link.py @@ -4,7 +4,7 @@ from posthog.api.routing import TeamAndOrgViewSetMixin from posthog.api.shared import UserBasicSerializer -from posthog.hogql.ast import Field +from posthog.hogql.ast import Field, Call from posthog.hogql.database.database import create_hogql_database from posthog.hogql.parser import parse_expr from posthog.warehouse.models import DataWarehouseJoin @@ -71,8 +71,8 @@ def _validate_join_key(self, join_key: Optional[str], table: Optional[str], team raise serializers.ValidationError(f"Invalid table: {table}") node = parse_expr(join_key) - if not isinstance(node, Field): - raise serializers.ValidationError(f"Join key {join_key} must be a table field - no function calls allowed") + if not isinstance(node, Field) and not (isinstance(node, Call) and isinstance(node.args[0], Field)): + raise serializers.ValidationError(f"Join key {join_key} must be a table field") return diff --git a/posthog/warehouse/models/join.py b/posthog/warehouse/models/join.py index 51108c58e578c..0bfef1317c830 100644 --- a/posthog/warehouse/models/join.py +++ b/posthog/warehouse/models/join.py @@ -65,14 +65,20 @@ def _join_function( raise ResolutionError(f"No fields requested from {join_to_add.to_table}") left = parse_expr(_source_table_key) - if not isinstance(left, ast.Field): - raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node") - left.chain = [join_to_add.from_table, *left.chain] + if isinstance(left, ast.Field): + left.chain = [join_to_add.from_table, *left.chain] + elif isinstance(left, ast.Call) and isinstance(left.args[0], ast.Field): + left.args[0].chain = [join_to_add.from_table, *left.args[0].chain] + else: + raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node") right = parse_expr(_joining_table_key) - if not isinstance(right, ast.Field): - raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node") - right.chain = [join_to_add.to_table, *right.chain] + if isinstance(right, ast.Field): + right.chain = [join_to_add.to_table, *right.chain] + elif isinstance(right, ast.Call) and isinstance(right.args[0], ast.Field): + right.args[0].chain = [join_to_add.to_table, *right.args[0].chain] + else: + raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node") join_expr = ast.JoinExpr( table=ast.SelectQuery(