# Author: usmanyousaf
# Last commit: "Update app.py" (88caa86, verified)
import streamlit as st
import requests
from PIL import Image
from io import BytesIO
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from tensorflow.keras.models import load_model
import numpy as np
import io
import warnings
# Suppress warnings
warnings.filterwarnings('ignore')

# Set Streamlit page configuration (call this before any other Streamlit command)
st.set_page_config(
    page_title="Sketch to Image using GAN",
    layout="centered",
    page_icon="🖌️",
)

# Title and description
st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN 🖌️</h1>", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload your sketch to generate an image!</h2>", unsafe_allow_html=True)

# Upload file widget
uploaded_file = st.file_uploader("Upload a sketch (jpg, jpeg, png):", type=["jpg", "jpeg", "png"])

# Model Loading
MODEL_PATH = 'model.h5'  # Update this path to your actual model file


def _load_generator(path):
    """Load and return the Keras generator model stored at *path*."""
    return load_model(path)


# Streamlit re-executes this whole script on every widget interaction, which
# would reload the model each time. Cache it across reruns when the installed
# Streamlit provides st.cache_resource; older releases fall back to a plain load.
if hasattr(st, "cache_resource"):
    _load_generator = st.cache_resource(_load_generator)

try:
    generator_model = _load_generator(MODEL_PATH)
    st.success("Model loaded successfully!")
except Exception as e:
    # Keep the name bound so later code can detect the failure instead of
    # crashing with a NameError.
    generator_model = None
    st.error(f"Error loading the model: {str(e)}")
# Image processing function
def process_and_generate_image(image_data, model=None, size=(256, 256)):
    """Turn raw sketch bytes into a generated image using the GAN generator.

    Args:
        image_data: Encoded image bytes in any format PIL can open. The image
            is converted to RGB and resized before inference.
        model: Object with a Keras-style ``predict`` method. Defaults to the
            module-level ``generator_model``.
        size: ``(width, height)`` the sketch is resized to before inference.
            Presumably this must match the generator's input shape (TODO confirm).

    Returns:
        A ``PIL.Image.Image`` built from the generator output.

    Raises:
        RuntimeError: If no model is passed and the module-level generator is
            missing or failed to load.
        PIL.UnidentifiedImageError: If ``image_data`` is not a readable image.
    """
    if model is None:
        try:
            model = generator_model
        except NameError:  # the module-level model was never bound
            model = None
        if model is None:
            raise RuntimeError("Generator model is not loaded; cannot generate an image.")

    sketch = Image.open(io.BytesIO(image_data)).convert('RGB').resize(size)

    # Scale pixels from [0, 255] to [-1, 1] and add a batch dimension.
    # float32 is Keras' default float dtype, so build the input in it directly.
    pixels = (np.asarray(sketch, dtype=np.float32) - 127.5) / 127.5
    batch = np.expand_dims(pixels, axis=0)

    # Generate fake image and drop the size-1 batch axis.
    output = np.squeeze(model.predict(batch))

    # Rescale from [-1, 1] to [0, 1]. Clip first: casting out-of-range floats
    # to uint8 is undefined in NumPy and would otherwise wrap around.
    output = np.clip((output + 1) / 2.0, 0.0, 1.0)
    return Image.fromarray((output * 255).astype(np.uint8))
# Show the uploaded sketch, then generate an image from it when requested
if uploaded_file is not None:
    st.image(uploaded_file, caption="Uploaded Sketch", width=300)
    generate_requested = st.button("Generate Image")
    if generate_requested:
        with st.spinner('Generating...'):
            try:
                sketch_bytes = uploaded_file.getvalue()
                result_image = process_and_generate_image(sketch_bytes)
                st.image(result_image, caption="Generated Image", width=300)
            except Exception as exc:
                # Report any failure during generation or display in the UI
                st.error(f"Error generating image: {str(exc)}")
# FastAPI app for backend
app = FastAPI()


@app.post("/generate-image/")
async def generate_image(file: UploadFile = File(...)):
    """Generate an image from an uploaded sketch and return it as a JPEG.

    Expects a multipart upload with the sketch in the ``file`` field.
    Responds with HTTP 400 if the upload cannot be decoded as an image.
    """
    # Local imports keep this block self-contained; both come from packages
    # the file already depends on.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    contents = await file.read()
    try:
        # Preprocessing and model inference are synchronous; run them in a
        # worker thread so they do not block the event loop for other requests.
        generated_image = await run_in_threadpool(process_and_generate_image, contents)
    except UnidentifiedImageError as err:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from err

    img_io = io.BytesIO()
    generated_image.save(img_io, 'JPEG')
    img_io.seek(0)
    return StreamingResponse(img_io, media_type="image/jpeg")
# Running FastAPI app if script is executed directly
if __name__ == '__main__':
    # `streamlit run` also executes this file with __name__ == "__main__".
    # Only start the API server when there is no active Streamlit script
    # context; otherwise uvicorn.run() would block the Streamlit script thread
    # and the script run would never finish.
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ImportError:  # older Streamlit layout: keep the original behaviour
        def get_script_run_ctx():
            return None

    if get_script_run_ctx() is None:
        import uvicorn
        uvicorn.run(app, host="127.0.0.1", port=8000)