alessandro trinca tornidor commited on
Commit
eec88db
·
1 Parent(s): 7ad428c

feat: adding explicit gpu init in get_model()

Browse files
Files changed (1) hide show
  1. lisa_on_cuda/utils/app_helpers.py +10 -0
lisa_on_cuda/utils/app_helpers.py CHANGED
@@ -169,6 +169,12 @@ def load_model_for_causal_llm_pretrained(
169
  return _model
170
 
171
 
 
 
 
 
 
 
172
  def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
173
  """Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
174
 
@@ -186,6 +192,10 @@ def get_model(args_to_parse, internal_logger: logging = None, inference_decorato
186
  if internal_logger is None:
187
  internal_logger = app_logger
188
  internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
 
 
 
 
189
  try:
190
  vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
191
  logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
 
169
  return _model
170
 
171
 
172
+ def gpu_init_zero(internal_logger: logging = None):
173
+ if internal_logger is None:
174
+ internal_logger = app_logger
175
+ internal_logger.info("GPU init...")
176
+
177
+
178
  def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
179
  """Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
180
 
 
192
  if internal_logger is None:
193
  internal_logger = app_logger
194
  internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
195
+ if inference_decorator:
196
+ internal_logger.info(f"try explicit gpu init with decorator {inference_decorator.__name__}...")
197
+ inference_decorator(gpu_init_zero(internal_logger=internal_logger))
198
+ internal_logger.info(f"gpu explicitly initialized!")
199
  try:
200
  vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
201
  logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")