# generatorImage / generate.py
# Author: alfrds
# Commit message: "new app"
# Commit: 4603c75
import gradio as gr
from gradio.inputs import Textbox
import torch
from diffusers import StableDiffusionPipeline
import boto3
from io import BytesIO
import os
import botocore
from time import sleep
# AWS credentials come from the environment (None when the variable is unset).
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# Bucket that generated images are uploaded to.
S3_BUCKET_NAME = 'pineblogs101145-dev'

# Stable Diffusion pipeline, loaded once at import time and reused by every
# generation call; placed on the GPU when one is available.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float32)
pipe = pipe.to(device)

# Two S3 handles with the same credentials: the resource API is used for
# bucket listing / object existence checks, the client API for head_object
# and uploads.
s3 = boto3.resource('s3',
                    aws_access_key_id=AWS_ACCESS_KEY_ID,
                    aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
s3_client = boto3.client('s3',
                         aws_access_key_id=AWS_ACCESS_KEY_ID,
                         aws_secret_access_key=AWS_SECRET_ACCESS_KEY)

# Defaults for the scan run when the script starts. Derived from
# S3_BUCKET_NAME so the two bucket names cannot drift apart.
bucket_name = S3_BUCKET_NAME
folder = 'public/mdx/'
def _upload_webp_to_s3(image, s3_key):
    """Encode *image* as WEBP in memory and upload it to S3_BUCKET_NAME at *s3_key*."""
    with BytesIO() as image_buffer:
        image.save(image_buffer, format='WEBP')
        # Rewind so the upload reads the encoded bytes from the start.
        image_buffer.seek(0)
        print('Saving image to s3')
        s3_client.upload_fileobj(image_buffer, S3_BUCKET_NAME, s3_key)
    print('Image saved to s3')


def text_to_image(prompt, save_as, key_id):
    """Generate an image for *prompt* and upload it to S3 as WEBP.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline ``pipe``.
        save_as: Base name for the stored image; runs of whitespace are
            replaced by ``-`` and ``.webp`` is appended.
        key_id: Must equal ``AWS_ACCESS_KEY_ID`` for the call to proceed.

    Returns:
        The image file name (``<save_as>.webp``), which is uploaded to
        ``S3_BUCKET_NAME`` under the key ``public/<image name>``; or the
        string ``"not permition"`` when *key_id* does not match.

    Raises:
        Any exception raised by the pipeline (printed first, then re-raised)
        or by the S3 upload.
    """
    if AWS_ACCESS_KEY_ID != key_id:
        return "not permition"

    image_name = '-'.join(save_as.split()) + ".webp"

    print('Starting to generate the image ...')
    try:
        image = pipe(prompt).images[0]
    except Exception as e:
        print('Error: ', e)
        # Previously the error was swallowed and the code went on to use an
        # unbound ``image``; surface the real failure instead.
        raise
    print('Image generation completed')

    _upload_webp_to_s3(image, "public/" + image_name)
    return image_name
def check_if_exist(bucket_name, key):
    """Report whether object *key* exists in bucket *bucket_name*.

    Returns True when the object's metadata loads, False when S3 answers
    with a 404 error code. Any other ClientError is re-raised unchanged.
    """
    try:
        s3.Object(bucket_name, key).load()
        return True
    except botocore.exceptions.ClientError as err:
        error_code = err.response['Error']['Code']
        if error_code != "404":
            # Permission problems, throttling, etc. are not "missing".
            raise
        return False
def list_s3_files(bucket_name, folder):
    """Generate a WEBP image for every object under *folder* that lacks one.

    For each object in *bucket_name* whose key starts with *folder*:
      * the expected image key ``public/<name>.webp`` is derived from the
        object's base name (extension dropped, runs of whitespace turned
        into ``-`` -- the same transformation ``text_to_image`` applies);
      * if that image already exists, the object is skipped;
      * otherwise the object's ``resumen`` user metadata is used as the
        prompt for ``text_to_image``; objects without it are skipped.

    Args:
        bucket_name: Bucket to scan.
        folder: Key prefix to scan within the bucket.

    Returns:
        Newline-separated results of the ``text_to_image`` calls made during
        this scan (empty string if none), shown in the Gradio text output.
    """
    generated = []
    for obj in s3.Bucket(bucket_name).objects.filter(Prefix=folder):
        print(obj.key)
        if obj.key.endswith('/'):
            # Folder placeholder object: no base name to build an image from.
            continue

        filename = os.path.splitext(os.path.basename(obj.key))[0]
        # Must match the key text_to_image writes, otherwise names containing
        # whitespace are never found and get regenerated on every run.
        # NOTE(review): text_to_image uploads to S3_BUCKET_NAME, while this
        # check looks in *bucket_name*; they agree only for the default bucket.
        s3image = 'public/%s.webp' % '-'.join(filename.split())
        if check_if_exist(bucket_name, s3image):
            print('Image %s already exists!' % s3image)
            continue

        metadata = s3_client.head_object(Bucket=bucket_name, Key=obj.key)['Metadata']
        print(metadata)
        resumen = metadata.get('resumen')
        if not resumen:
            # Really skip: previously this fell through and generated with an
            # unbound or stale summary from an earlier object.
            print('There is NOT resume, skipping..')
            continue

        print('Has resume, ready to create image!')
        print('Start creating image.. %s ' % s3image)
        sleep(500/1000)
        generated.append(text_to_image(resumen, filename, AWS_ACCESS_KEY_ID))
    return '\n'.join(generated)
# Start-up pass over the default bucket/folder, then expose the same scan
# as a Gradio app taking the bucket name and key prefix as text inputs.
list_s3_files(bucket_name, folder)
iface = gr.Interface(
    fn=list_s3_files,
    # gr.Textbox replaces the deprecated gradio.inputs.Textbox
    # (gradio.inputs was removed in Gradio 4).
    inputs=[gr.Textbox(label="bucket_name"), gr.Textbox(label="folder")],
    outputs="text",
)
iface.launch()