import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import data_clevr
import data_kandinsky
from percept import SlotAttentionPerceptionModule, YOLOPerceptionModule
from facts_converter import FactsConverter
from nsfr import NSFReasoner
from logic_utils import build_infer_module, build_clause_infer_module
from valuation import SlotAttentionValuationModule, YOLOValuationModule
attrs = ['color', 'shape', 'material', 'size']
def update_initial_clauses(clauses, obj_num):
    """Truncate the single initial clause's body to its first ``obj_num`` atoms.

    Mutates the clause in place and returns it wrapped in a one-element list.
    """
    assert len(clauses) == 1, "Too many initial clauses."
    initial = clauses[0]
    initial.body = initial.body[:obj_num]
    return [initial]
def valuation_to_attr_string(v, atoms, e, th=0.5):
    """Generate string explanations of the scene.

    One line per object: comma-separated ``prob:atom`` entries for every
    attribute atom (pred name in the module-level ``attrs``) mentioning that
    object whose valuation exceeds ``th``.
    """
    out = ''
    for obj_idx in range(e):
        obj_name = 'obj' + str(obj_idx)
        entries = []
        for atom_idx, atom in enumerate(atoms):
            term_names = [str(term) for term in atom.terms]
            if obj_name in term_names and atom.pred.name in attrs:
                if v[atom_idx] > th:
                    # round the score to two decimals for readability
                    prob = np.round(v[atom_idx].detach().cpu().numpy(), 2)
                    entries.append(str(prob) + ':' + str(atom))
        out += ','.join(entries) + '\n'
    return out
def valuation_to_rel_string(v, atoms, th=0.5):
    """Generate a string of relational atoms whose valuation exceeds ``th``.

    Attribute atoms (``attrs``) and the bookkeeping predicates ``in`` / ``.``
    are skipped.  A newline is inserted whenever roughly 100 characters have
    been emitted since the last break.
    """
    wrap_at = 100
    out = ''
    col = 0
    skipped = attrs + ['in', '.']
    for idx, atom in enumerate(atoms):
        if v[idx] > th and atom.pred.name not in skipped:
            prob = np.round(v[idx].detach().cpu().numpy(), 2)
            piece = str(prob) + ':' + str(atom) + ','
            out += piece
            col += len(piece)
            if col > wrap_at:
                out += '\n'
                col = 0
    # drop the trailing separator and terminate with a newline
    return out[:-1] + '\n'
def valuation_to_string(v, atoms, e, th=0.5):
    """Combine the attribute and relation explanations of one valuation."""
    attr_part = valuation_to_attr_string(v, atoms, e, th)
    rel_part = valuation_to_rel_string(v, atoms, th)
    return attr_part + rel_part
def valuations_to_string(V, atoms, e, th=0.5):
    """Generate string explanation of the scenes.

    Emits an ``image i`` header followed by that image's explanation for
    every valuation in the batch ``V``.
    """
    parts = []
    for idx in range(V.size(0)):
        # one section per data point in the batch
        parts.append('image ' + str(idx) + '\n')
        parts.append(valuation_to_string(V[idx], atoms, e, th))
    return ''.join(parts)
def generate_captions(V, atoms, e, th):
    """Build one caption string per valuation in the batch ``V``."""
    return [valuation_to_string(v, atoms, e, th) for v in V]
def save_images_with_captions(imgs, captions, folder, img_id_start, dataset):
    """Save each image to ``folder`` as ``<id>.png`` with its caption as xlabel.

    Image ids start at ``img_id_start`` and increase by one per image.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
    # figure size is tuned per dataset layout
    if dataset == 'online-pair':
        figsize = (15, 15)
    elif dataset == 'red-triangle':
        figsize = (10, 8)
    else:
        figsize = (12, 6)
    # imgs should be denormalized.
    for offset, img in enumerate(imgs):
        plt.figure(figsize=figsize, dpi=80)
        plt.imshow(img)
        plt.xlabel(captions[offset])
        plt.tight_layout()
        # NOTE(review): `folder` is used as a raw filename prefix — callers
        # presumably pass a trailing slash; confirm, otherwise files land
        # next to the folder instead of inside it.
        plt.savefig(folder + str(img_id_start + offset) + '.png')
        plt.close()
def denormalize_clevr(imgs):
    """Undo CLEVR normalization, mapping values from [-1, 1] back to [0, 1].

    Inverse of the forward transform ``image = (image - 0.5) * 2.0``.
    """
    return imgs * 0.5 + 0.5
def denormalize_kandinsky(imgs):
    """Return Kandinsky images unchanged (they are not normalized)."""
    return imgs
def to_plot_images_clevr(imgs):
    """Convert a batch of CLEVR CHW tensors into HWC numpy arrays for plotting."""
    plots = []
    for img in denormalize_clevr(imgs):
        plots.append(img.permute(1, 2, 0).detach().numpy())
    return plots
def to_plot_images_kandinsky(imgs):
    """Convert a batch of Kandinsky CHW tensors into HWC numpy arrays for plotting."""
    plots = []
    for img in denormalize_kandinsky(imgs):
        plots.append(img.permute(1, 2, 0).detach().numpy())
    return plots
def get_data_loader(args):
    """Return (train, val, test) loaders for the configured dataset type."""
    if args.dataset_type == 'kandinsky':
        return get_kandinsky_loader(args)
    if args.dataset_type == 'clevr':
        return get_clevr_loader(args)
    assert 0, 'Invalid dataset type: ' + args.dataset_type
def get_data_pos_loader(args):
    """Return loaders over positive examples only for the configured dataset type."""
    if args.dataset_type == 'kandinsky':
        return get_kandinsky_pos_loader(args)
    if args.dataset_type == 'clevr':
        return get_clevr_pos_loader(args)
    assert 0, 'Invalid dataset type: ' + args.dataset_type
def get_clevr_loader(args):
    """Create (train, val, test) DataLoaders for the CLEVR-Hans dataset.

    Only the train loader shuffles; all splits share batch size and workers.
    """
    loaders = []
    for split, do_shuffle in (('train', True), ('val', False), ('test', False)):
        dataset = data_clevr.CLEVRHans(args.dataset, split)
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            shuffle=do_shuffle,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        ))
    return tuple(loaders)
def get_kandinsky_loader(args, shuffle=False):
    """Create (train, val, test) DataLoaders for the KANDINSKY dataset.

    NOTE(review): the ``shuffle`` parameter is never used — the train loader
    always shuffles and val/test never do; confirm whether it can be dropped.
    """
    dataset_train = data_kandinsky.KANDINSKY(
        args.dataset, 'train', small_data=args.small_data)
    dataset_val = data_kandinsky.KANDINSKY(
        args.dataset, 'val', small_data=args.small_data)
    # the test split never uses the small-data flag
    dataset_test = data_kandinsky.KANDINSKY(args.dataset, 'test')

    def _make(ds, do_shuffle):
        # shared DataLoader construction for the three splits
        return torch.utils.data.DataLoader(
            ds,
            shuffle=do_shuffle,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

    return _make(dataset_train, True), _make(dataset_val, False), _make(dataset_test, False)
def get_kandinsky_pos_loader(args, shuffle=False):
    """Create (train, val, test) DataLoaders over positive KANDINSKY examples.

    Uses ``args.batch_size_bs``; the train loader shuffles only when
    ``shuffle`` is True, val/test never shuffle.
    """
    dataset_train = data_kandinsky.KANDINSKY_POSITIVE(
        args.dataset, 'train', small_data=args.small_data)
    dataset_val = data_kandinsky.KANDINSKY_POSITIVE(
        args.dataset, 'val', small_data=args.small_data)
    # the test split never uses the small-data flag
    dataset_test = data_kandinsky.KANDINSKY_POSITIVE(args.dataset, 'test')

    def _make(ds, do_shuffle):
        # shared DataLoader construction for the three splits
        return torch.utils.data.DataLoader(
            ds,
            shuffle=do_shuffle,
            batch_size=args.batch_size_bs,
            num_workers=args.num_workers,
        )

    return _make(dataset_train, shuffle), _make(dataset_val, False), _make(dataset_test, False)
def get_clevr_pos_loader(args):
    """Create (train, val, test) DataLoaders over positive CLEVR-Hans examples.

    Uses ``args.batch_size_bs``; only the train loader shuffles.
    """
    loaders = []
    for split, do_shuffle in (('train', True), ('val', False), ('test', False)):
        dataset = data_clevr.CLEVRHans_POSITIVE(args.dataset, split)
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            shuffle=do_shuffle,
            batch_size=args.batch_size_bs,
            num_workers=args.num_workers,
        ))
    return tuple(loaders)
def get_prob(v_T, NSFR, args):
    """Predict the probability of the target predicate ``kp``.

    Args:
        v_T: valuation tensor after forward reasoning.
        NSFR: the neuro-symbolic forward reasoner.
        args: parsed CLI arguments (unused; kept for interface compatibility).

    Returns:
        The reasoner's prediction for predicate ``kp``.
    """
    # The old dataset-specific multi-predicate dispatch was dead code parked
    # in the docstring position; all datasets predict the single target 'kp'.
    return NSFR.predict(v=v_T, predname='kp')
def get_prob_by_prednames(v_T, NSFR, prednames):
    """Predict probabilities for the given target predicates.

    Args:
        v_T: valuation tensor after forward reasoning.
        NSFR: the neuro-symbolic forward reasoner.
        prednames: list of target predicate names to predict.

    Returns:
        The reasoner's multi-predicate prediction.

    Fix: the original body referenced an undefined module-global ``args``
    (a NameError at call time) and ignored ``prednames`` entirely; it now
    dispatches on the explicit ``prednames`` argument.
    """
    return NSFR.predict_multi(v=v_T, prednames=prednames)
def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False):
    """Assemble a Neuro-Symbolic Forward Reasoner for the configured dataset.

    Chooses YOLO-based perception/valuation for 'kandinsky' and slot-attention
    modules for 'clevr', then wires them into an NSFReasoner.
    """
    if args.dataset_type == 'kandinsky':
        perception = YOLOPerceptionModule(e=args.e, d=11, device=device)
        valuation = YOLOValuationModule(
            lang=lang, device=device, dataset=args.dataset)
    elif args.dataset_type == 'clevr':
        perception = SlotAttentionPerceptionModule(e=10, d=19, device=device)
        valuation = SlotAttentionValuationModule(lang=lang, device=device)
    else:
        assert False, "Invalid dataset type: " + str(args.dataset_type)
    facts_converter = FactsConverter(
        lang=lang, perception_module=perception,
        valuation_module=valuation, device=device)
    infer = build_infer_module(
        clauses, bk_clauses, atoms, lang,
        m=args.m, infer_step=2, device=device, train=train)
    clause_infer = build_clause_infer_module(
        clauses, bk_clauses, atoms, lang,
        m=len(clauses), infer_step=2, device=device)
    # Neuro-Symbolic Forward Reasoner
    return NSFReasoner(
        perception_module=perception, facts_converter=facts_converter,
        infer_module=infer, clause_infer_module=clause_infer,
        atoms=atoms, bk=bk, clauses=clauses)
def __get_nsfr_model_from_nsfr(NSFR, lang, clauses, atoms, bk, bk_clauses, device):
    """Clone an NSFReasoner with a new clause set, reusing its language.

    Builds fresh YOLO perception/valuation modules (kandinsky path only) and
    new inference modules over ``clauses``.

    NOTE(review): the ``lang`` parameter is immediately overridden by
    ``NSFR.lang`` — confirm whether it can be removed.
    """
    lang = NSFR.lang
    PM = YOLOPerceptionModule(e=NSFR.pm.e, d=11, device=device)
    VM = YOLOValuationModule(lang=lang, device=device)
    FC = FactsConverter(lang=lang, perception_module=PM,
                        valuation_module=VM, device=device)
    IM = build_infer_module(clauses, bk_clauses, atoms, lang,
                            m=len(clauses), infer_step=4, device=device)
    # Fix: the clause-infer module was built with build_infer_module; every
    # other call site builds it with build_clause_infer_module.
    CIM = build_clause_infer_module(clauses, bk_clauses, atoms, lang,
                                    m=len(clauses), infer_step=4, device=device)
    # Neuro-Symbolic Forward Reasoner
    NSFR = NSFReasoner(perception_module=PM, facts_converter=FC,
                       infer_module=IM, clause_infer_module=CIM,
                       atoms=atoms, bk=bk, clauses=clauses)
    return NSFR
def update_nsfr_clauses(nsfr, clauses, bk_clauses, device):
    """Rebuild only the clause-inference module of ``nsfr`` for a new clause set.

    Perception, facts converter, and the main inference module are reused
    from the existing reasoner; a summary of the new reasoner is printed.
    """
    clause_infer = build_clause_infer_module(
        clauses, bk_clauses, nsfr.atoms, nsfr.fc.lang,
        m=len(clauses), device=device)
    rebuilt = NSFReasoner(
        perception_module=nsfr.pm, facts_converter=nsfr.fc,
        infer_module=nsfr.im, clause_infer_module=clause_infer,
        atoms=nsfr.atoms, bk=nsfr.bk, clauses=clauses)
    rebuilt._summary()
    # drop the local reference to the old reasoner
    del nsfr
    return rebuilt
def get_nsfr_model_train(args, lang, clauses, atoms, bk, device, m):
    # Build an NSFReasoner configured for training (train=True, m clause slots).
    if args.dataset_type == 'kandinsky':
        PM = YOLOPerceptionModule(e=args.e, d=11, device=device)
        VM = YOLOValuationModule(
            lang=lang, device=device, dataset=args.dataset)
    elif args.dataset_type == 'clevr':
        PM = SlotAttentionPerceptionModule(e=10, d=19, device=device)
        VM = SlotAttentionValuationModule(lang=lang, device=device)
    else:
        assert False, "Invalid dataset type: " + str(args.dataset_type)
    FC = FactsConverter(lang=lang, perception_module=PM,
                        valuation_module=VM, device=device)
    # NOTE(review): other call sites pass bk_clauses between `clauses` and
    # `atoms` when calling build_infer_module; here `atoms` occupies that
    # slot — confirm against build_infer_module's actual signature.
    IM = build_infer_module(clauses, atoms, lang,
                            m=m, infer_step=4, device=device, train=True)
    # Neuro-Symbolic Forward Reasoner
    # NOTE(review): unlike get_nsfr_model, no clause_infer_module is passed —
    # confirm NSFReasoner allows it to default.
    NSFR = NSFReasoner(perception_module=PM, facts_converter=FC,
                       infer_module=IM, atoms=atoms, bk=bk, clauses=clauses)
    return NSFR
def __valuation_to_text(v, atoms, e, th=0.5):
    """Render a valuation as one ``object i: ...`` line of attribute terms.

    Args:
        v: valuation vector; ``v[j]`` is the score of ``atoms[j]``.
        atoms: list of ground atoms aligned with ``v``.
        e: number of objects in the scene.
        th: score threshold for including an atom's terms.

    Returns:
        One line per object, newline-separated, listing the (non-object)
        terms of every above-threshold atom that mentions the object,
        excluding the bookkeeping predicate ``in``.
    """
    # Fixes over the original: the parameter `atoms` was clobbered with []
    # before the loop (so the loop body never ran), the object name was built
    # from the atom index j instead of the object index i, results were
    # appended to an undefined name `indices`, and `atoms` was mutated while
    # being iterated; the dead second loop is removed.
    lines = []
    for i in range(e):
        line = 'object ' + str(i) + ': '
        obj_name = 'obj' + str(i)
        for j, atom in enumerate(atoms):
            term_names = [str(term) for term in atom.terms]
            if obj_name in term_names and atom.pred.name != 'in':
                if v[j] > th:
                    if len(atom.terms) == 2:
                        line += str(atom.terms[1]) + ' '
                    elif len(atom.terms) == 1:
                        line += str(atom.terms[0])
        lines.append(line)
    return '\n'.join(lines)