hysts commited on
Commit
66de195
1 Parent(s): e89b524
Files changed (1) hide show
  1. model.py +4 -0
model.py CHANGED
@@ -107,6 +107,8 @@ class Model:
107
 
108
  def generate_label_image(self, pose_data: torch.Tensor,
109
  shape_text: str) -> np.ndarray:
 
 
110
  self.model.feed_pose_data(pose_data)
111
  shape_attributes = generate_shape_attributes(shape_text)
112
  shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
@@ -118,6 +120,8 @@ class Model:
118
 
119
  def generate_human(self, label_image: np.ndarray, texture_text: str,
120
  sample_steps: int, seed: int) -> np.ndarray:
 
 
121
  mask = label_image.copy()
122
  seg_map = self.process_mask(mask)
123
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
 
107
 
108
  def generate_label_image(self, pose_data: torch.Tensor,
109
  shape_text: str) -> np.ndarray:
110
+ if pose_data is None:
111
+ return
112
  self.model.feed_pose_data(pose_data)
113
  shape_attributes = generate_shape_attributes(shape_text)
114
  shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
 
120
 
121
  def generate_human(self, label_image: np.ndarray, texture_text: str,
122
  sample_steps: int, seed: int) -> np.ndarray:
123
+ if label_image is None:
124
+ return
125
  mask = label_image.copy()
126
  seg_map = self.process_mask(mask)
127
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(