Skip to content

Commit

Permalink
fix(experiments): apply new count method and fix continuous (#27639)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
andehen and github-actions[bot] authored Jan 23, 2025
1 parent 377a1cc commit 8e3b930
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 153 deletions.
22 changes: 5 additions & 17 deletions frontend/src/scenes/experiments/experimentLogic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1427,8 +1427,8 @@ export const experimentLogic = kea<experimentLogicType>([
},
],
credibleIntervalForVariant: [
(s) => [s.experimentStatsVersion],
(experimentStatsVersion) =>
() => [],
() =>
(
metricResult: CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse | null,
variantKey: string,
Expand Down Expand Up @@ -1460,26 +1460,14 @@ export const experimentLogic = kea<experimentLogicType>([
const controlVariant = (metricResult.variants as TrendExperimentVariant[]).find(
({ key }) => key === 'control'
) as TrendExperimentVariant
const variant = (metricResult.variants as TrendExperimentVariant[]).find(
({ key }) => key === variantKey
) as TrendExperimentVariant

const controlMean = controlVariant.count / controlVariant.absolute_exposure

const meanLowerBound =
experimentStatsVersion === 2
? credibleInterval[0] / variant.absolute_exposure
: credibleInterval[0]
const meanUpperBound =
experimentStatsVersion === 2
? credibleInterval[1] / variant.absolute_exposure
: credibleInterval[1]

// Calculate the percentage difference between the credible interval bounds of the variant and the control's mean.
// This represents the range in which the true percentage change relative to the control is likely to fall.
const lowerBound = ((meanLowerBound - controlMean) / controlMean) * 100
const upperBound = ((meanUpperBound - controlMean) / controlMean) * 100
return [lowerBound, upperBound]
const relativeLowerBound = ((credibleInterval[0] - controlMean) / controlMean) * 100
const relativeUpperBound = ((credibleInterval[1] - controlMean) / controlMean) * 100
return [relativeLowerBound, relativeUpperBound]
},
],
getIndexForVariant: [
Expand Down
53 changes: 34 additions & 19 deletions posthog/hogql_queries/experiments/experiment_trends_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from posthog.constants import ExperimentNoResultsErrorKeys
from posthog.hogql import ast
from posthog.hogql_queries.experiments import CONTROL_VARIANT_KEY
from posthog.hogql_queries.experiments.types import ExperimentMetricType
from posthog.hogql_queries.experiments.trends_statistics import (
are_results_significant,
calculate_credible_intervals,
Expand Down Expand Up @@ -69,6 +70,8 @@ def __init__(self, *args, **kwargs):

self.stats_version = self.experiment.get_stats_config("version") or 1

self._fix_math_aggregation()

self.prepared_count_query = self._prepare_count_query()
self.prepared_exposure_query = self._prepare_exposure_query()

Expand All @@ -86,6 +89,14 @@ def _uses_math_aggregation_by_user_or_property_value(self, query: TrendsQuery):
math_keys.remove("sum")
return any(entity.math in math_keys for entity in query.series)

def _fix_math_aggregation(self):
    """Coerce unsupported math aggregations on the count query to SUM.

    Some experiments were configured with per-user/property math (e.g. AVG),
    which is not a valid input for the experiment analysis, so the first
    series is rewritten in place to use a SUM aggregation instead.
    """
    if self._uses_math_aggregation_by_user_or_property_value(self.query.count_query):
        self.query.count_query.series[0].math = PropertyMathType.SUM

def _get_date_range(self) -> DateRange:
"""
Returns a DateRange object based on the experiment's start and end dates,
Expand Down Expand Up @@ -117,6 +128,14 @@ def _get_data_warehouse_breakdown_filter(self) -> BreakdownFilter:
breakdown_type="data_warehouse",
)

def _get_metric_type(self) -> ExperimentMetricType:
    """Infer the experiment metric type from the first series' math setting.

    Property-sum and HogQL aggregations are treated as continuous metrics;
    every other math value (total, DAU, unique session, None, ...) counts
    as a plain count metric.
    """
    series_math = self.query.count_query.series[0].math
    if series_math in (PropertyMathType.SUM, "hogql"):
        return ExperimentMetricType.CONTINUOUS
    return ExperimentMetricType.COUNT

def _prepare_count_query(self) -> TrendsQuery:
"""
This method takes the raw trend query and adapts it
Expand All @@ -129,13 +148,6 @@ def _prepare_count_query(self) -> TrendsQuery:
"""
prepared_count_query = TrendsQuery(**self.query.count_query.model_dump())

uses_math_aggregation = self._uses_math_aggregation_by_user_or_property_value(prepared_count_query)

# Only SUM is supported now, but some earlier experiments used AVG. That does not
# make sense as input for experiment analysis, so we'll switch that to SUM here
if uses_math_aggregation:
prepared_count_query.series[0].math = PropertyMathType.SUM

prepared_count_query.trendsFilter = TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE)
prepared_count_query.dateRange = self._get_date_range()
if self._is_data_warehouse_query(prepared_count_query):
Expand Down Expand Up @@ -270,18 +282,21 @@ def run(query_runner: TrendsQueryRunner, result_key: str, is_parallel: bool):
# Statistical analysis
control_variant, test_variants = self._get_variants_with_base_stats(count_result, exposure_result)
if self.stats_version == 2:
if self.query.count_query.series[0].math:
probabilities = calculate_probabilities_v2_continuous(control_variant, test_variants)
significance_code, p_value = are_results_significant_v2_continuous(
control_variant, test_variants, probabilities
)
credible_intervals = calculate_credible_intervals_v2_continuous([control_variant, *test_variants])
else:
probabilities = calculate_probabilities_v2_count(control_variant, test_variants)
significance_code, p_value = are_results_significant_v2_count(
control_variant, test_variants, probabilities
)
credible_intervals = calculate_credible_intervals_v2_count([control_variant, *test_variants])
match self._get_metric_type():
case ExperimentMetricType.CONTINUOUS:
probabilities = calculate_probabilities_v2_continuous(control_variant, test_variants)
significance_code, p_value = are_results_significant_v2_continuous(
control_variant, test_variants, probabilities
)
credible_intervals = calculate_credible_intervals_v2_continuous([control_variant, *test_variants])
case ExperimentMetricType.COUNT:
probabilities = calculate_probabilities_v2_count(control_variant, test_variants)
significance_code, p_value = are_results_significant_v2_count(
control_variant, test_variants, probabilities
)
credible_intervals = calculate_credible_intervals_v2_count([control_variant, *test_variants])
case _:
raise ValueError(f"Unsupported metric type: {self._get_metric_type()}")
else:
probabilities = calculate_probabilities(control_variant, test_variants)
significance_code, p_value = are_results_significant(control_variant, test_variants, probabilities)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from django.test import override_settings
from ee.clickhouse.materialized_columns.columns import get_enabled_materialized_columns, materialize
from posthog.hogql_queries.experiments.experiment_trends_query_runner import ExperimentTrendsQueryRunner
from posthog.hogql_queries.experiments.types import ExperimentMetricType
from posthog.models.experiment import Experiment, ExperimentHoldout
from posthog.models.feature_flag.feature_flag import FeatureFlag
from posthog.schema import (
BaseMathType,
DataWarehouseNode,
EventsNode,
ExperimentSignificanceCode,
ExperimentTrendsQuery,
ExperimentTrendsQueryResponse,
PersonsOnEventsMode,
PropertyMathType,
TrendsQuery,
)
from posthog.settings import (
Expand All @@ -27,7 +30,7 @@
flush_persons_and_events,
)
from freezegun import freeze_time
from typing import cast
from typing import cast, Any
from django.utils import timezone
from datetime import datetime, timedelta
from posthog.test.test_journeys import journeys_for
Expand Down Expand Up @@ -2363,3 +2366,47 @@ def test_validate_event_variants_no_exposure(self):
}
)
self.assertEqual(cast(list, context.exception.detail)[0], expected_errors)

def test_get_metric_type(self):
    """The metric type is derived from the math on the first count-query series."""
    feature_flag = self.create_feature_flag()
    experiment = self.create_experiment(feature_flag=feature_flag)

    def _runner_for(count_query: TrendsQuery) -> ExperimentTrendsQueryRunner:
        # Build a query runner around the given count query for this experiment.
        return ExperimentTrendsQueryRunner(
            query=ExperimentTrendsQuery(
                experiment_id=experiment.id,
                kind="ExperimentTrendsQuery",
                count_query=count_query,
            ),
            team=self.team,
        )

    # Count-style math types (including no math at all) map to COUNT.
    for count_math in [BaseMathType.TOTAL, BaseMathType.DAU, BaseMathType.UNIQUE_SESSION, None]:
        runner = _runner_for(TrendsQuery(series=[EventsNode(event="$pageview", math=count_math)]))
        self.assertEqual(runner._get_metric_type(), ExperimentMetricType.COUNT)

    # Sum-style math types map to CONTINUOUS.
    continuous_math_types: list[Any] = [PropertyMathType.SUM, "hogql"]
    for continuous_math in continuous_math_types:
        runner = _runner_for(
            TrendsQuery(
                series=[EventsNode(event="checkout completed", math=continuous_math, math_property="revenue")]
            )
        )
        self.assertEqual(runner._get_metric_type(), ExperimentMetricType.CONTINUOUS)

    # AVG is rewritten to SUM when the runner is constructed, so it is CONTINUOUS too.
    runner = _runner_for(
        TrendsQuery(
            series=[EventsNode(event="checkout completed", math=PropertyMathType.AVG, math_property="revenue")]
        )
    )
    self.assertEqual(runner._get_metric_type(), ExperimentMetricType.CONTINUOUS)
    # Confirm the in-place conversion from AVG to SUM actually happened.
    self.assertEqual(runner.query.count_query.series[0].math, PropertyMathType.SUM)
Loading

0 comments on commit 8e3b930

Please sign in to comment.