sky24h's picture
gradio demo for ZeroGPU, HF
a9d25c7
# Central module for project-wide settings: import it (Python loads a module only once) instead
# of redefining these helpers, so that the device and pretrained models folder stay consistent
# across the project.
import os
import torch
# Module-level flag; starts disabled and is set to True by use_lower_vram().
# (A `global` statement at module scope is a no-op, so none is used here.)
low_vram_mode: bool = False
def use_lower_vram() -> None:
    """Set the module-level ``low_vram_mode`` flag to ``True``."""
    # Declare the name global so the assignment rebinds the module attribute
    # instead of creating a function-local variable.
    global low_vram_mode
    low_vram_mode = True
def get_device() -> torch.device:
    """Return the torch device used by the project.

    Always CUDA: the online demo version must run on a GPU, so there is
    intentionally no CPU fallback. Building the ``torch.device`` object does
    not itself check whether CUDA is available.
    """
    return torch.device("cuda")
def set_random_seed(seed: int) -> None:
    """Seed torch's random number generators and configure cuDNN for determinism.

    Seeds the default torch generator, the current CUDA device and all CUDA
    devices. It then sets cuDNN to use deterministic algorithms and disables
    cuDNN benchmark autotuning, to favour reproducible runs.

    Args:
        seed: Value passed to every torch seeding function.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    # Benchmark mode may choose different kernels from run to run; require
    # deterministic algorithms and turn autotuning off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def get_pretrained_models_folder() -> str:
    """Return the absolute path of the ``pretrained-models`` folder.

    The folder is located one directory above the directory containing this
    module. The path is resolved with ``os.path.realpath``, which has three effects:

    - the result contains no ``..`` segment;
    - the result no longer depends on the working directory if ``__file__`` is relative;
    - ``..`` is applied after following symlinks, as the OS would do.

    The folder is not required to exist.

    Returns:
        Absolute, normalized path to the pretrained models folder.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.realpath(os.path.join(module_dir, os.pardir, "pretrained-models"))
# def download_pretrained_models():
# pretrained_models_folder = get_pretrained_models_folder()
# # hard-coded download links
# groundingdino_link = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
# sam_link = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
# ram_link = "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
# groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
# sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
# ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
# # download pretrained models if not exists
# if not os.path.exists(groundingdino_ckpt):
# print(f"Downloading pretrained model: {groundingdino_ckpt}")
# os.system(f"wget -O {groundingdino_ckpt} {groundingdino_link} -q")
# if not os.path.exists(sam_ckpt):
# print(f"Downloading pretrained model: {sam_ckpt}")
# os.system(f"wget -O {sam_ckpt} {sam_link} -q")
# if not os.path.exists(ram_ckpt):
# print(f"Downloading pretrained model: {ram_ckpt}")
# os.system(f"wget -O {ram_ckpt} {ram_link} -q")
# # download pretrained models when imported
# download_pretrained_models()