drhead commited on
Commit
0cdffb9
1 Parent(s): accdcf1

clean imports, strip underscores from tags

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -1,21 +1,16 @@
1
  import json
2
 
3
- from PIL import Image
4
  import gradio as gr
 
 
 
 
 
5
  import torch
6
  from torchvision.transforms import transforms
7
  from torchvision.transforms import InterpolationMode
8
  import torchvision.transforms.functional as TF
9
 
10
- import spaces
11
-
12
- import huggingface_hub
13
- import timm
14
- from timm.models import VisionTransformer
15
- import safetensors.torch
16
-
17
-
18
- torch.jit.script = lambda f: f
19
  torch.set_grad_enabled(False)
20
 
21
  class Fit(torch.nn.Module):
@@ -130,6 +125,9 @@ with open("tagger_tags.json", "r") as file:
130
  tags = json.load(file) # type: dict
131
  allowed_tags = list(tags.keys())
132
 
 
 
 
133
  @spaces.GPU(duration=5)
134
  def create_tags(image, threshold):
135
  img = image.convert('RGB')
 
1
  import json
2
 
 
3
  import gradio as gr
4
+ from PIL import Image
5
+ import safetensors.torch
6
+ import spaces
7
+ import timm
8
+ from timm.models import VisionTransformer
9
  import torch
10
  from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
 
 
 
 
 
 
 
 
 
 
14
  torch.set_grad_enabled(False)
15
 
16
  class Fit(torch.nn.Module):
 
125
  tags = json.load(file) # type: dict
126
  allowed_tags = list(tags.keys())
127
 
128
+ for idx, tag in enumerate(allowed_tags):
129
+ allowed_tags[idx] = tag.replace("_", " ")
130
+
131
  @spaces.GPU(duration=5)
132
  def create_tags(image, threshold):
133
  img = image.convert('RGB')