# mvp2 / CullerWhale
# Last commit: 85f5b46 (verified) - "Rename app (4).py to"
# Size: 4.5 kB
from io import BytesIO, StringIO

import boto3
import gradio as gr
import numpy as np
import pandas as pd
import PIL.Image
from fastai.vision.all import *
def get_x(r):
    """Return the image path stored under the 'Image Path' key of row ``r``."""
    image_path = r['Image Path']
    return image_path
def get_y(r):
    """Return the label of row ``r``: the value stored under 'Survived'."""
    label = r['Survived']
    return label
def ProjectReportSplitter(df, valid_pct=0.2, seed=None):
    """Split the rows of ``df`` into train/validation sets by 'Project Report'.

    All rows sharing a 'Project Report' value land on the same side, so no
    report is split between training and validation. The held-out set is
    ``int(n_reports * valid_pct)`` distinct reports drawn at random. The count
    is rounded down, so it can be zero for very small datasets.

    Args:
        df: DataFrame with a 'Project Report' column (the DataBlock items).
        valid_pct: Fraction of distinct reports to hold out (default 0.2).
        seed: Optional seed for a reproducible split. When None the global
            NumPy RNG is used, exactly as before.

    Returns:
        ``(train_idx, valid_idx)`` as lists of *positional* row indices.
        fastai treats splitter output as positions into the items. Returning
        index labels would pick the wrong rows whenever ``df`` does not have
        a default 0..n-1 RangeIndex.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    reports = df['Project Report']
    unique_reports = reports.unique()
    n_valid = int(len(unique_reports) * valid_pct)
    valid_reports = rng.choice(unique_reports, n_valid, replace=False)
    # Boolean mask over row positions; independent of the DataFrame's index labels.
    in_valid = reports.isin(valid_reports).to_numpy()
    valid_idx = np.flatnonzero(in_valid).tolist()
    train_idx = np.flatnonzero(~in_valid).tolist()
    return train_idx, valid_idx
# Resolve a row to its image via the row's 'Image Path'.
def get_x_transformed(r):
    """Load the image referenced by row ``r``.

    The row's 'Image Path' (via ``get_x``) is handed to ``open_image_from_s3``
    and its result is returned.

    NOTE(review): ``open_image_from_s3`` is not defined in the visible source.
    Calling this raises NameError unless it is provided elsewhere. Presumably
    it fetched the image from S3 (boto3/BytesIO are imported) - confirm.
    """
    image_path = get_x(r)
    return open_image_from_s3(image_path)
# DataBlock describing the training pipeline. It is not referenced by the
# inference code below. Presumably it and the helpers above are kept so the
# training-time definitions exist when the exported learner is unpickled.
# The original call was never closed (syntax error) and lacked the
# splitter/getters used at training time; both are restored here.
dblock = DataBlock(
    blocks=(ImageBlock(cls=PILImage), CategoryBlock),
    splitter=ProjectReportSplitter,
    get_x=get_x_transformed,
    get_y=get_y,
    # Pad (rather than crop) to 460px so the whole photo is kept.
    item_tfms=Resize(460, method='pad', pad_mode='zeros'),
    batch_tfms=aug_transforms(mult=2, do_flip=True, max_rotate=20, max_zoom=1.1, max_warp=0.2),
)
# Load the exported fastai Learner. load_learner unpickles the file, so only
# point it at a trusted model file.
learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")
# Class labels, in the same order as the probabilities returned by learn.predict.
labels = learn.dls.vocab
print("Model Vocabulary:", labels)
def predict(img):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        img: The value passed by the Gradio ``gr.Image`` input (anything
            ``PILImage.create`` accepts), or None when nothing was uploaded.

    Returns:
        Dict mapping each vocabulary label to its predicted probability, the
        mapping format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an obscure failure from ``PILImage.create(None)``.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    img = PILImage.create(img)
    # learn.predict returns (decoded label, class index, probabilities);
    # only the probabilities are needed here.
    _, _, probs = learn.predict(img)
    return {label: float(prob) for label, prob in zip(labels, probs)}
# def predict(img):
# img = PILImage.create(img)
# pred, pred_idx, probs = learn.predict(img)
# results = {labels[i]: float(probs[i]) for i in range(len(labels))}
# # Adjust results to highlight when 'Survived' meets the 75% threshold
# if results['Survived'] >= 0.75:
# results['Survived'] = 1.0 # Indicating high confidence of survival
# else:
# results['Survived'] = 0.0 # Indicating it did not meet the threshold
# return results
# def predict(img):
# img = PILImage.create(img)
# pred, pred_idx, probs = learn.predict(img)
# results = {labels[i]: float(probs[i]) for i in range(len(labels))}
# # Adjusting to display survival status based on the threshold
# survival_status = 'Survived' if results['Survived'] >= 0.75 else 'Not Survived'
# results['Survival Status'] = survival_status
# return results
# Gradio interface setup: a single image input mapped to a two-class label output.
title = "Photo Culling AI"
description = "Upload your photo to check if it survives culling."
article = "This interface uses a model trained to predict whether a photo is relevant for a project report."
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    article=article,
)
# share=True requests a public share link; show_error=True surfaces errors
# raised inside predict in the UI.
demo.launch(share=True, show_error=True)
# import gradio as gr
# import PIL.Image
# import pandas as pd
# import boto3
# from io import BytesIO, StringIO
# from import *
# def get_x(r): return r['Image Path']
# def get_y(r): return r['Survived']
# def ProjectReportSplitter(df):
# valid_pct = 0.2
# unique_reports = df['Project Report'].unique()
# valid_reports = np.random.choice(unique_reports, int(len(unique_reports) * valid_pct), replace=False)
# valid_idx = df.index[df['Project Report'].isin(valid_reports)].tolist()
# train_idx = df.index[~df.index.isin(valid_idx)].tolist()
# return train_idx, valid_idx
# # Use a function to resolve path
# def get_x_transformed(r): return open_image_from_s3(get_x(r))
# dblock = DataBlock(
# blocks=(ImageBlock(cls=PILImage), CategoryBlock),
# splitter=ProjectReportSplitter,
# get_x=get_x_transformed,
# get_y=get_y,
# item_tfms=Resize(460, method='pad', pad_mode='zeros'),
# batch_tfms=aug_transforms(mult=2, do_flip=True, max_rotate=20, max_zoom=1.1, max_warp=0.2)
# )
# # Load your model
# learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")
# # Print the vocabulary of the model
# print("Model Vocabulary:", learn.dls.vocab)
# # Update prediction function to directly read from S3
# def predict(img_path):
# pred, pred_idx, probs = learn.predict(img_path)
# return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
# # Gradio interface setup
# title = "Photo Culling AI"
# description = "Upload your photo to check if it survives culling."
# article = "This interface uses a model trained to predict whether a photo is relevant for a project report."
# gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=2), title=title, description=description, article=article).launch(share=True)