Hjgugugjhuhjggg committed on
Commit
ec488c7
·
verified ·
1 Parent(s): 540ad6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -96,26 +96,35 @@ class S3DirectStream:
96
  config = AutoConfig.from_pretrained(model_prefix, **config_dict)
97
  logger.info(f"Config loaded for {model_prefix}.")
98
 
99
- model_files = self._get_model_files(model_prefix) #This line is not used anymore
100
-
101
- state_dict = {}
102
- for model_file in model_files: #This loop is not used anymore
103
- model_path = os.path.join(model_prefix, model_file)
104
- logger.info(f"Loading model file: {model_path}")
105
- model_stream = self.stream_from_s3(model_path)
106
- try:
107
- if model_path.endswith(".safetensors"):
108
- shard_state = safetensors.torch.load_stream(model_stream)
109
- elif model_path.endswith(".bin"):
110
- shard_state = torch.load(model_stream, map_location="cpu")
111
- else:
112
- logger.error(f"Unsupported model file type: {model_path}")
113
- raise ValueError(f"Unsupported model file type: {model_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- state_dict.update(shard_state)
116
- except Exception as e:
117
- logger.exception(f"Error loading model file {model_path}: {e}")
118
- raise
119
 
120
  model = AutoModelForCausalLM.from_config(config)
121
  model.load_state_dict(state_dict)
 
96
  config = AutoConfig.from_pretrained(model_prefix, **config_dict)
97
  logger.info(f"Config loaded for {model_prefix}.")
98
 
99
+ try:
100
+ api = HfApi()
101
+ model_files = api.list_repo_files(model_prefix)
102
+ state_dict = {}
103
+ for file_info in model_files:
104
+ if file_info.rfilename.endswith(('.bin', '.safetensors')):
105
+ file_url = api.download_file(model_prefix, file_info.rfilename)
106
+ model_path = os.path.join(model_prefix, file_info.rfilename)
107
+ logger.info(f"Downloading model file from {file_url} to {model_path} ...")
108
+ with requests.get(file_url, stream=True) as response:
109
+ if response.status_code == 200:
110
+ try:
111
+ model_stream = response.raw
112
+ if model_path.endswith(".safetensors"):
113
+ shard_state = safetensors.torch.load_stream(model_stream)
114
+ elif model_path.endswith(".bin"):
115
+ shard_state = torch.load(model_stream, map_location="cpu")
116
+ state_dict.update(shard_state)
117
+ logger.info(f"Downloaded and loaded model file {model_path}")
118
+ except Exception as e:
119
+ logger.exception(f"Error loading model file {model_path}: {e}")
120
+ raise
121
+ else:
122
+ logger.error(f"Error downloading {file_url} with status code: {response.status_code}")
123
+ raise HTTPException(status_code=500, detail=f"Error downloading model file from Hugging Face")
124
+ except Exception as e:
125
+ logger.exception(f"Error loading model files for {model_prefix} : {e}")
126
+ raise
127
 
 
 
 
 
128
 
129
  model = AutoModelForCausalLM.from_config(config)
130
  model.load_state_dict(state_dict)