import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler  # ensures sklearn is importable when joblib unpickles the scaler
import joblib
import os

# Select GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the ViT backbone (PyTorch) and replace its classification head
# with a 2-class linear layer for benign/malignant prediction
vit_model = models.vit_b_16(weights=None)  # weights=None replaces the deprecated pretrained=False
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification

# Load fine-tuned ViT weights if the checkpoint is present
vit_model_path = "vit_bc.pth"  # Update with uploaded model path
if os.path.exists(vit_model_path):
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
vit_model.to(device)
vit_model.eval()

# Image preprocessing for ViT (resize, tensor conversion, ImageNet normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Class labels
class_names = ["Benign", "Malignant"]

# Load the trained feature-based Neural Network (TensorFlow/Keras), if available
nn_model_path = "my_NN_BC_model.keras"  # Update with uploaded model path
nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None

# Load the scaler fitted on the training features, for input normalization
scaler_path = "nn_bc_scaler.pkl"  # Update with uploaded model path
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None

# Names of the 30 tabular input features (mean / standard error / worst statistics)
feature_names = [
    "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
    "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry",
    "Mean Fractal Dimension",
    "SE Radius", "SE Texture", "SE Perimeter", "SE Area", "SE Smoothness",
    "SE Compactness", "SE Concavity", "SE Concave Points", "SE Symmetry",
    "SE Fractal Dimension",
    "Worst Radius", "Worst Texture", "Worst Perimeter", "Worst Area", "Worst Smoothness",
    "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry",
    "Worst Fractal Dimension"
]


def classify(model_choice, image=None, *features):
    """Classify using the ViT (image input) or the Neural Network (30 numeric features)."""
    if model_choice == "ViT":
        if image is None:
            return "Please upload an image for ViT classification."
        image = image.convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = vit_model(input_tensor)
        predicted_class = torch.argmax(output, dim=1).item()
        return class_names[predicted_class]

    if model_choice == "Neural Network":
        if nn_model is None:
            return "Neural Network model file not found."
        if len(features) != len(feature_names) or any(f is None for f in features):
            return "Please enter all 30 numerical features."
        input_data = np.array(features, dtype=np.float32).reshape(1, -1)
        # Apply the training-time normalization when the scaler is available
        input_data_std = scaler.transform(input_data) if scaler is not None else input_data
        prediction = nn_model.predict(input_data_std)
        predicted_class = int(np.argmax(prediction))
        return class_names[predicted_class]

    return "Please choose a model."


# Gradio UI: model selector, image input for ViT, and 30 number inputs for the NN
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
image_input = gr.Image(type="pil", label="Upload Mammogram Image")
feature_inputs = [gr.Number(label=feature) for feature in feature_names]

iface = gr.Interface(
    fn=classify,
    inputs=[model_selector, image_input] + feature_inputs,
    outputs="text",
    title="Breast Cancer Classification",
    description="Choose between ViT (image-based) and Neural Network (feature-based) classification."
)

iface.launch()
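
# ---------------------------------------------------------------------------
# Illustrative sketch (assumption: the weight files referenced above exist on
# disk). It shows how classify() could be exercised directly, without the
# Gradio UI. The image path and feature values are hypothetical placeholders,
# not real data, and iface.launch() blocks until the server stops, so the
# example is left commented out.
#
#   sample_image = Image.open("example_mammogram.png")   # hypothetical file name
#   print(classify("ViT", sample_image))
#
#   dummy_features = [1.0] * 30                           # placeholder values only
#   print(classify("Neural Network", None, *dummy_features))
# ---------------------------------------------------------------------------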