Skip to content

Commit

Permalink
Refactor and optimize ADXR and DX indicators with improved Numba impl…
Browse files Browse the repository at this point in the history
…ementation

- Rewrite ADXR indicator calculation with more precise DX and ADX computation
- Optimize DX indicator by separating directional movement and true range calculations
- Improve computational efficiency and numerical stability
- Use RMA function for smoother indicator calculations
- Enhance Numba JIT compilation for better performance
  • Loading branch information
morteza-koohgard authored and saleh-mir committed Feb 20, 2025
1 parent b1776de commit 5b628f6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 81 deletions.
37 changes: 19 additions & 18 deletions jesse/indicators/adxr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Union

import numpy as np
from numba import njit
from jesse.helpers import slice_candles
from numba import njit


@njit(cache=True)
Expand All @@ -12,7 +12,7 @@ def _adxr(high, low, close, period):
DMP = np.zeros(n)
DMM = np.zeros(n)

# First value initialization
# First value
TR[0] = high[0] - low[0]

# Calculate TR, DMP, DMM
Expand Down Expand Up @@ -54,28 +54,29 @@ def _adxr(high, low, close, period):
# Calculate DI+ and DI-
DI_plus = np.zeros(n)
DI_minus = np.zeros(n)
DX = np.zeros(n)

for i in range(n):
if STR[i] != 0:
DI_plus[i] = (S_DMP[i] / STR[i]) * 100
DI_minus[i] = (S_DMM[i] / STR[i]) * 100

denom = DI_plus[i] + DI_minus[i]
if denom != 0:
DX[i] = (abs(DI_plus[i] - DI_minus[i]) / denom) * 100

# Calculate DX
DX = np.zeros(n)
for i in range(n):
denom = DI_plus[i] + DI_minus[i]
if denom != 0:
DX[i] = (abs(DI_plus[i] - DI_minus[i]) / denom) * 100

# Calculate ADX
ADX = np.zeros(n)
ADX = np.full(n, np.nan)
if n >= period:
# First ADX value
ADX[period-1] = np.mean(DX[:period])
# Rest of ADX values
for i in range(period, n):
ADX[i] = ((ADX[i-1] * (period-1)) + DX[i]) / period
for i in range(period-1, n):
sum_dx = 0
for j in range(period):
sum_dx += DX[i-j]
ADX[i] = sum_dx / period

# Calculate ADXR
ADXR = np.zeros(n)
ADXR = np.full(n, np.nan)
if n > period:
for i in range(period, n):
ADXR[i] = (ADX[i] + ADX[i-period]) / 2
Expand All @@ -95,11 +96,11 @@ def adxr(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Uni
:return: ADXR as float or np.ndarray
"""
candles = slice_candles(candles, sequential)

high = candles[:, 3]
low = candles[:, 4]
close = candles[:, 2]

res = _adxr(high, low, close, period)
ADXR = _adxr(high, low, close, period)

return res if sequential else res[-1]
return ADXR if sequential else ADXR[-1]
101 changes: 38 additions & 63 deletions jesse/indicators/dx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,38 @@

import numpy as np
from jesse.helpers import slice_candles
from jesse.indicators.rma import rma
from numba import njit

DX = namedtuple('DX', ['adx', 'plusDI', 'minusDI'])

@njit(cache=True)
def _rma(src, length):
alpha = 1.0 / length
output = np.zeros_like(src)
output[0] = src[0]
for i in range(1, len(src)):
output[i] = alpha * src[i] + (1 - alpha) * output[i-1]
return output

@njit(cache=True)
def _dx(high, low, close, di_length, adx_smoothing):
def _fast_dm_tr(high: np.ndarray, low: np.ndarray, close: np.ndarray) -> tuple:
n = len(high)

# Pre-allocate arrays
up = np.zeros(n)
down = np.zeros(n)
plusDM = np.zeros(n)
minusDM = np.zeros(n)
tr = np.zeros(n)

# Calculate True Range and Directional Movement
for i in range(1, n):
high_diff = high[i] - high[i-1]
low_diff = low[i-1] - low[i]

# +DM
if high_diff > low_diff and high_diff > 0:
plusDM[i] = high_diff

# -DM
if low_diff > high_diff and low_diff > 0:
minusDM[i] = low_diff

# True Range
tr[i] = max(
high[i] - low[i],
abs(high[i] - close[i-1]),
abs(low[i] - close[i-1])
)

# Calculate smoothed values
tr_rma = _rma(tr, di_length)
plus_rma = _rma(plusDM, di_length)
minus_rma = _rma(minusDM, di_length)

# Calculate +DI and -DI
plusDI = np.zeros(n)
minusDI = np.zeros(n)
true_range = np.zeros(n)

for i in range(n):
if tr_rma[i] != 0:
plusDI[i] = 100 * plus_rma[i] / tr_rma[i]
minusDI[i] = 100 * minus_rma[i] / tr_rma[i]
if i == 0:
up[i] = 0
down[i] = 0
plusDM[i] = 0
minusDM[i] = 0
true_range[i] = high[i] - low[i]
else:
up[i] = high[i] - high[i - 1]
down[i] = low[i - 1] - low[i]
plusDM[i] = up[i] if (up[i] > down[i] and up[i] > 0) else 0
minusDM[i] = down[i] if (down[i] > up[i] and down[i] > 0) else 0
a = high[i] - low[i]
b = abs(high[i] - close[i - 1])
c = abs(low[i] - close[i - 1])
true_range[i] = max(a, b, c)

# Calculate DX
dx_values = np.zeros(n)
for i in range(n):
di_sum = plusDI[i] + minusDI[i]
if di_sum != 0:
dx_values[i] = 100 * abs(plusDI[i] - minusDI[i]) / di_sum

# Calculate ADX
adx = _rma(dx_values, adx_smoothing)

return adx, plusDI, minusDI
return plusDM, minusDM, true_range

def dx(candles: np.ndarray, di_length: int = 14, adx_smoothing: int = 14, sequential: bool = False) -> Union[float, np.ndarray]:
"""
Expand All @@ -84,14 +49,24 @@ def dx(candles: np.ndarray, di_length: int = 14, adx_smoothing: int = 14, sequen
:return: DX(adx, plusDI, minusDI)
"""
candles = slice_candles(candles, sequential)
high = candles[:, 3]
low = candles[:, 4]
close = candles[:, 2]

plusDM, minusDM, true_range = _fast_dm_tr(high, low, close)

tr_rma = rma(true_range, di_length, sequential=True)
plus_rma = rma(plusDM, di_length, sequential=True)
minus_rma = rma(minusDM, di_length, sequential=True)

# Compute +DI and -DI, avoiding division by zero
plusDI = np.where(tr_rma == 0, 0, 100 * plus_rma / tr_rma)
minusDI = np.where(tr_rma == 0, 0, 100 * minus_rma / tr_rma)

adx, plusDI, minusDI = _dx(
candles[:, 3], # high
candles[:, 4], # low
candles[:, 2], # close
di_length,
adx_smoothing
)
di_sum = plusDI + minusDI
di_diff = np.abs(plusDI - minusDI)
directional_index = di_diff / np.where(di_sum == 0, 1, di_sum)
adx = 100 * rma(directional_index, adx_smoothing, sequential=True)

if sequential:
return DX(adx, plusDI, minusDI)
Expand Down

0 comments on commit 5b628f6

Please sign in to comment.