Skip to content

Commit

Permalink
use a trading calendar to get the start date that is length bars earl…
Browse files Browse the repository at this point in the history
…ier then end_date
  • Loading branch information
brettelliot committed Nov 8, 2024
1 parent ce08f8c commit 0cc30db
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
10 changes: 9 additions & 1 deletion lumibot/data_sources/tradier_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytz

from lumibot.entities import Asset, Bars
from lumibot.tools.helpers import create_options_symbol, parse_timestep_qty_and_unit
from lumibot.tools.helpers import create_options_symbol, parse_timestep_qty_and_unit, get_trading_days
from lumiwealth_tradier import Tradier

from .data_source import DataSource
Expand Down Expand Up @@ -202,6 +202,14 @@ def get_historical_prices(
td, _ = self.convert_timestep_str_to_timedelta(timestep)
start_date = end_date - (td * length)

if timestep == 'day' and timeshift is None:
# What we really want is the last n bars, not the bars from the last n days.
# get twice as many days as we need to ensure we get enough bars
tcal_start_date = end_date - (td * length * 2)
trading_days = get_trading_days(market='NYSE', start_date=tcal_start_date, end_date=end_date)
# Now, start_date is the length bars before the last trading day
start_date = trading_days.index[-length]

# Check what timestep we are using, different endpoints are required for different timesteps
try:
if parsed_timestep_unit == "minute":
Expand Down
8 changes: 4 additions & 4 deletions tests/test_bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def setup_class(cls):
df['expected_return'] = df['Adj Close'].pct_change()
cls.expected_df = df

@pytest.mark.skip()
# @pytest.mark.skip()
@pytest.mark.skipif(not ALPACA_CONFIG['API_KEY'], reason="This test requires an alpaca API key")
@pytest.mark.skipif(ALPACA_CONFIG['API_KEY'] == '<your key here>', reason="This test requires an alpaca API key")
def test_alpaca_data_source_daily_bars(self):
Expand All @@ -78,7 +78,7 @@ def test_alpaca_data_source_daily_bars(self):
# check that there is no dividend column... This test will fail when dividends are added. We hope that's soon.
assert "dividend" not in prices.df.columns

@pytest.mark.skip()
# @pytest.mark.skip()
def test_yahoo_data_source_daily_bars(self):
"""
This tests that the yahoo data_source calculates adjusted returns for bars and that they
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_yahoo_data_source_daily_bars(self):
rtol=0
)

@pytest.mark.skip()
# @pytest.mark.skip()
def test_pandas_data_source_daily_bars(self, pandas_data_fixture):
"""
This tests that the pandas data_source calculates adjusted returns for bars and that they
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_pandas_data_source_daily_bars(self, pandas_data_fixture):
rtol=0
)

@pytest.mark.skip()
# @pytest.mark.skip()
@pytest.mark.skipif(POLYGON_API_KEY == '<your key here>', reason="This test requires a Polygon.io API key")
def test_polygon_data_source_daily_bars(self):
"""
Expand Down

0 comments on commit 0cc30db

Please sign in to comment.