shigeru saito commited on
Commit
df52446
·
1 Parent(s): 59876db

複数replicateアカウント対応を修正

Browse files
Files changed (3) hide show
  1. app.py +86 -90
  2. schema.json +1 -1
  3. template.md +1 -1
app.py CHANGED
@@ -20,25 +20,24 @@ from PIL import Image as PIL_Image
20
  from jinja2 import Template
21
 
22
  ENV = os.getenv("ENV")
23
- MODEL = "gpt-3.5-turbo"
24
- # MODEL = "gpt-4"
25
 
26
  load_dotenv()
27
  openai.api_key = os.getenv('OPENAI_API_KEY')
28
- # REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
29
- # REPLICATE_API_TOKEN_LISTをロードし、カンマで分割してリストに変換
30
  REPLICATE_API_TOKEN_LIST = os.getenv("REPLICATE_API_TOKEN_LIST").split(',')
31
- REPLICATE_API_TOKEN_INDEX = 0 # トークンのインデックスを初期化
32
  NUMBER_OF_SCENES = os.getenv("NUMBER_OF_SCENES")
33
 
34
  if ENV == "PRODUCTION":
35
  import replicate
 
36
  else:
37
- from stub import replicate
 
38
 
39
  class Video:
40
- def __init__(self, scene, index, token_controller):
41
- self.token_controller = token_controller
42
  self.scene = scene
43
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
44
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
@@ -49,40 +48,77 @@ class Video:
49
  self.video_id = uuid.uuid4()
50
  self.file_path = f"assets/thread_{index}_request_{self.video_id}_video.mp4"
51
 
52
- MAX_RETRIES = 2
 
 
53
  def run_replicate(self, retries=0):
54
  try:
55
- self.token = self.token_controller.get_next_token()
56
  start_time = time.time()
57
 
58
- os.environ["REPLICATE_API_TOKEN"] = self.token
59
  #tokenの最初の10文字だけ出力
60
- print(f"Thread {self.index} token: {self.token[:10]}")
61
 
62
- self.output_url = replicate.run(
63
- "lucataco/animate-diff:1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663",
 
 
64
  input={
65
  "motion_module": "mm_sd_v14",
66
  "prompt": self.prompt,
67
  "n_prompt": self.nagative_prompt,
68
  "seed": 0,
69
- }
70
  )
 
 
 
 
 
 
 
 
71
 
 
 
 
 
 
 
 
 
 
72
  end_time = time.time()
73
  duration = end_time - start_time
74
 
75
  self.download_and_save(url=self.output_url, file_path=self.file_path)
76
  self.print_thread_info(start_time, end_time, duration)
77
  except replicate.exceptions.ReplicateError as e:
78
- if str(e) == "The requested resource could not be found." and retries < self.MAX_RETRIES:
79
- print("リソースが見つからないエラーが発生しました。2秒後に再試行します。")
80
- time.sleep(2)
81
- self.run_replicate(retries + 1) # 再帰的に関数を呼び出して再試行
82
- elif retries >= self.MAX_RETRIES:
83
- print("最大再試行回数に達しました。スレッドを終了します。")
84
- # 最大再試行回数に達した場合の追加処理
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  else:
 
86
  print("予期しないエラーが発生しました。スレッドを終了します。")
87
  # 予期しないエラーが発生した場合の追加処理
88
  except Exception as e:
@@ -105,24 +141,37 @@ class ThreadController:
105
  scenes = args.get("scenes")
106
  self.videos = []
107
  self.threads = []
108
- self.token_index = 0
109
  self.lock = threading.Lock()
110
- for index, scene in enumerate(scenes):
111
- for _ in REPLICATE_API_TOKEN_LIST:
112
- # token = REPLICATE_API_TOKEN_LIST[self.token_index]
113
- video = Video(scene, index, self)
 
 
 
 
 
 
114
  self.videos.append(video)
115
- self.token_index = (self.token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
 
 
116
 
117
  def run_threads(self):
118
  os.makedirs("assets", exist_ok=True)
119
 
 
120
  for video in self.videos:
 
 
 
 
 
121
  thread = threading.Thread(target=video.run_replicate)
122
  self.threads.append(thread)
123
  thread.start()
124
- # 1秒待ってから実行
125
- # time.sleep(1)
126
 
127
  for thread in self.threads:
128
  thread.join()
@@ -150,12 +199,6 @@ class ThreadController:
150
  for video in self.videos:
151
  print(f"Thread {video.index} prompt: {video.prompt}")
152
 
153
- def get_next_token(self):
154
- with self.lock:
155
- token = REPLICATE_API_TOKEN_LIST[self.token_index]
156
- self.token_index = (self.token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
157
- return token
158
-
159
  def main(args):
160
  thread_controller = ThreadController(args)
161
  thread_controller.run_threads()
@@ -213,34 +256,6 @@ class NajiminoAI:
213
 
214
  def generate_markdown(self, args, generation_time):
215
 
216
- # # lang=args.get("lang")
217
- # title=args.get("title")
218
- # description=args.get("description")
219
- # visual_prompt_in_en=args.get("visual_prompt_in_en")
220
- # scenes = args.get("scenes")
221
-
222
- # prompt_for_visual_expression = \
223
- # visual_prompt_in_en
224
-
225
- # print("prompt_for_visual_expression: "+prompt_for_visual_expression)
226
-
227
- # prompts = []
228
- # if scenes:
229
- # for scene_data in scenes:
230
- # prompt = scene_data.get("visual_prompt_in_en")
231
- # prompt = prompt + ", " + scene_data.get("cinematic_angles")
232
- # prompt = prompt + ", " + scene_data.get("visual_prompt_in_en")
233
- # prompts.append(prompt)
234
- # print("scenes: " + json.dumps(scenes, indent=2))
235
- # if scenes:
236
- # for scene_data in scenes:
237
- # scene = scene_data.get("scene")
238
- # cinematic_angles = scene_data.get("cinematic_angles")
239
- # visual_prompt_in_en = scene_data.get("visual_prompt_in_en")
240
- # print("scene: ", scene)
241
- # print("cinematic_angles: ", cinematic_angles)
242
- # print("visual_prompt_in_en: ", visual_prompt_in_en)
243
-
244
  template_string = get_filetext(filename = "template.md")
245
 
246
  template = Template(template_string)
@@ -282,29 +297,15 @@ class NajiminoAI:
282
 
283
  function_name = message["function_call"]["name"]
284
 
285
- args = json.loads(message["function_call"]["arguments"])
 
 
 
 
 
286
 
287
  print("args: " + json.dumps(args, indent=2))
288
 
289
- # # lang=args.get("lang")
290
- # title=args.get("title")
291
- # description=args.get("description")
292
- # visual_prompt_in_en=args.get("visual_prompt_in_en")
293
- # scenes = args.get("scenes")
294
-
295
- # prompt_for_visual_expression = \
296
- # visual_prompt_in_en
297
-
298
- # print("prompt_for_visual_expression: "+prompt_for_visual_expression)
299
-
300
- # prompts = []
301
- # if scenes:
302
- # for scene_data in scenes:
303
- # prompt = scene_data.get("visual_prompt_in_en")
304
- # prompt = prompt + ", " + scene_data.get("cinematic_angles")
305
- # prompt = prompt + ", " + scene_data.get("visual_prompt_in_en")
306
- # prompts.append(prompt)
307
-
308
  video_path = main(args)
309
 
310
  main_end_time = time.time()
@@ -325,7 +326,6 @@ class NajiminoAI:
325
  )
326
  return [video_path, html]
327
 
328
-
329
  if __name__ == "__main__":
330
  parser = argparse.ArgumentParser(description="Generate videos from text prompts")
331
 
@@ -338,10 +338,6 @@ if __name__ == "__main__":
338
  # main(prompts)
339
  NajiminoAI.generate("伝統工芸と最新技術の融合")
340
  else:
341
- # def create_video(prompt):
342
- # prompts = prompt.strip().split('\n')
343
- # output_path = main(prompts)
344
- # return output_path
345
 
346
  iface = gr.Interface(
347
  fn=NajiminoAI.generate,
@@ -357,7 +353,7 @@ if __name__ == "__main__":
357
  examples=[
358
  ["侍たちは野を超え山を超え、敵軍大将を討ち取り、天下の大将軍となった!"],
359
  ["子どもたちが笑ったり怒ったり泣いたり楽しんだりする"],
360
- ["彼女のダンスは炎のように激しく、風のように自由に、水のように柔軟に、木のように生き生きと、虹のように美しく舞う"],
361
  ],
362
  )
363
  iface.launch()
 
20
  from jinja2 import Template
21
 
22
  ENV = os.getenv("ENV")
23
+ # MODEL = "gpt-3.5-turbo"
24
+ MODEL = "gpt-4"
25
 
26
  load_dotenv()
27
  openai.api_key = os.getenv('OPENAI_API_KEY')
 
 
28
  REPLICATE_API_TOKEN_LIST = os.getenv("REPLICATE_API_TOKEN_LIST").split(',')
 
29
  NUMBER_OF_SCENES = os.getenv("NUMBER_OF_SCENES")
30
 
31
  if ENV == "PRODUCTION":
32
  import replicate
33
+ from replicate.client import Client
34
  else:
35
+ # from stub import replicate
36
+ pass
37
 
38
  class Video:
39
+ def __init__(self, scene, index, client: Client):
40
+ self.client = client
41
  self.scene = scene
42
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
43
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
 
48
  self.video_id = uuid.uuid4()
49
  self.file_path = f"assets/thread_{index}_request_{self.video_id}_video.mp4"
50
 
51
+ REPLICATE_MODEL_PATH = "lucataco/animate-diff"
52
+ REPLICATE_MODEL_VERSION = "1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663"
53
+
54
  def run_replicate(self, retries=0):
55
  try:
56
+ # self.client.api_token = self.client.api_token_controller.get_next_token()
57
  start_time = time.time()
58
 
59
+ # os.environ["REPLICATE_API_TOKEN"] = self.client.api_token
60
  #tokenの最初の10文字だけ出力
61
+ print(f"Thread {self.index} token: {self.client.api_token[:10]}")
62
 
63
+ model = self.client.models.get(self.REPLICATE_MODEL_PATH)
64
+ version = model.versions.get(self.REPLICATE_MODEL_VERSION)
65
+ self.prediction = self.client.predictions.create(
66
+ version=version,
67
  input={
68
  "motion_module": "mm_sd_v14",
69
  "prompt": self.prompt,
70
  "n_prompt": self.nagative_prompt,
71
  "seed": 0,
72
+ },
73
  )
74
+
75
+ self.prediction_id = self.prediction.id
76
+
77
+ # print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction: {self.prediction}")
78
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
79
+
80
+ self.prediction.reload()
81
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
82
 
83
+ self.prediction.wait()
84
+
85
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
86
+ if self.prediction.status == "succeeded":
87
+ self.output_url = self.prediction.output
88
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
89
+ else:
90
+ self.output_url = None
91
+
92
  end_time = time.time()
93
  duration = end_time - start_time
94
 
95
  self.download_and_save(url=self.output_url, file_path=self.file_path)
96
  self.print_thread_info(start_time, end_time, duration)
97
  except replicate.exceptions.ReplicateError as e:
98
+ if self.prediction and str(e) == "The requested resource could not be found.":
99
+ predictions = self.client.predictions.list()
100
+ self.prediction = next((p for p in predictions if p.id == self.prediction_id), None)
101
+
102
+ if self.prediction:
103
+ print(f"Found prediction with ID {self.prediction_id}: {self.prediction}")
104
+ else:
105
+ print(f"No prediction found with ID {self.prediction_id}")
106
+ self.prediction.wait()
107
+
108
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}")
109
+ if self.prediction.status == "succeeded":
110
+ self.output_url = self.prediction.output
111
+ print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}")
112
+ else:
113
+ self.output_url = None
114
+
115
+ end_time = time.time()
116
+ duration = end_time - start_time
117
+
118
+ self.download_and_save(url=self.output_url, file_path=self.file_path)
119
+ self.print_thread_info(start_time, end_time, duration)
120
  else:
121
+ print(f"Error in thread {self.index}: {e}")
122
  print("予期しないエラーが発生しました。スレッドを終了します。")
123
  # 予期しないエラーが発生した場合の追加処理
124
  except Exception as e:
 
141
  scenes = args.get("scenes")
142
  self.videos = []
143
  self.threads = []
 
144
  self.lock = threading.Lock()
145
+ self.replicate_client_list = {}
146
+ for token in REPLICATE_API_TOKEN_LIST:
147
+
148
+ client = Client()
149
+ client.api_token = token
150
+ self.replicate_client_list[token] = client
151
+
152
+ for index, scene in enumerate(scenes):
153
+ # token = REPLICATE_API_TOKEN_LIST[self.client.api_token_index]
154
+ video = Video(scene, index, client)
155
  self.videos.append(video)
156
+
157
+ # self.client.api_token_index = (self.client.api_token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
158
+
159
 
160
  def run_threads(self):
161
  os.makedirs("assets", exist_ok=True)
162
 
163
+ token = None
164
  for video in self.videos:
165
+ if token is not None and video.client.api_token != token:
166
+ # tokenが異なる場合、1秒待ってから次を実行
167
+ print(f"Thread {video.index} token changed. Waiting 3 seconds.")
168
+ time.sleep(5)
169
+
170
  thread = threading.Thread(target=video.run_replicate)
171
  self.threads.append(thread)
172
  thread.start()
173
+ token = video.client.api_token
174
+ # time.sleep(5)
175
 
176
  for thread in self.threads:
177
  thread.join()
 
199
  for video in self.videos:
200
  print(f"Thread {video.index} prompt: {video.prompt}")
201
 
 
 
 
 
 
 
202
  def main(args):
203
  thread_controller = ThreadController(args)
204
  thread_controller.run_threads()
 
256
 
257
  def generate_markdown(self, args, generation_time):
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  template_string = get_filetext(filename = "template.md")
260
 
261
  template = Template(template_string)
 
297
 
298
  function_name = message["function_call"]["name"]
299
 
300
+ try:
301
+ args = json.loads(message["function_call"]["arguments"])
302
+ except json.JSONDecodeError as e:
303
+ print(f"JSON decode error at position {e.pos}: {e.msg}")
304
+ print("message: " + json.dumps(message, indent=2))
305
+ raise e
306
 
307
  print("args: " + json.dumps(args, indent=2))
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  video_path = main(args)
310
 
311
  main_end_time = time.time()
 
326
  )
327
  return [video_path, html]
328
 
 
329
  if __name__ == "__main__":
330
  parser = argparse.ArgumentParser(description="Generate videos from text prompts")
331
 
 
338
  # main(prompts)
339
  NajiminoAI.generate("伝統工芸と最新技術の融合")
340
  else:
 
 
 
 
341
 
342
  iface = gr.Interface(
343
  fn=NajiminoAI.generate,
 
353
  examples=[
354
  ["侍たちは野を超え山を超え、敵軍大将を討ち取り、天下の大将軍となった!"],
355
  ["子どもたちが笑ったり怒ったり泣いたり楽しんだりする"],
356
+ ["日は昇り、大地を照らし、日は沈む。闇夜を照らし、陽はまた昇る。 "],
357
  ],
358
  )
359
  iface.launch()
schema.json CHANGED
@@ -20,7 +20,7 @@
20
  },
21
  "story": {
22
  "type": "string",
23
- "description": "映画のあらすじを、起承転結を交えて時系列に詳しく説明する"
24
  },
25
  "visual_style": {
26
  "type": "string",
 
20
  },
21
  "story": {
22
  "type": "string",
23
+ "description": "映画のあらすじを、起承転結を交えて時系列に文学的に詳しく説明する"
24
  },
25
  "visual_style": {
26
  "type": "string",
template.md CHANGED
@@ -12,4 +12,4 @@
12
 
13
  | Scene | visual_prompt_in_en | negative_visual_prompt_in_en | cinematic_angles |
14
  |----:|----|----|----|{% for item in args.scenes %}
15
- |{{ item.scene }}|{{ item.visual_prompt_in_en }}|{{ item.negative_visual_prompt_in_en }}|{{ item.camera_work}}|{% endfor %}
 
12
 
13
  | Scene | visual_prompt_in_en | negative_visual_prompt_in_en | cinematic_angles |
14
  |----:|----|----|----|{% for item in args.scenes %}
15
+ |{{ item.scene }}|{{ item.visual_prompt_in_en }}|{{ item.negative_visual_prompt_in_en }}|{{ item.cinematic_angles}}|{% endfor %}