File size: 738 Bytes
15a6715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52ffbe2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import gradio as gr
import timm
import torch

# Hugging Face Hub repository of the ViT classifier used by the transformers pipeline.
_VIT_MODEL_ID = "carbon225/vit-base-patch16-224-hentai"
# Hugging Face Hub repository of the timm rating model. timm only fetches a
# Hub repository when the name carries the "hf-hub:" scheme; a bare
# "org/name" is reduced to "name" and looked up in timm's built-in registry,
# which does not contain it.
_RATING_MODEL_ID = "hf-hub:deepghs/anime_rating"

# Image-classification pipeline built from the ViT model and its matching
# feature extractor (both loaded from the same repository id).
nsfw_tf = pipeline(
    "image-classification",
    model=AutoModelForImageClassification.from_pretrained(_VIT_MODEL_ID),
    feature_extractor=AutoFeatureExtractor.from_pretrained(_VIT_MODEL_ID),
)

# timm model with pretrained weights, put in eval (inference) mode.
nsfw_tm = timm.create_model(_RATING_MODEL_ID, pretrained=True).eval()
# Preprocessing config (input size, normalization, interpolation, ...)
# resolved from the model that was just created, so the transform below
# matches what that model expects.
tm_config = timm.data.resolve_model_data_config(nsfw_tm)
# Inference-time transform (is_training=False: no augmentation).
tm_trans = timm.data.create_transform(**tm_config, is_training=False)

def launch(img):
    tm_output = nsfw_tm(transforms(img).unsqueeze(0))