usmanyousaf commited on
Commit
460f6cf
Β·
verified Β·
1 Parent(s): 93402af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -93
app.py CHANGED
@@ -5,131 +5,82 @@ from io import BytesIO
5
  from fastapi import FastAPI, File, UploadFile
6
  from fastapi.responses import StreamingResponse
7
  from tensorflow.keras.models import load_model
8
- from tensorflow.keras.layers import Conv2DTranspose # Import Conv2DTranspose
9
  import numpy as np
10
  import io
11
  import warnings
12
 
 
 
 
13
  # Set Streamlit page configuration
14
  st.set_page_config(
15
  page_title="Sketch to Image using GAN",
16
  layout="centered",
17
  page_icon="πŸ–ŒοΈ",
18
- initial_sidebar_state="expanded",
19
- )
20
-
21
- # Custom CSS for styling
22
- st.markdown(
23
- """
24
- <style>
25
- body {
26
- background-color: #2a2a2a; /* Cool dark background */
27
- color:#ffffff; /* White text */
28
- font-family: 'Courier New', monospace; /* Cool font */
29
- }
30
- h1, h2 {
31
- color: #ff6347; /* Tomato color for titles */
32
- font-weight: bold; /* Bold text */
33
- }
34
- .stButton>button {
35
- color: #ffffff;
36
- background-color: #ff6347; /* Tomato color for buttons */
37
- border-radius: 10px;
38
- border: 2px solid #ffffff;
39
- font-weight: bold; /* Bold text */
40
- }
41
- </style>
42
- """,
43
- unsafe_allow_html=True,
44
  )
45
 
46
- # Title with colors and emojis
47
  st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN πŸ–ŒοΈ</h1>", unsafe_allow_html=True)
 
48
 
49
- # Description Section
50
- st.markdown("<h2 style='text-align: center; color: #ff6347;'>Empowering Multiple Fields with GANs 🌐</h2>", unsafe_allow_html=True)
51
-
52
- # Display Logo Image
53
- logo_image = Image.open("home1.jpeg") # Make sure to replace with the actual path
54
- st.image(logo_image, width=300)
55
-
56
- st.write("The application of Generative Adversarial Networks (GANs) in the Sketch to Image project extends beyond creative endeavors, finding significant utility in various fields. The ability to transform sketches into vibrant and detailed images has far-reaching implications, especially in sectors such as law enforcement, forensic science, and more.")
57
-
58
- # Upload Section
59
- st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload Your Sketch πŸ“€</h2>", unsafe_allow_html=True)
60
- uploaded_file = st.file_uploader("Choose an image... πŸ“€", type=["jpg", "jpeg", "png"])
61
-
62
- if uploaded_file is not None:
63
- # Display uploaded image in the center
64
- col1, col2, col3 = st.columns([1, 2, 1])
65
- with col2:
66
- st.image(uploaded_file, caption="Uploaded Image πŸ–ΌοΈ", width=300)
67
-
68
- # Button to generate the image
69
- if st.button('Generate πŸš€'):
70
- with st.spinner('Wait for it... Generating your image 🎨'):
71
- try:
72
- # Prepare file for sending
73
- files = {"file": uploaded_file.getvalue()}
74
-
75
- # Send POST request to FastAPI server
76
- response = requests.post("http://127.0.0.1:8000/generate-image/", files=files)
77
-
78
- if response.status_code == 200:
79
- # Convert response content to an image
80
- generated_image = Image.open(BytesIO(response.content))
81
-
82
- # Display generated image in the center
83
- col1, col2, col3 = st.columns([1, 2, 1])
84
- with col2:
85
- st.image(generated_image, caption="Generated Image ✨", width=300)
86
- else:
87
- st.error("Error in image generation 😒")
88
- except requests.ConnectionError:
89
- st.error("Unable to connect to the FastAPI server. Please make sure it is running.")
90
-
91
- # FastAPI Section
92
- warnings.filterwarnings('ignore')
93
-
94
- # Ensure that Conv2DTranspose is included when loading the model
95
- custom_objects = {"Conv2DTranspose": Conv2DTranspose}
96
 
97
- # Load generator model (you may need to update the model path)
98
  try:
99
- generator_model = load_model('model.h5', custom_objects=custom_objects)
100
- st.success("Model loaded successfully! πŸŽ‰")
101
  except Exception as e:
102
- st.error(f"Failed to load model: {str(e)}")
103
 
104
- # FastAPI application
105
- app = FastAPI()
106
-
107
- @app.post("/generate-image/")
108
- async def generate_image(file: UploadFile = File(...)):
109
- contents = await file.read()
110
- image = Image.open(io.BytesIO(contents)).convert('RGB')
111
  image = image.resize((256, 256))
112
 
113
  # Preprocess image
114
  image_array = np.array(image)
115
- image_array = (image_array - 127.5) / 127.5 # Normalize between -1 and 1
116
  image_array = np.expand_dims(image_array, axis=0)
117
 
118
  # Generate fake image
119
  fake_image = generator_model.predict(image_array)
120
  fake_image = (fake_image + 1) / 2.0 # Rescale to [0, 1]
121
  fake_image = np.squeeze(fake_image)
 
 
 
122
 
123
- fake_image = (fake_image * 255).astype(np.uint8) # Convert to uint8
124
- fake_image = Image.fromarray(fake_image)
 
125
 
126
- # Prepare image for streaming
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  img_io = io.BytesIO()
128
- fake_image.save(img_io, 'JPEG', quality=70)
129
  img_io.seek(0)
130
 
131
- return StreamingResponse(img_io, media_type='image/jpeg')
132
 
 
133
  if __name__ == '__main__':
134
  import uvicorn
135
- uvicorn.run(app, host='127.0.0.1', port=8000)
 
5
  from fastapi import FastAPI, File, UploadFile
6
  from fastapi.responses import StreamingResponse
7
  from tensorflow.keras.models import load_model
 
8
  import numpy as np
9
  import io
10
  import warnings
11
 
12
+ # Suppress warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
  # Set Streamlit page configuration
16
  st.set_page_config(
17
  page_title="Sketch to Image using GAN",
18
  layout="centered",
19
  page_icon="πŸ–ŒοΈ",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
 
22
+ # Title and description
23
  st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN πŸ–ŒοΈ</h1>", unsafe_allow_html=True)
24
+ st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload your sketch to generate an image!</h2>", unsafe_allow_html=True)
25
 
26
+ # Upload file widget
27
+ uploaded_file = st.file_uploader("Upload a sketch (jpg, jpeg, png):", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Model Loading
30
  try:
31
+ generator_model = load_model('model.h5') # Update this path to your actual model file
32
+ st.success("Model loaded successfully!")
33
  except Exception as e:
34
+ st.error(f"Error loading the model: {str(e)}")
35
 
36
+ # Image processing function
37
+ def process_and_generate_image(image_data):
38
+ image = Image.open(io.BytesIO(image_data)).convert('RGB')
 
 
 
 
39
  image = image.resize((256, 256))
40
 
41
  # Preprocess image
42
  image_array = np.array(image)
43
+ image_array = (image_array - 127.5) / 127.5 # Normalize to [-1, 1]
44
  image_array = np.expand_dims(image_array, axis=0)
45
 
46
  # Generate fake image
47
  fake_image = generator_model.predict(image_array)
48
  fake_image = (fake_image + 1) / 2.0 # Rescale to [0, 1]
49
  fake_image = np.squeeze(fake_image)
50
+ fake_image = (fake_image * 255).astype(np.uint8)
51
+
52
+ return Image.fromarray(fake_image)
53
 
54
+ # Display uploaded image and handle generation
55
+ if uploaded_file is not None:
56
+ st.image(uploaded_file, caption="Uploaded Sketch", width=300)
57
 
58
+ if st.button("Generate Image"):
59
+ with st.spinner('Generating...'):
60
+ try:
61
+ # Generate the image
62
+ generated_image = process_and_generate_image(uploaded_file.getvalue())
63
+
64
+ # Display the generated image
65
+ st.image(generated_image, caption="Generated Image", width=300)
66
+ except Exception as e:
67
+ st.error(f"Error generating image: {str(e)}")
68
+
69
+ # FastAPI app for backend
70
+ app = FastAPI()
71
+
72
+ @app.post("/generate-image/")
73
+ async def generate_image(file: UploadFile = File(...)):
74
+ contents = await file.read()
75
+ generated_image = process_and_generate_image(contents)
76
+
77
  img_io = io.BytesIO()
78
+ generated_image.save(img_io, 'JPEG')
79
  img_io.seek(0)
80
 
81
+ return StreamingResponse(img_io, media_type="image/jpeg")
82
 
83
+ # Running FastAPI app if script is executed directly
84
  if __name__ == '__main__':
85
  import uvicorn
86
+ uvicorn.run(app, host="127.0.0.1", port=8000)