import argparse

import numpy as np
import torch
from sklearn.metrics import accuracy_score, recall_score, roc_curve
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from logic_utils import get_lang
from nsfr_utils import denormalize_kandinsky, get_data_loader, get_prob, get_nsfr_model
from nsfr_utils import save_images_with_captions, to_plot_images_kandinsky, generate_captions


def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid input, including a
            missing ``--dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size to infer with")
    parser.add_argument("--e", type=int, default=4,
                        help="The maximum number of objects in one image")
    # NOTE(review): the original gives no help text for --m; its meaning is
    # not visible in this file.
    parser.add_argument("--m", type=int, default=3)
    # Required: the dataset name is concatenated into output paths by the
    # caller, so a missing value would otherwise only fail later with an
    # opaque TypeError.
    parser.add_argument("--dataset", required=True,
                        choices=["twopairs", "threepairs", "red-triangle", "closeby",
                                 "online", "online-pair", "nine-circles"],
                        help="Use kandinsky patterns dataset")
    parser.add_argument("--dataset-type", default="kandinsky",
                        help="kandinsky or clevr")
    parser.add_argument("--small-data", action="store_true", help="Use small training data.")
    parser.add_argument('--device', default='cpu',
                        help='cuda device, i.e. 0 or cpu')
    parser.add_argument("--no-cuda", action="store_true",
                        help="Run on CPU instead of GPU (not recommended)")
    parser.add_argument("--num-workers", type=int, default=4,
                        help="Number of threads for data loader")
    parser.add_argument('--gamma', default=0.01, type=float,
                        help='Smooth parameter in the softor function')
    parser.add_argument("--plot", action="store_true",
                        help="Plot images with captions.")
    return parser.parse_args(argv)
def predict(NSFR, loader, args, device, writer, th=None, split='train', max_batches=101):
    """Run the reasoner over a data split and score its predictions.

    Each batch is moved to ``device``, passed through ``NSFR`` and converted
    to target probabilities with ``get_prob``. The collected probabilities
    are then thresholded and compared with the targets.

    Args:
        NSFR: Callable reasoning model; also supplies ``atoms`` and ``pm.e``
            for caption generation when plotting.
        loader: Iterable of ``(images, targets)`` batches. Both elements must
            support ``.to(device)`` (presumably tensors; confirm against
            ``get_data_loader``).
        args: Parsed options; ``args.plot`` and ``args.dataset`` are read here
            and the whole object is forwarded to ``get_prob``.
        device: Device the batches are moved to.
        writer: Unused in this function; kept so existing callers keep working.
        th: Decision threshold. When ``None``, the threshold that maximises
            accuracy over the ROC thresholds of this split is selected.
        split: Split name, used only in the output folder for plotted images.
        max_batches: Maximum number of batches to evaluate. The default of 101
            reproduces the original hard-coded cap; ``None`` disables the cap.

    Returns:
        tuple: ``(accuracy, per-class recall, threshold used)``.
    """
    predicted_list = []
    target_list = []
    count = 0
    # Inference only: no autograd graph is needed, and not building one keeps
    # memory flat while per-batch outputs are accumulated.
    with torch.no_grad():
        for i, sample in tqdm(enumerate(loader, start=0)):
            if max_batches is not None and i >= max_batches:
                break
            # move the batch to the target device
            imgs, target_set = map(lambda x: x.to(device), sample)
            # infer and predict the target probability
            V_T = NSFR(imgs)
            predicted = get_prob(V_T, NSFR, args)
            predicted_list.append(predicted)
            target_list.append(target_set)
            if args.plot:
                imgs = to_plot_images_kandinsky(imgs)
                captions = generate_captions(
                    V_T, NSFR.atoms, NSFR.pm.e, th=0.3)
                save_images_with_captions(
                    imgs, captions,
                    folder='result/kandinsky/' + args.dataset + '/' + split + '/',
                    img_id_start=count, dataset=args.dataset)
            # running image index so saved image ids stay unique across batches
            count += V_T.size(0)  # batch size

    predicted = torch.cat(predicted_list, dim=0).detach().cpu().numpy()
    target_set = torch.cat(target_list, dim=0).to(
        torch.int64).detach().cpu().numpy()

    if th is None:
        # Search the ROC thresholds for the one giving the best accuracy.
        _, _, thresholds = roc_curve(target_set, predicted, pos_label=1)
        print('ths', thresholds)
        accuracies = np.array([
            accuracy_score(target_set, predicted > thresh)
            for thresh in thresholds
        ])
        best_idx = accuracies.argmax()
        max_accuracy = accuracies[best_idx]
        max_accuracy_threshold = thresholds[best_idx]
        # Recall must be reported at the selected threshold, not at whichever
        # threshold happened to be visited last in the search.
        rec_score = recall_score(
            target_set, predicted > max_accuracy_threshold, average=None)
        print('target_set: ', target_set, target_set.shape)
        print('predicted: ', predicted, predicted.shape)
        print('accuracy: ', max_accuracy)
        print('threshold: ', max_accuracy_threshold)
        print('recall: ', rec_score)
        return max_accuracy, rec_score, max_accuracy_threshold

    decisions = predicted > th
    accuracy = accuracy_score(target_set, decisions)
    rec_score = recall_score(target_set, decisions, average=None)
    return accuracy, rec_score, th
def main():
    """Evaluate the reasoner on the validation, training and test splits.

    Parses the command-line options, selects the torch device, builds the
    data loaders, the logical language and the NSFR model, then calls
    ``predict`` on each split and prints accuracy, threshold and recall.
    The validation split is scored with a fixed threshold of 0.33, and the
    threshold returned from that call is reused for the other two splits.
    """
    args = get_args()
    print('args ', args)
    # '--device cpu' is the parser default; without this check it would be
    # turned into the invalid device string 'cuda:cpu'.
    if args.no_cuda or args.device == 'cpu':
        device = torch.device('cpu')
    elif len(args.device.split(',')) > 1:
        # multi gpu
        device = torch.device('cuda')
    else:
        device = torch.device('cuda:' + args.device)
    print('device: ', device)

    run_name = 'predict/' + args.dataset
    writer = SummaryWriter(f"runs/{run_name}", purge_step=0)
    try:
        # get torch data loader
        train_loader, val_loader, test_loader = get_data_loader(args)

        # load logical representations
        lark_path = 'src/lark/exp.lark'
        lang_base_path = 'data/lang/'
        lang, clauses, bk_clauses, bk, atoms = get_lang(
            lark_path, lang_base_path, args.dataset_type, args.dataset)
        print("clauses: ", clauses)

        # Neuro-Symbolic Forward Reasoner
        NSFR = get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device=device)

        # validation split
        print("Predicting on validation data set...")
        acc_val, rec_val, th_val = predict(
            NSFR, val_loader, args, device, writer, th=0.33, split='val')

        # training split
        print("Predicting on training data set...")
        acc, rec, th = predict(
            NSFR, train_loader, args, device, writer, th=th_val, split='train')

        # test split
        print("Predicting on test data set...")
        acc_test, rec_test, th_test = predict(
            NSFR, test_loader, args, device, writer, th=th_val, split='test')

        print("training acc: ", acc, "threashold: ", th, "recall: ", rec)
        print("val acc: ", acc_val, "threashold: ", th_val, "recall: ", rec_val)
        print("test acc: ", acc_test, "threashold: ", th_test, "recall: ", rec_test)
    finally:
        # Release the event-file handle even if evaluation fails part-way.
        writer.close()
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()