Skip to content

Commit

Permalink
pr update
Browse files Browse the repository at this point in the history
  • Loading branch information
Lin-Dongzhao committed Mar 19, 2024
1 parent bcf3443 commit 3f1e7a8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
18 changes: 6 additions & 12 deletions rqalpha/data/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,19 +716,13 @@ def _get_array(self, instrument, start_date):
for field in self._fields:
dtype.append((field, record.dtype[field]))

dt_arr = np.array(df.index.tolist()).reshape((-1, 1))
dt_arr = np.apply_along_axis(self._get_trading_date, 1, dt_arr)
dt_arr = (dt_arr.astype('datetime64[Y]').astype(int) + 1970) * 10000 + (dt_arr.astype('datetime64[M]').astype(int) % 12 + 1) * 100 + (dt_arr.astype('datetime64[D]') - dt_arr.astype('datetime64[M]') + 1)
dt_arr.astype(int)
arr = np.ones((dt_arr.shape[0], ), dtype=dtype)
arr['trading_dt'] = dt_arr
dt = np.array(df.index.tolist())
trading_dt = self._env.data_proxy._data_source.get_trading_date_for_np(dt)
# trading_dt = trading_dt.year * 10000 + trading_dt.month * 100 + trading_dt.day
trading_dt = convert_date_to_date_int(trading_dt)
arr = np.ones((trading_dt.shape[0], ), dtype=dtype)
arr['trading_dt'] = trading_dt
for field in self._fields:
arr[field] = df[field].values
return arr
return None

def _get_trading_date(self, dt):
# type: (numpy.ndarray) -> Timestamp
dt = dt[0]
dt = self._env.data_proxy._data_source.get_future_trading_date(dt)
return dt
9 changes: 9 additions & 0 deletions rqalpha/data/trading_dates_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Dict, Optional, Union

import pandas as pd
import numpy as np

from rqalpha.utils.functools import lru_cache
from rqalpha.const import TRADING_CALENDAR_TYPE
Expand Down Expand Up @@ -114,3 +115,11 @@ def _get_future_trading_date(self, dt):
return trading_dates[pos + 1]

return td

def get_trading_date_for_np(self, dt_arr):
# 获取 numpy.array 中所有时间所在的交易日
# 认为晚八点后为第二个交易日,认为晚八点至次日凌晨四点为夜盘
dt = dt_arr - datetime.timedelta(hours=4)
trading_dates = self.get_trading_calendar(TRADING_CALENDAR_TYPE.EXCHANGE)
pos = trading_dates.searchsorted(dt.astype("datetime64[D]")) + np.where((dt.astype('datetime64[h]') - dt.astype('datetime64[D]')).astype(int) >= 16, 1, 0)
return trading_dates[pos]
4 changes: 2 additions & 2 deletions rqalpha/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class Environment(object):
_env = None # type: Environment

def __init__(self, config):
def __init__(self, config, rqdatac_init):
Environment._env = self
self.config = config
self.data_proxy = None # type: Optional[rqalpha.data.data_proxy.DataProxy]
Expand All @@ -55,7 +55,7 @@ def __init__(self, config):
self._frontend_validators = {} # type: Dict[str, List]
self._default_frontend_validators = []
self._transaction_cost_decider_dict = {}
self.rqdatac_init = False
self.rqdatac_init = rqdatac_init # type: Boolean

# Environment.event_bus used in StrategyUniverse()
from rqalpha.core.strategy_universe import StrategyUniverse
Expand Down
8 changes: 4 additions & 4 deletions rqalpha/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_strategy_apis():
return {n: getattr(api, n) for n in api.__all__}


def init_rqdatac(rqdatac_uri, env):
def init_rqdatac(rqdatac_uri):
if rqdatac_uri in ["disabled", "DISABLED"]:
return

Expand All @@ -122,13 +122,14 @@ def init_rqdatac(rqdatac_uri, env):
init_rqdatac_env(rqdatac_uri)
try:
rqdatac.init()
env.set_rqdatac_init(result=True)
return True
except Exception as e:
system_log.warn(_('rqdatac init failed, some apis will not function properly: {}').format(str(e)))
return


def run(config, source_code=None, user_funcs=None):
env = Environment(config)
env = Environment(config, init_rqdatac(getattr(config.base, 'rqdatac_uri', None)))
persist_helper = None
init_succeed = False
mod_handler = ModHandler()
Expand All @@ -137,7 +138,6 @@ def run(config, source_code=None, user_funcs=None):
# avoid register handlers everytime
# when running in ipython
set_loggers(config)
init_rqdatac(getattr(config.base, 'rqdatac_uri', None), env)
system_log.debug("\n" + pformat(config.convert_to_dict()))

env.set_strategy_loader(init_strategy_loader(env, source_code, user_funcs, config))
Expand Down

0 comments on commit 3f1e7a8

Please sign in to comment.