Skip to content

Commit

Permalink
Search dataset through tags until valid url
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorgadov committed Jul 22, 2024
1 parent 5187b8a commit 37ca9b5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/set-version.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
version:
type: string
required: true
description: Version to set

jobs:
upgrade:
Expand Down
33 changes: 19 additions & 14 deletions pactus/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, List, Tuple, Union

import numpy as np
import requests
from git.cmd import Git
from sklearn.model_selection import train_test_split
from yupi import Trajectory
Expand Down Expand Up @@ -85,7 +86,7 @@ def take(
) -> Data:
"""
Takes a subset of the dataset.
Parameters
----------
size : Union[float, int]
Expand All @@ -101,7 +102,7 @@ def take(
by default True.
random_state : Union[int, None], optional
Random state to be used, by default None.
Returns
-------
Data
Expand All @@ -128,7 +129,7 @@ def cut(self, size: Union[float, int]) -> Tuple[Data, Data]:
as the proportion of the dataset to be taken. If int, it should be
between 0 and the dataset size and it will be interpreted as the
number of trajectories to be taken.
Returns
-------
Tuple[Data, Data]
Expand Down Expand Up @@ -264,8 +265,6 @@ class Dataset(Data):
Dataset version.
"""

_last_tag: Union[str, None] = None

def __init__(
self,
name: str,
Expand Down Expand Up @@ -296,14 +295,15 @@ def _from_json(name: str, data: dict) -> Dataset:

@staticmethod
def _get_dataset_url(name: str) -> str:
    """
    Finds the download url of a dataset among the repository releases.

    The tags of the remote repository are listed with ``git ls-remote``
    (requested sorted by ``-v:refname``) and probed one by one, in that
    order, until one of them answers for the ``{name}.zip`` release asset.

    Parameters
    ----------
    name : str
        Name of the dataset, i.e. the asset file name without ``.zip``.

    Returns
    -------
    str
        Url of the dataset asset for the first tag that provides it.

    Raises
    ------
    AssertionError
        If none of the tags provides the requested dataset.
    """
    g_cmd = Git()
    output = g_cmd.ls_remote(REPO_URL, sort="-v:refname", tags=True)
    for ref in output.splitlines():
        tag = ref.split("/")[-1]
        # Skip blank lines and the peeled ``<tag>^{}`` entries that
        # ``ls-remote --tags`` lists next to annotated tags: neither
        # names a release, so probing them only wastes a request.
        if not tag or tag.endswith("^{}"):
            continue
        url = f"{REPO_URL}/releases/download/{tag}/{name}.zip"
        # Redirects are deliberately not followed: a 302 answer is what
        # marks the asset as present for this tag. The timeout keeps a
        # stalled connection from blocking the lookup forever.
        response = requests.head(url, allow_redirects=False, timeout=10)
        if response.status_code == 302:
            return url
    # Raised explicitly (instead of ``assert False``) so the failure is
    # still reported when Python runs with assertions disabled (``-O``).
    raise AssertionError("Could not find the given dataset")

@staticmethod
def _from_url(name: str, force: bool = False) -> Dataset:
Expand Down Expand Up @@ -374,11 +374,16 @@ def uci_characters(redownload: bool = False) -> Dataset:
"""Loads the uci_characters dataset"""
return Dataset.get("uci_characters", redownload=redownload)

@staticmethod
def traffic(redownload: bool = False) -> Dataset:
    """
    Loads the traffic dataset.

    Parameters
    ----------
    redownload : bool, optional
        Forwarded unchanged to ``Dataset.get``, by default False.

    Returns
    -------
    Dataset
        The dataset returned by ``Dataset.get`` for ``"traffic"``.
    """
    dataset_name = "traffic"
    return Dataset.get(dataset_name, redownload=redownload)

@staticmethod
def diffusive_particles(redownload: bool = False) -> Dataset:
    """
    Loads the diffusive particles dataset.

    Parameters
    ----------
    redownload : bool, optional
        Forwarded unchanged to ``Dataset.get``, by default False.

    Returns
    -------
    Dataset
        The dataset returned by ``Dataset.get`` for
        ``"diffusive_particles"``.
    """
    dataset_name = "diffusive_particles"
    return Dataset.get(dataset_name, redownload=redownload)

@staticmethod
def get(dataset_name: str, redownload: bool = False) -> Dataset:
"""Loads a dataset from the trajectory-dataset repository"""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ dependencies = [
"tensorflow >= 2.12.0",
"scikit-learn >= 1.1.1",
"xgboost >= 1.7.4",
"GitPython >= 3.1.29"
"GitPython >= 3.1.29",
"requests >= 2.32.3"
]
requires-python = ">=3.8"

Expand Down

0 comments on commit 37ca9b5

Please sign in to comment.