alessandro trinca tornidor
commited on
Commit
·
eec88db
1
Parent(s):
7ad428c
feat: adding explicit gpu init in get_model()
Browse files
lisa_on_cuda/utils/app_helpers.py
CHANGED
@@ -169,6 +169,12 @@ def load_model_for_causal_llm_pretrained(
|
|
169 |
return _model
|
170 |
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
|
173 |
"""Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
|
174 |
|
@@ -186,6 +192,10 @@ def get_model(args_to_parse, internal_logger: logging = None, inference_decorato
|
|
186 |
if internal_logger is None:
|
187 |
internal_logger = app_logger
|
188 |
internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
|
|
|
|
|
|
|
|
|
189 |
try:
|
190 |
vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
|
191 |
logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
|
|
|
169 |
return _model
|
170 |
|
171 |
|
172 |
+
def gpu_init_zero(internal_logger: logging = None):
|
173 |
+
if internal_logger is None:
|
174 |
+
internal_logger = app_logger
|
175 |
+
internal_logger.info("GPU init...")
|
176 |
+
|
177 |
+
|
178 |
def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
|
179 |
"""Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
|
180 |
|
|
|
192 |
if internal_logger is None:
|
193 |
internal_logger = app_logger
|
194 |
internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
|
195 |
+
if inference_decorator:
|
196 |
+
internal_logger.info(f"try explicit gpu init with decorator {inference_decorator.__name__}...")
|
197 |
+
inference_decorator(gpu_init_zero(internal_logger=internal_logger))
|
198 |
+
internal_logger.info(f"gpu explicitly initialized!")
|
199 |
try:
|
200 |
vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
|
201 |
logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
|