# Avoid importing the same module multiple times: import shared modules once through this file.
# This also keeps the device and the pretrained-models folder consistent across the project.
import os | |
import torch | |
# Module-wide switch for the low-VRAM mode. It starts disabled and is flipped
# to True by use_lower_vram(). No `global` statement is needed here: a name
# assigned at module level is already a module global.
low_vram_mode: bool = False
def use_lower_vram() -> None:
    """Turn on the module-level ``low_vram_mode`` flag.

    Rebinds the global ``low_vram_mode`` to ``True``. Takes no arguments and
    returns nothing; calling it more than once leaves the flag at ``True``.
    """
    # Rebind the module global rather than creating a function-local name.
    global low_vram_mode
    low_vram_mode = True
def get_device():
    """Return the torch device object for ``"cuda"``.

    The device is hard-coded and is not checked for availability here:
    the online demo version must use the GPU.
    """
    return torch.device("cuda")
def set_random_seed(seed: int):
    """Seed the torch random number generators and configure cuDNN flags.

    Args:
        seed: Value passed to ``torch.manual_seed``, ``torch.cuda.manual_seed``
            and ``torch.cuda.manual_seed_all``, in that order.

    Besides seeding, this sets ``torch.backends.cudnn.deterministic`` to
    ``True`` and ``torch.backends.cudnn.benchmark`` to ``False``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def get_pretrained_models_folder():
    """Return the path of the ``pretrained-models`` folder.

    The path is built by joining the directory of this file with
    ``"../pretrained-models"``. It is returned as-is: the ``..`` component is
    not normalized and the folder is not checked for existence.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "../pretrained-models")
# def download_pretrained_models():
#     pretrained_models_folder = get_pretrained_models_folder()
#     # hard-coded download links
#     groundingdino_link = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
#     sam_link = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
#     ram_link = "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
#     groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
#     sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
#     ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
#     # download pretrained models if they do not exist
#     if not os.path.exists(groundingdino_ckpt):
#         print(f"Downloading pretrained model: {groundingdino_ckpt}")
#         os.system(f"wget -O {groundingdino_ckpt} {groundingdino_link} -q")
#     if not os.path.exists(sam_ckpt):
#         print(f"Downloading pretrained model: {sam_ckpt}")
#         os.system(f"wget -O {sam_ckpt} {sam_link} -q")
#     if not os.path.exists(ram_ckpt):
#         print(f"Downloading pretrained model: {ram_ckpt}")
#         os.system(f"wget -O {ram_ckpt} {ram_link} -q")
# # download pretrained models when imported
# download_pretrained_models()