# Header captured with this file from its hosting page:
#   Space status: Build error
#   File size: 2,328 bytes
#   Revision: 5a6f45f
import argparse
import base64
import os
from pathlib import Path
from io import BytesIO
import time
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from consts import IMAGES_OUTPUT_DIR
from utils import parse_arg_boolean, parse_arg_dalle_version
from consts import ModelSize
app = Flask(__name__)
# Enable CORS for every route of the app.
CORS(app)

# flush=True: when stdout is not a TTY (e.g. piped container logs) Python
# block-buffers it, and without an explicit flush this notice would only show
# up after the slow start-up work it is announcing.
print("--> Starting DALL-E Server. This might take up to two minutes.", flush=True)

# Kept after the notice above; presumably importing the model module is itself
# slow -- TODO confirm.
from dalle_model import DalleModel  # noqa: E402

# Replaced by the loaded DalleModel instance during start-up, before app.run().
dalle_model = None

parser = argparse.ArgumentParser(
    description="A DALL-E app to turn your textual prompts into visionary delights"
)
parser.add_argument("--port", type=int, default=8000, help="backend port")
parser.add_argument(
    "--model_version",
    type=parse_arg_dalle_version,
    default=ModelSize.MINI,
    help="Mini, Mega, or Mega_full",
)
parser.add_argument(
    "--save_to_disk",
    type=parse_arg_boolean,
    default=False,
    help="Should save generated images to disk",
)
# Parsed at import time from sys.argv.
# NOTE(review): parse_args() exits on unrecognised arguments, so importing this
# module under a WSGI server that has its own argv would fail here; consider
# parse_known_args() if that deployment mode is ever needed.
args = parser.parse_args()
def _safe_dir_component(text, max_len=50):
    """Reduce *text* to a string usable as a single, harmless path component.

    Every character that is not alphanumeric, a space, ``-`` or ``_`` becomes
    ``_``. This removes path separators and dots, so a prompt such as
    ``../../etc`` cannot escape the output directory. The result is capped at
    *max_len* characters: even at 4 bytes per UTF-8 character, 50 characters
    plus the timestamp prefix stay below the common 255-byte file-name limit.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in " -_" else "_" for ch in text)
    return cleaned[:max_len].strip()


@app.route("/dalle", methods=["POST"])
@cross_origin()
def generate_images_api():
    """Generate images for a text prompt and return them base64-encoded.

    Expects a JSON object body ``{"text": <str>, "num_images": <int>}``. The
    body is parsed as JSON regardless of its Content-Type. Responds with a JSON
    list of base64-encoded JPEG images, or with ``{"error": ...}`` and HTTP 400
    when the request is malformed.

    When the server runs with ``--save_to_disk``, each image is also written
    as ``<idx>.jpeg`` into a new timestamped sub-directory of
    ``IMAGES_OUTPUT_DIR``, named after a sanitised form of the prompt.
    """
    # silent=True: return None instead of raising on unparsable JSON, so every
    # malformed request gets the same JSON error shape below.
    json_data = request.get_json(force=True, silent=True)
    if not isinstance(json_data, dict):
        return jsonify(error="Request body must be a JSON object"), 400

    text_prompt = json_data.get("text")
    if not isinstance(text_prompt, str):
        return jsonify(error="'text' must be a string"), 400

    try:
        # int() also accepts numeric strings; OverflowError covers the
        # non-standard 'Infinity' literal that Python's JSON parser accepts.
        num_images = int(json_data.get("num_images"))
    except (TypeError, ValueError, OverflowError):
        return jsonify(error="'num_images' must be an integer"), 400
    if num_images < 1:
        return jsonify(error="'num_images' must be at least 1"), 400

    generated_imgs = dalle_model.generate_images(text_prompt, num_images)

    if args.save_to_disk:
        # '-' instead of ':' in the time part: colons are invalid in Windows
        # file names. The prompt is sanitised because it is untrusted input.
        timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        dir_name = os.path.join(
            IMAGES_OUTPUT_DIR, f"{timestamp}_{_safe_dir_component(text_prompt)}"
        )
        Path(dir_name).mkdir(parents=True, exist_ok=True)

    generated_images = []
    for idx, img in enumerate(generated_imgs):
        if args.save_to_disk:
            img.save(os.path.join(dir_name, f"{idx}.jpeg"), format="JPEG")
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        generated_images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

    # Report the number actually produced rather than the number requested.
    print(f"Created {len(generated_images)} images from text prompt [{text_prompt}]")
    return jsonify(generated_images)
@app.route("/", methods=["GET"])
@cross_origin()
def health_check():
    """Answer GET / with the JSON object ``{"success": true}``.

    Lets clients check that the server process is up and serving requests.
    """
    status = {"success": True}
    return jsonify(status)
# Runs at import time (not only under __main__): load the model once, then do
# a throw-away "warm-up" generation, presumably so one-time initialisation cost
# is not paid by the first real request -- TODO confirm.
with app.app_context():
    dalle_model = DalleModel(args.model_version)
    dalle_model.generate_images("warm-up", 1)
    # flush=True so the readiness messages appear immediately in piped logs.
    print("--> DALL-E Server is up and running!", flush=True)
    print(f"--> Model selected - DALL-E {args.model_version}", flush=True)

if __name__ == "__main__":
    # Listen on all interfaces on the port given by --port (default 8000).
    app.run(host="0.0.0.0", port=args.port, debug=False)