dmolino commited on
Commit
d209127
Β·
verified Β·
1 Parent(s): 9b97177

Create demo_inference

Browse files
Files changed (1) hide show
  1. demo_inference +567 -0
demo_inference ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tifffile
3
+ import pydicom
4
+ from scipy.ndimage import zoom
5
+ import torch
6
+ from core.models.dani_model import dani_model
7
+ import numpy as np
8
+ from PIL import Image
9
+ import base64
10
+ import time
11
+
12
+
13
+ # Funzione per convertire un'immagine in base64
14
+ def image_to_base64(image_path):
15
+ with open(image_path, "rb") as img_file:
16
+ return base64.b64encode(img_file.read()).decode()
17
+
18
+
19
+ st.markdown("""
20
+ <style>
21
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
22
+ /* Apply the font to everything */
23
+ html, body, [class*="st"] {
24
+ font-family: 'Roboto', sans-serif;
25
+ }
26
+ </style>
27
+ """, unsafe_allow_html=True)
28
+
29
+
30
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+
32
+ # Dati di esempio predefiniti
33
+ esempi = {
34
+ "Frontal βž” Lateral": {'Frontal': 'FtoL.png', 'Lateral': 'LfromF.png'},
35
+ "Frontal βž” Report": {'Frontal': '31d9847f-987fcf63-704f7496-d2b21eb8-63cd973e.tiff', 'Report': 'Small bilateral pleural effusions, left greater than right.'},
36
+ "Frontal βž” Lateral + Report": {'Frontal': '81bca127-0c416084-67f8033c-ecb26476-6d1ecf60.tiff', 'Lateral': 'd52a0c5c-bb7104b0-b1d821a5-959984c3-33c04ccb.tiff', 'Report': 'No acute intrathoracic process. Heart Size is normal. Lungs are clear. No pneumothorax'},
37
+ "Lateral βž” Frontal": {'Lateral': 'LtoF.png', 'Frontal': 'FfromL.png'},
38
+ "Lateral βž” Report": {'Lateral': 'd52a0c5c-bb7104b0-b1d821a5-959984c3-33c04ccb.tiff', 'Report': 'no acute cardiopulmonary process. if concern for injury persists, a dedicated rib series with markers would be necessary to ensure no rib fractures.'},
39
+ "Lateral βž” Frontal + Report": {'Lateral': 'reald52a0c5c-bb7104b0-b1d821a5-959984c3-33c04ccb.tiff', 'Frontal': 'ab37274f-b4c1fc04-e2ff24b4-4a130ba3-cd167968.tiff', 'Report': 'No acute intrathoracic process. If there is strong concern for rib fracture, a dedicated rib series may be performed.'},
40
+ "Report βž” Frontal": {'Report': 'Left lung opacification which may reflect pneumonia superimposed on metastatic disease.', 'Frontal': '02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.tiff'},
41
+ "Report βž” Lateral": {'Report': 'Bilateral pleural effusions, cardiomegaly and mild edema suggest fluid overload.', 'Lateral': '489faba7-a9dc5f1d-fd7241d6-9638d855-eaa952b1.tiff'},
42
+ "Report βž” Frontal + Lateral": {'Report': 'No acute intrathoracic process. The lungs are clean and heart is normal size.', 'Frontal': 'f27ba7cd-44486c2e-29f3e890-f2b9f94e-84110448.tiff', 'Lateral': 'b20c9570-de77944a-b8604ba0-73305a7b-d608a72b.tiff'},
43
+ "Frontal + Lateral βž” Report": {'Frontal': '95856dd1-5878b5b1-9c104817-760c0122-6187946f.tiff', 'Lateral': '3723d912-71940d69-4fef2dd2-27af5a7b-127ba20c.tiff', 'Report': 'Opacities in the right upper or middle lobe, maybe early pneumonia.'},
44
+ "Frontal + Report βž” Lateral": {'Frontal': 'e7f21453-7956d79a-44e44614-fae8ff16-d174d1a0.tiff', 'Report': 'No focal consolidation.', 'Lateral': '8037e6b9-06367464-a4ccd63a-5c5c5a81-ce3e7ffc.tiff'},
45
+ "Lateral + Report βž” Frontal": {'Lateral': '02c66644-b1883a91-54aed0e7-62d25460-398f9865.tiff', 'Report': 'No evidence of acute cardiopulmonary process.', 'Frontal': 'b1f169f1-12177dd5-2fa1c4b1-7b816311-85d769e9.tiff'}
46
+ }
47
+
48
+
49
+ # CSS per personalizzare il tema
50
+ st.markdown("""
51
+ <style>
52
+ /* Sfondo scuro */
53
+ body {
54
+ background-color: #121212;
55
+ color: white;
56
+ }
57
+ /* Personalizzazione del titolo */
58
+ .title {
59
+ font-size: 35px !important;
60
+ font-weight: bold;
61
+ color: #f63366;
62
+ }
63
+ /* Personalizzazione dei sottotitoli e testi principali */
64
+ .stText, .stButton, .stMarkdown {
65
+ font-size: 18px !important;
66
+ }
67
+ </style>
68
+ """, unsafe_allow_html=True)
69
+
70
+
71
+ # Sostituisci questo con il link dell'immagine online
72
+ logo_1_path = "./DEMO/Loghi/Logo_UCBM.png" # Sostituisci con il percorso del primo logo
73
+ logo_2_path = "./DEMO/Loghi/Logo UmU.png" # Sostituisci con il percorso del secondo logo
74
+ logo_3_path = "./DEMO/Loghi/Logo COSBI.png" # Sostituisci con il percorso del terzo logo
75
+ logo_4_path = "./DEMO/Loghi/logo trasparent.png" # Sostituisci con il percorso del quarto logo
76
+ # Converti le immagini in base64
77
+ logo_1_base64 = image_to_base64(logo_1_path)
78
+ logo_2_base64 = image_to_base64(logo_2_path)
79
+ logo_3_base64 = image_to_base64(logo_3_path)
80
+ logo_4_base64 = image_to_base64(logo_4_path)
81
+
82
+ # CSS per posizionare i loghi in basso a destra e renderli piccoli
83
+ st.markdown(f"""
84
+ <style>
85
+ .footer {{
86
+ position: fixed;
87
+ bottom: 20px;
88
+ right: 20px;
89
+ z-index: 100;
90
+ display: flex;
91
+ gap: 10px; /* Spazio tra i loghi */
92
+ }}
93
+ .footer img {{
94
+ height: 60px; /* Altezza dei loghi */
95
+ width: auto; /* Mantiene il rapporto di aspetto originale */
96
+ }}
97
+ </style>
98
+ <div class="footer">
99
+ <img src="data:image/png;base64,{logo_1_base64}" alt="Logo 1">
100
+ <img src="data:image/png;base64,{logo_2_base64}" alt="Logo 2">
101
+ <img src="data:image/png;base64,{logo_3_base64}" alt="Logo 3">
102
+ <img src="data:image/png;base64,{logo_4_base64}" alt="Logo 4">
103
+ </div>
104
+ """, unsafe_allow_html=True)
105
+
106
+ # Inizializzazione dello stato della sessione
107
+ if 'step' not in st.session_state:
108
+ st.session_state['step'] = 1
109
+ if 'selected_option' not in st.session_state:
110
+ st.session_state['selected_option'] = None
111
+ if 'frontal_file' not in st.session_state:
112
+ st.session_state['frontal_file'] = None
113
+ if 'lateral_file' not in st.session_state:
114
+ st.session_state['lateral_file'] = None
115
+ if 'report' not in st.session_state:
116
+ st.session_state['report'] = ""
117
+ if 'inputs' not in st.session_state:
118
+ st.session_state['inputs'] = None
119
+ if 'outputs' not in st.session_state:
120
+ st.session_state['outputs'] = None
121
+ if 'frontal' not in st.session_state:
122
+ st.session_state['frontal'] = None
123
+ if 'lateral' not in st.session_state:
124
+ st.session_state['lateral'] = None
125
+ if 'report' not in st.session_state:
126
+ st.session_state['report'] = ""
127
+ if 'generate' not in st.session_state:
128
+ st.session_state['generate'] = False
129
+
130
+ # Inizializza inference_tester solo una volta
131
+ if 'inference_tester' not in st.session_state:
132
+ model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_video_diffuser_8frames.pth']
133
+ st.session_state['inference_tester'] = dani_model(model='thesis_model',
134
+ data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
135
+ pth=model_load_paths, load_weights=False)
136
+ inference_tester = st.session_state['inference_tester']
137
+
138
+ # Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
139
+ if 'weights_loaded' not in st.session_state:
140
+ st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
141
+
142
+ # Usa inference_tester dalla sessione
143
+ inference_tester = st.session_state['inference_tester']
144
+
145
+
146
+ st.markdown('<h1 style="text-align: center" class="title">MedCoDi-M</h1>', unsafe_allow_html=True)
147
+
148
+ if st.session_state['step'] == 1:
149
+ # Breve descrizione del lavoro
150
+ st.markdown("""
151
+ <div style='text-align: justify; font-size: 18px; line-height: 1.6;'>
152
+ This work introduces MedCoDi-M, a novel multi-prompt foundation model for multi-modal medical data generation.
153
+ In this demo, you will be able to perform various generation tasks including frontal and lateral chest X-rays and clinical report generation.
154
+ MedCoDi-M enables flexible, any-to-any generation across different medical data modalities, utilizing contrastive learning and a modular approach for enhanced performance.
155
+ </div>
156
+ """, unsafe_allow_html=True)
157
+
158
+ # lasciamo un po' di spazio
159
+ st.markdown('<br>', unsafe_allow_html=True)
160
+
161
+ # Immagine con didascalia migliorata e con dimensione della caption aumentata
162
+ image_path = "./DEMO/Loghi/model_final.png" # Sostituisci con il percorso della tua immagine
163
+ st.image(image_path, caption='', use_container_width=True)
164
+
165
+ # Caption con dimensione del testo migliorata
166
+ st.markdown("""
167
+ <div style='text-align: center; font-size: 16px; font-style: italic; margin-top: 10px;'>
168
+ Framework of MedCoDi-M: This demo allows you to generate frontal and lateral chest X-rays, as well as medical reports, through the MedCoDi-M model.
169
+ </div>
170
+ """, unsafe_allow_html=True)
171
+
172
+ # lasciamo un po' di spazio
173
+ st.markdown('<br>', unsafe_allow_html=True)
174
+
175
+ # Bottone con testo "Try it out"
176
+ if st.button("Try it out!"):
177
+ st.session_state['step'] = 2
178
+ st.rerun()
179
+
180
+
181
+ # Fase 1: Selezione dell'opzione
182
+ if st.session_state['step'] == 2:
183
+ # Opzioni disponibili
184
+ options = [
185
+ "Frontal βž” Lateral", "Frontal βž” Report", "Frontal βž” Lateral + Report",
186
+ "Lateral βž” Frontal", "Lateral βž” Report", "Lateral βž” Frontal + Report",
187
+ "Report βž” Frontal", "Report βž” Lateral", "Report βž” Frontal + Lateral",
188
+ "Frontal + Lateral βž” Report", "Frontal + Report βž” Lateral", "Lateral + Report βž” Frontal"
189
+ ]
190
+
191
+ # Messaggio di selezione con dimensione aumentata
192
+ st.markdown(
193
+ "<h4 style='text-align: justify'><strong>Select the type of generation you want to perform:</strong></h4>",
194
+ unsafe_allow_html=True)
195
+
196
+ # Aumentare la dimensione di "Please select an option:"
197
+ st.markdown(
198
+ "<h4 style='text-align: justify'><strong>Please select an option:</strong></h4>",
199
+ unsafe_allow_html=True)
200
+
201
+ # Reset esplicito del valore di `selectbox` in caso di reset
202
+ st.session_state['selected_option'] = st.selectbox(
203
+ "", options, key='selectbox_option', index=0) # Rimuoviamo il testo dal selectbox
204
+
205
+ st.markdown('<br>', unsafe_allow_html=True)
206
+
207
+ # Creiamo colonne per i pulsanti
208
+ col1, col2, col3 = st.columns(3)
209
+
210
+ # Pulsante per provare un esempio
211
+ with col1:
212
+ if st.button("Inference"):
213
+ st.session_state['step'] = 3 # Passa al passo 3
214
+ st.rerun()
215
+
216
+ # Pulsante per provare un esempio
217
+ with col2:
218
+ if st.button("Try an example"):
219
+ st.session_state['step'] = 5 # Passa al passo 5
220
+ st.rerun()
221
+
222
+ # Pulsante per tornare all'inizio
223
+ with col3:
224
+ if st.button("Return to the beginning"):
225
+ # Ripristina lo stato della sessione
226
+ st.session_state['step'] = 1
227
+ st.session_state['selected_option'] = None
228
+ st.session_state['selected_option2'] = None
229
+ st.session_state['frontal_file'] = None
230
+ st.session_state['lateral_file'] = None
231
+ st.session_state['report'] = ""
232
+ st.rerun()
233
+
234
+
235
+ # Fase 2: Caricamento file
236
+ if st.session_state['step'] == 3:
237
+ st.markdown(
238
+ f"<h4 style='text-align: justify'><strong>You selected: {st.session_state['selected_option']}. Now, please upload the required files below:</strong></h4>",
239
+ unsafe_allow_html=True)
240
+
241
+ # Carica l'immagine frontale
242
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[0]:
243
+ st.markdown("<h5 style='font-size: 18px;'>Load the Frontal X-ray in DICOM format</h5>", unsafe_allow_html=True)
244
+ st.session_state['frontal_file'] = st.file_uploader("", type=["dcm"])
245
+
246
+ # Carica l'immagine laterale
247
+ if "Lateral" in st.session_state['selected_option'].split(" βž”")[0]:
248
+ st.markdown("<h5 style='font-size: 18px;'>Load the Lateral X-ray in DICOM format</h5>", unsafe_allow_html=True)
249
+ st.session_state['lateral_file'] = st.file_uploader("", type=["dcm"])
250
+
251
+ # Inserisci il report clinico
252
+ if "Report" in st.session_state['selected_option'].split(" βž”")[0]:
253
+ st.markdown("<h5 style='font-size: 18px;'>Type the clinical report</h5>", unsafe_allow_html=True)
254
+ st.session_state['report'] = st.text_area("", value=st.session_state['report'])
255
+
256
+ # lasciamo un po' di spazio
257
+ st.markdown('<br>', unsafe_allow_html=True)
258
+
259
+ # Creare colonne per allineare i pulsanti in orizzontale
260
+ col1, col2 = st.columns(2)
261
+
262
+ with col1:
263
+ if st.button("Start Generation"):
264
+ frontal = None
265
+ lateral = None
266
+ report = None
267
+ # Dato che questo step Γ¨ velocissimo, prima di procedere mettiamo una finta barra di caricamento di 3 secondi
268
+ with st.spinner("Preprocessing the data..."):
269
+ time.sleep(3)
270
+ # Controllo che i file necessari siano stati caricati
271
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[0] and not st.session_state['frontal_file']:
272
+ st.error("Load the Frontal image.")
273
+ elif "Lateral" in st.session_state['selected_option'].split(" βž”")[0] and not st.session_state['lateral_file']:
274
+ st.error("Load the Lateral image.")
275
+ elif "Report" in st.session_state['selected_option'].split(" βž”")[0] and not st.session_state['report']:
276
+ st.error("Type the clinical report.")
277
+ else:
278
+ st.write(f"Execution of: {st.session_state['selected_option']}")
279
+
280
+ # Carica l'immagine e avvia l'inferenza
281
+ if st.session_state['frontal_file']:
282
+ dicom = pydicom.dcmread(st.session_state['frontal_file'])
283
+ image = dicom.pixel_array
284
+ if dicom.PhotometricInterpretation == 'MONOCHROME1':
285
+ image = (2 ** dicom.BitsStored - 1) - image
286
+ if dicom.ImagerPixelSpacing != [0.139, 0.139]:
287
+ zoom_factor = [0.139 / dicom.ImagerPixelSpacing[0], 0.139 / dicom.ImagerPixelSpacing[1]]
288
+ image = zoom(image, zoom_factor)
289
+ image = image / (2 ** dicom.BitsStored - 1)
290
+ # Se l'immagine non Γ¨ quadrata, facciamo padding
291
+ if image.shape[0] != image.shape[1]:
292
+ diff = abs(image.shape[0] - image.shape[1])
293
+ pad_size = diff // 2
294
+ if image.shape[0] > image.shape[1]:
295
+ padded_image = np.pad(image, ((0, 0), (pad_size, pad_size)))
296
+ else:
297
+ padded_image = np.pad(image, ((pad_size, pad_size), (0, 0)))
298
+ # Resizing a 256x256 e a 512x512
299
+ zoom_factor = [256 / padded_image.shape[0], 256 / padded_image.shape[1]]
300
+ image_256 = zoom(padded_image, zoom_factor)
301
+ frontal = image_256
302
+ if frontal.dtype != np.uint8:
303
+ frontal2 = (255 * (frontal - frontal.min()) / (frontal.max() - frontal.min())).astype(np.uint8)
304
+ frontal = torch.tensor(frontal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
305
+ frontal2 = Image.fromarray(frontal2)
306
+ st.write("Frontal Image loaded successfully!")
307
+ # Mostra l'immagine caricata
308
+ st.image(frontal2, caption="Frontal Image Loaded", use_container_width=True)
309
+ if st.session_state['lateral_file']:
310
+ dicom = pydicom.dcmread(st.session_state['lateral_file'])
311
+ image = dicom.pixel_array
312
+ if dicom.PhotometricInterpretation == 'MONOCHROME1':
313
+ image = (2 ** dicom.BitsStored - 1) - image
314
+ if dicom.ImagerPixelSpacing != [0.139, 0.139]:
315
+ zoom_factor = [0.139 / dicom.ImagerPixelSpacing[0], 0.139 / dicom.ImagerPixelSpacing[1]]
316
+ image = zoom(image, zoom_factor)
317
+ image = image / (2 ** dicom.BitsStored - 1)
318
+ # Se l'immagine non Γ¨ quadrata, facciamo padding
319
+ if image.shape[0] != image.shape[1]:
320
+ diff = abs(image.shape[0] - image.shape[1])
321
+ pad_size = diff // 2
322
+ if image.shape[0] > image.shape[1]:
323
+ padded_image = np.pad(image, ((0, 0), (pad_size, pad_size)))
324
+ else:
325
+ padded_image = np.pad(image, ((pad_size, pad_size), (0, 0)))
326
+ # Resizing a 256x256 e a 512x512
327
+ zoom_factor = [256 / padded_image.shape[0], 256 / padded_image.shape[1]]
328
+ image_256 = zoom(padded_image, zoom_factor)
329
+ lateral = image_256
330
+ if lateral.dtype != np.uint8:
331
+ lateral2 = (255 * (lateral - lateral.min()) / (lateral.max() - lateral.min())).astype(np.uint8)
332
+ lateral = torch.tensor(lateral, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
333
+ lateral2 = Image.Frontalmarray(lateral2)
334
+ st.write("Lateral Image loaded successfully!")
335
+ st.image(lateral2, caption="Lateral Image Loaded", use_container_width=True)
336
+ if st.session_state['report']:
337
+ report = st.session_state['report']
338
+ st.write(f"Loaded Report: {report}")
339
+
340
+ inputs = []
341
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[0]:
342
+ inputs.append('frontal')
343
+ if "Lateral" in st.session_state['selected_option'].split(" βž”")[0]:
344
+ inputs.append('lateral')
345
+ if "Report" in st.session_state['selected_option'].split(" βž”")[0]:
346
+ inputs.append('text')
347
+
348
+ # Ora vediamo cosa c'Γ¨ dopo la freccia
349
+ outputs = []
350
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[1]:
351
+ outputs.append('frontal')
352
+ if "Lateral" in st.session_state['selected_option'].split(" βž”")[1]:
353
+ outputs.append('lateral')
354
+ if "Report" in st.session_state['selected_option'].split(" βž”")[1]:
355
+ outputs.append('text')
356
+
357
+ # Ultima cosa che va fatta Γ¨ passare allo step 4, prima di farlo perΓ², tutte le variabili che ci servono
358
+ # devono essere salvate nello stato della sessione
359
+ st.session_state['inputs'] = inputs
360
+ st.session_state['outputs'] = outputs
361
+ st.session_state['frontal'] = frontal
362
+ st.session_state['lateral'] = lateral
363
+ st.session_state['report'] = report
364
+ st.session_state['generate'] = True
365
+
366
+ st.session_state['step'] = 4
367
+ st.rerun()
368
+
369
+ with col2:
370
+ if st.button("Return to the beginning"):
371
+ # Ripristina lo stato della sessione
372
+ st.session_state['step'] = 1
373
+ st.session_state['selected_option'] = None
374
+ st.session_state['selected_option2'] = None
375
+ st.session_state['frontal_file'] = None
376
+ st.session_state['lateral_file'] = None
377
+ st.session_state['report'] = ""
378
+ st.rerun()
379
+
380
+ if st.session_state['step'] == 4:
381
+ # Costruzione del prompt
382
+ if st.session_state['generate'] is True:
383
+ conditioning = []
384
+ for inp in st.session_state['inputs']:
385
+ if inp == 'frontal':
386
+ cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
387
+ uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
388
+ encode_type='encode_vision').to(device)
389
+ conditioning.append(torch.cat([uim, cim]))
390
+ elif inp == 'lateral':
391
+ cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
392
+ uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
393
+ encode_type='encode_vision').to(device)
394
+ conditioning.append(torch.cat([uim, cim]))
395
+ elif inp == 'text':
396
+ ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
397
+ utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
398
+ conditioning.append(torch.cat([utx, ctx]))
399
+
400
+ # Costruzione delle shapes
401
+ shapes = []
402
+ for out in st.session_state['outputs']:
403
+ if out == 'frontal' or out == 'lateral':
404
+ shape = [1, 4, 256 // 8, 256 // 8]
405
+ shapes.append(shape)
406
+ elif out == 'text':
407
+ shape = [1, 768]
408
+ shapes.append(shape)
409
+
410
+ progress_bar = st.progress(0)
411
+
412
+ # Inferenza
413
+ z, _ = inference_tester.sampler.sample(
414
+ steps=50,
415
+ shape=shapes,
416
+ condition=conditioning,
417
+ unconditional_guidance_scale=7.5,
418
+ xtype=st.session_state['outputs'],
419
+ condition_types=st.session_state['inputs'],
420
+ eta=1,
421
+ verbose=False,
422
+ mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
423
+ progress_bar=progress_bar)
424
+
425
+ # Decoder e visualizzazione dei risultati
426
+ output_cols = st.columns(len(st.session_state['outputs']))
427
+
428
+ # Definire due colonne per le immagini
429
+ col1, col2 = st.columns(2)
430
+
431
+ # Iterare sugli output e assegnare le immagini alle colonne corrispondenti
432
+ for i, out in enumerate(st.session_state['outputs']):
433
+ if out == 'frontal':
434
+ x = inference_tester.net.autokl_decode(z[i])
435
+ x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
436
+ im = x[0].cpu().numpy()
437
+ with col1: # Mostrare la frontal image nella prima colonna
438
+ st.image(im, caption="Generated Frontal Image")
439
+ elif out == 'lateral':
440
+ x = inference_tester.net.autokl_decode(z[i])
441
+ x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
442
+ im = x[0].cpu().numpy()
443
+ with col2: # Mostrare la lateral image nella seconda colonna
444
+ st.image(im, caption="Generated Lateral Image")
445
+ elif out == 'text':
446
+ x = inference_tester.net.optimus_decode(z[i], max_length=100)
447
+ x = [a.tolist() for a in x]
448
+ rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
449
+ rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
450
+ st.write(f"Generated Report: {rec_text}")
451
+
452
+ st.write("Generation completed successfully!")
453
+ st.session_state['generate'] = False
454
+
455
+ if st.button("Return to the beginning"):
456
+ # Ripristina lo stato della sessione
457
+ st.session_state['generate'] = False
458
+ st.session_state['step'] = 1
459
+ st.session_state['selected_option'] = None
460
+ st.session_state['frontal_file'] = None
461
+ st.session_state['lateral_file'] = None
462
+ st.session_state['report'] = ""
463
+ st.session_state['inputs'] = None
464
+ st.session_state['outputs'] = None
465
+ st.session_state['frontal'] = None
466
+ st.session_state['lateral'] = None
467
+ st.session_state['report'] = ""
468
+ st.rerun()
469
+
470
+ if st.session_state['step'] == 5:
471
+ st.markdown(
472
+ f"<h4 style='text-align: justify'><strong>You selected: {st.session_state['selected_option']}</strong></h4>",
473
+ unsafe_allow_html=True)
474
+
475
+ inputs = []
476
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[0]:
477
+ inputs.append('Frontal')
478
+ if "Lateral" in st.session_state['selected_option'].split(" βž”")[0]:
479
+ inputs.append('Lateral')
480
+ if "Report" in st.session_state['selected_option'].split(" βž”")[0]:
481
+ inputs.append('Report')
482
+
483
+ outputs = []
484
+ if "Frontal" in st.session_state['selected_option'].split(" βž”")[1]:
485
+ outputs.append('Frontal')
486
+ if "Lateral" in st.session_state['selected_option'].split(" βž”")[1]:
487
+ outputs.append('Lateral')
488
+ if "Report" in st.session_state['selected_option'].split(" βž”")[1]:
489
+ outputs.append('Report')
490
+
491
+ esempio = esempi[st.session_state['selected_option']]
492
+
493
+ # Mostra i file associati all'esempio
494
+ st.markdown(
495
+ "<h3 style='text-align: center'><strong>INPUT:</strong></h3>",
496
+ unsafe_allow_html=True)
497
+
498
+ # Colonne per gli INPUTS
499
+ input_cols = st.columns(len(inputs))
500
+
501
+ for idx, inp in enumerate(inputs):
502
+ with input_cols[idx]:
503
+ if inp == 'Frontal':
504
+ path = "./DEMO/ESEMPI/" + esempio['Frontal']
505
+ print(path)
506
+ if path.endswith(".tiff"):
507
+ im = tifffile.imread(path)
508
+ im = np.clip(im, 0, 1)
509
+ elif path.endswith(".png"):
510
+ im = Image.open(path)
511
+ st.image(im, caption="Frontal Image")
512
+ elif inp == 'Lateral':
513
+ path = "./DEMO/ESEMPI/" + esempio['Lateral']
514
+ if path.endswith(".tiff"):
515
+ im = tifffile.imread(path)
516
+ im = np.clip(im, 0, 1)
517
+ elif path.endswith(".png"):
518
+ im = Image.open(path)
519
+ st.image(im, caption="Lateral Image")
520
+ elif inp == 'Report':
521
+ st.markdown(
522
+ f"<p style='font-size:20px;'><strong>Report:</strong> {esempio['Report']}</p>",
523
+ unsafe_allow_html=True
524
+ )
525
+ st.markdown(
526
+ "<h3 style='text-align: center'><strong>OUTPUT:</strong></h3>",
527
+ unsafe_allow_html=True)
528
+
529
+ # Colonne per gli OUTPUTS
530
+ output_cols = st.columns(len(outputs))
531
+
532
+ for idx, out in enumerate(outputs):
533
+ with output_cols[idx]:
534
+ if out == 'Frontal':
535
+ path = "./DEMO/ESEMPI/" + esempio['Frontal']
536
+ if path.endswith(".tiff"):
537
+ im = tifffile.imread(path)
538
+ # facciamo clamp tra 0 e 1
539
+ im = np.clip(im, 0, 1)
540
+ elif path.endswith(".png"):
541
+ im = Image.open(path)
542
+ st.image(im, caption="Frontal Image")
543
+ elif out == 'Lateral':
544
+ path = "./DEMO/ESEMPI/" + esempio['Lateral']
545
+ if path.endswith(".tiff"):
546
+ im = tifffile.imread(path)
547
+ # facciamo clamp tra 0 e 1
548
+ im = np.clip(im, 0, 1)
549
+ elif path.endswith(".png"):
550
+ im = Image.open(path)
551
+ st.image(im, caption="Lateral Image")
552
+ elif out == 'Report':
553
+ st.markdown(
554
+ f"<p style='font-size:20px;'><strong>Report:</strong> {esempio['Report']}</p>",
555
+ unsafe_allow_html=True
556
+ )
557
+
558
+ # Pulsante per tornare all'inizio
559
+ if st.button("Return to the beginning"):
560
+ # Ripristina lo stato della sessione
561
+ st.session_state['step'] = 1
562
+ st.session_state['selected_option'] = None
563
+ st.session_state['selected_option2'] = None
564
+ st.session_state['frontal_file'] = None
565
+ st.session_state['lateral_file'] = None
566
+ st.session_state['report'] = ""
567
+ st.rerun()