Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kamo-naoyuki committed Aug 29, 2020
1 parent 2749fec commit 0397328
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
22 changes: 14 additions & 8 deletions espnet_model_zoo/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,14 @@ def get_url(self, name: str = None, version: int = -1, **kwargs: str) -> str:
url = list(urls)[version]
return url

def _get_file_name(self, url):
@staticmethod
def _get_file_name(url):
ma = re.match(r"https://.*/([^/]*)\?download=[0-9]*$", url)
if ma is not None:
# URL e.g.
# https://sandbox.zenodo.org/record/646767/files/asr_train_raw_bpe_valid.acc.best.zip?download=1
return ma.groups()[0]
a = ma.groups()[0]
return a
else:
# If not Zenodo
r = requests.head(url)
Expand Down Expand Up @@ -252,12 +254,16 @@ def download_and_unpack(
if name is not None and is_url(name):
# Specify the downloading link directly. "kwargs" are ignored in this case.
url = name
elif name is not None and Path(name).exists() and len(kwargs) == 0:
return self.unpack_local_file(name)
else:
url = self.get_url(name=name, version=version, **kwargs)

# If the registered url is a file path
try:
url = self.get_url(name=name, version=version, **kwargs)
except RuntimeError:
if name is not None and Path(name).exists() and len(kwargs) == 0:
return self.unpack_local_file(name)
else:
raise

# If the registered url in table.csv is a file path
if not is_url(url) and Path(url).exists():
return self.unpack_local_file(url)

Expand Down Expand Up @@ -317,7 +323,7 @@ def cmd_query(cmd=None):

parser.add_argument(
"condition",
action="append",
nargs="*",
default=[],
help="Given desired condition in form of <key>=<value>. "
"e.g. fs=16000. "
Expand Down
6 changes: 6 additions & 0 deletions test/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def test_download_and_unpack_non_matching():
d.download_and_unpack(task="dummy")


def test_download_and_unpack_local_file():
d = ModelDownloader()
path = d.download("test")
d.download_and_unpack(path)


def test_cmd_download():
cmd_download(["test"])

Expand Down

0 comments on commit 0397328

Please sign in to comment.