Geek7 commited on
Commit
8be12ab
·
verified ·
1 Parent(s): 3039308

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +59 -42
myapp.py CHANGED
@@ -1,60 +1,77 @@
1
- import gradio as gr
2
- from random import randint
3
- from all_models import models
4
- from externalmod import gr_Interface_load
5
  import asyncio
 
6
  import os
7
  from threading import RLock
 
 
 
 
 
8
 
9
  lock = RLock()
10
- HF_TOKEN = os.environ.get("HF_TOKEN")
11
 
12
- def load_fn(models):
13
- global models_load
14
- models_load = {}
15
-
16
- for model in models:
17
- if model not in models_load.keys():
18
- try:
19
- m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
- except Exception as error:
21
- print(error)
22
- m = gr.Interface(lambda: None, ['text'], ['image'])
23
- models_load.update({model: m})
24
-
25
- load_fn(models)
26
-
27
- num_models = 6
28
- MAX_SEED = 3999999999
29
- default_models = models[:num_models]
30
- inference_timeout = 600
31
-
32
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
- kwargs = {"seed": seed}
34
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
  await asyncio.sleep(0)
36
  try:
37
  result = await asyncio.wait_for(task, timeout=timeout)
38
  except (Exception, asyncio.TimeoutError) as e:
39
  print(e)
40
- print(f"Task timed out: {model_str}")
41
- if not task.done():
42
  task.cancel()
43
  result = None
 
44
  if task.done() and result is not None:
45
  with lock:
46
- png_path = "image.png"
47
- result.save(png_path)
48
- return png_path
 
49
  return None
50
 
51
- # Expose Gradio API
52
- def generate_api(model_str, prompt, seed=1):
53
- result = asyncio.run(infer(model_str, prompt, seed))
54
- if result:
55
- return result # Path to generated image
56
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Launch Gradio API without frontend
59
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
- iface.launch(show_api=True, share=True)
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS
 
 
3
  import asyncio
4
+ import tempfile
5
  import os
6
  from threading import RLock
7
+ from huggingface_hub import InferenceClient
8
+ from all_models import models # Importing models from all_models
9
+
10
+ myapp = Flask(__name__)
11
+ CORS(myapp) # Enable CORS for all routes
12
 
13
  lock = RLock()
14
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
15
 
16
+ inference_timeout = 600 # Set timeout for inference
17
+
18
+ # Function to dynamically load models from the "models" list
19
+ def get_model_from_name(model_name):
20
+ return model_name if model_name in models else None
21
+
22
+ # Asynchronous function to perform inference
23
+ async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
24
+ task = asyncio.create_task(
25
+ asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
26
+ )
 
 
 
 
 
 
 
 
 
 
 
 
27
  await asyncio.sleep(0)
28
  try:
29
  result = await asyncio.wait_for(task, timeout=timeout)
30
  except (Exception, asyncio.TimeoutError) as e:
31
  print(e)
32
+ print(f"Task timed out for model: {model}")
33
+ if not task.done():
34
  task.cancel()
35
  result = None
36
+
37
  if task.done() and result is not None:
38
  with lock:
39
+ temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
40
+ with open(temp_image.name, "wb") as f:
41
+ f.write(result) # Save the result image as a temporary file
42
+ return temp_image.name # Return the path to the saved image
43
  return None
44
 
45
+ # Flask route for the API endpoint
46
+ @myapp.route('/generate_api', methods=['POST'])
47
+ def generate_api():
48
+ data = request.get_json()
49
+
50
+ # Extract required fields from the request
51
+ prompt = data.get('prompt', '')
52
+ seed = data.get('seed', 1)
53
+ model_name = data.get('model', 'prompthero/openjourney-v4') # Default to "prompthero/openjourney-v4" if not provided
54
+
55
+ if not prompt:
56
+ return jsonify({"error": "Prompt is required"}), 400
57
+
58
+ # Get the model from all_models
59
+ model = get_model_from_name(model_name)
60
+ if not model:
61
+ return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
62
+
63
+ try:
64
+ # Create a generic InferenceClient for the model
65
+ client = InferenceClient(token=HF_TOKEN) # Pass Hugging Face token if needed
66
+
67
+ # Call the async inference function
68
+ result_path = asyncio.run(infer(client, prompt, seed, model=model))
69
+ if result_path:
70
+ return send_file(result_path, mimetype='image/png') # Send back the generated image file
71
+ else:
72
+ return jsonify({"error": "Failed to generate image"}), 500
73
+ except Exception as e:
74
+ return jsonify({"error": str(e)}), 500
75
 
76
+ if __name__ == "__main__":
77
+ myapp.run(host='0.0.0.0', port=7860) # Run directly if needed for testing