Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -96,26 +96,35 @@ class S3DirectStream:
|
|
96 |
config = AutoConfig.from_pretrained(model_prefix, **config_dict)
|
97 |
logger.info(f"Config loaded for {model_prefix}.")
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
state_dict.update(shard_state)
|
116 |
-
except Exception as e:
|
117 |
-
logger.exception(f"Error loading model file {model_path}: {e}")
|
118 |
-
raise
|
119 |
|
120 |
model = AutoModelForCausalLM.from_config(config)
|
121 |
model.load_state_dict(state_dict)
|
|
|
96 |
config = AutoConfig.from_pretrained(model_prefix, **config_dict)
|
97 |
logger.info(f"Config loaded for {model_prefix}.")
|
98 |
|
99 |
+
try:
|
100 |
+
api = HfApi()
|
101 |
+
model_files = api.list_repo_files(model_prefix)
|
102 |
+
state_dict = {}
|
103 |
+
for file_info in model_files:
|
104 |
+
if file_info.rfilename.endswith(('.bin', '.safetensors')):
|
105 |
+
file_url = api.download_file(model_prefix, file_info.rfilename)
|
106 |
+
model_path = os.path.join(model_prefix, file_info.rfilename)
|
107 |
+
logger.info(f"Downloading model file from {file_url} to {model_path} ...")
|
108 |
+
with requests.get(file_url, stream=True) as response:
|
109 |
+
if response.status_code == 200:
|
110 |
+
try:
|
111 |
+
model_stream = response.raw
|
112 |
+
if model_path.endswith(".safetensors"):
|
113 |
+
shard_state = safetensors.torch.load_stream(model_stream)
|
114 |
+
elif model_path.endswith(".bin"):
|
115 |
+
shard_state = torch.load(model_stream, map_location="cpu")
|
116 |
+
state_dict.update(shard_state)
|
117 |
+
logger.info(f"Downloaded and loaded model file {model_path}")
|
118 |
+
except Exception as e:
|
119 |
+
logger.exception(f"Error loading model file {model_path}: {e}")
|
120 |
+
raise
|
121 |
+
else:
|
122 |
+
logger.error(f"Error downloading {file_url} with status code: {response.status_code}")
|
123 |
+
raise HTTPException(status_code=500, detail=f"Error downloading model file from Hugging Face")
|
124 |
+
except Exception as e:
|
125 |
+
logger.exception(f"Error loading model files for {model_prefix} : {e}")
|
126 |
+
raise
|
127 |
|
|
|
|
|
|
|
|
|
128 |
|
129 |
model = AutoModelForCausalLM.from_config(config)
|
130 |
model.load_state_dict(state_dict)
|