Files
citypulse/backend/test/Machine_Learning/test_trained_ml.py
Cursor Agent 46dea3304f Refactor: Integrate backend API and normalize data
This commit integrates the backend API for fetching and updating report data. It also includes a normalization function to handle data consistency between the API and local storage.

Co-authored-by: anthonymuncher <anthonymuncher@gmail.com>
2025-09-26 10:27:39 +00:00

40 lines
1.4 KiB
Python

import torch
from torchvision import transforms, models
from PIL import Image
import os
# ---------- CONFIG ----------
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Human-readable labels, indexed by the classifier's output position.
# NOTE(review): the list is alphabetical, which matches the class-index order
# torchvision's ImageFolder would assign during training -- confirm against
# the training script, since a mismatch silently mislabels every prediction.
CLASS_NAMES = [
    "broken_streetlight",
    "drainage",
    "garbage",
    "pothole",
    "signage",
    "streetlight",
]
# Derived from CLASS_NAMES so the classifier width and the label list cannot
# drift apart (this used to be a separately hard-coded 6).
NUM_CLASSES = len(CLASS_NAMES)
# Relative paths, resolved against the current working directory.
MODEL_PATH = "best_model.pth"  # fine-tuned weights, loaded as a state_dict
TEST_IMAGES_DIR = "images"  # folder containing test images
# ---------- MODEL ----------
# ResNet-18 backbone whose final fully-connected layer is replaced by a
# NUM_CLASSES-way classifier, then filled with the fine-tuned weights.
# weights=None: the strict load_state_dict() below overwrites every parameter
# and buffer, so first downloading the ImageNet weights (the previous
# IMAGENET1K_V1 argument) only cost start-up time and network/cache access.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
# NOTE(review): torch.load unpickles the checkpoint -- only load files from a
# trusted source. On torch >= 1.13, weights_only=True restricts what can load.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE)
# Switch layers such as BatchNorm to inference behaviour.
model.eval()
# ---------- IMAGE PREPROCESS ----------
# Pipeline applied to every input image: resize to the fixed 224x224 input
# size, then convert the PIL image to a float tensor.
# NOTE(review): no transforms.Normalize step is applied here. This must match
# the preprocessing used when best_model.pth was trained -- confirm against
# the training script before adding ImageNet mean/std normalization.
_PREPROCESS_STEPS = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_PREPROCESS_STEPS)
# ---------- INFERENCE ----------
# Classify every PNG/JPEG file in TEST_IMAGES_DIR and print one line per image.
# The listing is sorted so the output order is reproducible; os.listdir
# returns entries in arbitrary order.
for image_name in sorted(os.listdir(TEST_IMAGES_DIR)):
    image_path = os.path.join(TEST_IMAGES_DIR, image_name)
    if not image_name.lower().endswith((".png", ".jpg", ".jpeg")):
        continue
    # Skip sub-directories whose names happen to end in an image extension.
    if not os.path.isfile(image_path):
        continue
    try:
        # The context manager closes the file handle that Image.open keeps
        # open; convert() returns a new, fully loaded image that stays usable.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except OSError as err:
        # Corrupt or unrecognised file (Pillow raises OSError subclasses):
        # report it and carry on instead of aborting the whole run.
        print(f"{image_name} --> Skipped (could not read image: {err})")
        continue
    input_tensor = preprocess(image).unsqueeze(0).to(DEVICE)  # add batch dimension
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        outputs = model(input_tensor)
    predicted_index = outputs.argmax(dim=1).item()
    predicted_class = CLASS_NAMES[predicted_index]
    print(f"{image_name} --> Predicted class: {predicted_class}")