Files
citypulse/backend/app/services/ai_service.py
Cursor Agent 46dea3304f Refactor: Integrate backend API and normalize data
This commit integrates the backend API for fetching and updating report data. It also includes a normalization function to handle data consistency between the API and local storage.

Co-authored-by: anthonymuncher <anthonymuncher@gmail.com>
2025-09-26 10:27:39 +00:00

139 lines
5.3 KiB
Python

import os
import logging
from typing import Tuple
import torch
from torchvision import transforms, models
from PIL import Image
import cv2
from ultralytics import YOLO
import json
# Module-level logger, named after this module so its records can be filtered
# or routed independently of other loggers.
logger = logging.getLogger(__name__)
# Let INFO-level records through this logger; whether they are actually shown
# still depends on the handlers configured by the application.
logger.setLevel(logging.INFO)
# ----------------------
# AI Model Manager
# ----------------------
class AIModelManager:
    """Loads and keeps classification and detection models in memory.

    Both models are loaded eagerly by ``__init__``; no exception is caught
    here, so any load error propagates out of the constructor.

    Attributes:
        device: ``torch.device`` the classification model is moved to.
        class_model_path: Path of the classification state dict.
        class_mapping_path: Path of the JSON file mapping class index to name.
        detection_model_path: Path of the YOLO detection weights.
        class_model: ResNet-18 classifier in eval mode.
        class_names: Class names ordered by class index.
        detection_model: ``ultralytics.YOLO`` detection model.
        preprocess: torchvision transform applied before classification.
    """

    def __init__(self, device: str = None):
        """Resolve the device, locate the model files and load both models.

        Args:
            device: Device string handed to ``torch.device``. When falsy,
                ``"cuda"`` is used if available and ``"cpu"`` otherwise.
        """
        self.device = torch.device(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Model files live under "<parent of this file's directory>/models".
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        classification_dir = os.path.join(base_dir, "models", "classification")
        self.class_model_path = os.path.join(classification_dir, "best_model.pth")
        self.class_mapping_path = os.path.join(classification_dir, "class_mapping.json")
        self.detection_model_path = os.path.join(
            base_dir, "models", "detection", "best_severity_check.pt"
        )

        # Initialize models
        self.class_model = None
        self.class_names = None
        self._load_classification_model()
        self.detection_model = None
        self._load_detection_model()

        # Preprocess for classification: fixed 224x224 input, converted to a
        # tensor. No normalisation step is applied here.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _read_class_names(mapping_path: str) -> list:
        """Return class names ordered by index from a JSON index->name mapping.

        The file must contain an object whose keys are the strings
        ``"0"`` .. ``"n-1"``.

        Args:
            mapping_path: Path of the JSON mapping file.

        Returns:
            The class names, position ``i`` holding the name stored under
            key ``str(i)``.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If an index in ``0 .. n-1`` is missing from the mapping.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding when class names contain non-ASCII text.
        with open(mapping_path, "r", encoding="utf-8") as f:
            class_mapping = json.load(f)
        return [class_mapping[str(i)] for i in range(len(class_mapping))]

    def _load_classification_model(self):
        """Load the class names and the ResNet-18 classifier weights."""
        logger.info("Loading classification model...")
        self.class_names = self._read_class_names(self.class_mapping_path)

        # Rebuild the architecture without pretrained weights and resize the
        # final layer to the number of classes before loading the state dict.
        model = models.resnet18(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, len(self.class_names))

        # NOTE(security): torch.load unpickles the checkpoint. Only load files
        # from a trusted source; consider ``weights_only=True`` once the
        # minimum supported torch version provides it.
        state_dict = torch.load(self.class_model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()  # inference mode: fixes dropout / batch-norm behaviour
        self.class_model = model
        logger.info("Classification model loaded successfully.")

    def _load_detection_model(self):
        """Load the YOLO detection model from ``detection_model_path``."""
        logger.info("Loading YOLO detection model...")
        self.detection_model = YOLO(self.detection_model_path)
        logger.info("YOLO detection model loaded successfully.")
# ----------------------
# AI Service
# ----------------------
class AIService:
    """Handles classification and detection using preloaded models."""

    # Colours used when annotating a detection, keyed by severity label.
    # OpenCV images are in BGR channel order: green, yellow, red.
    _SEVERITY_COLORS = {
        "Low": (0, 255, 0),
        "Medium": (0, 255, 255),
        "High": (0, 0, 255),
    }

    def __init__(self, model_manager: "AIModelManager"):
        """Store the model manager that owns the loaded models.

        Args:
            model_manager: Provides ``class_model``, ``class_names``,
                ``detection_model``, ``preprocess`` and ``device``.
        """
        self.models = model_manager

    # ----------------------
    # Classification
    # ----------------------
    def classify_category(self, image_path: str) -> str:
        """Classify the image at ``image_path`` and return its category name.

        Args:
            image_path: Path of the image file to classify.

        Returns:
            The entry of ``class_names`` with the highest model output.
        """
        # Image.open keeps the file handle open; convert() yields an
        # independent in-memory copy, so the handle can be released at once.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        # Add the batch dimension expected by the model.
        input_tensor = self.models.preprocess(image).unsqueeze(0).to(self.models.device)
        with torch.no_grad():
            outputs = self.models.class_model(input_tensor)
            _, predicted = torch.max(outputs, 1)

        category = self.models.class_names[predicted.item()]
        logger.info(f"Image '{image_path}' classified as '{category}'.")
        return category

    # ----------------------
    # Detection / Severity
    # ----------------------
    @staticmethod
    def classify_severity(box: Tuple[int, int, int, int], image_height: int) -> str:
        """Rate one detection as ``"High"``, ``"Medium"`` or ``"Low"``.

        A box is rated by its pixel area and by how far down the image its
        bottom edge reaches; either criterion is enough to raise the rating.

        Args:
            box: ``(x1, y1, x2, y2)`` corner coordinates; any iterable of
                four numbers is accepted.
            image_height: Height of the image in pixels.

        Returns:
            ``"High"`` if area > 50000 or the bottom edge is below 75% of the
            height, ``"Medium"`` if area > 20000 or the bottom edge is below
            50% of the height, otherwise ``"Low"``.
        """
        x1, y1, x2, y2 = box
        area = (x2 - x1) * (y2 - y1)
        if area > 50000 or y2 > image_height * 0.75:
            return "High"
        if area > 20000 or y2 > image_height * 0.5:
            return "Medium"
        return "Low"

    @staticmethod
    def _highest_severity(severities) -> str:
        """Return the worst label in ``severities`` or ``"Unknown"`` if empty."""
        if not severities:
            return "Unknown"
        for level in ("High", "Medium"):
            if level in severities:
                return level
        return "Low"

    @staticmethod
    def draw_boxes_and_severity(image, results) -> None:
        """Draw each detected box with its severity and confidence, in place.

        Args:
            image: Image array to annotate (modified in place); its first
                dimension is used as the image height.
            results: Iterable of detection results exposing ``boxes.xyxy``
                and, optionally, ``boxes.conf``.
        """
        image_height = image.shape[0]
        for r in results:
            boxes = r.boxes
            if boxes is None:
                continue
            confidences = getattr(boxes, "conf", None)
            for index, box in enumerate(boxes.xyxy):
                x1, y1, x2, y2 = map(int, box.cpu().numpy())
                # Use the confidence belonging to THIS box, not the first one.
                if confidences is not None and index < len(confidences):
                    conf = float(confidences[index])
                else:
                    conf = 0.0
                severity = AIService.classify_severity((x1, y1, x2, y2), image_height)
                color = AIService._SEVERITY_COLORS.get(severity, (0, 0, 255))
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                cv2.putText(image, f"{severity} ({conf:.2f})", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    def detect_pothole_severity(self, image_path: str, output_path: str = None) -> Tuple[str, str]:
        """Detect potholes, annotate the image and report the worst severity.

        Args:
            image_path: Path of the image to analyse.
            output_path: Where to write the annotated image. When falsy,
                nothing is written and ``image_path`` is returned instead.

        Returns:
            ``(severity, path)`` where ``severity`` is ``"High"``,
            ``"Medium"``, ``"Low"`` or ``"Unknown"`` (no detections) and
            ``path`` is ``output_path`` if given, else ``image_path``.

        Raises:
            ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None; passing that on
            # would make the detector run on something other than this image.
            raise ValueError(f"Could not read image '{image_path}'.")

        results = self.models.detection_model(image)
        self.draw_boxes_and_severity(image, results)

        # Determine highest severity
        image_height = image.shape[0]
        severities = []
        for r in results:
            if r.boxes is None:
                continue
            for box in r.boxes.xyxy:
                coords = tuple(map(int, box.cpu().numpy()))
                severities.append(self.classify_severity(coords, image_height))
        severity = self._highest_severity(severities)

        # Save annotated image
        if output_path:
            output_dir = os.path.dirname(output_path)
            # A bare file name has no directory part; makedirs("") would raise.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            if cv2.imwrite(output_path, image):
                logger.info(f"Pothole severity: {severity}, output image saved to '{output_path}'.")
            else:
                logger.warning(
                    f"Pothole severity: {severity}, failed to write output image to '{output_path}'."
                )
        else:
            output_path = image_path
            logger.info(
                f"Pothole severity: {severity}, no output path given; annotated image not saved."
            )
        return severity, output_path