Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,131 +5,82 @@ from io import BytesIO
|
|
5 |
from fastapi import FastAPI, File, UploadFile
|
6 |
from fastapi.responses import StreamingResponse
|
7 |
from tensorflow.keras.models import load_model
|
8 |
-
from tensorflow.keras.layers import Conv2DTranspose # Import Conv2DTranspose
|
9 |
import numpy as np
|
10 |
import io
|
11 |
import warnings
|
12 |
|
|
|
|
|
|
|
13 |
# Set Streamlit page configuration
|
14 |
st.set_page_config(
|
15 |
page_title="Sketch to Image using GAN",
|
16 |
layout="centered",
|
17 |
page_icon="ποΈ",
|
18 |
-
initial_sidebar_state="expanded",
|
19 |
-
)
|
20 |
-
|
21 |
-
# Custom CSS for styling
|
22 |
-
st.markdown(
|
23 |
-
"""
|
24 |
-
<style>
|
25 |
-
body {
|
26 |
-
background-color: #2a2a2a; /* Cool dark background */
|
27 |
-
color:#ffffff; /* White text */
|
28 |
-
font-family: 'Courier New', monospace; /* Cool font */
|
29 |
-
}
|
30 |
-
h1, h2 {
|
31 |
-
color: #ff6347; /* Tomato color for titles */
|
32 |
-
font-weight: bold; /* Bold text */
|
33 |
-
}
|
34 |
-
.stButton>button {
|
35 |
-
color: #ffffff;
|
36 |
-
background-color: #ff6347; /* Tomato color for buttons */
|
37 |
-
border-radius: 10px;
|
38 |
-
border: 2px solid #ffffff;
|
39 |
-
font-weight: bold; /* Bold text */
|
40 |
-
}
|
41 |
-
</style>
|
42 |
-
""",
|
43 |
-
unsafe_allow_html=True,
|
44 |
)
|
45 |
|
46 |
-
# Title
|
47 |
st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN ποΈ</h1>", unsafe_allow_html=True)
|
|
|
48 |
|
49 |
-
#
|
50 |
-
st.
|
51 |
-
|
52 |
-
# Display Logo Image
|
53 |
-
logo_image = Image.open("home1.jpeg") # Make sure to replace with the actual path
|
54 |
-
st.image(logo_image, width=300)
|
55 |
-
|
56 |
-
st.write("The application of Generative Adversarial Networks (GANs) in the Sketch to Image project extends beyond creative endeavors, finding significant utility in various fields. The ability to transform sketches into vibrant and detailed images has far-reaching implications, especially in sectors such as law enforcement, forensic science, and more.")
|
57 |
-
|
58 |
-
# Upload Section
|
59 |
-
st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload Your Sketch π€</h2>", unsafe_allow_html=True)
|
60 |
-
uploaded_file = st.file_uploader("Choose an image... π€", type=["jpg", "jpeg", "png"])
|
61 |
-
|
62 |
-
if uploaded_file is not None:
|
63 |
-
# Display uploaded image in the center
|
64 |
-
col1, col2, col3 = st.columns([1, 2, 1])
|
65 |
-
with col2:
|
66 |
-
st.image(uploaded_file, caption="Uploaded Image πΌοΈ", width=300)
|
67 |
-
|
68 |
-
# Button to generate the image
|
69 |
-
if st.button('Generate π'):
|
70 |
-
with st.spinner('Wait for it... Generating your image π¨'):
|
71 |
-
try:
|
72 |
-
# Prepare file for sending
|
73 |
-
files = {"file": uploaded_file.getvalue()}
|
74 |
-
|
75 |
-
# Send POST request to FastAPI server
|
76 |
-
response = requests.post("http://127.0.0.1:8000/generate-image/", files=files)
|
77 |
-
|
78 |
-
if response.status_code == 200:
|
79 |
-
# Convert response content to an image
|
80 |
-
generated_image = Image.open(BytesIO(response.content))
|
81 |
-
|
82 |
-
# Display generated image in the center
|
83 |
-
col1, col2, col3 = st.columns([1, 2, 1])
|
84 |
-
with col2:
|
85 |
-
st.image(generated_image, caption="Generated Image β¨", width=300)
|
86 |
-
else:
|
87 |
-
st.error("Error in image generation π’")
|
88 |
-
except requests.ConnectionError:
|
89 |
-
st.error("Unable to connect to the FastAPI server. Please make sure it is running.")
|
90 |
-
|
91 |
-
# FastAPI Section
|
92 |
-
warnings.filterwarnings('ignore')
|
93 |
-
|
94 |
-
# Ensure that Conv2DTranspose is included when loading the model
|
95 |
-
custom_objects = {"Conv2DTranspose": Conv2DTranspose}
|
96 |
|
97 |
-
#
|
98 |
try:
|
99 |
-
generator_model = load_model('model.h5'
|
100 |
-
st.success("Model loaded successfully!
|
101 |
except Exception as e:
|
102 |
-
st.error(f"
|
103 |
|
104 |
-
#
|
105 |
-
|
106 |
-
|
107 |
-
@app.post("/generate-image/")
|
108 |
-
async def generate_image(file: UploadFile = File(...)):
|
109 |
-
contents = await file.read()
|
110 |
-
image = Image.open(io.BytesIO(contents)).convert('RGB')
|
111 |
image = image.resize((256, 256))
|
112 |
|
113 |
# Preprocess image
|
114 |
image_array = np.array(image)
|
115 |
-
image_array = (image_array - 127.5) / 127.5 # Normalize
|
116 |
image_array = np.expand_dims(image_array, axis=0)
|
117 |
|
118 |
# Generate fake image
|
119 |
fake_image = generator_model.predict(image_array)
|
120 |
fake_image = (fake_image + 1) / 2.0 # Rescale to [0, 1]
|
121 |
fake_image = np.squeeze(fake_image)
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
-
|
|
|
125 |
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
img_io = io.BytesIO()
|
128 |
-
|
129 |
img_io.seek(0)
|
130 |
|
131 |
-
return StreamingResponse(img_io, media_type=
|
132 |
|
|
|
133 |
if __name__ == '__main__':
|
134 |
import uvicorn
|
135 |
-
uvicorn.run(app, host=
|
|
|
5 |
from fastapi import FastAPI, File, UploadFile
|
6 |
from fastapi.responses import StreamingResponse
|
7 |
from tensorflow.keras.models import load_model
|
|
|
8 |
import numpy as np
|
9 |
import io
|
10 |
import warnings
|
11 |
|
12 |
+
# Suppress warnings
|
13 |
+
warnings.filterwarnings('ignore')
|
14 |
+
|
15 |
# Set Streamlit page configuration
|
16 |
st.set_page_config(
|
17 |
page_title="Sketch to Image using GAN",
|
18 |
layout="centered",
|
19 |
page_icon="ποΈ",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
)
|
21 |
|
22 |
+
# Title and description
|
23 |
st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN ποΈ</h1>", unsafe_allow_html=True)
|
24 |
+
st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload your sketch to generate an image!</h2>", unsafe_allow_html=True)
|
25 |
|
26 |
+
# Upload file widget
|
27 |
+
uploaded_file = st.file_uploader("Upload a sketch (jpg, jpeg, png):", type=["jpg", "jpeg", "png"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
# Model Loading
|
30 |
try:
|
31 |
+
generator_model = load_model('model.h5') # Update this path to your actual model file
|
32 |
+
st.success("Model loaded successfully!")
|
33 |
except Exception as e:
|
34 |
+
st.error(f"Error loading the model: {str(e)}")
|
35 |
|
36 |
+
# Image processing function
|
37 |
+
def process_and_generate_image(image_data):
|
38 |
+
image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
|
|
|
|
|
|
|
|
39 |
image = image.resize((256, 256))
|
40 |
|
41 |
# Preprocess image
|
42 |
image_array = np.array(image)
|
43 |
+
image_array = (image_array - 127.5) / 127.5 # Normalize to [-1, 1]
|
44 |
image_array = np.expand_dims(image_array, axis=0)
|
45 |
|
46 |
# Generate fake image
|
47 |
fake_image = generator_model.predict(image_array)
|
48 |
fake_image = (fake_image + 1) / 2.0 # Rescale to [0, 1]
|
49 |
fake_image = np.squeeze(fake_image)
|
50 |
+
fake_image = (fake_image * 255).astype(np.uint8)
|
51 |
+
|
52 |
+
return Image.fromarray(fake_image)
|
53 |
|
54 |
+
# Display uploaded image and handle generation
|
55 |
+
if uploaded_file is not None:
|
56 |
+
st.image(uploaded_file, caption="Uploaded Sketch", width=300)
|
57 |
|
58 |
+
if st.button("Generate Image"):
|
59 |
+
with st.spinner('Generating...'):
|
60 |
+
try:
|
61 |
+
# Generate the image
|
62 |
+
generated_image = process_and_generate_image(uploaded_file.getvalue())
|
63 |
+
|
64 |
+
# Display the generated image
|
65 |
+
st.image(generated_image, caption="Generated Image", width=300)
|
66 |
+
except Exception as e:
|
67 |
+
st.error(f"Error generating image: {str(e)}")
|
68 |
+
|
69 |
+
# FastAPI app for backend
|
70 |
+
app = FastAPI()
|
71 |
+
|
72 |
+
@app.post("/generate-image/")
|
73 |
+
async def generate_image(file: UploadFile = File(...)):
|
74 |
+
contents = await file.read()
|
75 |
+
generated_image = process_and_generate_image(contents)
|
76 |
+
|
77 |
img_io = io.BytesIO()
|
78 |
+
generated_image.save(img_io, 'JPEG')
|
79 |
img_io.seek(0)
|
80 |
|
81 |
+
return StreamingResponse(img_io, media_type="image/jpeg")
|
82 |
|
83 |
+
# Running FastAPI app if script is executed directly
|
84 |
if __name__ == '__main__':
|
85 |
import uvicorn
|
86 |
+
uvicorn.run(app, host="127.0.0.1", port=8000)
|