Skip to content

Commit

Permalink
technical_analysis_patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
hugo2046 committed Feb 24, 2022
1 parent 28b25e6 commit 20f5fdb
Showing 1 changed file with 45 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
'''
Author: Hugo
Date: 2022-02-18 21:17:27
LastEditTime: 2022-02-21 20:48:09
LastEditTime: 2022-02-23 21:43:17
LastEditors: Please set LastEditors
'''

# 引入库
from collections import (defaultdict, namedtuple)
from typing import (List, Tuple, Dict, Callable, Union)
import itertools
import functools
from tqdm.notebook import tqdm
import warnings

Expand Down Expand Up @@ -93,7 +94,7 @@ def rolling_windows(a: Union[np.ndarray, pd.Series, pd.DataFrame], window: int)


def calc_smooth(prices: pd.Series, *, bw: Union[np.ndarray, str] = 'cv_ls', a: float = None, use_array: bool = True) -> Union[pd.Series, np.ndarray]:
"""计算NadarayaWatson核估计后的价格数据
"""计算Nadaraya-Watson核估计后的价格数据
Args:
prices (pd.Series): 价格数据
Expand Down Expand Up @@ -134,15 +135,18 @@ def calc_smooth(prices: pd.Series, *, bw: Union[np.ndarray, str] = 'cv_ls', a: f
# TODO:考虑改为全部使用numpy以提升速度
# TODO:是否考虑使用sklearn接口?
# 预留平滑函数 后续考虑使用机器学习算法
def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] = 'cv_ls', a: float = 0.3, offset: int = 1, smooth_fumc: Callable = calc_smooth) -> pd.Series:


def find_price_argrelextrema(prices: pd.Series, *, offset: int = 1, smooth_fumc: Callable = calc_smooth, **kw) -> pd.Series:
"""平滑数据并识别极大极小值
Args:
smooth_prices (pd.Series): 价格序列
bw (Union[str,np.ndarray]): Either a user-specified bandwidth or the method for bandwidth selection. Defaults to cv_ls.
a (float):论文中所说的比例数据. Defaults to 0.3.
offset (int, optional): 避免陷入局部最大最小值. Defaults to 1.
smooth_fumc (Callable,optional): 平滑处理方法函数. Defaults to calc_smooth
smooth_fumc (Callable,optional): 平滑处理方法函数,返回值需要为ndarray. Defaults to calc_smooth
kw : 该参数传递给smooth_func
Returns:
pd.Series: 最大最小值的目标索引下标 index-dt value-price
"""
Expand All @@ -153,7 +157,8 @@ def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] =
raise ValueError('price数据长度过小')

# 计算平滑价格
smooth_arr: np.ndarray = smooth_fumc(prices, a=a, bw=bw, use_array=True)
#
smooth_arr: np.ndarray = smooth_fumc(prices, **kw)

# 获取平滑后的高低点
local_max = argrelmax(smooth_arr)[0]
Expand Down Expand Up @@ -186,6 +191,8 @@ def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] =
return prices.loc[idx]

# TODO:算法待优化 是否能减少时间复杂度


def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict:
"""识别匹配常见形态,由于时间区间不同可能会有多个值
Expand All @@ -208,8 +215,8 @@ def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict:
if size < 5:
return {}

arrs: np.ndarray = rolling_windows(max_min, 5) # 平滑并确定好高低点的价格数据
idxs: np.ndarray = rolling_windows(np.array(max_min.index), 5) # 索引
arrs: np.ndarray = rolling_windows(max_min.values, 5) # 平滑并确定好高低点的价格数据
idxs: np.ndarray = rolling_windows(max_min.index.values, 5) # 索引

for idx, arr in zip(idxs, arrs):

Expand Down Expand Up @@ -262,8 +269,6 @@ def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict:
return patterns




"""形态定义
论文中e_1 is a maximum/minimum 但是HS和IHS应该是e_1与e_2比较大小(我是看论文图例得出的结论,否则找出的形态有问题)
Expand Down Expand Up @@ -432,6 +437,8 @@ def _pattern_RBOT(arr: np.ndarray) -> bool:

# TODO:过于冗余 考虑使用多进程提升运行效率
# 初步完成多进程版本rolling_patterns2pool


def rolling_patterns(price: pd.Series, *, bw: Union[str, np.ndarray] = 'cv_ls', a: float = 0.3, n: int = 35, offset: int = 1, reset_window: int = None) -> namedtuple:
"""滑动窗口识别
当窗口滑动时,历史上同一时间出现的形态可能会在多个连续窗口中被识别出来。为了不重复分析,我们只保留滑动期n内第一次识别到该形态的时点。
Expand Down Expand Up @@ -629,9 +636,11 @@ def _get_slice_price(tline: Union[Dict, np.array]) -> pd.DataFrame:
type='candle', datetime_format='%Y-%m-%d', ax=ax)
return ax


"""使用Multiprocessing"""

def _roll_patterns_series(arrs: List)->Tuple[int,defaultdict]:

def _roll_patterns_series(arrs: List, **kw) -> Tuple[defaultdict]:
"""获取窗口期内第一个匹配到的形态信息
Args:
Expand All @@ -640,15 +649,18 @@ def _roll_patterns_series(arrs: List)->Tuple[int,defaultdict]:
Returns:
Dict: 形态信息(find_price_patterns的返回值)
"""
slice_arr, idx_arr,id_num = arrs
# slice_arr, idx_arr, id_num = arrs
slice_arr, idx_arr = arrs

close_ser = pd.Series(data=slice_arr, index=idx_arr)

max_min = find_price_argrelextrema(close_ser)
max_min = find_price_argrelextrema(close_ser, **kw)

return (id_num,find_price_patterns(max_min))
# return (id_num, find_price_patterns(max_min))
return find_price_patterns(max_min)


def rolling_patterns2pool(price: pd.Series,n:int,reset_window:int=None,n_workers: int = CPU_WORKER_NUM)->namedtuple:
def rolling_patterns2pool(price: pd.Series, n: int, reset_window: int = None, *, roll: bool = True, n_workers: int = CPU_WORKER_NUM, **kw) -> namedtuple:
"""使用多进程匹配
Args:
Expand All @@ -664,33 +676,35 @@ def rolling_patterns2pool(price: pd.Series,n:int,reset_window:int=None,n_workers
namedtuple: Record,字段为patterns(识别好的形态)与points(出现形态的时点)
"""
size = len(price)

if reset_window is None:

reset_window = size - n + 1 # 表示不更新

if reset_window >= size:

raise ValueError('reset_window不能大于price长度')

idxs: np.ndarray = rolling_windows(price.index.values, n)
arr: np.ndarray = rolling_windows(price.values, n)
id_num:np.ndarray = np.arange(len(idxs))
chunk_size = calculate_best_chunk_size(len(idxs),n_workers)

# 用于储存结果
Record = namedtuple('Record', 'patterns,points')
patterns = defaultdict(list) # 储存识别好的形态
points = defaultdict(list) # 出现形态时点

with Pool(processes=n_workers) as pool:

res_tuple:Tuple[Tuple[int,Dict]] = tuple(pool.imap_unordered(_roll_patterns_series,
zip(arr, idxs,id_num), chunksize=chunk_size))
idxs: np.ndarray = rolling_windows(price.index.values, n)
arr: np.ndarray = rolling_windows(price.values, n)

chunk_size = calculate_best_chunk_size(len(idxs), n_workers)

roll_patterns_series = functools.partial(_roll_patterns_series, **kw)

res_tuple = sorted(res_tuple)
with Pool(processes=n_workers) as pool:

for sub_res in res_tuple:
res_tuple: Tuple[Dict] = tuple(
pool.imap(roll_patterns_series, zip(arr, idxs), chunksize=chunk_size))

num,current_pattern = sub_res
for num, sub_res in enumerate(res_tuple):

current_pattern = sub_res

if current_pattern:

Expand Down Expand Up @@ -734,12 +748,14 @@ def calculate_best_chunk_size(data_length: int, n_workers: int) -> int:
# if __name__ == '__main__':

# industry_data = pd.read_csv(
# r'中信二级行业测试\DATA.csv', index_col=[0, 1], parse_dates=[1])
# 'DATA.csv', index_col=[0, 1], parse_dates=[1])

# industry_data.columns = industry_data.columns.str.lower()
# industry_data.index.names = ['WIND_CODE','DATE']
# idx = pd.IndexSlice
# close_ser = industry_data.loc[idx['CI005101.WI','2021-10-01':'2021-12-31'],'close'].reset_index(level=0,drop=True)
# res = rolling_patterns(close_ser) # rolling_patterns2pool(close_ser)

# # 使用多进程
# res = rolling_patterns2pool(close_ser,n=30,n_workers=4,a=None) # rolling_patterns2(close_ser)
# # 使用单进程
# res = rolling_patterns(close_ser,n=30,a=None)
# print(res)

0 comments on commit 20f5fdb

Please sign in to comment.