diff --git "a/C-\346\213\251\346\227\266\347\261\273/\346\212\200\346\234\257\345\210\206\346\236\220\347\256\227\346\263\225\346\241\206\346\236\266\344\270\216\345\256\236\346\210\230/py/technical_analysis_patterns.py" "b/C-\346\213\251\346\227\266\347\261\273/\346\212\200\346\234\257\345\210\206\346\236\220\347\256\227\346\263\225\346\241\206\346\236\266\344\270\216\345\256\236\346\210\230/py/technical_analysis_patterns.py" index 299eb30..b565e50 100644 --- "a/C-\346\213\251\346\227\266\347\261\273/\346\212\200\346\234\257\345\210\206\346\236\220\347\256\227\346\263\225\346\241\206\346\236\266\344\270\216\345\256\236\346\210\230/py/technical_analysis_patterns.py" +++ "b/C-\346\213\251\346\227\266\347\261\273/\346\212\200\346\234\257\345\210\206\346\236\220\347\256\227\346\263\225\346\241\206\346\236\266\344\270\216\345\256\236\346\210\230/py/technical_analysis_patterns.py" @@ -1,7 +1,7 @@ ''' Author: Hugo Date: 2022-02-18 21:17:27 -LastEditTime: 2022-02-21 20:48:09 +LastEditTime: 2022-02-23 21:43:17 LastEditors: Please set LastEditors ''' @@ -9,6 +9,7 @@ from collections import (defaultdict, namedtuple) from typing import (List, Tuple, Dict, Callable, Union) import itertools +import functools from tqdm.notebook import tqdm import warnings @@ -93,7 +94,7 @@ def rolling_windows(a: Union[np.ndarray, pd.Series, pd.DataFrame], window: int) def calc_smooth(prices: pd.Series, *, bw: Union[np.ndarray, str] = 'cv_ls', a: float = None, use_array: bool = True) -> Union[pd.Series, np.ndarray]: - """计算Nadaraya–Watson核估计后的价格数据 + """计算Nadaraya-Watson核估计后的价格数据 Args: prices (pd.Series): 价格数据 @@ -134,7 +135,9 @@ def calc_smooth(prices: pd.Series, *, bw: Union[np.ndarray, str] = 'cv_ls', a: f # TODO:考虑改为全部使用numpy以提升速度 # TODO:是否考虑娱乐sklearn接口? # 预留平滑函数 后续考虑使用机器学习算法 -def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] = 'cv_ls', a: float = 0.3, offset: int = 1, smooth_fumc: Callable = calc_smooth) -> pd.Series: + + +def find_price_argrelextrema(prices: pd.Series, *, offset: int = 1, smooth_fumc: Callable = calc_smooth, **kw) -> pd.Series: """平滑数据并识别极大极小值 Args: @@ -142,7 +145,8 @@ def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] = bw (Union[str,np.ndarray]): Either a user-specified bandwidth or the method for bandwidth selection. Defaults to cv_ls. a (float):论文中所说的比例数据. Defaults to 0.3. offset (int, optional): 避免陷入局部最大最小值. Defaults to 1. - smooth_fumc (Callable,optional): 平滑处理方法函数. Defaults to calc_smooth + smooth_fumc (Callable,optional): 平滑处理方法函数,返回值需要为ndarray. Defaults to calc_smooth + kw : 该参数传递给smooth_func Returns: pd.Series: 最大最小值的目标索引下标 index-dt value-price """ @@ -153,7 +157,8 @@ def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] = raise ValueError('price数据长度过小') # 计算平滑价格 - smooth_arr: np.ndarray = smooth_fumc(prices, a=a, bw=bw, use_array=True) + # + smooth_arr: np.ndarray = smooth_fumc(prices, **kw) # 请多平滑后的高低点 local_max = argrelmax(smooth_arr)[0] @@ -186,6 +191,8 @@ def find_price_argrelextrema(prices: pd.Series, *, bw: Union[str, np.ndarray] = return prices.loc[idx] # TODO:算法待优化 是否能减少时间复杂度 + + def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict: """识别匹配常见形态,由于时间区间不同可能会有多个值 @@ -208,8 +215,8 @@ def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict: if size < 5: return {} - arrs: np.ndarray = rolling_windows(max_min, 5) # 平滑并确定好高低点的价格数据 - idxs: np.ndarray = rolling_windows(np.array(max_min.index), 5) # 索引 + arrs: np.ndarray = rolling_windows(max_min.values, 5) # 平滑并确定好高低点的价格数据 + idxs: np.ndarray = rolling_windows(max_min.index.values, 5) # 索引 for idx, arr in zip(idxs, arrs): @@ -262,8 +269,6 @@ def find_price_patterns(max_min: pd.Series, save_all: bool = True) -> Dict: return patterns - - """形态定义 论文中e_1 is a maximum/minimum 但是HS和IHS应该是e_1与e_2比较大小(我是看论文图例得出的结论,否则找出的形态有问题) @@ -432,6 +437,8 @@ def _pattern_RBOT(arr: np.ndarray) -> bool: # TODO:过于冗余 考虑使用多进程提升运行效率 # 初步完成多进程版本rolling_patterns2pool + + def rolling_patterns(price: pd.Series, *, bw: Union[str, np.ndarray] = 'cv_ls', a: float = 0.3, n: int = 35, offset: int = 1, reset_window: int = None) -> namedtuple: """滑动窗口识别 当窗口滑动时,历史上同一时间出现的形态可能会在多个连续窗口中被识别出来了不重复分析,我们只保留滑动期n内第一次识别到该形态的时点。 @@ -629,9 +636,11 @@ def _get_slice_price(tline: Union[Dict, np.array]) -> pd.DataFrame: type='candle', datetime_format='%Y-%m-%d', ax=ax) return ax + """使用Multiprocessing""" -def _roll_patterns_series(arrs: List)->Tuple[int,defaultdict]: + +def _roll_patterns_series(arrs: List, **kw) -> Tuple[defaultdict]: """获取窗口期内第一个匹配到的形态信息 Args: @@ -640,15 +649,18 @@ def _roll_patterns_series(arrs: List)->Tuple[int,defaultdict]: Returns: Tuple[int,defaultdict]: 0-切片下标信息 1-形态信息 """ - slice_arr, idx_arr,id_num = arrs + # slice_arr, idx_arr, id_num = arrs + slice_arr, idx_arr = arrs + close_ser = pd.Series(data=slice_arr, index=idx_arr) - max_min = find_price_argrelextrema(close_ser) + max_min = find_price_argrelextrema(close_ser, **kw) - return (id_num,find_price_patterns(max_min)) + # return (id_num, find_price_patterns(max_min)) + return find_price_patterns(max_min) -def rolling_patterns2pool(price: pd.Series,n:int,reset_window:int=None,n_workers: int = CPU_WORKER_NUM)->namedtuple: +def rolling_patterns2pool(price: pd.Series, n: int, reset_window: int = None, *, roll: bool = True, n_workers: int = CPU_WORKER_NUM, **kw) -> namedtuple: """使用多进程匹配 Args: @@ -664,6 +676,7 @@ def rolling_patterns2pool(price: pd.Series,n:int,reset_window:int=None,n_workers namedtuple: _description_ """ size = len(price) + if reset_window is None: reset_window = size - n + 1 # 表示不更新 @@ -671,26 +684,27 @@ def rolling_patterns2pool(price: pd.Series,n:int,reset_window:int=None,n_workers if reset_window >= size: raise ValueError('reset_window不能大于price长度') - - idxs: np.ndarray = rolling_windows(price.index.values, n) - arr: np.ndarray = rolling_windows(price.values, n) - id_num:np.ndarray = np.arange(len(idxs)) - chunk_size = calculate_best_chunk_size(len(idxs),n_workers) + # 用于储存结果 Record = namedtuple('Record', 'patterns,points') patterns = defaultdict(list) # 储存识别好的形态 points = defaultdict(list) # 出现形态时点 - - with Pool(processes=n_workers) as pool: - res_tuple:Tuple[Tuple[int,Dict]] = tuple(pool.imap_unordered(_roll_patterns_series, - zip(arr, idxs,id_num), chunksize=chunk_size)) + idxs: np.ndarray = rolling_windows(price.index.values, n) + arr: np.ndarray = rolling_windows(price.values, n) + + chunk_size = calculate_best_chunk_size(len(idxs), n_workers) + + roll_patterns_series = functools.partial(_roll_patterns_series, **kw) - res_tuple = sorted(res_tuple) + with Pool(processes=n_workers) as pool: - for sub_res in res_tuple: + res_tuple: Tuple[Dict] = tuple( + pool.imap(roll_patterns_series, zip(arr, idxs), chunksize=chunk_size)) - num,current_pattern = sub_res + for num, sub_res in enumerate(res_tuple): + + current_pattern = sub_res if current_pattern: @@ -734,12 +748,14 @@ def calculate_best_chunk_size(data_length: int, n_workers: int) -> int: # if __name__ == '__main__': # industry_data = pd.read_csv( -# r'中信二级行业测试\DATA.csv', index_col=[0, 1], parse_dates=[1]) +# 'DATA.csv', index_col=[0, 1], parse_dates=[1]) # industry_data.columns = industry_data.columns.str.lower() # industry_data.index.names = ['WIND_CODE','DATE'] # idx = pd.IndexSlice # close_ser = industry_data.loc[idx['CI005101.WI','2021-10-01':'2021-12-31'],'close'].reset_index(level=0,drop=True) -# res = rolling_patterns(close_ser) # rolling_patterns2pool(close_ser) - +# # 使用多进程 +# res = rolling_patterns2pool(close_ser,n=30,n_workers=4,a=None) # rolling_patterns2(close_ser) +# # 使用单进程 +# res = rolling_patterns(close_ser,n=30,n_workers=4,a=None) # print(res)