Commit

Add redundancy to OpenML datasets with Figshare (#1218)
Co-authored-by: Jerome Dockes <[email protected]>
Vincent-Maladiere and jeromedockes authored Jan 27, 2025
1 parent 9640263 commit 0eb71e7
Showing 14 changed files with 13,756 additions and 13,880 deletions.
6 changes: 5 additions & 1 deletion CHANGES.rst
@@ -17,6 +17,10 @@ New features

Changes
-------
* New dataset fetching functions have been added: :func:`fetch_videogame_sales`,
  :func:`fetch_bike_sharing`, :func:`fetch_flight_delays`, and
  :func:`fetch_country_happiness`; :func:`fetch_road_safety` has been removed.
  :pr:`1218` by :user:`Vincent Maladiere <Vincent-Maladiere>`

Bug fixes
---------
@@ -30,7 +34,7 @@ Release 0.4.1
Changes
-------
* :class: `TableReport` has `write_html` method
:pr:`1190` by :user: `Mojdeh Rastgoo<mrastgoo>`.
:pr:`1190` by :user:`Mojdeh Rastgoo<mrastgoo>`.

* A new parameter ``verbose`` has been added to the :class:`TableReport` to toggle on or off the
printing of progress information when a report is being generated.
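The fetchers introduced in the changelog entry above appear to follow the same pattern as the existing skrub dataset loaders: each returns a bunch-like object whose attributes hold the individual tables. A minimal usage sketch, with attribute names taken from the example updates later in this commit (the printing is illustrative only):

from skrub import datasets

# Single-table dataset: the dataframe is exposed as a named attribute.
bike = datasets.fetch_bike_sharing()
print(bike.bike_sharing.shape)

# Supervised dataset: features and target are exposed as .X and .y.
vgsales = datasets.fetch_videogame_sales()
print(vgsales.X.shape, vgsales.y.shape)

# Multi-table dataset: each table is a separate attribute.
delays = datasets.fetch_flight_delays()
print(delays.flights.shape, delays.airports.shape, delays.weather.shape)
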
20 changes: 12 additions & 8 deletions doc/reference/index.rst
@@ -138,17 +138,21 @@ Downloading a dataset
:template: base.rst
:nosignatures:

fetch_bike_sharing
fetch_country_happiness
fetch_credit_fraud
fetch_drug_directory
fetch_employee_salaries
fetch_flight_delays
fetch_ken_embeddings
fetch_ken_table_aliases
fetch_ken_types
fetch_medical_charge
fetch_midwest_survey
fetch_movielens
fetch_open_payments
fetch_road_safety
fetch_toxicity
fetch_traffic_violations
fetch_drug_directory
fetch_world_bank_indicator
fetch_movielens
fetch_credit_fraud
fetch_ken_table_aliases
fetch_ken_types
fetch_ken_embeddings
fetch_videogame_sales
get_data_dir
make_deduplication_data
8 changes: 4 additions & 4 deletions examples/03_datetime_encoder.py
@@ -51,10 +51,10 @@

import pandas as pd

data = pd.read_csv(
"https://raw.githubusercontent.com/skrub-data/datasets/master"
"/data/bike-sharing-dataset.csv"
)
from skrub import datasets

data = datasets.fetch_bike_sharing().bike_sharing

# Extract our input data (X) and the target column (y)
y = data["cnt"]
X = data[["date", "holiday", "temp", "hum", "windspeed", "weathersit"]]
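The hunk above swaps the raw CSV download for fetch_bike_sharing. As a quick, purely illustrative sanity check on the fetched table (the "date" and "cnt" column names come from the code above; that the dates parse with pandas is an assumption about their format):

import pandas as pd

from skrub import datasets

data = datasets.fetch_bike_sharing().bike_sharing
# Inspect the time span covered by the dataset and the target distribution.
dates = pd.to_datetime(data["date"])
print(dates.min(), "->", dates.max())
print(data["cnt"].describe())
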
25 changes: 8 additions & 17 deletions examples/04_fuzzy_joining.py
@@ -32,16 +32,10 @@
# --------------------------------
#
# We import the happiness score table first:
import pandas as pd
from skrub import datasets

df = pd.read_csv(
(
"https://raw.githubusercontent.com/skrub-data/datasets/"
"master/data/Happiness_report_2022.csv"
),
thousands=",",
)
df.drop(df.tail(1).index, inplace=True)
happiness_data = datasets.fetch_country_happiness()
df = happiness_data.happiness_report

###############################################################################
# Let's look at the table:
@@ -66,23 +60,20 @@
# complete our covariates (X table).
#
# Interesting tables can be found on `the World Bank open data platform
# <https://data.worldbank.org/>`_, for which we have a downloading
# function:
from skrub.datasets import fetch_world_bank_indicator

###############################################################################
# <https://data.worldbank.org/>`_, which are also available in the dataset.
# We extract the table containing GDP per capita by country:
gdp_per_capita = fetch_world_bank_indicator(indicator_id="NY.GDP.PCAP.CD").X

gdp_per_capita = happiness_data.GDP_per_capita
gdp_per_capita.head(3)

###############################################################################
# Then another table, with life expectancy by country:
life_exp = fetch_world_bank_indicator("SP.DYN.LE00.IN").X
life_exp = happiness_data.life_expectancy
life_exp.head(3)

###############################################################################
# And a table with legal rights strength by country:
legal_rights = fetch_world_bank_indicator("IC.LGL.CRED.XQ").X
legal_rights = happiness_data.legal_rights_index
legal_rights.head(3)

###############################################################################
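For context, the tables loaded above are later combined with skrub's fuzzy joining. A minimal sketch, assuming the happiness table has a "Country" column and the World Bank table a "Country Name" column (these key names are assumptions, not visible in this hunk):

from skrub import datasets, fuzzy_join

happiness_data = datasets.fetch_country_happiness()
df = happiness_data.happiness_report
gdp_per_capita = happiness_data.GDP_per_capita

# Fuzzy-join on country names, which do not match exactly across sources.
joined = fuzzy_join(df, gdp_per_capita, left_on="Country", right_on="Country Name")
joined.head(3)
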
30 changes: 16 additions & 14 deletions examples/06_ken_embeddings.py
@@ -44,18 +44,16 @@
# Let's retrieve the dataset:
import pandas as pd

X = pd.read_csv(
"https://raw.githubusercontent.com/William2064888/vgsales.csv/main/vgsales.csv",
sep=";",
on_bad_lines="skip",
)
# Shuffle the data
X = X.sample(frac=1, random_state=11, ignore_index=True)
from skrub import datasets

data = datasets.fetch_videogame_sales()

X = data.X
X.head(3)

###############################################################################
# Our goal will be to predict the sales amount (y, our target column):
y = X["Global_Sales"]
y = data.y
y


@@ -81,9 +79,9 @@
# Before moving further, let's carry out some basic preprocessing:

# Get a mask of the rows with missing values in "Publisher" and "Global_Sales"
mask = X.isna()["Publisher"] | X.isna()["Global_Sales"]
mask = X["Publisher"].isna() | y.isna()
# And remove them
X.dropna(subset=["Publisher", "Global_Sales"], inplace=True)
X = X[~mask]
y = y[~mask]

###############################################################################
@@ -202,14 +200,18 @@

###############################################################################
# The |Pipeline| can now be readily applied to the dataframe for prediction:
from sklearn.model_selection import cross_validate
from sklearn.model_selection import KFold, cross_validate

# We will save the results in a dictionary:
all_r2_scores = dict()
all_rmse_scores = dict()

# The dataset is ordered by rank (most sales first), so we need to shuffle it
# before splitting into cross-validation folds.
cv = KFold(shuffle=True, random_state=0)

cv_results = cross_validate(
pipeline, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
pipeline, X_full, y, scoring=["r2", "neg_root_mean_squared_error"], cv=cv
)

all_r2_scores["Base features"] = cv_results["test_r2"]
@@ -243,7 +245,7 @@
###############################################################################
# Let's look at the results:
cv_results = cross_validate(
pipeline2, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
pipeline2, X_full, y, scoring=["r2", "neg_root_mean_squared_error"], cv=cv
)

all_r2_scores["KEN features"] = cv_results["test_r2"]
@@ -287,7 +289,7 @@
###############################################################################
# Let's look at the results:
cv_results = cross_validate(
pipeline3, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
pipeline3, X_full, y, scoring=["r2", "neg_root_mean_squared_error"], cv=cv
)

all_r2_scores["Base + KEN features"] = cv_results["test_r2"]
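The three cross_validate calls above fill all_r2_scores with per-fold scores. A small, illustrative way to compare the feature sets once they have all run:

import numpy as np

# Compare mean and spread of the cross-validated R2 for each feature set.
for name, scores in all_r2_scores.items():
    print(f"{name}: R2 = {np.mean(scores):.3f} +/- {np.std(scores):.3f}")
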
11 changes: 6 additions & 5 deletions examples/07_multiple_key_join.py
@@ -45,10 +45,11 @@

import pandas as pd

from skrub.datasets import fetch_figshare
from skrub.datasets import fetch_flight_delays

dataset = fetch_flight_delays()
seed = 1
flights = fetch_figshare("41771418").X
flights = dataset.flights

# Sampling for faster computation.
flights = flights.sample(5_000, random_state=seed, ignore_index=True)
@@ -75,7 +76,7 @@
# - The ``airports`` dataset, with information such as their name
# and location (longitude, latitude).

airports = fetch_figshare("41710257").X
airports = dataset.airports
airports.head()

########################################################################
@@ -85,7 +86,7 @@
# Both tables are from the Global Historical Climatology Network.
# Here, we consider only weather measurements from 2008.

weather = fetch_figshare("41771457").X
weather = dataset.weather
# Sampling for faster computation.
weather = weather.sample(10_000, random_state=seed, ignore_index=True)
weather.head()
@@ -94,7 +95,7 @@
# - The ``stations`` dataset. Provides location of all the weather
# measurement stations in the US.

stations = fetch_figshare("41710524").X
stations = dataset.stations
stations.head()

###############################################################################
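The four tables returned by fetch_flight_delays are linked through key columns such as the departure airport. A brief sketch of one such link, using the same "Origin"/"iata" columns that example 09 below merges on:

from skrub.datasets import fetch_flight_delays

dataset = fetch_flight_delays()
# Attach airport metadata (name, state, coordinates) to each flight.
flights_with_airports = dataset.flights.merge(
    dataset.airports, left_on="Origin", right_on="iata", how="left"
)
flights_with_airports.head(3)
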
16 changes: 11 additions & 5 deletions examples/09_interpolation_join.py
@@ -35,14 +35,18 @@
# stations’ latitude and longitude. We subsample these large tables for the example to
# run faster.

from skrub.datasets import fetch_figshare
import pandas as pd

weather = fetch_figshare("41771457").X
from skrub.datasets import fetch_flight_delays

dataset = fetch_flight_delays()
weather = dataset.weather
weather = weather.sample(100_000, random_state=0, ignore_index=True)
stations = fetch_figshare("41710524").X
stations = dataset.stations
weather = stations.merge(weather, on="ID")[
["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"]
]
weather["YEAR/MONTH/DAY"] = pd.to_datetime(weather["YEAR/MONTH/DAY"])

######################################################################
# The ``'TMAX'`` is in tenths of degree Celsius -- a ``'TMAX'`` of 297 means the maximum
@@ -124,9 +128,11 @@
# ``'Origin'`` which refers to the departure airport’s IATA code. We use only a subset
# to speed up the example.

flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
flights = dataset.flights
flights["Year_Month_DayofMonth"] = pd.to_datetime(flights["Year_Month_DayofMonth"])
flights = flights[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
flights = flights.sample(20_000, random_state=0, ignore_index=True)
airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]]
airports = dataset.airports[["iata", "airport", "state", "lat", "long"]]
flights = flights.merge(airports, left_on="Origin", right_on="iata")
# printing the first row is more readable than the head() when we have many columns
flights.iloc[0]
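Building on the weather table prepared above, and on the note that 'TMAX' is stored in tenths of a degree Celsius, a purely illustrative conversion (whether the full example rescales the column is not visible in this hunk):

# Convert "TMAX" from tenths of a degree Celsius to degrees Celsius,
# so that a stored value of 297 reads as 29.7.
weather = weather.copy()
weather["TMAX"] = weather["TMAX"] / 10.0
weather["TMAX"].describe()
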