ZhangYuhan commited on
Commit
f88246b
โ€ข
1 Parent(s): d0e65e5

update serve

Browse files
constants.py CHANGED
@@ -18,6 +18,9 @@ APPEND_JSON = "append_json"
18
  SAVE_IMAGE = "save_image"
19
  SAVE_LOG = "save_log"
20
 
 
 
 
21
  NUM_SIDES = 2
22
  EVALUATE_DIMS = 5
23
  PROMPT_NUM = 510
 
18
  SAVE_IMAGE = "save_image"
19
  SAVE_LOG = "save_log"
20
 
21
+ GIF_SERVER = os.getenv("GIF_SERVER", f"{LOG_SERVER}/get_gif")
22
+ RGBA_SERVER = os.getenv("RGBA_SERVER", f"{LOG_SERVER}/get_rgba")
23
+
24
  NUM_SIDES = 2
25
  EVALUATE_DIMS = 5
26
  PROMPT_NUM = 510
model/model_worker.py CHANGED
@@ -9,7 +9,7 @@ import subprocess
9
  import requests
10
  from gradio_client import Client
11
  # from .client import Gau2Mesh_client
12
- from constants import REPLICATE_API_TOKEN, LOG_SERVER
13
  # os.environ("REPLICATE_API_TOKEN", "yourKey")
14
 
15
  class BaseModelWorker:
@@ -43,10 +43,10 @@ class BaseModelWorker:
43
  # else:
44
  # return None
45
  galley = "image2shape" if self.i2s_model else "text2shape"
46
- rgb_name = f"{galley}_{self.model_name}_{offline_idx}_rgb"
47
- normal_name = f"{galley}_{self.model_name}_{offline_idx}_normal"
48
- rgb_url = f"{LOG_SERVER}/get_{rgb_name}"
49
- normal_url = f"{LOG_SERVER}/get_{normal_name}"
50
  return {'rgb': rgb_url, 'normal': normal_url}
51
 
52
  def inference(self, prompt):
 
9
  import requests
10
  from gradio_client import Client
11
  # from .client import Gau2Mesh_client
12
+ from constants import REPLICATE_API_TOKEN, LOG_SERVER, GIF_SERVER
13
  # os.environ("REPLICATE_API_TOKEN", "yourKey")
14
 
15
  class BaseModelWorker:
 
43
  # else:
44
  # return None
45
  galley = "image2shape" if self.i2s_model else "text2shape"
46
+ rgb_name = f"{galley}_{self.model_name}_{offline_idx}_rgb.gif"
47
+ normal_name = f"{galley}_{self.model_name}_{offline_idx}_normal.gif"
48
+ rgb_url = f"{GIF_SERVER}/{rgb_name}"
49
+ normal_url = f"{GIF_SERVER}/{normal_name}"
50
  return {'rgb': rgb_url, 'normal': normal_url}
51
 
52
  def inference(self, prompt):
serve/gradio_web_t2s.py CHANGED
@@ -80,8 +80,8 @@ Find out who is the ๐Ÿฅ‡conditional image generation models! More models are goi
80
  value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
81
  )
82
 
83
- geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
84
- with gr.Row(elem_id="Geometry Quality"):
85
  geo_leftvote_btn = gr.Button(
86
  value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
87
  )
@@ -164,8 +164,8 @@ Find out who is the ๐Ÿฅ‡conditional image generation models! More models are goi
164
  # value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
165
  # )
166
 
167
- # with gr.Row(elem_id="Geometry Quality"):
168
- # geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
169
  # geo_leftvote_btn = gr.Button(
170
  # value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
171
  # )
@@ -434,8 +434,8 @@ Find out who is the ๐Ÿฅ‡conditional image generation models! More models are goi
434
  value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
435
  )
436
 
437
- geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
438
- with gr.Row(elem_id="Geometry Quality"):
439
  geo_leftvote_btn = gr.Button(
440
  value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
441
  )
@@ -708,8 +708,8 @@ def build_t2s_ui_single_model(models):
708
  plausive_downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
709
  plausive_flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
710
 
711
- with gr.Row(elem_id="Geometry Quality"):
712
- geo_md = gr.Markdown("Geometry Quality: ", elem_id="evaldim_markdown")
713
  geo_upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
714
  geo_downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
715
  geo_flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
 
80
  value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
81
  )
82
 
83
+ geo_md = gr.Markdown("Geometry Details: ", visible=False, elem_id="evaldim_markdown")
84
+ with gr.Row(elem_id="Geometry Details"):
85
  geo_leftvote_btn = gr.Button(
86
  value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
87
  )
 
164
  # value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
165
  # )
166
 
167
+ # with gr.Row(elem_id="Geometry Details"):
168
+ # geo_md = gr.Markdown("Geometry Details: ", visible=False, elem_id="evaldim_markdown")
169
  # geo_leftvote_btn = gr.Button(
170
  # value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
171
  # )
 
434
  value="๐Ÿ‘Ž Both are bad", visible=False, interactive=False
435
  )
436
 
437
+ geo_md = gr.Markdown("Geometry Details: ", visible=False, elem_id="evaldim_markdown")
438
+ with gr.Row(elem_id="Geometry Details"):
439
  geo_leftvote_btn = gr.Button(
440
  value="๐Ÿ‘ˆ A is better", visible=False, interactive=False
441
  )
 
708
  plausive_downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
709
  plausive_flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
710
 
711
+ with gr.Row(elem_id="Geometry Details"):
712
+ geo_md = gr.Markdown("Geometry Details: ", elem_id="evaldim_markdown")
713
  geo_upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
714
  geo_downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
715
  geo_flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
serve/inference.py CHANGED
@@ -4,7 +4,7 @@ import time
4
 
5
  from .utils import *
6
  from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
7
- from constants import IMAGE_DIR, LOG_SERVER, TEXT_PROMPT_PATH, PROMPT_NUM
8
 
9
  with open(TEXT_PROMPT_PATH, 'r') as f:
10
  prompt_list = json.load(f)
@@ -121,7 +121,7 @@ def sample_image(state, model_name):
121
  state = State(model_name)
122
 
123
  idx = random.randint(0, PROMPT_NUM-1)
124
- img_url = f"{LOG_SERVER}/get_rgba_{idx}"
125
 
126
  state.model_name = model_name
127
  state.image = img_url
@@ -137,7 +137,7 @@ def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
137
  state_1 = State(model_name_1)
138
 
139
  idx = random.randint(0, PROMPT_NUM-1)
140
- img_url = f"{LOG_SERVER}/get_rgba_{idx}"
141
 
142
  state_0.i2s_mode, state_1.i2s_mode = True, True
143
  state_0.offline, state_1.offline = True, True
 
4
 
5
  from .utils import *
6
  from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
7
+ from constants import RGBA_SERVER, LOG_SERVER, TEXT_PROMPT_PATH, PROMPT_NUM
8
 
9
  with open(TEXT_PROMPT_PATH, 'r') as f:
10
  prompt_list = json.load(f)
 
121
  state = State(model_name)
122
 
123
  idx = random.randint(0, PROMPT_NUM-1)
124
+ img_url = f"{RGBA_SERVER}/{idx}.png"
125
 
126
  state.model_name = model_name
127
  state.image = img_url
 
137
  state_1 = State(model_name_1)
138
 
139
  idx = random.randint(0, PROMPT_NUM-1)
140
+ img_url = f"{RGBA_SERVER}/{idx}.png"
141
 
142
  state_0.i2s_mode, state_1.i2s_mode = True, True
143
  state_0.offline, state_1.offline = True, True