# Metal_Defect_Detection / classify.py
# Author: prathampatel1 — last commit 402a941 ("Update classify.py"), 890 bytes.
import pandas as pd
from ultralytics import YOLO
class Classification:
    """Wrap a YOLO classification model and tabulate its top-5 predictions."""

    def __init__(self, model_path: str = 'model/cls_best.pt') -> None:
        """Load the classification weights.

        Args:
            model_path: Path to the YOLO classification weights. The default
                is a relative path; presumably it is resolved against the
                current working directory, so verify this when running from
                elsewhere.
        """
        self.__cls_model = YOLO(model_path)

    def classify_defect(self, image_path) -> pd.DataFrame:
        """Classify one or more images and return their top-5 predictions.

        Args:
            image_path: Source forwarded unchanged to ``YOLO.predict``, such as
                a single image path or a list of paths. Each source image
                yields one prediction result.

        Returns:
            A DataFrame with one row per (image, top-5 class) pair and the
            columns ``"Image/File Name"``, ``"Detected class by cls"`` and
            ``"Conf score"``. The rows follow each result's top-5 order. The
            columns are present even when there are no predictions.
        """
        result_cls = self.__cls_model.predict(image_path, stream=False)
        class_names = self.__cls_model.names

        rows = []
        for result in result_cls:
            # Use the current result, not result_cls[0]. Otherwise every row
            # would repeat the first image when several images are passed.
            probs = result.probs
            confidences = probs.top5conf.tolist()  # convert once per result
            for class_idx, conf in zip(probs.top5, confidences):
                rows.append({
                    "Image/File Name": result.path,
                    "Detected class by cls": class_names[class_idx],
                    "Conf score": conf,
                })

        # Explicit columns keep the schema stable for empty predictions.
        return pd.DataFrame(
            rows,
            columns=["Image/File Name", "Detected class by cls", "Conf score"],
        )
# df1.to_csv('classification_results.csv', index=False)