usmanyousaf commited on
Commit
bed4d70
Β·
verified Β·
1 Parent(s): f83f699

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -20
app.py CHANGED
@@ -5,9 +5,7 @@ from io import BytesIO
5
  from fastapi import FastAPI, File, UploadFile
6
  from fastapi.responses import StreamingResponse
7
  from tensorflow.keras.models import load_model
8
- from tensorflow.keras.utils import get_custom_objects
9
-
10
-
11
  import numpy as np
12
  import io
13
  import warnings
@@ -41,7 +39,7 @@ st.markdown(
41
  font-weight: bold; /* Bold text */
42
  }
43
  </style>
44
- """,
45
  unsafe_allow_html=True,
46
  )
47
 
@@ -51,38 +49,37 @@ st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image usi
51
  # Description Section
52
  st.markdown("<h2 style='text-align: center; color: #ff6347;'>Empowering Multiple Fields with GANs 🌐</h2>", unsafe_allow_html=True)
53
 
54
- # Logo Image
55
- logo_image = Image.open("home1.jpeg")
56
  st.image(logo_image, width=300)
57
 
58
  st.write("The application of Generative Adversarial Networks (GANs) in the Sketch to Image project extends beyond creative endeavors, finding significant utility in various fields. The ability to transform sketches into vibrant and detailed images has far-reaching implications, especially in sectors such as law enforcement, forensic science, and more.")
59
 
60
- # Upload Pic Section
61
  st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload Your Sketch πŸ“€</h2>", unsafe_allow_html=True)
62
  uploaded_file = st.file_uploader("Choose an image... πŸ“€", type=["jpg", "jpeg", "png"])
63
 
64
  if uploaded_file is not None:
65
- # Display the uploaded image in the center
66
  col1, col2, col3 = st.columns([1, 2, 1])
67
  with col2:
68
  st.image(uploaded_file, caption="Uploaded Image πŸ–ΌοΈ", width=300)
69
 
70
- # Button to generate the image with emoji
71
  if st.button('Generate πŸš€'):
72
- # Display a message while generating the image
73
  with st.spinner('Wait for it... Generating your image 🎨'):
74
  try:
75
- # Prepare the file for sending
76
  files = {"file": uploaded_file.getvalue()}
77
-
78
  # Send POST request to FastAPI server
79
  response = requests.post("http://127.0.0.1:8000/generate-image/", files=files)
80
 
81
  if response.status_code == 200:
82
- # Convert the response content to an image
83
  generated_image = Image.open(BytesIO(response.content))
84
-
85
- # Display the generated image in the center
86
  col1, col2, col3 = st.columns([1, 2, 1])
87
  with col2:
88
  st.image(generated_image, caption="Generated Image ✨", width=300)
@@ -93,9 +90,18 @@ if uploaded_file is not None:
93
 
94
  # FastAPI Section
95
  warnings.filterwarnings('ignore')
96
- #generator_model = load_model('model.h5') # Update this with your generator model's path
 
97
  custom_objects = {"Conv2DTranspose": Conv2DTranspose}
98
- generator_model = load_model('model.h5', custom_objects=custom_objects)
 
 
 
 
 
 
 
 
99
  app = FastAPI()
100
 
101
  @app.post("/generate-image/")
@@ -104,17 +110,20 @@ async def generate_image(file: UploadFile = File(...)):
104
  image = Image.open(io.BytesIO(contents)).convert('RGB')
105
  image = image.resize((256, 256))
106
 
 
107
  image_array = np.array(image)
108
- image_array = (image_array - 127.5) / 127.5
109
  image_array = np.expand_dims(image_array, axis=0)
110
 
 
111
  fake_image = generator_model.predict(image_array)
112
- fake_image = (fake_image + 1) / 2.0
113
  fake_image = np.squeeze(fake_image)
114
 
115
- fake_image = (fake_image * 255).astype(np.uint8)
116
  fake_image = Image.fromarray(fake_image)
117
 
 
118
  img_io = io.BytesIO()
119
  fake_image.save(img_io, 'JPEG', quality=70)
120
  img_io.seek(0)
 
5
  from fastapi import FastAPI, File, UploadFile
6
  from fastapi.responses import StreamingResponse
7
  from tensorflow.keras.models import load_model
8
+ from tensorflow.keras.layers import Conv2DTranspose
 
 
9
  import numpy as np
10
  import io
11
  import warnings
 
39
  font-weight: bold; /* Bold text */
40
  }
41
  </style>
42
+ """,
43
  unsafe_allow_html=True,
44
  )
45
 
 
49
  # Description Section
50
  st.markdown("<h2 style='text-align: center; color: #ff6347;'>Empowering Multiple Fields with GANs 🌐</h2>", unsafe_allow_html=True)
51
 
52
+ # Display Logo Image
53
+ logo_image = Image.open("home1.jpeg") # Make sure to replace with the actual path
54
  st.image(logo_image, width=300)
55
 
56
  st.write("The application of Generative Adversarial Networks (GANs) in the Sketch to Image project extends beyond creative endeavors, finding significant utility in various fields. The ability to transform sketches into vibrant and detailed images has far-reaching implications, especially in sectors such as law enforcement, forensic science, and more.")
57
 
58
+ # Upload Section
59
  st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload Your Sketch πŸ“€</h2>", unsafe_allow_html=True)
60
  uploaded_file = st.file_uploader("Choose an image... πŸ“€", type=["jpg", "jpeg", "png"])
61
 
62
  if uploaded_file is not None:
63
+ # Display uploaded image in the center
64
  col1, col2, col3 = st.columns([1, 2, 1])
65
  with col2:
66
  st.image(uploaded_file, caption="Uploaded Image πŸ–ΌοΈ", width=300)
67
 
68
+ # Button to generate the image
69
  if st.button('Generate πŸš€'):
 
70
  with st.spinner('Wait for it... Generating your image 🎨'):
71
  try:
72
+ # Prepare file for sending
73
  files = {"file": uploaded_file.getvalue()}
74
+
75
  # Send POST request to FastAPI server
76
  response = requests.post("http://127.0.0.1:8000/generate-image/", files=files)
77
 
78
  if response.status_code == 200:
79
+ # Convert response content to an image
80
  generated_image = Image.open(BytesIO(response.content))
81
+
82
+ # Display generated image in the center
83
  col1, col2, col3 = st.columns([1, 2, 1])
84
  with col2:
85
  st.image(generated_image, caption="Generated Image ✨", width=300)
 
90
 
91
  # FastAPI Section
92
  warnings.filterwarnings('ignore')
93
+
94
+ # Ensure that Conv2DTranspose is included when loading the model
95
  custom_objects = {"Conv2DTranspose": Conv2DTranspose}
96
+
97
+ # Load generator model (you may need to update the model path)
98
+ try:
99
+ generator_model = load_model('model.h5', custom_objects=custom_objects)
100
+ st.success("Model loaded successfully! πŸŽ‰")
101
+ except Exception as e:
102
+ st.error(f"Failed to load model: {str(e)}")
103
+
104
+ # FastAPI application
105
  app = FastAPI()
106
 
107
  @app.post("/generate-image/")
 
110
  image = Image.open(io.BytesIO(contents)).convert('RGB')
111
  image = image.resize((256, 256))
112
 
113
+ # Preprocess image
114
  image_array = np.array(image)
115
+ image_array = (image_array - 127.5) / 127.5 # Normalize between -1 and 1
116
  image_array = np.expand_dims(image_array, axis=0)
117
 
118
+ # Generate fake image
119
  fake_image = generator_model.predict(image_array)
120
+ fake_image = (fake_image + 1) / 2.0 # Rescale to [0, 1]
121
  fake_image = np.squeeze(fake_image)
122
 
123
+ fake_image = (fake_image * 255).astype(np.uint8) # Convert to uint8
124
  fake_image = Image.fromarray(fake_image)
125
 
126
+ # Prepare image for streaming
127
  img_io = io.BytesIO()
128
  fake_image.save(img_io, 'JPEG', quality=70)
129
  img_io.seek(0)