import os
import argparse
from glob import glob

import torch
import torch.nn as nn
from torchvision import transforms as T
from tqdm import tqdm
from PIL import Image

import network
import utils
from datasets import VOCSegmentation, Cityscapes
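
# Example invocation (illustrative; the script and checkpoint file names below
# are assumptions, not fixed by this code):
#   python predict.py --input ./samples --dataset voc \
#       --model deeplabv3plus_mobilenet \
#       --ckpt ./best_deeplabv3plus_mobilenet_voc_os16.pth \
#       --save_val_results_to ./results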


def get_argparser():
    parser = argparse.ArgumentParser()

    # Dataset options
    parser.add_argument("--input", type=str, required=True,
                        help="path to a single image or a directory of images")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc', 'cityscapes'],
                        help="name of the dataset the model was trained on")

    # Model options
    available_models = sorted(
        name for name in network.modeling.__dict__
        if name.islower() and not name.startswith('_')
        and callable(network.modeling.__dict__[name])
    )
    parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
                        choices=available_models, help="model name")
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to the decoder and ASPP")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Output options
    parser.add_argument("--save_val_results_to", default=None,
                        help="save segmentation results to the specified directory")

    # Preprocessing options
    parser.add_argument("--crop_val", action='store_true', default=False,
                        help="resize and center-crop inputs to --crop_size (default: False)")
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help="batch size for validation (default: 4)")
    parser.add_argument("--crop_size", type=int, default=513)

    # Checkpoint and device options
    parser.add_argument("--ckpt", default=None, type=str,
                        help="path to a trained checkpoint")
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")
    return parser


def main():
    opts = get_argparser().parse_args()

    # Each dataset determines the number of output classes and the decode
    # function that maps predicted class indices to RGB colors.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)
    image_files = []
    if os.path.isdir(opts.input):
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            files = glob(os.path.join(opts.input, '**/*.%s' % ext), recursive=True)
            if len(files) > 0:
                image_files.extend(files)
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)
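
    # Build the model by name from network.modeling, then load trained
    # weights from --ckpt if one is provided.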
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes,
                                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Loaded model from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] No checkpoint found; using an untrained model")
        model = nn.DataParallel(model)
        model.to(device)
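
    # Preprocessing: optionally resize + center-crop, then normalize with
    # the ImageNet channel statistics the backbone was trained with.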
    if opts.crop_val:
        transform = T.Compose([
            T.Resize(opts.crop_size),
            T.CenterCrop(opts.crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    else:
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    if opts.save_val_results_to is not None:
        os.makedirs(opts.save_val_results_to, exist_ok=True)
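
    # Inference: for each image, take the argmax over the class dimension
    # of the network output, then colorize the label map with decode_fn.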
    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0)  # add batch dimension: (1, C, H, W)
            img = img.to(device)

            pred = model(img).max(1)[1].cpu().numpy()[0]  # (H, W) label map
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name + '.png'))


if __name__ == '__main__':
    main()