Skip to content

Commit

Permalink
tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
M-Chris committed Oct 14, 2024
1 parent b3902a3 commit 82fa04a
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 26 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@ Allora Model Maker is a comprehensive machine learning framework designed for ti

2. Create a conda environment:
```bash
conda create --name modelmaker python=3.9 && conda activate modelmaker
conda env create -f environment.yml
```

3. Preinstall setuptools, cython and numpy
If you want to manually do it:
```bash
conda create --name modelmaker python=3.9 && conda activate modelmaker
```
Preinstall setuptools, cython and numpy
```bash
pip install setuptools==72.1.0 Cython==3.0.11 numpy==1.24.3
```

4. Install dependencies:
Install dependencies:
```bash
pip install -r requirements.txt
```
Expand Down
6 changes: 4 additions & 2 deletions data/tiingo_data_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

# Load the .env.local file if it exists, otherwise load .env
if os.path.exists(".env.local"):
load_dotenv(dotenv_path=".env.local")
print("Loading .env.local file...")
load_dotenv(dotenv_path=".env.local", override=True)
else:
load_dotenv() # Defaults to loading .env
print("Loading .env file...")
load_dotenv(dotenv_path=".env") # Defaults to loading .env

# Retrieve the API keys from environment variables
TIINGO_API_KEY = os.getenv("TIINGO_API_KEY")
Expand Down
137 changes: 137 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
name: modelmaker
channels:
- conda-forge
- defaults
dependencies:
- appnope=0.1.4
- asttokens=2.4.1
- bzip2=1.0.8
- ca-certificates=2024.8.30
- comm=0.2.2
- debugpy=1.8.6
- decorator=5.1.1
- exceptiongroup=1.2.2
- executing=2.1.0
- importlib-metadata=8.5.0
- ipykernel=6.29.5
- ipython=8.18.1
- jedi=0.19.1
- jupyter_client=8.6.3
- jupyter_core=5.7.2
- krb5=1.21.3
- libcxx=19.1.0
- libedit=3.1.20191231
- libffi=3.4.2
- libsodium=1.0.20
- libsqlite=3.46.1
- libzlib=1.3.1
- matplotlib-inline=0.1.7
- ncurses=6.5
- nest-asyncio=1.6.0
- openssl=3.3.2
- packaging=24.1
- parso=0.8.4
- pexpect=4.9.0
- pickleshare=0.7.5
- pip=24.2
- platformdirs=4.3.6
- prompt-toolkit=3.0.48
- psutil=6.0.0
- ptyprocess=0.7.0
- pure_eval=0.2.3
- pygments=2.18.0
- python=3.9.20
- python_abi=3.9
- pyzmq=26.2.0
- readline=8.2
- six=1.16.0
- stack_data=0.6.2
- tk=8.6.13
- tornado=6.4.1
- traitlets=5.14.3
- typing_extensions=4.12.2
- wcwidth=0.2.13
- wheel=0.44.0
- xz=5.2.6
- zeromq=4.3.5
- zipp=3.20.2
- pip:
- aiohappyeyeballs==2.4.3
- aiohttp==3.10.8
- aiosignal==1.3.1
- annotated-types==0.7.0
- anyio==4.6.0
- astroid==3.2.4
- async-timeout==4.0.3
- attrs==24.2.0
- black==24.8.0
- certifi==2024.8.30
- cfgv==3.4.0
- charset-normalizer==3.3.2
- click==8.1.7
- cmdstanpy==1.2.4
- contourpy==1.3.0
- cycler==0.12.1
- cython==3.0.11
- dill==0.3.9
- distlib==0.3.8
- filelock==3.16.1
- fonttools==4.54.1
- frozenlist==1.4.1
- fsspec==2024.9.0
- holidays==0.57
- identify==2.6.1
- idna==3.10
- importlib-resources==6.4.5
- isort==5.13.2
- jinja2==3.1.4
- joblib==1.4.2
- kiwisolver==1.4.7
- markupsafe==2.1.5
- matplotlib==3.9.2
- mccabe==0.7.0
- mpmath==1.3.0
- multidict==6.1.0
- mypy-extensions==1.0.0
- networkx==3.2.1
- nodeenv==1.9.1
- numpy==1.24.3
- pandas==2.2.3
- pathspec==0.12.1
- patsy==0.5.6
- pillow==10.4.0
- pipdeptree==2.23.4
- plotly==5.24.0
- pre-commit==3.8.0
- prophet==1.1.5
- pydantic==2.9.2
- pydantic-core==2.23.4
- pylint==3.2.7
- pyparsing==3.1.4
- pystan==2.19.1.1
- python-dateutil==2.9.0.post0
- python-dotenv==1.0.1
- pytz==2024.2
- pyyaml==6.0.2
- requests==2.32.3
- schedule==1.2.2
- scikit-learn==1.5.1
- scipy==1.13.1
- seaborn==0.13.2
- setuptools==72.1.0
- sniffio==1.3.1
- stanio==0.5.1
- starlette==0.38.5
- statsmodels==0.14.2
- sympy==1.13.3
- tenacity==9.0.0
- threadpoolctl==3.5.0
- tomli==2.0.1
- tomlkit==0.13.2
- torch==2.4.1
- tqdm==4.66.5
- tzdata==2024.2
- urllib3==2.2.3
- virtualenv==20.26.6
- xgboost==2.1.1
- yarl==1.13.1
17 changes: 8 additions & 9 deletions models/arima/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ def train(self, data: pd.DataFrame):
print("Data is not stationary, applying differencing...")
close_prices = differencing(close_prices)

# Perform grid search if enabled in the configuration
# Perform enhanced grid search if enabled in the configuration
if self.config.use_grid_search:
print("Performing grid search for ARIMA parameters...")
grid_search_arima(
self,
data=close_prices,
p_values=self.config.p_values,
d_values=self.config.d_values,
q_values=self.config.q_values,
)
else:
print("Using default ARIMA parameters...")

# Fit ARIMA model with the configured best parameters
# Fit ARIMA model with the configured or best-found parameters
self.model = ARIMA(close_prices, order=self.config.best_params)
self.model = self.model.fit(method_kwargs={"maxiter": self.config.max_iter})

Expand All @@ -73,23 +76,19 @@ def inference(self, input_data: pd.DataFrame) -> pd.DataFrame:
close_prices = resample_data(self, close_prices)

# Forecast the number of steps equal to the length of the input data
# pylint: disable=no-member
predictions = self.model.forecast(steps=len(close_prices))

# If differencing was applied, reverse the differencing to get price predictions
if self.config.best_params[1] > 0: # If d > 0, reverse differencing
predictions = reverse_differencing(close_prices, predictions)

# Optionally check for negative values and flag them instead of clipping
predictions[predictions < 0] = (
np.nan
) # Replace unreasonable negative prices with NaN
# Replace unreasonable negative values with NaN
predictions[predictions < 0] = np.nan

# convert the date to string
# Convert the date to string and return results
predictions = pd.Series(
predictions.values, index=close_prices.index.astype(str)
)

return pd.DataFrame(
{"date": predictions.index, "prediction": predictions.values.ravel()}
)
Expand Down
4 changes: 2 additions & 2 deletions models/lstm/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class LstmConfig:
def __init__(self):
# Model architecture parameters
self.input_size = 1 # Input size (number of features)
self.hidden_size = 50 # Number of LSTM units per layer
self.hidden_size = 64 # Number of LSTM units per layer
self.output_size = 1 # Output size (prediction dimension)
self.num_layers = 2 # Number of stacked LSTM layers
self.dropout = 0.5 # Dropout probability for regularization

# Training parameters
self.learning_rate = 0.001 # Learning rate for the optimizer
self.learning_rate = 0.0001 # Learning rate for the optimizer
self.batch_size = 32 # Batch size for training
self.epochs = 100 # Number of training epochs
self.early_stopping_patience = 10 # Early stopping patience in epochs
Expand Down
13 changes: 8 additions & 5 deletions models/prophet/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ class ProphetConfig:

def __init__(self):
# Prophet model configuration
self.growth = "linear" # Options: 'linear', 'logistic'
self.changepoint_prior_scale = 0.05 # Regularization strength for changepoints
self.seasonality_mode = "additive" # Options: 'additive', 'multiplicative'
self.yearly_seasonality = True # Whether to include yearly seasonality
self.growth = "logistic" # Options: 'linear', 'logistic'
self.cap = None # Optional, for logistic growth. Can be dynamically calculated if None.
self.changepoint_prior_scale = 0.25 # Regularization strength for changepoints
self.seasonality_mode = (
"multiplicative" # Options: 'additive', 'multiplicative'
)
self.yearly_seasonality = False # Whether to include yearly seasonality
self.weekly_seasonality = True # Whether to include weekly seasonality
self.daily_seasonality = False # Whether to include daily seasonality
self.daily_seasonality = True # Whether to include daily seasonality

# Forecast parameters
self.periods = 365 # Default number of periods for future forecasts (trading days for stocks usually 252)
Expand Down
36 changes: 33 additions & 3 deletions models/prophet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,30 @@ def __init__(self, model_name="prophet", config=ProphetConfig(), debug=False):

def train(self, data: pd.DataFrame):
df = data[["date", "close"]].copy()
df["date"] = pd.to_datetime(df["date"], errors="coerce")

# Drop rows with NaN values in 'date' or 'close'
df = df.dropna(subset=["date", "close"])

if self.config.remove_timezone:
df["date"] = df["date"].dt.tz_localize(None) # Remove timezone if present
df["date"] = df["date"].dt.tz_localize(None)

df = df.rename(columns={"date": "ds", "close": "y"})
self.model.fit(df)

# Handle logistic growth: Set 'cap' and 'floor' values
if self.config.growth == "logistic":
max_y = df["y"].max()
df["cap"] = max_y * 1.1 # Set cap to 10% above max value
df["floor"] = 0 # Set a floor to prevent negative growth

if self.debug:
# Print data to check for NaNs or extreme values
print(df.isna().sum())
print(df.describe())

# Fit model
self.model.fit(df, iter=20000)

self.save()

def inference(self, input_data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -47,12 +65,24 @@ def inference(self, input_data: pd.DataFrame) -> pd.DataFrame:
if self.config.remove_timezone:
future["ds"] = pd.to_datetime(future["ds"]).dt.tz_localize(None)

# Handle logistic growth: Ensure 'cap' and 'floor' columns are present
if self.config.growth == "logistic":
if "cap" not in input_data.columns:
# Dynamically set 'cap' if missing (assuming similar logic as training)
future["cap"] = (
input_data["close"].max() * 1.1
) # Example logic for setting 'cap'

if "floor" not in input_data.columns:
# Set the 'floor' if it's missing (e.g., default to 0)
future["floor"] = 0

if self.debug:
print("Future DataFrame for ProphetModel (after renaming):")
print(future)

# Perform the forecast using Prophet
forecast = self.model.predict(future)
forecast = self.model.predict(future, False)
if self.debug:
print("Forecast Output:")
print(forecast[["ds", "yhat"]])
Expand Down

0 comments on commit 82fa04a

Please sign in to comment.