Spaces:
Running
Running
import streamlit as st | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
from fastapi import FastAPI, File, UploadFile | |
from fastapi.responses import StreamingResponse | |
from tensorflow.keras.models import load_model | |
import numpy as np | |
import io | |
import warnings | |
# Suppress warnings.
# NOTE(review): this silences *every* Python warning process-wide, which can
# also hide warnings that point at real problems.
warnings.filterwarnings('ignore')

# Icon shown in the browser tab and in the title heading.
# NOTE(review): the emoji in this copy of the file was mis-encoded; "🖌️" is a
# best-guess reconstruction — confirm against the original project.
PAGE_ICON = "🖌️"

# Set Streamlit page configuration
st.set_page_config(
    page_title="Sketch to Image using GAN",
    layout="centered",
    page_icon=PAGE_ICON,
)

# Title and description
st.markdown(
    "<h1 style='text-align: center; color: #ff6347;'>"
    f"Sketch to Image using GAN {PAGE_ICON}</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<h2 style='text-align: center; color: #ff6347;'>"
    "Upload your sketch to generate an image!</h2>",
    unsafe_allow_html=True,
)

# Upload file widget; stays None until the user picks a file (checked below).
uploaded_file = st.file_uploader(
    "Upload a sketch (jpg, jpeg, png):", type=["jpg", "jpeg", "png"]
)
# Model Loading
@st.cache_resource
def _load_generator(model_path):
    """Load the Keras generator from *model_path*, cached for the process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model file would be reloaded from disk on each rerun.
    """
    return load_model(model_path)


MODEL_PATH = 'model.h5'  # Update this path to your actual model file

try:
    generator_model = _load_generator(MODEL_PATH)
    st.success("Model loaded successfully!")
except Exception as e:
    # Keep the name bound so later code sees an explicit "not loaded" state
    # (None) instead of failing with a NameError.
    generator_model = None
    st.error(f"Error loading the model: {str(e)}")
# Image processing function
def process_and_generate_image(image_data, model=None):
    """Turn an encoded sketch into a generated image using the GAN generator.

    Args:
        image_data: Raw encoded image bytes (e.g. an uploaded JPEG/PNG file).
        model: Object with a Keras-style ``predict`` method. When omitted,
            the module-level ``generator_model`` is used.

    Returns:
        A ``PIL.Image.Image`` built from the generator's output.

    Raises:
        RuntimeError: If no model was given and ``generator_model`` is
            missing or None (e.g. model loading failed).
        PIL.UnidentifiedImageError: If ``image_data`` is not a readable image.
    """
    if model is None:
        # Looked up lazily so a failed/missing model load produces a clear
        # error here rather than a NameError.
        model = globals().get('generator_model')
    if model is None:
        raise RuntimeError("Generator model is not loaded; cannot generate an image.")

    # Decode, force 3 channels (drops alpha / expands grayscale) and resize
    # to 256x256. The context manager closes the decoder's handle.
    with Image.open(io.BytesIO(image_data)) as source:
        image = source.convert('RGB').resize((256, 256))

    # Preprocess: map uint8 pixels [0, 255] -> [-1, 1] and add a batch axis.
    image_array = (np.asarray(image, dtype=np.float32) - 127.5) / 127.5
    image_array = np.expand_dims(image_array, axis=0)

    # Generate, then map [-1, 1] back to [0, 1]. Clip before the uint8 cast:
    # a float->uint8 astype does not saturate, so values outside the range
    # would turn into garbage pixels instead of black/white.
    fake_image = model.predict(image_array)
    fake_image = np.clip((fake_image + 1) / 2.0, 0.0, 1.0)
    fake_image = np.squeeze(fake_image)
    fake_image = (fake_image * 255).astype(np.uint8)
    return Image.fromarray(fake_image)
# Once a sketch has been uploaded: preview it, and run it through the
# generator when the user asks for it. Any failure is reported in the page.
if uploaded_file is not None:
    st.image(uploaded_file, caption="Uploaded Sketch", width=300)

    if st.button("Generate Image"):
        with st.spinner('Generating...'):
            try:
                # Hand the raw uploaded bytes to the shared pipeline.
                generated_image = process_and_generate_image(
                    uploaded_file.getvalue()
                )
                st.image(
                    generated_image,
                    caption="Generated Image",
                    width=300,
                )
            except Exception as exc:
                st.error("Error generating image: " + str(exc))
# FastAPI app for backend
app = FastAPI()


# NOTE(review): the original handler had no route decorator, so FastAPI never
# exposed it. "/generate" is a chosen path — adjust if clients expect another.
@app.post("/generate")
async def generate_image(file: UploadFile = File(...)):
    """Generate an image from an uploaded sketch and return it as a JPEG.

    Args:
        file: The uploaded sketch image (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` containing the generated image as JPEG.

    Raises:
        HTTPException: 400 if the upload is empty or is not a decodable image.
    """
    # Function-scope imports keep this block self-contained.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        # Model inference is blocking CPU work; run it in the threadpool so it
        # does not stall the async event loop for concurrent requests.
        generated_image = await run_in_threadpool(process_and_generate_image, contents)
    except UnidentifiedImageError:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from None

    img_io = io.BytesIO()
    generated_image.save(img_io, 'JPEG')
    img_io.seek(0)
    return StreamingResponse(img_io, media_type="image/jpeg")
# Run the FastAPI backend when the file is launched with plain `python`.
# Streamlit also executes scripts with __name__ == '__main__', so without the
# runtime check every Streamlit rerun would try to start (and block on) a
# uvicorn server bound to port 8000.
if __name__ == '__main__':
    from streamlit import runtime as _st_runtime

    if not _st_runtime.exists():
        import uvicorn
        uvicorn.run(app, host="127.0.0.1", port=8000)