shigeru saito commited on
Commit
7632937
·
1 Parent(s): dc9caf6

Interpolation対応、冗長な処理をリファクタリング

Browse files
Files changed (1) hide show
  1. app.py +77 -46
app.py CHANGED
@@ -45,7 +45,6 @@ class Replicate:
45
  self.REPLICATE_MODEL_PATH = ""
46
  self.REPLICATE_MODEL_VERSION = ""
47
  self.input={}
48
- self.output_url = None
49
  self.response = None
50
  self.prediction_id = None
51
  self.lock = threading.Lock()
@@ -78,52 +77,19 @@ class Replicate:
78
 
79
  print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
80
  if self.prediction.status == "succeeded":
81
- self.output_url = self.prediction.output
 
82
  print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
83
  else:
84
- self.output_url = None
85
 
86
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
87
  end_time = time.time()
88
  duration = end_time - start_time
89
 
90
- self.download_and_save(url=self.output_url, file_path=self.file_path)
91
  self.print_thread_info(start_time, end_time, duration)
92
- except replicate.exceptions.ReplicateError as e:
93
- print(f"Error fetching model or version: {e}")
94
- print(f"Model Path: {self.REPLICATE_MODEL_PATH}")
95
- print(f"Model Version: {self.REPLICATE_MODEL_VERSION}")
96
- if self.prediction_id and str(e) == "The requested resource could not be found.":
97
- predictions = self.client.predictions.list()
98
- self.prediction = next((p for p in predictions if p.id == self.prediction_id), None)
99
-
100
- if self.prediction:
101
- print(f"Found prediction with ID {self.prediction_id}: {self.prediction}")
102
- else:
103
- print(f"No prediction found with ID {self.prediction_id}")
104
- self.prediction.wait()
105
-
106
- print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
107
- if self.prediction.status == "succeeded":
108
- self.output_url = self.prediction.output
109
- print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
110
- else:
111
- self.output_url = None
112
- print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: Error")
113
- print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
114
-
115
- self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
116
- end_time = time.time()
117
- duration = end_time - start_time
118
-
119
- self.download_and_save(url=self.output_url, file_path=self.file_path)
120
- self.print_thread_info(start_time, end_time, duration)
121
- else:
122
- print(f"Error in thread {self.index}: {e}")
123
- print(traceback.format_exc())
124
- print("予期しないエラーが発生しました。スレッドを終了します。")
125
- # 予期しないエラーが発生した場合の追加処理
126
- raise e
127
  except Exception as e:
128
  print(f"Error in thread {self.index}: {e}")
129
  print(traceback.format_exc())
@@ -135,22 +101,22 @@ class Replicate:
135
  f.write(response.content)
136
 
137
  def print_thread_info(self, start_time, end_time, duration):
138
- print(f"Thread {self.index} output_url: {self.output_url}")
139
  print(f"Thread {self.index} start time: {start_time}")
140
  print(f"Thread {self.index} end time: {end_time}")
141
  print(f"Thread {self.index} duration: {duration}")
142
 
143
- class Video(Replicate):
144
 
145
  def __init__(self, id, client: Client, args, scene, index=None):
146
  super().__init__(id, client, args, index)
147
  self.REPLICATE_MODEL_PATH = "lucataco/animate-diff"
148
- self.REPLICATE_MODEL_VERSION = "1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663"
149
  self.scene = scene
150
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
151
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
152
  self.prompt = self.prompt + ", cinematic-angles-" + scene.get("cinematic_angles")
153
- self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, "
154
  self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
155
  self.file_path = None
156
  self.input={
@@ -162,9 +128,74 @@ class Video(Replicate):
162
 
163
  def run_replicate(self, retries=0):
164
  self.response = super().run_replicate()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
166
  return self.response
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  class Music(Replicate):
169
 
170
  def __init__(self, id, client: Client, args, duration):
@@ -203,13 +234,12 @@ class Music(Replicate):
203
  }
204
  )
205
  print(output)
206
- self.output_url = output
207
  self.response = output
208
 
209
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
210
  end_time = time.time()
211
  duration = end_time - start_time
212
- self.download_and_save(url=self.output_url, file_path=self.file_path)
213
  self.print_thread_info(start_time, end_time, duration)
214
 
215
  return self.response
@@ -471,7 +501,8 @@ if __name__ == "__main__":
471
  if args.prompts_file:
472
  prompts = load_prompts(args.prompts_file)
473
  # main(prompts)
474
- NajiminoAI.generate("伝統工芸と最新技術の融合")
 
475
  else:
476
 
477
  description = """
 
45
  self.REPLICATE_MODEL_PATH = ""
46
  self.REPLICATE_MODEL_VERSION = ""
47
  self.input={}
 
48
  self.response = None
49
  self.prediction_id = None
50
  self.lock = threading.Lock()
 
77
 
78
  print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
79
  if self.prediction.status == "succeeded":
80
+ self.response = self.prediction.output
81
+ self.response = self.response
82
  print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
83
  else:
84
+ self.response = None
85
 
86
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
87
  end_time = time.time()
88
  duration = end_time - start_time
89
 
 
90
  self.print_thread_info(start_time, end_time, duration)
91
+
92
+ return self.response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  except Exception as e:
94
  print(f"Error in thread {self.index}: {e}")
95
  print(traceback.format_exc())
 
101
  f.write(response.content)
102
 
103
  def print_thread_info(self, start_time, end_time, duration):
104
+ print(f"Thread {self.index} response: {self.response}")
105
  print(f"Thread {self.index} start time: {start_time}")
106
  print(f"Thread {self.index} end time: {end_time}")
107
  print(f"Thread {self.index} duration: {duration}")
108
 
109
+ class LucatacoAnimateDiff(Replicate):
110
 
111
  def __init__(self, id, client: Client, args, scene, index=None):
112
  super().__init__(id, client, args, index)
113
  self.REPLICATE_MODEL_PATH = "lucataco/animate-diff"
114
+ self.REPLICATE_MODEL_VERSION = "beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f"
115
  self.scene = scene
116
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
117
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
118
  self.prompt = self.prompt + ", cinematic-angles-" + scene.get("cinematic_angles")
119
+ self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, deformed iris, deformed pupils, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
120
  self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
121
  self.file_path = None
122
  self.input={
 
128
 
129
  def run_replicate(self, retries=0):
130
  self.response = super().run_replicate()
131
+ self.download_and_save(url=self.response, file_path=self.file_path)
132
+ self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
133
+ return self.response
134
+
135
+ class ZsxkibAnimateDiff(Replicate):
136
+
137
+ def __init__(self, id, client: Client, args, scene, index=None):
138
+ super().__init__(id, client, args, index)
139
+ self.REPLICATE_MODEL_PATH = "zsxkib/animate-diff"
140
+ self.REPLICATE_MODEL_VERSION = "269a616c8b0c2bbc12fc15fd51bb202b11e94ff0f7786c026aa905305c4ed9fb"
141
+ self.scene = scene
142
+ self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
143
+ self.prompt = self.prompt + scene.get("visual_prompt_in_en")
144
+ self.prompt = self.prompt + ", cinematic-angles-" + scene.get("cinematic_angles")
145
+ self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, deformed iris, deformed pupils, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
146
+ self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
147
+ self.file_path = None
148
+ self.input={
149
+ "prompt": self.prompt,
150
+ "negative_prompt": self.nagative_prompt,
151
+ "base_model": "toonyou_beta3", #Allowed values:realisticVisionV20_v20, lyriel_v16, majicmixRealistic_v5Preview, rcnzCartoon3d_v10, toonyou_beta3
152
+ }
153
+
154
+ def run_replicate(self, retries=0):
155
+ self.response = super().run_replicate()
156
+ self.video = self.response[0]
157
+ self.download_and_save(url=self.video, file_path=self.file_path)
158
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
159
  return self.response
160
 
161
+ class Interpolator(Replicate):
162
+
163
+ def __init__(self, id, client: Client, args, video, index=None):
164
+ super().__init__(id, client, args, index)
165
+ self.REPLICATE_MODEL_PATH = "zsxkib/st-mfnet"
166
+ self.REPLICATE_MODEL_VERSION = "faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
167
+ self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
168
+ self.file_path = None
169
+ self.input={
170
+ "mp4": video,
171
+ "framerate_multiplier": 4,
172
+ "keep_original_duration": False,
173
+ "custom_fps": 24,
174
+ }
175
+
176
+ def run_replicate(self, retries=0):
177
+ self.response = super().run_replicate()
178
+ self.download_and_save(url=list(self.response)[-1], file_path=self.file_path)
179
+ self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
180
+ return self.response
181
+
182
+ class Video():
183
+
184
+ def __init__(self, id, client: Client, args, scene, index=None):
185
+ self.client = client
186
+ self.index = index
187
+ # self.animatediff = LucatacoAnimateDiff(id, client, args, scene, index)
188
+ self.animatediff = ZsxkibAnimateDiff(id, client, args, scene, index)
189
+ self.prompt = self.animatediff.prompt
190
+ self.interpolator = None
191
+
192
+ def run_replicate(self, retries=0):
193
+ self.animatediff.run_replicate(retries)
194
+ self.interpolator = Interpolator(self.animatediff.id, self.animatediff.client, self.animatediff.args, self.animatediff.video, self.animatediff.index)
195
+ self.response = self.interpolator.run_replicate(retries)
196
+ self.file_path = self.interpolator.file_path
197
+ return self.response
198
+
199
  class Music(Replicate):
200
 
201
  def __init__(self, id, client: Client, args, duration):
 
234
  }
235
  )
236
  print(output)
 
237
  self.response = output
238
 
239
  self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id)
240
  end_time = time.time()
241
  duration = end_time - start_time
242
+ self.download_and_save(url=self.response, file_path=self.file_path)
243
  self.print_thread_info(start_time, end_time, duration)
244
 
245
  return self.response
 
501
  if args.prompts_file:
502
  prompts = load_prompts(args.prompts_file)
503
  # main(prompts)
504
+ NajiminoAI.generate("子どもたちが笑ったり怒ったり泣いたり楽しんだりする")
505
+
506
  else:
507
 
508
  description = """