diff --git a/predict.py b/predict.py
index b74c4608dd..18192015fe 100755
--- a/predict.py
+++ b/predict.py
@@ -12,6 +12,8 @@
 from unet import UNet
 from utils.utils import plot_img_and_mask
 
+
+# for predicting a single PIL image
 def predict_img(net,
                 full_img,
                 device,
@@ -33,6 +35,37 @@ def predict_img(net,
     return mask[0].long().squeeze().numpy()
 
 
+# for predicting a batch of np.ndarray images
+def predict_imgs(net,
+                 image_arrays,
+                 device,
+                 scale_factor=1,
+                 out_threshold=0.5):
+    """Predict masks for a batch of images stored in one np.ndarray.
+
+    The array is fed to the network as-is (converted to float32, no
+    resizing or normalisation); presumably it is laid out as
+    (batch, channels, height, width) -- confirm against the network.
+    scale_factor is accepted for parity with predict_img but is unused.
+
+    Returns an integer np.ndarray that always keeps the batch axis; the
+    channel axis of a single-class output is dropped.
+    """
+    net.eval()
+    img = torch.from_numpy(image_arrays)
+    img = img.to(device=device, dtype=torch.float32)
+
+    with torch.no_grad():
+        output = net(img).cpu()
+        if net.n_classes > 1:
+            mask = output.argmax(dim=1)
+        else:
+            # Drop only the channel axis so a batch of one keeps its batch axis.
+            mask = (torch.sigmoid(output) > out_threshold).squeeze(1)
+
+    return mask.long().numpy()
+
+
 def get_args():
     parser = argparse.ArgumentParser(description='Predict masks from input images')
     parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
@@ -80,38 +113,56 @@ def mask_to_image(mask: np.ndarray, mask_values):
     args = get_args()
     logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
 
-    in_files = args.input
-    out_files = get_output_filenames(args)
+    image_in_file = False  # hard-coded switch; False selects the random-batch demo branch
+    if image_in_file:
+        in_files = args.input
+        out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
+        net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
 
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    logging.info(f'Loading model {args.model}')
-    logging.info(f'Using device {device}')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        logging.info(f'Loading model {args.model}')
+        logging.info(f'Using device {device}')
 
-    net.to(device=device)
-    state_dict = torch.load(args.model, map_location=device)
-    mask_values = state_dict.pop('mask_values', [0, 1])
-    net.load_state_dict(state_dict)
+        net.to(device=device)
+        state_dict = torch.load(args.model, map_location=device)
+        mask_values = state_dict.pop('mask_values', [0, 1])
+        net.load_state_dict(state_dict)
 
-    logging.info('Model loaded!')
+        logging.info('Model loaded!')
 
-    for i, filename in enumerate(in_files):
-        logging.info(f'Predicting image {filename} ...')
-        img = Image.open(filename)
+        for i, filename in enumerate(in_files):
+            logging.info(f'Predicting image {filename} ...')
+            img = Image.open(filename)
 
-        mask = predict_img(net=net,
-                           full_img=img,
-                           scale_factor=args.scale,
-                           out_threshold=args.mask_threshold,
-                           device=device)
+            mask = predict_img(net=net,
+                               full_img=img,
+                               scale_factor=args.scale,
+                               out_threshold=args.mask_threshold,
+                               device=device)
 
-        if not args.no_save:
-            out_filename = out_files[i]
-            result = mask_to_image(mask, mask_values)
-            result.save(out_filename)
-            logging.info(f'Mask saved to {out_filename}')
+            if not args.no_save:
+                out_filename = out_files[i]
+                result = mask_to_image(mask, mask_values)
+                result.save(out_filename)
+                logging.info(f'Mask saved to {out_filename}')
 
-        if args.viz:
-            logging.info(f'Visualizing results for image {filename}, close to continue...')
-            plot_img_and_mask(img, mask)
+            if args.viz:
+                logging.info(f'Visualizing results for image {filename}, close to continue...')
+                plot_img_and_mask(img, mask)
+    else:
+        batch_size = 128
+
+        net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=1.0)
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        logging.info(f'Using device {device}')
+        net.to(device=device)
+        mask_values = [0, 1]
+
+        image_arrays = np.random.rand(batch_size, 3, 112, 112)
+        mask_arrays = predict_imgs(net=net,
+                                   image_arrays=image_arrays,
+                                   scale_factor=args.scale,
+                                   out_threshold=args.mask_threshold,
+                                   device=device)
+        logging.info(f'Predicted mask batch shape: {mask_arrays.shape}')