Spaces:
Sleeping
Sleeping
""" | |
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py | |
""" | |
import os | |
import hashlib | |
import requests | |
from tqdm import tqdm | |
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    The comparison is case-insensitive and only covers the shorter of the two
    hex strings, so a truncated expected hash (a prefix of the full digest)
    is accepted.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits. May be a prefix of the
        full digest; an empty string matches any file.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Hash in 1 MiB chunks so the whole file is never held in memory.
        for data in iter(lambda: f.read(1048576), b''):
            sha1.update(data)

    actual = sha1.hexdigest()
    # hexdigest() is always lowercase; normalise the expected value so an
    # upper-case hash is not reported as a mismatch.
    expected = sha1_hash.lower()
    # Compare only the common prefix length to allow shortened hashes.
    prefix_len = min(len(actual), len(expected))
    return actual[:prefix_len] == expected[:prefix_len]
def download_file(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    The download is skipped when the destination file already exists, unless
    ``overwrite`` is set or ``sha1_hash`` is given and the existing file does
    not match it.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If it is an existing
        directory, the file is stored inside it under the name taken from url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match it.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (
            sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...' % (fname, url))
        # Stream into a sibling temporary file and move it into place only
        # after the transfer finished. An interrupted download therefore never
        # leaves a truncated file at `fname` that a later call would accept as
        # already downloaded, and an existing file survives a failed overwrite.
        tmp_name = '%s.%d.part' % (fname, os.getpid())
        try:
            # The context manager releases the connection even on errors.
            with requests.get(url, stream=True) as r:
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url %s" % url)
                total_length = r.headers.get('content-length')
                chunks = r.iter_content(chunk_size=1024)
                if total_length is not None:
                    # Content length is known: wrap the stream in a progress
                    # bar. Ceiling division so a trailing partial chunk counts.
                    total_length = int(total_length)
                    chunks = tqdm(chunks,
                                  total=(total_length + 1023) // 1024,
                                  unit='KB',
                                  unit_scale=False,
                                  dynamic_ncols=True)
                with open(tmp_name, 'wb') as f:
                    for chunk in chunks:
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            os.replace(tmp_name, fname)
        finally:
            # Only present if the download or the move failed.
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))
    return fname