File size: 4,111 Bytes
0a4e529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from huggingface_hub import HfFolder, hf_hub_url
import os
import requests
import tqdm
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout
from tqdm.contrib.concurrent import thread_map
from pathlib import Path
import time

# Read the Hugging Face access token from the environment instead of hard-coding it.
# HfFolder.save_token stores it so that HfFolder.get_token() (used when the request
# headers are built for each download) can read it back later.
_hf_token = os.getenv('HUGGING_FACE_TOKEN')
if _hf_token:
    HfFolder.save_token(_hf_token)
else:
    # save_token expects a string; calling it with None would fail, so fall back to
    # whatever token huggingface_hub may already have saved.
    print("HUGGING_FACE_TOKEN is not set; falling back to any previously saved Hugging Face token.")

# Define the repository to download from
repo_id = "google/gemma-7b"
repo_type = "model"

# Local path where you want to save the downloaded files
local_folder_path = "./models"

# Path of the single file inside the repository to download. hf_hub_url builds a
# per-file URL, so this must name a file rather than a directory.
download_target = "model-00001-of-00004.safetensors"  # Change this to the desired file name

print(f"Downloading {download_target} from {repo_id} to {local_folder_path}...")

# Create the local directory if it doesn't exist
os.makedirs(local_folder_path, exist_ok=True)

# Print the URL for debugging
print(f"URL: {hf_hub_url(repo_id, download_target, repo_type=repo_type)}")

def get_session(max_retries=5):
    """Create a ``requests.Session`` for talking to the Hugging Face Hub.

    If ``max_retries`` is truthy, a separate ``HTTPAdapter`` configured with that
    retry count is mounted for the LFS CDN host and for the main Hub host.
    Otherwise the session keeps requests' default adapters.
    """
    http = requests.Session()
    if not max_retries:
        return http
    for prefix in ('https://cdn-lfs.huggingface.co', 'https://huggingface.co'):
        http.mount(prefix, HTTPAdapter(max_retries=max_retries))
    return http

def _auth_headers():
    """Return request headers carrying the saved Hugging Face token, if there is one."""
    token = HfFolder.get_token()
    # Without a token, omit the header instead of sending "Bearer None".
    return {"Authorization": f"Bearer {token}"} if token else {}


def _download_to_path(session, url, output_path, start_from_scratch):
    """Run a single download attempt of ``url`` into ``output_path``.

    Network and HTTP errors propagate as ``requests`` exceptions so the caller
    can decide whether to retry.
    """
    headers = _auth_headers()
    already_have = 0
    if output_path.exists() and not start_from_scratch:
        already_have = output_path.stat().st_size
        # Probe the remote size. The context manager closes the streamed body
        # without reading it, so the connection is not leaked.
        with session.get(url, headers=headers, stream=True, timeout=20) as probe:
            # An error page's Content-Length must not be mistaken for the file size.
            probe.raise_for_status()
            total_size = int(probe.headers.get('content-length', 0))
        # Only treat the local file as complete when the remote size is known.
        if total_size and already_have >= total_size:
            return
        headers['Range'] = f'bytes={already_have}-'

    with session.get(url, headers=headers, stream=True, timeout=30) as r:
        r.raise_for_status()
        if already_have and r.status_code != 206:
            # The server ignored the Range header and is sending the whole file:
            # overwrite rather than appending a second copy onto the partial file.
            already_have = 0
        mode = 'ab' if already_have else 'wb'
        # For a 206 response Content-Length is the number of bytes still to come.
        remaining = int(r.headers.get('content-length', 0))
        block_size = 1024 * 1024  # 1MB
        tqdm_kwargs = {
            'total': already_have + remaining if remaining else None,
            'initial': already_have,
            'unit': 'iB',
            'unit_scale': True,
            'bar_format': '{l_bar}{bar}|{n_fmt}/{total_fmt}{rate_fmt}',
        }
        with open(output_path, mode) as f, tqdm.tqdm(**tqdm_kwargs) as t:
            for data in r.iter_content(block_size):
                f.write(data)
                t.update(len(data))


def get_single_file(url, output_folder, start_from_scratch=False, max_retries=7):
    """Download one file from ``url`` into ``output_folder``, resuming partial downloads.

    The local file name is the last path segment of ``url``. Unless
    ``start_from_scratch`` is set, an existing local file is treated as a partial
    download. If it is already at least as large as the remote Content-Length,
    nothing is fetched. Otherwise only the missing bytes are requested with an
    HTTP Range header and appended.

    Request errors, including HTTP error statuses, are retried with exponential
    back-off (2, 4, 8, ... seconds) for up to ``max_retries`` attempts in total.

    Args:
        url: Direct download URL (for example from ``hf_hub_url``).
        output_folder: Destination directory (``Path`` or ``str``).
        start_from_scratch: If True, ignore and overwrite any existing local file.
        max_retries: Maximum number of attempts, including the first one.

    Returns:
        True if the file is complete on disk, False if every attempt failed.
    """
    filename = Path(url.rsplit('/', 1)[1])
    output_path = Path(output_folder) / filename
    for attempt in range(1, max_retries + 1):
        try:
            # Use a fresh session per attempt and close it (and its connection pool) afterwards.
            with get_session() as session:
                _download_to_path(session, url, output_path, start_from_scratch)
            return True
        except (RequestException, ConnectionError, Timeout) as e:
            print(f"Error downloading {filename}: {e}.")
            print(f"That was attempt {attempt}/{max_retries}. ", end='')
            if attempt < max_retries:
                print(f"Retry begins in {2**attempt} seconds.")
                time.sleep(2**attempt)
            else:
                print("Failed to download after the maximum number of attempts.")
    return False

def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=4):
    """Download every URL in ``file_list`` concurrently with ``threads`` workers.

    Each URL is handed to ``get_single_file``. ``disable=True`` turns off
    thread_map's own aggregate progress bar; the per-file bars come from
    get_single_file.
    """
    def _fetch(link):
        return get_single_file(link, output_folder, start_from_scratch=start_from_scratch)

    thread_map(_fetch, file_list, max_workers=threads, disable=True)

def download_model_files(model, branch, links, output_folder, start_from_scratch=False, threads=4):
    """Create ``output_folder`` if needed and download all ``links`` into it in parallel.

    Args:
        model: Repository id. It is accepted but not used by this function.
        branch: Revision name. It is accepted but not used by this function.
        links: Direct download URLs to fetch.
        output_folder: Destination directory (``str`` or ``Path``).
        start_from_scratch: Forwarded to the per-file downloader; overwrite existing files.
        threads: Number of concurrent download workers.
    """
    target_dir = Path(output_folder)
    target_dir.mkdir(parents=True, exist_ok=True)
    print(f"Downloading the model to {target_dir}")
    start_download_threads(
        links,
        target_dir,
        start_from_scratch=start_from_scratch,
        threads=threads,
    )

# Download the specified file. hf_hub_url resolves a single file, pinned to the
# revision named by `branch`, so changing `branch` actually changes what is fetched.
# (The previously created-but-unused requests session has been removed.)
branch = "main"
links = [hf_hub_url(repo_id, download_target, repo_type=repo_type, revision=branch)]

download_model_files(repo_id, branch, links, local_folder_path)

print("Download complete!")