# %%
"""SnakeCLEF 2024 inference script.

Loads a ViT classifier from a local checkpoint, predicts one class per test
image and writes an ``observation_id,class_id`` submission CSV.
"""
import json
import os

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm

# %%
SZ = 224  # side length (pixels) of the square input fed to the backbone

# Default backbone / checkpoint used by PytorchWorker.
BACKBONE_NAME = "vit_base_patch16_224"
CHECKPOINT_PATH = (
    "./NB_EXP_V2_008/"
    "vit_base_patch16_224_224_bs32_ep16_lr6e05_wd0.05_mixup_cutmix_CV_0.pth"
)


def _load_json(path):
    """Read and return the JSON document at *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Stringified model output index -> species label (looked up in make_submission).
LABELS = _load_json("./labels_class_map_rev.json")
# Species label -> class_id written to the submission CSV.
ORIGINAL_LABELS = _load_json("./original_mapping.json")


def is_gpu_available():
    """Return True if PyTorch can see a CUDA device."""
    return torch.cuda.is_available()


def get_corn_model(model_name, pretrained=True, **kwargs):
    """Build a timm backbone followed by a small classification head.

    The head maps the backbone's ``num_classes`` outputs to ``len(LABELS)``
    logits. The layer order must not change: the checkpoint is loaded with
    ``load_state_dict`` and its keys (``0.*``, ``2.*``, ``3.*``) depend on the
    ``nn.Sequential`` indices.

    :param model_name: timm model identifier.
    :param pretrained: Whether timm should download pretrained backbone weights.
    :param kwargs: Forwarded to ``timm.create_model``.
    :return: The assembled ``nn.Sequential`` model.
    """
    backbone = timm.create_model(model_name, pretrained=pretrained, **kwargs)
    return nn.Sequential(
        backbone,
        nn.Dropout(0.15),
        nn.Linear(backbone.num_classes, backbone.num_classes * 2),
        # NOTE(review): no activation between the two Linear layers, so they
        # compose to a single affine map; kept as-is for checkpoint compatibility.
        nn.Linear(backbone.num_classes * 2, len(LABELS)),
    )


class PytorchWorker:
    """Hold the trained classifier and its preprocessing for single-image inference."""

    def __init__(self, model_name=BACKBONE_NAME, checkpoint_path=CHECKPOINT_PATH):
        """Pick a device, build the preprocessing pipeline and load the model.

        :param model_name: timm backbone identifier passed to ``get_corn_model``.
        :param checkpoint_path: Path of the state-dict file to load.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.transforms = T.Compose([
            T.Resize((SZ, SZ)),
            T.ToTensor(),
            # Maps [0, 1] pixel values to [-1, 1].
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        self.model = self._load_model(model_name, checkpoint_path)

    def _load_model(self, model_name, checkpoint_path):
        """Build the network, load its weights and prepare it for inference."""
        print("Setting up Pytorch Model")
        print(f"Using device: {self.device}")
        model = get_corn_model(model_name, pretrained=False)
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(state_dict)
        # Without eval(), Dropout stays active and predictions are random.
        model.eval()
        return model.to(self.device)

    def predict_image(self, image: Image.Image) -> list:
        """Run the PyTorch model on a single image.

        :param image: Input RGB image (PIL); ``T.Resize`` requires PIL or a tensor.
        :return: Logits as a nested list of shape ``[1][len(LABELS)]``.
        """
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(batch)
        return logits.tolist()


def make_submission(test_metadata, model_path, model_name,
                    output_csv_path="./submission.csv",
                    images_root_path="/tmp/data/private_testset"):
    """Predict a class for every test image and write the submission CSV.

    :param test_metadata: DataFrame with at least ``filename`` and
        ``observation_id`` columns; a ``class_id`` column is added in place.
    :param model_path: Unused; the checkpoint is configured by PytorchWorker.
    :param model_name: Unused; the backbone is configured by PytorchWorker.
    :param output_csv_path: Destination of the ``observation_id,class_id`` CSV.
    :param images_root_path: Directory that ``filename`` entries are relative to.
    """
    worker = PytorchWorker()
    predictions = []
    for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
        image_path = os.path.join(images_root_path, row.filename)
        # Context manager guarantees the file handle is released per image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        logits = worker.predict_image(image)
        # Batch size is 1, so argmax over the flattened logits is the class index.
        species = LABELS.get(str(np.argmax(logits)), 'Acanthophis antarcticus')
        predictions.append(ORIGINAL_LABELS.get(species, 1))
    print(predictions)

    test_metadata["class_id"] = predictions
    # Keep the prediction of the first image of each observation.
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


# %%
if __name__ == "__main__":
    import zipfile

    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    # Passed for interface compatibility; make_submission does not use them.
    MODEL_PATH = "pytorch_model.bin"
    MODEL_NAME = "swinv2_tiny_window16_256.ms_in1k"

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        model_name=MODEL_NAME,
    )