Geek7 commited on
Commit
3af9bb2
·
verified ·
1 Parent(s): 57b49bd

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +70 -5
myapp.py CHANGED
@@ -1,16 +1,81 @@
1
- from flask import Flask, jsonify, request, send_file
2
  from flask_cors import CORS
3
- import io
 
 
 
 
 
4
 
5
- # Initialize the Flask app
6
  myapp = Flask(__name__)
7
- CORS(myapp) # Enable CORS if needed
 
 
 
 
 
8
 
9
  @myapp.route('/')
10
  def home():
11
- return "Welcome to the Image Background Remover!" # Basic home response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
13
 
 
 
 
 
 
 
 
 
14
 
15
  # Add this block to make sure your app runs when called
16
  if __name__ == "__main__":
 
1
+ from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
+ import asyncio
4
+ import tempfile
5
+ import os
6
+ from threading import RLock
7
+ from huggingface_hub import InferenceClient
8
+ from all_models import models # Importing models from all_models
9
 
 
10
  myapp = Flask(__name__)
11
+ CORS(myapp) # Enable CORS for all routes
12
+
13
+ lock = RLock()
14
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
15
+
16
+ inference_timeout = 600 # Set timeout for inference
17
 
18
  @myapp.route('/')
19
  def home():
20
+ return "Welcome to the Image Background Remover!"
21
+
22
+ # Function to dynamically load models from the "models" list
23
+ def get_model_from_name(model_name):
24
+ return model_name if model_name in models else None
25
+
26
+ # Asynchronous function to perform inference
27
+ async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
28
+ task = asyncio.create_task(
29
+ asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
30
+ )
31
+ await asyncio.sleep(0)
32
+ try:
33
+ result = await asyncio.wait_for(task, timeout=timeout)
34
+ except (Exception, asyncio.TimeoutError) as e:
35
+ print(e)
36
+ print(f"Task timed out for model: {model}")
37
+ if not task.done():
38
+ task.cancel()
39
+ result = None
40
+
41
+ if task.done() and result is not None:
42
+ with lock:
43
+ temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
44
+ with open(temp_image.name, "wb") as f:
45
+ f.write(result) # Save the result image as a temporary file
46
+ return temp_image.name # Return the path to the saved image
47
+ return None
48
+
49
+ # Flask route for the API endpoint
50
+ @myapp.route('/generate_api', methods=['POST'])
51
+ def generate_api():
52
+ data = request.get_json()
53
+
54
+ # Extract required fields from the request
55
+ prompt = data.get('prompt', '')
56
+ seed = data.get('seed', 1)
57
+ model_name = data.get('model', 'prompthero/openjourney-v4') # Default to "prompthero/openjourney-v4" if not provided
58
+
59
+ if not prompt:
60
+ return jsonify({"error": "Prompt is required"}), 400
61
+
62
+ # Get the model from all_models
63
+ model = get_model_from_name(model_name)
64
+ if not model:
65
+ return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
66
 
67
+ try:
68
+ # Create a generic InferenceClient for the model
69
+ client = InferenceClient(token=HF_TOKEN) # Pass Hugging Face token if needed
70
 
71
+ # Call the async inference function
72
+ result_path = asyncio.run(infer(client, prompt, seed, model=model))
73
+ if result_path:
74
+ return send_file(result_path, mimetype='image/png') # Send back the generated image file
75
+ else:
76
+ return jsonify({"error": "Failed to generate image"}), 500
77
+ except Exception as e:
78
+ return jsonify({"error": str(e)}), 500
79
 
80
  # Add this block to make sure your app runs when called
81
  if __name__ == "__main__":