# tobacco / part3.py
# Author: lyimo · commit fbaa8e3 "Create part3.py" (verified) · 6.4 kB
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor
import matplotlib.pyplot as plt
import io
from PIL import Image
class SAMAnalyzer:
    """Segment farmland with SAM and score vegetation from plain RGB imagery.

    Pipeline (driven by ``process_image``):
      1. segment the region under the image centre with a SAM point prompt,
      2. compute an RGB vegetation index inside that mask,
      3. summarise the index into low / moderate / high vegetation fractions,
      4. render a three-panel PNG visualisation.
    """

    # Class boundaries on the rescaled [0, 1] vegetation index. Shared by
    # analyze_crop_health and create_visualization so the text summary and
    # the plotted level map always agree.
    LOW_VEGETATION_MAX = 0.3
    MODERATE_VEGETATION_MAX = 0.6
    # Mean index strictly above which the field is reported as 'Healthy'.
    HEALTHY_INDEX_MIN = 0.5

    def __init__(self, model_path="sam_vit_h_4b8939.pth", model_type="vit_h", device=None):
        """Create the analyzer and load the SAM model immediately.

        Args:
            model_path: path to the SAM checkpoint file.
            model_type: key into ``sam_model_registry``. It must match the
                checkpoint (the default checkpoint name is a ViT-H one).
            device: optional device the model is moved to (e.g. ``"cuda"``).
                ``None`` leaves the model where the registry builds it.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.device = device
        self.sam = None
        self.predictor = None
        self.initialize_sam()

    def initialize_sam(self):
        """Build the SAM model from the checkpoint and wrap it in a SamPredictor.

        Raises:
            Any exception raised while building the model (e.g. a missing
            checkpoint or an unknown ``model_type``). It is printed, then
            re-raised.
        """
        try:
            self.sam = sam_model_registry[self.model_type](checkpoint=self.model_path)
            if self.device is not None:
                self.sam.to(device=self.device)
            self.predictor = SamPredictor(self.sam)
            print("SAM model initialized successfully")
        except Exception as e:
            print(f"Error initializing SAM model: {e}")
            raise

    def process_image(self, image_data):
        """Run segmentation, index computation, health analysis and plotting.

        Args:
            image_data: one of:
                - an image file path (``str``),
                - encoded image bytes,
                - a file-like object readable by PIL,
                - a ``PIL.Image.Image``,
                - a numpy image array (assumed already in RGB order).

        Returns:
            ``(veg_index, health_analysis, viz_plot)``, where ``viz_plot`` is
            an ``io.BytesIO`` holding a PNG. On any failure the error is
            printed and ``(None, None, None)`` is returned instead.
        """
        try:
            image = self._load_rgb_image(image_data)

            print("Segmenting farmland...")
            farmland_mask = self.segment_farmland(image)

            print("Calculating vegetation index...")
            veg_index = self.calculate_vegetation_index(image, farmland_mask)

            print("Analyzing crop health...")
            health_analysis = self.analyze_crop_health(veg_index, farmland_mask)

            print("Generating visualization...")
            viz_plot = self.create_visualization(image, farmland_mask, veg_index)

            return veg_index, health_analysis, viz_plot
        except Exception as e:
            # Deliberate boundary: callers receive a None triple rather than an exception.
            print(f"Error processing image: {e}")
            return None, None, None

    @staticmethod
    def _load_rgb_image(image_data):
        """Decode ``image_data`` into an ``(H, W, 3)`` uint8 RGB array.

        Raises:
            ValueError: if OpenCV cannot read the path or decode the bytes.
        """
        if isinstance(image_data, str):
            bgr = cv2.imread(image_data, cv2.IMREAD_COLOR)
            if bgr is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise ValueError(f"Could not read image file: {image_data}")
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if isinstance(image_data, (bytes, bytearray)):
            bgr = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
            if bgr is None:
                raise ValueError("Could not decode image bytes")
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        if isinstance(image_data, np.ndarray):
            pil_image = Image.fromarray(image_data)
        elif isinstance(image_data, Image.Image):
            pil_image = image_data
        else:
            pil_image = Image.open(image_data)
        # PNG uploads are often RGBA or palette images, and scans may be
        # grayscale. Both SAM and the index computation need exactly 3 RGB channels.
        return np.array(pil_image.convert("RGB"))

    def segment_farmland(self, image):
        """Segment the region under the image centre using SAM.

        A single positive point prompt is placed at the centre pixel, so this
        assumes the field of interest covers the middle of the frame.

        Args:
            image: ``(H, W, 3)`` uint8 RGB array.

        Returns:
            The ``(H, W)`` mask with the highest predicted score among SAM's
            multimask proposals.
        """
        self.predictor.set_image(image)

        h, w = image.shape[:2]
        input_point = np.array([[w // 2, h // 2]])  # (x, y) order, as SAM expects
        input_label = np.array([1])  # 1 = foreground point
        masks, scores, _logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        return masks[scores.argmax()]

    def calculate_vegetation_index(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Compute a normalised green-excess index, masked to the farmland.

        Per pixel: ``(2G - R - B) / (2G + R + B)``, which lies in [-1, 1] for
        non-negative channels. It is rescaled to [0, 1] and then zeroed where
        ``mask`` is 0/False.

        Args:
            image: ``(H, W, >=3)`` array with channels in R, G, B order.
            mask: ``(H, W)`` mask (bool or 0/1) broadcastable against the index.

        Returns:
            ``(H, W)`` float array.
        """
        rgb = image[..., :3].astype(float)
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        numerator = 2 * g - r - b
        denominator = 2 * g + r + b
        # All-black pixels give 0/0. A tiny denominator turns that into
        # index 0, which is 0.5 after rescaling.
        denominator[denominator == 0] = 1e-10
        veg_index = (numerator / denominator + 1) / 2
        return veg_index * mask

    def analyze_crop_health(self, veg_index: np.ndarray, mask: np.ndarray) -> dict:
        """Summarise the vegetation index over the masked pixels.

        Args:
            veg_index: index array as returned by ``calculate_vegetation_index``.
            mask: same-shaped mask. Pixels where it is > 0 are analysed.

        Returns:
            Dict with the following keys:
              - ``average_index``: mean index over the masked pixels.
              - ``health_distribution``: fractions of masked pixels with
                index <= 0.3 (``low_vegetation``), in (0.3, 0.6]
                (``moderate_vegetation``) and > 0.6 (``high_vegetation``).
              - ``overall_health``: ``'Healthy'`` if the mean is > 0.5,
                otherwise ``'Needs attention'``. It is
                ``'No vegetation detected'`` when the mask is empty.
        """
        valid_pixels = veg_index[mask > 0]
        if valid_pixels.size == 0:
            return {
                'average_index': 0,
                'health_distribution': {
                    'low_vegetation': 0,
                    'moderate_vegetation': 0,
                    'high_vegetation': 0,
                },
                'overall_health': 'No vegetation detected',
            }

        low, moderate = self.LOW_VEGETATION_MAX, self.MODERATE_VEGETATION_MAX
        # Plain floats (not numpy scalars) keep the result JSON-serialisable.
        avg_index = float(np.mean(valid_pixels))
        health_categories = {
            'low_vegetation': float(np.mean(valid_pixels <= low)),
            'moderate_vegetation': float(np.mean((valid_pixels > low) & (valid_pixels <= moderate))),
            'high_vegetation': float(np.mean(valid_pixels > moderate)),
        }
        return {
            'average_index': avg_index,
            'health_distribution': health_categories,
            'overall_health': 'Healthy' if avg_index > self.HEALTHY_INDEX_MIN else 'Needs attention',
        }

    def create_visualization(self, image, mask, veg_index) -> io.BytesIO:
        """Render mask overlay, index heatmap and level map side by side.

        Returns:
            An ``io.BytesIO`` positioned at 0 that contains the PNG bytes.
        """
        fig, (ax_mask, ax_index, ax_levels) = plt.subplots(1, 3, figsize=(15, 5))
        try:
            # Panel 1: original image with the segmentation mask overlaid.
            ax_mask.imshow(image)
            ax_mask.imshow(mask, alpha=0.3, cmap='gray')
            ax_mask.set_title('Segmented Farmland')
            ax_mask.axis('off')

            # Panel 2: continuous vegetation-index heatmap.
            index_im = ax_index.imshow(veg_index, cmap='RdYlGn')
            fig.colorbar(index_im, ax=ax_index, label='Vegetation Index')
            ax_index.set_title('Vegetation Index')
            ax_index.axis('off')

            # Panel 3: discrete levels (1 = low, 2 = moderate, 3 = high).
            # Pixels outside the mask are set to 0.
            low, moderate = self.LOW_VEGETATION_MAX, self.MODERATE_VEGETATION_MAX
            health_mask = np.zeros_like(veg_index)
            health_mask[veg_index <= low] = 1
            health_mask[(veg_index > low) & (veg_index <= moderate)] = 2
            health_mask[veg_index > moderate] = 3
            health_mask = health_mask * mask
            levels_im = ax_levels.imshow(health_mask, cmap='viridis')
            fig.colorbar(levels_im, ax=ax_levels,
                         ticks=[1, 2, 3],
                         label='Vegetation Levels',
                         boundaries=np.arange(0.5, 4.5),
                         values=[1, 2, 3])
            ax_levels.set_title('Vegetation Levels')
            ax_levels.axis('off')

            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            return buf
        finally:
            # Close this specific figure rather than pyplot's "current" one.
            # This also runs when plotting fails, so figures don't accumulate.
            plt.close(fig)

    def format_analysis_text(self, health_analysis):
        """Format an ``analyze_crop_health`` result as a human-readable report."""
        distribution = health_analysis['health_distribution']
        return f"""
🌿 Vegetation Analysis Results:
📊 Average Vegetation Index: {health_analysis['average_index']:.2f}
🌱 Vegetation Distribution:
• Low Vegetation: {distribution['low_vegetation']*100:.1f}%
• Moderate Vegetation: {distribution['moderate_vegetation']*100:.1f}%
• High Vegetation: {distribution['high_vegetation']*100:.1f}%
📋 Overall Health Status: {health_analysis['overall_health']}
Note: Analysis uses SAM for farmland segmentation
"""