from pathlib import Path

from tqdm import tqdm
import argparse
import json
import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models.resnet50 import ResNet
from models.mobilenet_v2 import MobileNetV2
from dataset import PlantsDataset
from utils import train_transform, test_transform, EMA


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on the plants dataset")
    parser.add_argument("--train-root", type=str, default="data/plants/train", help="Path to the training data")
    parser.add_argument("--test-root", type=str, default="data/plants/test", help="Path to the testing data")
    # argparse's type=bool treats any non-empty string as True; BooleanOptionalAction (Python 3.9+)
    # parses --load-to-ram / --no-load-to-ram (and --pin-memory / --no-pin-memory) correctly.
    parser.add_argument("--load-to-ram", action=argparse.BooleanOptionalAction, default=False, help="Load dataset to RAM")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--weights-path", type=str, default="weights/mobilenet_v2-b0353104.pth",
                        choices=["weights/resnet50-0676ba61.pth", "weights/mobilenet_v2-b0353104.pth"],
                        help="Path to the pre-trained weights")
    parser.add_argument("--project-name", type=str, default="plants_classifier", help="WandB project name")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--criterion", type=str, default="CrossEntropyLoss", help="Loss function type")
    parser.add_argument("--labels-path", type=str, default="labels.json", help="Path to the labels json file")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run the training on")
    parser.add_argument("--model", type=str, default="mobilenet", choices=["resnet", "mobilenet"], help="Model class name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--logs-dir", type=str, default="resnet-logs", choices=["resnet-logs", "mobilenet-logs"],
                        help="Directory for checkpoints and WandB logs")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    with open(args.labels_path, "r") as fp:
        labels = json.load(fp)
    num_classes = len(labels)

    logs_dir = Path(args.logs_dir)
    logs_dir.mkdir(exist_ok=True)
    wandb.init(project=args.project_name, dir=logs_dir)

    train_dataset = PlantsDataset(root=args.train_root, load_to_ram=args.load_to_ram,
                                  transform=train_transform, labels=labels)
    test_dataset = PlantsDataset(root=args.test_root, load_to_ram=args.load_to_ram,
                                 transform=test_transform, labels=labels)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              pin_memory=args.pin_memory, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             pin_memory=args.pin_memory, num_workers=args.num_workers)

    device = torch.device(args.device)
    if args.model == "resnet":
        model = ResNet(weights_path=args.weights_path)
        model.fc = nn.Linear(512 * model.expansion, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        # Fine-tune only the last residual stage and the new classification head.
        for name, param in model.named_parameters():
            if "layer4" in name or "fc" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    elif args.model == "mobilenet":
        model = MobileNetV2(weights_path=args.weights_path)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
        nn.init.xavier_uniform_(model.classifier[1].weight)
        # Fine-tune only the classifier and the last two feature blocks. Each substring
        # must be tested against `name` individually; the original chained
        # `if "classifier" or "features.18" or ...` was always True.
        for name, param in model.named_parameters():
            if "classifier" in name or "features.18" in name or "features.17" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    model = model.to(device)

    optimizer_class = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    criterion_class = getattr(nn, args.criterion)
    criterion = criterion_class()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)

    best_accuracy = 0.0
    train_loss_ema, train_accuracy_ema, grad_norm_ema = EMA(), EMA(), EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            # Clip gradients before the optimizer step and keep the pre-clip norm for logging.
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_norm).item()
            optimizer.step()

            train_loss = loss.item()
            train_accuracy = (logits.argmax(dim=1) == labels).sum().item() / logits.shape[0]
            pbar.set_postfix({"loss": train_loss_ema(train_loss),
                              "accuracy": train_accuracy_ema(train_accuracy),
                              "grad_norm": grad_norm_ema(grad_norm)})
            wandb.log(
                {
                    "train/epoch": epoch,
                    "train/loss": train_loss,
                    "train/accuracy": train_accuracy,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                }
            )

        model.eval()
        test_loss, test_accuracy = 0.0, 0.0
        with torch.no_grad():
            pbar = tqdm(test_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)

                logits = model(images)
                loss = criterion(logits, labels)
                test_loss += loss.item()
                test_accuracy += (logits.argmax(dim=1) == labels).sum().item()
        # Average the loss over batches and the accuracy over samples.
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader.dataset)
        print(f"loss: {test_loss:.3f}, accuracy: {test_accuracy:.3f}")
        wandb.log(
            {
                "val/epoch": epoch,
                "val/test_loss": test_loss,
                "val/test_accuracy": test_accuracy,
            }
        )

        scheduler.step()
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(model.state_dict(), logs_dir / f"checkpoint-best-{epoch:09}.pth")
        elif epoch % args.save_frequency == 0:
            torch.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")

    wandb.finish()


if __name__ == "__main__":
    main()
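# Example invocation (a sketch; the script filename, data layout, and the specific
# hyperparameter values below are assumptions, not fixed by this repository):
#   python train.py --model mobilenet --weights-path weights/mobilenet_v2-b0353104.pth \
#       --logs-dir mobilenet-logs --batch-size 32 --num-epochs 10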