Skip to content

Commit

Permalink
fixing bug in copula rvs
Browse files Browse the repository at this point in the history
  • Loading branch information
tfm000 committed Dec 3, 2023
1 parent 386f0b4 commit 52ed7ed
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions sklarpy/copulas/_prefit_dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, Iterable, Callable, Dict, List
import numpy as np
import pandas as pd
from collections import deque

from sklarpy.copulas import MarginalFitter
from sklarpy.utils._input_handlers import check_multivariate_data, get_mask
Expand Down Expand Up @@ -799,17 +800,32 @@ def copula_rvs(self, size: int, copula_params: Union[Params, tuple],
distribution. These correspond to randomly sampled cdf /
pseudo-observation values of the univariate marginals.
"""
# generating random variables from multivariate distribution
raw_mv_rvs: np.ndarray = self._mv_object.rvs(size, copula_params)

# bounding these above and below
# eps: float = 10 ** -5
# rvs_df: pd.DataFrame = pd.DataFrame(raw_mv_rvs)
# rvs_df[rvs_df < 0] = eps
# rvs_df[rvs_df > 1] = 1 - eps
# mv_rvs: np.ndarray = rvs_df.to_numpy()

return self._g_to_u(raw_mv_rvs, copula_params)
max_num_loops: int = 100
num_loops: int = 0
d: int = self._mv_object._get_dim(copula_params)
valid_copula_rvs: deque = deque()
while size > 0:
# generating random variables from multivariate distribution
mv_rvs: np.ndarray = self._mv_object.rvs(size, copula_params)

# converting to copula rvs
raw_copula_rvs: np.ndarray = self._g_to_u(mv_rvs, copula_params)

# filtering out invalid copula rvs (not in [0, 1]^d)
mask: np.ndarray = ((raw_copula_rvs > 0) & (raw_copula_rvs < 1)
).sum(axis=1) == d
copula_rvs = raw_copula_rvs[mask]
valid_copula_rvs.append(copula_rvs)

# repeating until sample size reached
size -= copula_rvs.shape[0]
num_loops += 1

if num_loops > max_num_loops:
raise ArithmeticError(f"Unable to generate valid copula rvs. "
f"number of retries reached "
f"{max_num_loops}")
return np.concatenate(valid_copula_rvs, axis=0)

def _get_components_summary(self,
fitted_mv_object: FittedContinuousMultivariate,
Expand Down

0 comments on commit 52ed7ed

Please sign in to comment.