This commit integrates the backend API for fetching and updating report data. It also includes a normalization function to handle data consistency between the API and local storage. Co-authored-by: anthonymuncher <anthonymuncher@gmail.com>
139 lines
5.3 KiB
Python
139 lines
5.3 KiB
Python
import json
import logging
import os
from typing import List, Optional, Tuple

import cv2
import torch
from PIL import Image
from torchvision import transforms, models
from ultralytics import YOLO
|
|
|
|
# Module-level logger. The level is pinned to INFO here so the model-loading
# and inference messages below are not filtered out by this logger itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
# ----------------------
|
|
# AI Model Manager
|
|
# ----------------------
|
|
class AIModelManager:
    """Loads and keeps classification and detection models in memory.

    Both models are loaded eagerly in ``__init__``. After construction the
    instance exposes:

    * ``device``          - the ``torch.device`` the classifier runs on
    * ``class_model``     - a ResNet-18 with its final layer resized to the
                            number of classes, in eval mode
    * ``class_names``     - class labels ordered by class index
    * ``detection_model`` - the YOLO detector
    * ``preprocess``      - the torchvision transform applied before
                            classification
    """

    def __init__(self, device: Optional[str] = None):
        """Resolve the device and model paths, then load both models.

        Args:
            device: Device string passed to ``torch.device``. When omitted
                (or empty), CUDA is used if available, otherwise the CPU.
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # Model files live in "<parent of this file's directory>/models/...".
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        models_dir = os.path.join(base_dir, "models")
        self.class_model_path = os.path.join(models_dir, "classification", "best_model.pth")
        self.class_mapping_path = os.path.join(models_dir, "classification", "class_mapping.json")
        self.detection_model_path = os.path.join(models_dir, "detection", "best_severity_check.pt")

        # Initialize models
        self.class_model = None
        self.class_names = None
        self._load_classification_model()
        self.detection_model = None
        self._load_detection_model()

        # Preprocess for classification
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _read_class_names(mapping_path: str) -> List[str]:
        """Return the class labels from a JSON mapping, ordered by index.

        The file must contain an object whose keys are the strings
        ``"0"`` .. ``"n-1"``; the value of key ``"i"`` becomes element ``i``.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If an index in ``0..n-1`` is missing from the mapping.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(mapping_path, "r", encoding="utf-8") as f:
            class_mapping = json.load(f)
        return [class_mapping[str(i)] for i in range(len(class_mapping))]

    def _load_classification_model(self):
        """Load class names and the ResNet-18 classifier weights."""
        logger.info("Loading classification model...")
        self.class_names = self._read_class_names(self.class_mapping_path)

        # Build the architecture without pretrained weights, then resize the
        # final layer so it matches the checkpoint's number of classes.
        self.class_model = models.resnet18(weights=None)
        self.class_model.fc = torch.nn.Linear(self.class_model.fc.in_features, len(self.class_names))
        # NOTE(review): torch.load unpickles the checkpoint, so only trusted
        # files should be loaded here; consider weights_only=True on torch
        # versions that support it.
        state_dict = torch.load(self.class_model_path, map_location=self.device)
        self.class_model.load_state_dict(state_dict)
        self.class_model.to(self.device)
        self.class_model.eval()
        logger.info("Classification model loaded successfully.")

    def _load_detection_model(self):
        """Load the YOLO detection model from ``detection_model_path``."""
        logger.info("Loading YOLO detection model...")
        self.detection_model = YOLO(self.detection_model_path)
        logger.info("YOLO detection model loaded successfully.")
|
|
|
|
|
|
# ----------------------
|
|
# AI Service
|
|
# ----------------------
|
|
class AIService:
    """Handles classification and detection using preloaded models.

    The service loads nothing itself; it uses the ``preprocess``, ``device``,
    ``class_model``, ``class_names`` and ``detection_model`` attributes of
    the model manager it is given.
    """

    # Annotation colours per severity, in OpenCV's BGR channel order.
    _SEVERITY_COLORS = {
        "Low": (0, 255, 0),
        "Medium": (0, 255, 255),
        "High": (0, 0, 255),
    }

    def __init__(self, model_manager: "AIModelManager"):
        """Store the model manager that provides the preloaded models."""
        self.models = model_manager

    # ----------------------
    # Classification
    # ----------------------
    def classify_category(self, image_path: str) -> str:
        """Classify the image at ``image_path`` and return its class name.

        Args:
            image_path: Path of the image file to classify.

        Returns:
            The entry of ``class_names`` with the highest model output.
        """
        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        input_tensor = self.models.preprocess(image).unsqueeze(0).to(self.models.device)
        with torch.no_grad():
            outputs = self.models.class_model(input_tensor)
        _, predicted = torch.max(outputs, 1)
        category = self.models.class_names[predicted.item()]
        logger.info("Image '%s' classified as '%s'.", image_path, category)
        return category

    # ----------------------
    # Detection / Severity
    # ----------------------
    @staticmethod
    def classify_severity(box: Tuple[int, int, int, int], image_height: int) -> str:
        """Map a bounding box to ``"High"``, ``"Medium"`` or ``"Low"``.

        Args:
            box: ``(x1, y1, x2, y2)`` corner coordinates in pixels.
            image_height: Height of the image the box belongs to, in pixels.

        Returns:
            ``"High"`` if the box area exceeds 50000 or its bottom edge lies
            below 75% of the image height; ``"Medium"`` if the area exceeds
            20000 or the bottom edge lies below 50%; otherwise ``"Low"``.
        """
        x1, y1, x2, y2 = box
        area = (x2 - x1) * (y2 - y1)
        if area > 50000 or y2 > image_height * 0.75:
            return "High"
        if area > 20000 or y2 > image_height * 0.5:
            return "Medium"
        return "Low"

    @staticmethod
    def _highest_severity(severities) -> str:
        """Return the most severe label present, or ``"Unknown"`` if empty."""
        found = set(severities)
        if not found:
            return "Unknown"
        for label in ("High", "Medium"):
            if label in found:
                return label
        return "Low"

    @staticmethod
    def draw_boxes_and_severity(image, results) -> None:
        """Draw every detected box on ``image`` with its severity label.

        Args:
            image: Image array to annotate in place; ``image.shape[0]`` is
                used as the image height.
            results: Iterable of detection results, each exposing
                ``boxes.xyxy`` and, when available, ``boxes.conf``.
        """
        image_height = image.shape[0]
        for r in results:
            confidences = getattr(r.boxes, "conf", None)
            for idx, box in enumerate(r.boxes.xyxy):
                x1, y1, x2, y2 = map(int, box.cpu().numpy())
                # Use this box's own confidence, not the first box's.
                if confidences is not None and idx < len(confidences):
                    conf = float(confidences[idx])
                else:
                    conf = 0.0
                severity = AIService.classify_severity((x1, y1, x2, y2), image_height)
                color = AIService._SEVERITY_COLORS.get(severity, (0, 0, 255))
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                cv2.putText(image, f"{severity} ({conf:.2f})", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    def detect_pothole_severity(self, image_path: str, output_path: Optional[str] = None) -> Tuple[str, str]:
        """Run detection on an image and report the highest severity found.

        Args:
            image_path: Path of the image to analyse.
            output_path: Where to write the annotated image. When omitted,
                nothing is written and ``image_path`` is returned instead.

        Returns:
            ``(severity, path)`` where ``severity`` is ``"High"``,
            ``"Medium"``, ``"Low"`` or ``"Unknown"`` (no boxes detected) and
            ``path`` is ``output_path`` if given, else ``image_path``.

        Raises:
            ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(
                f"Could not read image '{image_path}' (missing file or unsupported format)."
            )

        results = self.models.detection_model(image)
        self.draw_boxes_and_severity(image, results)

        # Determine highest severity
        image_height = image.shape[0]
        severities = [
            self.classify_severity(tuple(map(int, box.cpu().numpy())), image_height)
            for r in results
            for box in r.boxes.xyxy
        ]
        severity = self._highest_severity(severities)

        # Save annotated image
        if output_path:
            out_dir = os.path.dirname(output_path)
            # A bare file name has no directory part; makedirs("") would raise.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            if cv2.imwrite(output_path, image):
                logger.info("Pothole severity: %s, output image saved to '%s'.", severity, output_path)
            else:
                logger.error("Pothole severity: %s, failed to write output image to '%s'.",
                             severity, output_path)
        else:
            output_path = image_path
            logger.info("Pothole severity: %s, annotated image not saved (no output path given).",
                        severity)

        return severity, output_path
|