
Commit

ENH Add parameter return_X_y to make_classification (scikit-learn#30196)

Co-authored-by: Adrin Jalali <[email protected]>
SuccessMoses and adrinjalali authored Jan 3, 2025
1 parent 6c163c6 commit c9aeb15
Showing 3 changed files with 136 additions and 16 deletions.
@@ -0,0 +1,3 @@
- New parameter ``return_X_y`` added to :func:`datasets.make_classification`. The
default value of the parameter does not change how the function behaves.
By :user:`Success Moses <SuccessMoses>` and :user:`Adam Cooper <arc12>`
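
A minimal usage sketch of the new parameter (illustrative values; the attribute names follow the diff below):

from sklearn.datasets import make_classification

# Default behaviour is unchanged: an (X, y) tuple is returned.
X, y = make_classification(n_samples=50, n_features=8, random_state=0)

# With return_X_y=False, a Bunch carrying the data plus metadata is returned.
data = make_classification(
    n_samples=50, n_features=8, n_informative=3, random_state=0, return_X_y=False
)
print(data.feature_info)  # e.g. ['informative', 'random', 'redundant', ...]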
98 changes: 82 additions & 16 deletions sklearn/datasets/_samples_generator.py
@@ -14,6 +14,8 @@
import scipy.sparse as sp
from scipy import linalg

from sklearn.utils import Bunch

from ..preprocessing import MultiLabelBinarizer
from ..utils import check_array, check_random_state
from ..utils import shuffle as util_shuffle
@@ -54,6 +56,7 @@ def _generate_hypercube(samples, dimensions, rng):
"scale": [Interval(Real, 0, None, closed="neither"), "array-like", None],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"return_X_y": ["boolean"],
},
prefer_skip_nested_validation=True,
)
@@ -74,6 +77,7 @@ def make_classification(
    scale=1.0,
    shuffle=True,
    random_state=None,
    return_X_y=True,
):
    """Generate a random n-class classification problem.
@@ -168,13 +172,32 @@
        for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    return_X_y : bool, default=True
        If True, a tuple ``(X, y)`` is returned instead of a Bunch object.

        .. versionadded:: 1.7
    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        The generated samples.

    y : ndarray of shape (n_samples,)
        The integer labels for class membership of each sample.

    data : :class:`~sklearn.utils.Bunch` if `return_X_y` is `False`.
        Dictionary-like object, with the following attributes.

        DESCR : str
            A description of the function that generated the dataset.
        parameters : dict
            A dictionary that stores the values of the arguments passed to the
            generator function.
        feature_info : list of length `n_features`
            A description for each generated feature.
        X : ndarray of shape (n_samples, n_features)
            The generated samples.
        y : ndarray of shape (n_samples,)
            An integer label for class membership of each sample.

        .. versionadded:: 1.7

    (X, y) : tuple if ``return_X_y`` is True
        A tuple of generated samples and labels.

    See Also
    --------
@@ -220,25 +243,28 @@ def make_classification(
)

    if weights is not None:
+        # we define a new variable, weights_, instead of modifying the
+        # user-defined parameter.
        if len(weights) not in [n_classes, n_classes - 1]:
            raise ValueError(
                "Weights specified but incompatible with number of classes."
            )
        if len(weights) == n_classes - 1:
            if isinstance(weights, list):
-                weights = weights + [1.0 - sum(weights)]
+                weights_ = weights + [1.0 - sum(weights)]
            else:
-                weights = np.resize(weights, n_classes)
-                weights[-1] = 1.0 - sum(weights[:-1])
+                weights_ = np.resize(weights, n_classes)
+                weights_[-1] = 1.0 - sum(weights_[:-1])
+        else:
+            weights_ = weights.copy()
    else:
-        weights = [1.0 / n_classes] * n_classes
+        weights_ = [1.0 / n_classes] * n_classes

-    n_useless = n_features - n_informative - n_redundant - n_repeated
+    n_random = n_features - n_informative - n_redundant - n_repeated
    n_clusters = n_classes * n_clusters_per_class

    # Distribute samples among clusters by weight
    n_samples_per_cluster = [
-        int(n_samples * weights[k % n_classes] / n_clusters_per_class)
+        int(n_samples * weights_[k % n_classes] / n_clusters_per_class)
        for k in range(n_clusters)
    ]
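
For intuition, a standalone sketch of the weight-completion rule above, using hypothetical values (not part of the diff):

# If one weight is missing, the last class weight is inferred so that
# the weights sum to 1 (hypothetical values).
n_classes = 3
weights = [0.2, 0.3]
weights_ = weights + [1.0 - sum(weights)]  # [0.2, 0.3, 0.5]

# Samples are then distributed over clusters in proportion to class weights.
n_samples, n_clusters_per_class = 100, 2
n_clusters = n_classes * n_clusters_per_class
n_samples_per_cluster = [
    int(n_samples * weights_[k % n_classes] / n_clusters_per_class)
    for k in range(n_clusters)
]
# -> [10, 15, 25, 10, 15, 25]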

@@ -282,14 +308,14 @@
)

    # Repeat some features
+    n = n_informative + n_redundant
    if n_repeated > 0:
-        n = n_informative + n_redundant
        indices = ((n - 1) * generator.uniform(size=n_repeated) + 0.5).astype(np.intp)
        X[:, n : n + n_repeated] = X[:, indices]

    # Fill useless features
-    if n_useless > 0:
-        X[:, -n_useless:] = generator.standard_normal(size=(n_samples, n_useless))
+    if n_random > 0:
+        X[:, -n_random:] = generator.standard_normal(size=(n_samples, n_random))

    # Randomly replace labels
    if flip_y >= 0.0:
@@ -305,16 +331,56 @@
        scale = 1 + 100 * generator.uniform(size=n_features)
    X *= scale

+    indices = np.arange(n_features)
    if shuffle:
        # Randomly permute samples
        X, y = util_shuffle(X, y, random_state=generator)

        # Randomly permute features
-        indices = np.arange(n_features)
        generator.shuffle(indices)
        X[:, :] = X[:, indices]

-    return X, y
+    if return_X_y:
+        return X, y

    # feat_desc describes features in X
    feat_desc = ["random"] * n_features
    for i, index in enumerate(indices):
        if index < n_informative:
            feat_desc[i] = "informative"
        elif n_informative <= index < n_informative + n_redundant:
            feat_desc[i] = "redundant"
        elif n <= index < n + n_repeated:
            feat_desc[i] = "repeated"
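    # Column layout before shuffling: [0, n_informative) are informative,
    # [n_informative, n) are redundant (n = n_informative + n_redundant),
    # [n, n + n_repeated) are repeated, and the trailing n_random columns
    # are random. Since indices[i] is the pre-shuffle position of column i,
    # the checks above recover each column's original role.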

    parameters = {
        "n_samples": n_samples,
        "n_features": n_features,
        "n_informative": n_informative,
        "n_redundant": n_redundant,
        "n_repeated": n_repeated,
        "n_classes": n_classes,
        "n_clusters_per_class": n_clusters_per_class,
        "weights": weights,
        "flip_y": flip_y,
        "class_sep": class_sep,
        "hypercube": hypercube,
        "shift": shift,
        "scale": scale,
        "shuffle": shuffle,
        "random_state": random_state,
        "return_X_y": return_X_y,
    }

    bunch = Bunch(
        DESCR=make_classification.__doc__,
        parameters=parameters,
        feature_info=feat_desc,
        X=X,
        y=y,
    )

    return bunch
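
As a quick sanity check that the default path is untouched, a minimal sketch (mirroring the assertions in the test below):

# Both return paths produce identical data for the same seed
# (illustrative check; see the test file below).
import numpy as np
X, y = make_classification(random_state=0)
bunch = make_classification(random_state=0, return_X_y=False)
assert np.array_equal(X, bunch.X) and np.array_equal(y, bunch.y)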


@validate_params(
51 changes: 51 additions & 0 deletions sklearn/datasets/tests/test_samples_generator.py
@@ -184,6 +184,57 @@ def test_make_classification_informative_features():
        make(n_features=2, n_informative=2, n_classes=3, n_clusters_per_class=2)


def test_make_classification_return_x_y():
    """
    Test that make_classification returns a Bunch when return_X_y is False,
    and that the Bunch's X and y match the tuple returned when it is True.
    """

    kwargs = {
        "n_samples": 100,
        "n_features": 20,
        "n_informative": 5,
        "n_redundant": 1,
        "n_repeated": 1,
        "n_classes": 3,
        "n_clusters_per_class": 2,
        "weights": None,
        "flip_y": 0.01,
        "class_sep": 1.0,
        "hypercube": True,
        "shift": 0.0,
        "scale": 1.0,
        "shuffle": True,
        "random_state": 42,
        "return_X_y": True,
    }

    X, y = make_classification(**kwargs)

    kwargs["return_X_y"] = False
    bunch = make_classification(**kwargs)

    assert (
        hasattr(bunch, "DESCR")
        and hasattr(bunch, "parameters")
        and hasattr(bunch, "feature_info")
        and hasattr(bunch, "X")
        and hasattr(bunch, "y")
    )

    def count(str_):
        return bunch.feature_info.count(str_)

    assert np.array_equal(X, bunch.X)
    assert np.array_equal(y, bunch.y)
    assert bunch.DESCR == make_classification.__doc__
    assert bunch.parameters == kwargs
    assert count("informative") == kwargs["n_informative"]
    assert count("redundant") == kwargs["n_redundant"]
    assert count("repeated") == kwargs["n_repeated"]

@pytest.mark.parametrize(
    "weights, err_type, err_msg",
    [
