Skip to content

Commit

Permalink
Merge pull request #662 from Lumiwealth/daylight_savings_backtest_fix
Browse files Browse the repository at this point in the history
BackTest: Daylight Savings Fix
  • Loading branch information
davidlatte authored Dec 12, 2024
2 parents ef50e38 + 785f195 commit 8887e2a
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 21 deletions.
8 changes: 7 additions & 1 deletion lumibot/backtesting/backtesting_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from decimal import Decimal
from functools import wraps

import pandas as pd
import pytz

from lumibot.brokers import Broker
from lumibot.data_sources import DataSourceBacktesting
Expand Down Expand Up @@ -88,13 +88,19 @@ def get_historical_account_value(self):
def _update_datetime(self, update_dt, cash=None, portfolio_value=None):
"""Works with either timedelta or datetime input
and updates the datetime of the broker"""
tz = self.datetime.tzinfo
is_pytz = isinstance(tz, (pytz.tzinfo.StaticTzInfo, pytz.tzinfo.DstTzInfo))

if isinstance(update_dt, timedelta):
new_datetime = self.datetime + update_dt
elif isinstance(update_dt, int) or isinstance(update_dt, float):
new_datetime = self.datetime + timedelta(seconds=update_dt)
else:
new_datetime = update_dt

# This is needed to handle Daylight Savings Time changes
new_datetime = tz.normalize(new_datetime) if is_pytz else new_datetime

self.data_source._update_datetime(new_datetime, cash=cash, portfolio_value=portfolio_value)
if self.option_source:
self.option_source._update_datetime(new_datetime, cash=cash, portfolio_value=portfolio_value)
Expand Down
8 changes: 5 additions & 3 deletions lumibot/data_sources/interactive_brokers_rest_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from termcolor import colored
from ..entities import Asset, Bars

from lumibot import LUMIBOT_DEFAULT_PYTZ
from ..entities import Asset, Bars
from .data_source import DataSource

import subprocess
import os
import time
Expand Down Expand Up @@ -817,7 +819,7 @@ def get_historical_prices(
# Convert timestamp to datetime and set as index
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
df["timestamp"] = (
df["timestamp"].dt.tz_localize("UTC").dt.tz_convert("America/New_York")
df["timestamp"].dt.tz_localize("UTC").dt.tz_convert(LUMIBOT_DEFAULT_PYTZ)
)
df.set_index("timestamp", inplace=True)

Expand Down Expand Up @@ -1082,4 +1084,4 @@ def get_quote(self, asset, quote=None, exchange=None):
else:
result["ask"] = None

return result
return result
7 changes: 4 additions & 3 deletions lumibot/data_sources/tradier_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
from datetime import datetime, date, timedelta

import pandas as pd
import pytz

from lumibot import LUMIBOT_DEFAULT_TIMEZONE
from lumibot import LUMIBOT_DEFAULT_PYTZ, LUMIBOT_DEFAULT_TIMEZONE
from lumibot.entities import Asset, Bars
from lumibot.tools.helpers import create_options_symbol, parse_timestep_qty_and_unit, get_trading_days
from lumiwealth_tradier import Tradier

from .data_source import DataSource


class TradierAPIError(Exception):
pass


class TradierData(DataSource):

MIN_TIMESTEP = "minute"
Expand Down Expand Up @@ -188,7 +189,7 @@ def get_historical_prices(
end_date = datetime.now()

# Use pytz to get the US/Eastern timezone
eastern = pytz.timezone("US/Eastern")
eastern = LUMIBOT_DEFAULT_PYTZ

# Convert datetime object to US/Eastern timezone
end_date = end_date.astimezone(eastern)
Expand Down
17 changes: 9 additions & 8 deletions lumibot/strategies/_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlalchemy import create_engine, inspect, text

import pandas as pd
from lumibot import LUMIBOT_DEFAULT_PYTZ
from ..backtesting import BacktestingBroker, PolygonDataBacktesting, ThetaDataBacktesting
from ..entities import Asset, Position, Order
from ..tools import (
Expand Down Expand Up @@ -1791,7 +1792,7 @@ def send_account_summary_to_discord(self):
cash = self.get_cash()

# # Get the datetime
now = pd.Timestamp(datetime.datetime.now()).tz_localize("America/New_York")
now = pd.Timestamp(datetime.datetime.now()).tz_localize(LUMIBOT_DEFAULT_PYTZ)

# Get the returns
returns_text, stats_df = self.calculate_returns()
Expand Down Expand Up @@ -1820,7 +1821,7 @@ def get_stats_from_database(self, stats_table_name, retries=5, delay=5):
self.logger.info(f"Table {stats_table_name} does not exist. Creating it now.")

# Get the current time in New York
ny_tz = pytz.timezone("America/New_York")
ny_tz = LUMIBOT_DEFAULT_PYTZ
now = datetime.datetime.now(ny_tz)

# Create an empty stats dataframe
Expand Down Expand Up @@ -1884,7 +1885,7 @@ def backup_variables_to_db(self):
self.db_engine = create_engine(self.db_connection_str)

# Get the current time in New York
ny_tz = pytz.timezone("America/New_York")
ny_tz = LUMIBOT_DEFAULT_PYTZ
now = datetime.datetime.now(ny_tz)

if not inspect(self.db_engine).has_table(self.backup_table_name):
Expand Down Expand Up @@ -2008,7 +2009,7 @@ def calculate_returns(self):
# Calculate the return over the past 24 hours, 7 days, and 30 days using the stats dataframe

# Get the current time in New York
ny_tz = pytz.timezone("America/New_York")
ny_tz = LUMIBOT_DEFAULT_PYTZ

# Get the datetime
now = datetime.datetime.now(ny_tz)
Expand All @@ -2025,11 +2026,11 @@ def calculate_returns(self):
# Check if the datetime column is timezone-aware
if stats_df['datetime'].dt.tz is None:
# If the datetime is timezone-naive, directly localize it to "America/New_York"
stats_df["datetime"] = stats_df["datetime"].dt.tz_localize("America/New_York", ambiguous='infer')
stats_df["datetime"] = stats_df["datetime"].dt.tz_localize(LUMIBOT_DEFAULT_PYTZ, ambiguous='infer')
else:
# If the datetime is already timezone-aware, first remove timezone and then localize
stats_df["datetime"] = stats_df["datetime"].dt.tz_localize(None)
stats_df["datetime"] = stats_df["datetime"].dt.tz_localize("America/New_York", ambiguous='infer')
stats_df["datetime"] = stats_df["datetime"].dt.tz_localize(LUMIBOT_DEFAULT_PYTZ, ambiguous='infer')

# Get the stats
stats_new = pd.DataFrame(
Expand All @@ -2049,7 +2050,7 @@ def calculate_returns(self):
stats_df = pd.concat([stats_df, stats_new])

# # Convert the datetime column to eastern time
stats_df["datetime"] = stats_df["datetime"].dt.tz_convert("America/New_York")
stats_df["datetime"] = stats_df["datetime"].dt.tz_convert(LUMIBOT_DEFAULT_PYTZ)

# Remove any duplicate rows
stats_df = stats_df[~stats_df["datetime"].duplicated(keep="last")]
Expand Down Expand Up @@ -2160,4 +2161,4 @@ def calculate_returns(self):
return results_text, stats_df

else:
return "Not enough data to calculate returns", stats_df
return "Not enough data to calculate returns", stats_df
4 changes: 2 additions & 2 deletions lumibot/tools/thetadata_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import pandas_market_calendars as mcal
import requests
from lumibot import LUMIBOT_CACHE_FOLDER
from lumibot import LUMIBOT_CACHE_FOLDER, LUMIBOT_DEFAULT_PYTZ
from lumibot.entities import Asset
from thetadata import ThetaClient
from tqdm import tqdm
Expand Down Expand Up @@ -295,7 +295,7 @@ def update_df(df_all, result):
],
}
"""
ny_tz = pytz.timezone('America/New_York')
ny_tz = LUMIBOT_DEFAULT_PYTZ
df = pd.DataFrame(result)
if not df.empty:
if "datetime" not in df.index.names:
Expand Down
10 changes: 6 additions & 4 deletions tests/backtest/test_thetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,12 @@ def verify_backtest_results(self, theta_strat_obj):
)
assert "fill" not in theta_strat_obj.order_time_tracker[stoploss_order_id]

@pytest.mark.skipif(
secrets_not_found,
reason="Skipping test because ThetaData API credentials not found in environment variables",
)
# @pytest.mark.skipif(
# secrets_not_found,
# reason="Skipping test because ThetaData API credentials not found in environment variables",
# )
@pytest.mark.skip("Skipping test because ThetaData API credentials not found in Github Pipeline "
"environment variables")
def test_thetadata_restclient(self):
"""
Test ThetaDataBacktesting with Lumibot Backtesting and real API calls to ThetaData. Using the Amazon stock
Expand Down

0 comments on commit 8887e2a

Please sign in to comment.