"""Train and evaluate DeepLabV3 / DeepLabV3+ segmentation models.

The script parses command-line options, builds the dataset and model,
then either runs a single validation pass (``--test_only``) or trains for
``--total_itrs`` iterations, validating and checkpointing every
``--val_interval`` iterations.
"""
import argparse
import json
import os
import random
import traceback

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import network
import utils
from datasets import VOCSegmentation, Cityscapes
from metrics import StreamSegMetrics
from utils import ext_transforms as et

# Per-channel statistics used by every normalize / denormalize step below.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def get_argparser():
    """Build the command-line parser for training and evaluation.

    Returns:
        argparse.ArgumentParser: parser with dataset, model, training and
        PASCAL VOC options registered.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --out_dir is parsed but not read anywhere in this file.
    parser.add_argument("--out_dir", type=str, default="run_0")

    # Dataset Options
    parser.add_argument("--data_root", type=str, default='',
                        help="path to Dataset")
    # NOTE(review): only 'voc' is selectable here although get_dataset() and
    # main() also contain 'cityscapes' branches.
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help='Name of dataset')
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Deeplab Options
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet101',
                        choices=['deeplabv3plus_resnet101', 'deeplabv3plus_resnet50',
                                 'deeplabv3plus_mobilenet', 'deeplabv3plus_xception',
                                 'deeplabv3plus_hrnetv2_48', 'deeplabv3plus_hrnetv2_32',
                                 'deeplabv3_resnet101', 'deeplabv3_resnet50',
                                 'deeplabv3_mobilenet', 'deeplabv3_xception',
                                 'deeplabv3_hrnetv2_48', 'deeplabv3_hrnetv2_32'],
                        help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Enhanced Model Options
    parser.add_argument("--use_eoaNet", action='store_true', default=True,
                        help="Use Entropy-Optimized Attention Network")
    parser.add_argument("--no_eoaNet", action='store_false', dest='use_eoaNet',
                        help="Disable Entropy-Optimized Attention Network")
    parser.add_argument("--msa_scales", nargs='+', type=int, default=[1, 2, 4],
                        help="Scales for Multi-Scale Attention")
    parser.add_argument("--eog_beta", type=float, default=0.3,
                        help="Entropy threshold for Entropy-Optimized Gating")

    # Train Options
    parser.add_argument("--test_only", action='store_true', default=False)
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    # argparse does not run `type` on non-string defaults, so the default
    # must already be an int (it used to be the float 30e3).
    parser.add_argument("--total_itrs", type=int, default=30000,
                        help="total number of training iterations (default: 30k)")
    parser.add_argument("--lr", type=float, default=0.02,
                        help="learning rate (default: 0.02)")
    parser.add_argument("--lr_policy", type=str, default='poly',
                        choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10000)
    parser.add_argument("--crop_val", action='store_true', default=True,
                        help='crop validation (default: True)')
    # --crop_val defaults to True, so a store_false twin is the only way to
    # turn it off from the command line (same pattern as --no_eoaNet).
    parser.add_argument("--no_crop_val", action='store_false', dest='crop_val',
                        help='disable validation cropping')
    parser.add_argument("--batch_size", type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)
    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")
    parser.add_argument("--continue_training", action='store_true', default=False)
    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    # NOTE(review): --gpu_id is parsed but not read anywhere in this file.
    parser.add_argument("--gpu_id", type=str, default='0,1', help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=100,
                        help="iteration interval for eval (default: 100)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")

    # PASCAL VOC Options
    parser.add_argument("--year", type=str, default='2012_aug',
                        choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'],
                        help='year of VOC')
    return parser


def _normalize():
    """Return the normalization transform shared by all pipelines."""
    return et.ExtNormalize(mean=list(_NORM_MEAN), std=list(_NORM_STD))


def get_dataset(opts):
    """Build the training and validation datasets with their augmentations.

    Args:
        opts: parsed options; reads ``dataset``, ``data_root``, ``crop_size``,
            ``crop_val``, ``year`` and ``download``.

    Returns:
        tuple: ``(train_dst, val_dst)``.

    Raises:
        ValueError: if ``opts.dataset`` is neither ``'voc'`` nor
            ``'cityscapes'``.
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            _normalize(),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                _normalize(),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                _normalize(),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download,
                                    transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False,
                                  transform=val_transform)
    elif opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            _normalize(),
        ])
        val_transform = et.ExtCompose([
            et.ExtToTensor(),
            _normalize(),
        ])
        train_dst = Cityscapes(root=opts.data_root, split='train',
                               transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root, split='val',
                             transform=val_transform)
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError("Unsupported dataset: %r" % (opts.dataset,))
    return train_dst, val_dst


def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Run one validation pass and collect metrics.

    Args:
        opts: parsed options; only ``save_val_results`` is read.
        model: network called as ``model(images)``.
        loader: iterable of ``(images, labels)`` batches; its ``dataset`` must
            offer ``decode_target`` when results are saved.
        device: device the batches are moved to.
        metrics: accumulator providing ``reset``, ``update`` and
            ``get_results``.
        ret_samples_ids: optional collection of batch indices; for each of
            them the first sample of the batch is returned.

    Returns:
        tuple: ``(score, ret_samples)`` where ``score`` is
        ``metrics.get_results()`` and ``ret_samples`` is a list of
        ``(image, target, pred)`` NumPy arrays.
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        os.makedirs('results', exist_ok=True)
        denorm = utils.Denormalize(mean=list(_NORM_MEAN), std=list(_NORM_STD))
        img_id = 0

    with torch.no_grad():
        # Wrapping the loader (not the enumerate) lets tqdm see its length.
        for batch_idx, (images, labels) in enumerate(tqdm(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # Index of the maximum along dim 1 is the predicted label map.
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and batch_idx in ret_samples_ids:
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                for sample_idx in range(len(images)):
                    image = images[sample_idx].detach().cpu().numpy()
                    target = targets[sample_idx]
                    pred = preds[sample_idx]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    # Overlay of the prediction on the input, without axes.
                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id,
                                bbox_inches='tight', pad_inches=0)
                    plt.close(fig)
                    img_id += 1

    score = metrics.get_results()
    return score, ret_samples


def main(opts):
    """Train (or only evaluate) the model described by ``opts``.

    Side effects: writes TensorBoard logs to ``logs/``, checkpoints and
    ``best_score.txt`` to ``checkpoints/``, and ``final_info.json`` to the
    working directory.

    Args:
        opts: namespace produced by ``get_argparser().parse_args()``;
            ``num_classes`` and possibly ``val_batch_size`` are overwritten.
    """
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup TensorBoard writer
    writer = SummaryWriter(log_dir='logs')

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        # Without cropping, the val transform applies no resize, so use
        # one image per batch.
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)

    # Adjust batch size if dataset is smaller than batch size
    effective_batch_size = min(opts.batch_size, len(train_dst))
    effective_val_batch_size = min(opts.val_batch_size, len(val_dst))
    if effective_batch_size < opts.batch_size:
        print(f"Warning: Reducing batch size from {opts.batch_size} to {effective_batch_size} due to small dataset")

    # The reduced batch size is now actually applied (the loader previously
    # ignored it). drop_last stays False so no training sample is skipped.
    train_loader = data.DataLoader(
        train_dst, batch_size=effective_batch_size, shuffle=True,
        num_workers=2, drop_last=False)
    val_loader = data.DataLoader(
        val_dst, batch_size=effective_val_batch_size, shuffle=True,
        num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model (all models are constructed at network.modeling)
    model = network.modeling.__dict__[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        use_eoaNet=opts.use_eoaNet,
        msa_scales=opts.msa_scales,
        eog_beta=opts.eog_beta
    )
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the backbone group uses a 10x smaller learning rate
    # than the classifier group.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save the current model, optimizer and scheduler state to ``path``."""
        torch.save({
            "cur_itrs": cur_itrs,
            # `model` is wrapped in nn.DataParallel below, so the bare
            # network's weights live under `.module`.
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    os.makedirs('checkpoints', exist_ok=True)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ========== Train Loop ==========#
    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device,
            metrics=metrics)
        print(metrics.to_str(val_score))
        writer.close()  # Close writer before returning
        return

    interval_loss = 0
    latest_checkpoints = []

    while True:  # left via the return once cur_itrs reaches opts.total_itrs
        # ===== Train =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            writer.add_scalar('Loss/train', np_loss, cur_itrs)

            # Honour --print_interval (the loop used a hard-coded 10).
            if cur_itrs % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if cur_itrs % opts.val_interval == 0:
                ckpt_path = f'checkpoints/latest_{cur_itrs}_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth'
                save_ckpt(ckpt_path)
                latest_checkpoints.append(ckpt_path)

                # Keep only the latest 2 checkpoints
                if len(latest_checkpoints) > 2:
                    oldest_ckpt_path = latest_checkpoints.pop(0)
                    try:
                        os.remove(oldest_ckpt_path)
                        print(f"Successfully removed old checkpoint: {oldest_ckpt_path}")
                    except FileNotFoundError:
                        # The file may already be gone; this is not fatal.
                        print(f"Warning: Could not remove checkpoint because it was not found: {oldest_ckpt_path}")
                    except OSError as e:
                        # Other failures such as permission issues.
                        print(f"Error removing checkpoint {oldest_ckpt_path}: {e}")

                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device,
                    metrics=metrics)
                print(metrics.to_str(val_score))

                # Log validation metrics to TensorBoard
                writer.add_scalar('Metrics/Mean_IoU', val_score['Mean IoU'], cur_itrs)
                writer.add_scalar('Metrics/Overall_Acc', val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('Metrics/Mean_Acc', val_score['Mean Acc'], cur_itrs)

                if val_score['Mean IoU'] > best_score:  # save best model
                    # Plain floats keep the checkpoint and the JSON below free
                    # of library scalar types, in case the metrics object
                    # returns NumPy scalars.
                    best_score = float(val_score['Mean IoU'])
                    save_ckpt(f'checkpoints/best_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth')
                    with open('checkpoints/best_score.txt', 'a') as f:
                        f.write(f"iter:{cur_itrs}\n{str(best_score)}\n")
                    # NOTE(review): the nesting of this write was reconstructed
                    # from a whitespace-mangled source; confirm it belongs in
                    # the best-score branch.
                    with open("final_info.json", "w") as f:
                        final_info = {
                            "voc12_aug": {
                                "means": {
                                    "mIoU": float(val_score['Mean IoU']),
                                    "OA": float(val_score['Overall Acc']),
                                    # Was mistakenly filled with 'Mean IoU'.
                                    "mAcc": float(val_score['Mean Acc']),
                                }
                            }
                        }
                        json.dump(final_info, f, indent=4)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                writer.close()
                return


if __name__ == '__main__':
    args = get_argparser().parse_args()
    try:
        main(args)
    except Exception:
        print("Original error in subprocess:", flush=True)
        # Context manager so the log is flushed and closed before re-raising.
        with open("traceback.log", "w") as log_file:
            traceback.print_exc(file=log_file)
        raise