diff --git a/jesse/indicators/adxr.py b/jesse/indicators/adxr.py index a1ba55826..5b5159717 100644 --- a/jesse/indicators/adxr.py +++ b/jesse/indicators/adxr.py @@ -1,8 +1,8 @@ from typing import Union import numpy as np -from numba import njit from jesse.helpers import slice_candles +from numba import njit @njit(cache=True) @@ -12,7 +12,7 @@ def _adxr(high, low, close, period): DMP = np.zeros(n) DMM = np.zeros(n) - # First value initialization + # First value TR[0] = high[0] - low[0] # Calculate TR, DMP, DMM @@ -54,28 +54,29 @@ def _adxr(high, low, close, period): # Calculate DI+ and DI- DI_plus = np.zeros(n) DI_minus = np.zeros(n) - DX = np.zeros(n) - for i in range(n): if STR[i] != 0: DI_plus[i] = (S_DMP[i] / STR[i]) * 100 DI_minus[i] = (S_DMM[i] / STR[i]) * 100 - - denom = DI_plus[i] + DI_minus[i] - if denom != 0: - DX[i] = (abs(DI_plus[i] - DI_minus[i]) / denom) * 100 + + # Calculate DX + DX = np.zeros(n) + for i in range(n): + denom = DI_plus[i] + DI_minus[i] + if denom != 0: + DX[i] = (abs(DI_plus[i] - DI_minus[i]) / denom) * 100 # Calculate ADX - ADX = np.zeros(n) + ADX = np.full(n, np.nan) if n >= period: - # First ADX value - ADX[period-1] = np.mean(DX[:period]) - # Rest of ADX values - for i in range(period, n): - ADX[i] = ((ADX[i-1] * (period-1)) + DX[i]) / period + for i in range(period-1, n): + sum_dx = 0 + for j in range(period): + sum_dx += DX[i-j] + ADX[i] = sum_dx / period # Calculate ADXR - ADXR = np.zeros(n) + ADXR = np.full(n, np.nan) if n > period: for i in range(period, n): ADXR[i] = (ADX[i] + ADX[i-period]) / 2 @@ -95,11 +96,11 @@ def adxr(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Uni :return: ADXR as float or np.ndarray """ candles = slice_candles(candles, sequential) - + high = candles[:, 3] low = candles[:, 4] close = candles[:, 2] - res = _adxr(high, low, close, period) + ADXR = _adxr(high, low, close, period) - return res if sequential else res[-1] + return ADXR if sequential else ADXR[-1] diff --git a/jesse/indicators/dx.py b/jesse/indicators/dx.py index 43e855c06..e0d22f0f8 100644 --- a/jesse/indicators/dx.py +++ b/jesse/indicators/dx.py @@ -4,73 +4,38 @@ import numpy as np from jesse.helpers import slice_candles +from jesse.indicators.rma import rma from numba import njit DX = namedtuple('DX', ['adx', 'plusDI', 'minusDI']) @njit(cache=True) -def _rma(src, length): - alpha = 1.0 / length - output = np.zeros_like(src) - output[0] = src[0] - for i in range(1, len(src)): - output[i] = alpha * src[i] + (1 - alpha) * output[i-1] - return output - -@njit(cache=True) -def _dx(high, low, close, di_length, adx_smoothing): +def _fast_dm_tr(high: np.ndarray, low: np.ndarray, close: np.ndarray) -> tuple: n = len(high) - - # Pre-allocate arrays + up = np.zeros(n) + down = np.zeros(n) plusDM = np.zeros(n) minusDM = np.zeros(n) - tr = np.zeros(n) - - # Calculate True Range and Directional Movement - for i in range(1, n): - high_diff = high[i] - high[i-1] - low_diff = low[i-1] - low[i] - - # +DM - if high_diff > low_diff and high_diff > 0: - plusDM[i] = high_diff - - # -DM - if low_diff > high_diff and low_diff > 0: - minusDM[i] = low_diff - - # True Range - tr[i] = max( - high[i] - low[i], - abs(high[i] - close[i-1]), - abs(low[i] - close[i-1]) - ) - - # Calculate smoothed values - tr_rma = _rma(tr, di_length) - plus_rma = _rma(plusDM, di_length) - minus_rma = _rma(minusDM, di_length) - - # Calculate +DI and -DI - plusDI = np.zeros(n) - minusDI = np.zeros(n) + true_range = np.zeros(n) for i in range(n): - if tr_rma[i] != 0: - plusDI[i] = 100 * plus_rma[i] / tr_rma[i] - minusDI[i] = 100 * minus_rma[i] / tr_rma[i] + if i == 0: + up[i] = 0 + down[i] = 0 + plusDM[i] = 0 + minusDM[i] = 0 + true_range[i] = high[i] - low[i] + else: + up[i] = high[i] - high[i - 1] + down[i] = low[i - 1] - low[i] + plusDM[i] = up[i] if (up[i] > down[i] and up[i] > 0) else 0 + minusDM[i] = down[i] if (down[i] > up[i] and down[i] > 0) else 0 + a = high[i] - low[i] + b = abs(high[i] - close[i - 1]) + c = abs(low[i] - close[i - 1]) + true_range[i] = max(a, b, c) - # Calculate DX - dx_values = np.zeros(n) - for i in range(n): - di_sum = plusDI[i] + minusDI[i] - if di_sum != 0: - dx_values[i] = 100 * abs(plusDI[i] - minusDI[i]) / di_sum - - # Calculate ADX - adx = _rma(dx_values, adx_smoothing) - - return adx, plusDI, minusDI + return plusDM, minusDM, true_range def dx(candles: np.ndarray, di_length: int = 14, adx_smoothing: int = 14, sequential: bool = False) -> Union[float, np.ndarray]: """ @@ -84,14 +49,24 @@ def dx(candles: np.ndarray, di_length: int = 14, adx_smoothing: int = 14, sequen :return: DX(adx, plusDI, minusDI) """ candles = slice_candles(candles, sequential) + high = candles[:, 3] + low = candles[:, 4] + close = candles[:, 2] + + plusDM, minusDM, true_range = _fast_dm_tr(high, low, close) + + tr_rma = rma(true_range, di_length, sequential=True) + plus_rma = rma(plusDM, di_length, sequential=True) + minus_rma = rma(minusDM, di_length, sequential=True) + + # Compute +DI and -DI, avoiding division by zero + plusDI = np.where(tr_rma == 0, 0, 100 * plus_rma / tr_rma) + minusDI = np.where(tr_rma == 0, 0, 100 * minus_rma / tr_rma) - adx, plusDI, minusDI = _dx( - candles[:, 3], # high - candles[:, 4], # low - candles[:, 2], # close - di_length, - adx_smoothing - ) + di_sum = plusDI + minusDI + di_diff = np.abs(plusDI - minusDI) + directional_index = di_diff / np.where(di_sum == 0, 1, di_sum) + adx = 100 * rma(directional_index, adx_smoothing, sequential=True) if sequential: return DX(adx, plusDI, minusDI)