# Source code for src.predict_clevr

import argparse

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from nsfr_utils import denormalize_clevr, get_data_loader, get_prob, get_nsfr_model
from nsfr_utils import save_images_with_captions, to_plot_images_clevr, generate_captions
from logic_utils import get_lang


def get_args(argv=None):
    """Parse the command-line options for CLEVR-Hans prediction.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size to infer with")
    parser.add_argument("--e", type=int, default=10,
                        help="The maximum number of objects in one image")
    # main() builds run names and output paths from this value, so a missing
    # dataset must be rejected here rather than crash later with a TypeError.
    parser.add_argument("--dataset", choices=["clevr-hans3", "clevr-hans7"],
                        required=True, help="Use clevr-hans dataset.")
    parser.add_argument("--dataset-type", default="clevr",
                        help="kandinsky or clevr")
    parser.add_argument('--device', default='cpu',
                        help='cuda device, i.e. 0 or cpu')
    parser.add_argument("--no-cuda", action="store_true",
                        help="Run on CPU instead of GPU (not recommended)")
    parser.add_argument("--num-workers", type=int, default=4,
                        help="Number of threads for data loader")
    parser.add_argument('--gamma', default=0.01, type=float,
                        help='Smooth parameter in the softor function')
    parser.add_argument("--plot", action="store_true",
                        help="Plot images with captions.")
    parser.add_argument("--plot-cam", action="store_true",
                        help="Plot images cam.")
    return parser.parse_args(argv)
def predict(NSFR, loader, args, device, writer, split='train'):
    """Run the reasoner over ``loader`` and score its class predictions.

    For every batch, the model valuation ``V_T`` is turned into per-class
    scores with ``get_prob``. The arg-max class is compared against the
    arg-max of the target vector. The first batch's images and valuation
    text are logged to TensorBoard under tags prefixed by ``split``. With
    ``args.plot``, every batch is saved as captioned images under
    ``result/clevr/<dataset>/<split>/``.

    Args:
        NSFR: the neuro-symbolic forward reasoner, optionally wrapped in
            ``torch.nn.DataParallel``.
        loader: iterable yielding ``(images, targets)`` tensor pairs.
        args: parsed command-line options (see ``get_args``).
        device: torch device the batches are moved to.
        writer: TensorBoard ``SummaryWriter``.
        split: data-split name; used in log tags and the plot folder.

    Returns:
        ``(accuracy, confusion_matrix)`` as computed by scikit-learn.
    """
    # DataParallel does not forward attribute access, so reasoner attributes
    # (atoms, pm, get_valuation_text) are read from the wrapped module.
    model = NSFR.module if isinstance(NSFR, torch.nn.DataParallel) else NSFR
    predicted_list = []
    target_list = []
    n_classes = 0  # widest score/target vector seen; fixes the cmat shape
    count = 0  # running image index, used to number the saved plots
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i, sample in enumerate(tqdm(loader, desc=split)):
            imgs, target_set = (x.to(device) for x in sample)
            # infer and predict the target probability
            V_T = NSFR(imgs)
            predicted = get_prob(V_T, NSFR, args)
            n_classes = max(n_classes, predicted.size(1), target_set.size(1))
            predicted_list.extend(
                np.argmax(predicted.detach().cpu().numpy(), axis=1).tolist())
            target_list.extend(
                np.argmax(target_set.detach().cpu().numpy(), axis=1).tolist())
            if i == 0:
                # Prefix tags with the split: main() calls this once per
                # split, all at global step 0, so shared tags would collide.
                if args.dataset_type == 'clevr':
                    shown = denormalize_clevr(imgs)
                else:
                    shown = imgs
                writer.add_images(f'{split}/images', shown.detach().cpu(), 0)
                writer.add_text(
                    f'{split}/V_T', model.get_valuation_text(V_T), 0)
            if args.plot:
                plot_imgs = to_plot_images_clevr(imgs)
                captions = generate_captions(
                    V_T, model.atoms, model.pm.e, th=0.33)
                save_images_with_captions(
                    plot_imgs, captions,
                    folder='result/clevr/' + args.dataset + '/' + split + '/',
                    img_id_start=count, dataset=args.dataset)
            count += V_T.size(0)  # batch size
    # Explicit labels keep the confusion matrix n_classes x n_classes even
    # when some class never appears in targets or predictions.
    labels = list(range(n_classes)) if n_classes else None
    return (accuracy_score(target_list, predicted_list),
            confusion_matrix(target_list, predicted_list, labels=labels))
def _select_device(args):
    """Return ``(device, cuda_ids)`` derived from ``--device``/``--no-cuda``.

    ``--no-cuda`` or ``--device cpu`` selects the CPU with no CUDA ids.
    Otherwise ``--device`` is a comma-separated list of CUDA indices; the
    first one becomes the primary device.
    """
    if args.no_cuda or args.device.strip().lower() == 'cpu':
        return torch.device('cpu'), []
    cuda_ids = [int(d) for d in args.device.split(',')]
    return torch.device(f'cuda:{cuda_ids[0]}'), cuda_ids


def main():
    """Evaluate an NSFR model on the CLEVR-Hans val, train and test splits.

    Prints the accuracy and confusion matrix for each split. First-batch
    images and valuations are logged to TensorBoard under
    ``runs/predict/<dataset>``.

    Raises:
        ValueError: if ``--dataset-type`` is not ``clevr``.
    """
    args = get_args()
    # An explicit check, unlike assert, survives ``python -O``.
    if args.dataset_type != 'clevr':
        raise ValueError('Use clevr dataset for this script.')
    print('args ', args)
    device, cuda_ids = _select_device(args)
    print('device: ', device)
    run_name = 'predict/' + args.dataset
    writer = SummaryWriter(f"runs/{run_name}", purge_step=0)
    try:
        # get torch data loader
        train_loader, val_loader, test_loader = get_data_loader(args)

        # load logical representations
        lark_path = 'src/lark/exp.lark'
        lang_base_path = 'data/lang/'
        lang, clauses, bk, atoms = get_lang(
            lark_path, lang_base_path, args.dataset_type, args.dataset)

        # Neuro-Symbolic Forward Reasoner
        NSFR = get_nsfr_model(args, lang, clauses, atoms, bk, device)
        if len(cuda_ids) > 1:
            NSFR = torch.nn.DataParallel(NSFR, device_ids=cuda_ids)

        # validation split
        print("Predicting on validation data set...")
        acc_val, cmat_val = predict(
            NSFR, val_loader, args, device, writer, split='val')

        # training split
        print("Predicting on training data set...")
        acc, cmat = predict(
            NSFR, train_loader, args, device, writer, split='train')

        # test split
        print("Predicting on test data set...")
        acc_test, cmat_test = predict(
            NSFR, test_loader, args, device, writer, split='test')
    finally:
        # Flush pending TensorBoard events even if prediction fails.
        writer.close()

    print("=== ACCURACY ===")
    print("training acc: ", acc)
    print("val acc: ", acc_val)
    print("test acc: ", acc_test)

    print("=== CONFUSION MATRIX ===")
    print('training:')
    print(cmat)
    print('val:')
    print(cmat_val)
    print('test:')
    print(cmat_test)
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()