from fol.logic import *
from nsfr_utils import update_nsfr_clauses, get_prob, get_nsfr_model
#from eval_clause import EvalInferModule
from refinement import RefinementGenerator
from tqdm import tqdm
import torch
import numpy as np
class ClauseGenerator(object):
    """Generate candidate clauses by top-down refinement and beam search.

    NOTE(review): the original text contained Sphinx ``[docs]`` scrape
    residue before several ``def`` lines, which made the file a SyntaxError;
    that residue has been removed.

    Parameters
    ----------
    args : argparse.Namespace
        runtime options; ``dataset_type`` and ``dataset`` are read to decide
        which confounder filter applies
    NSFR :
        neuro-symbolic forward reasoner; its ``atoms`` and ``bk`` are reused
        when a fresh scoring model is built in :meth:`eval_clauses`
    lang :
        the first-order language (predicates, constants, functions)
    pos_data_loader : torch.utils.data.DataLoader
        loader over positive examples; clauses are scored against it
    mode_declarations :
        mode declarations constraining the refinement operator
    bk_clauses :
        background-knowledge clauses forwarded to the model builder
    device : torch.device
        device on which scoring tensors are allocated
    no_xil : bool
        when True, confounder (XIL) filtering is disabled entirely
    """

    def __init__(self, args, NSFR, lang, pos_data_loader, mode_declarations,
                 bk_clauses, device, no_xil=False):
        self.args = args
        self.NSFR = NSFR
        self.lang = lang
        self.mode_declarations = mode_declarations
        self.bk_clauses = bk_clauses
        self.device = device
        self.no_xil = no_xil
        self.rgen = RefinementGenerator(
            lang=lang, mode_declarations=mode_declarations)
        self.pos_loader = pos_data_loader
        self.bce_loss = torch.nn.BCELoss()

    def _is_valid(self, clause):
        """Return True iff the clause is well-formed w.r.t. object variables.

        Every object-typed variable used by an attribute atom must be
        introduced by exactly one ``in`` atom, i.e. the number of ``in``
        atoms equals the number of distinct object variables elsewhere.
        """
        # objects bound by 'in' atoms
        obj_num = len([b for b in clause.body if b.pred.name == 'in'])
        # distinct object-typed variables appearing in non-'in' atoms
        attr_vars = set()
        for b in (a for a in clause.body if a.pred.name != 'in'):
            dtypes = b.pred.dtypes
            for dtype, term in zip(dtypes, b.terms):
                if dtype.name == 'object' and term.is_var():
                    attr_vars.add(term)
        return obj_num == len(attr_vars)

    def _cf0(self, clause):
        """Confounder check for CLEVR-Hans0: the clause mentions color gray."""
        for atom in clause.body:
            if atom.pred.name == 'color' and str(atom.terms[-1]) == 'gray':
                return True
        return False

    def _cf1(self, clause):
        """Confounder check for CLEVR-Hans1: the clause mentions a metal sphere.

        BUG FIX: the material value was compared against ``'gray'`` (a
        copy-paste from ``_cf0``); ``'gray'`` is not a material, so the check
        could never fire. It now compares against ``'metal'`` as the
        docstring intends.

        NOTE(review): the two atoms are not required to talk about the same
        object variable — any 'material=metal' plus any 'shape=sphere' atom
        triggers the filter; confirm this matches the intended confounder.
        """
        for bi in clause.body:
            for bj in clause.body:
                if bi.pred.name == 'material' and str(bi.terms[-1]) == 'metal':
                    if bj.pred.name == 'shape' and str(bj.terms[-1]) == 'sphere':
                        return True
        return False

    def _is_confounded(self, clause):
        """Return True when the clause matches a known dataset confounder.

        Always False when XIL is disabled or for the kandinsky dataset type;
        otherwise dispatches to the dataset-specific check.
        """
        if self.no_xil:
            return False
        if self.args.dataset_type == 'kandinsky':
            return False
        if self.args.dataset == 'clevr-hans0':
            return self._cf0(clause)
        if self.args.dataset == 'clevr-hans1':
            return self._cf1(clause)
        return False

    def generate(self, C_0, gen_mode='beam', T_beam=7, N_beam=20, N_max=100):
        """Generate clauses, with or without beam search.

        Inputs
        ------
        C_0 : Set[.logic.Clause]
            a set of initial clauses
        gen_mode : string
            'beam'  - refinement with beam search
            'naive' - refinement without beam search
        T_beam : int
            number of beam-search steps
        N_beam : int
            beam width
        N_max : int
            maximum number of clauses to generate

        Returns
        -------
        C : Set[.logic.Clause]
            the generated clauses

        Raises
        ------
        ValueError
            for an unknown ``gen_mode`` (previously returned None silently).
        """
        if gen_mode == 'beam':
            return self.beam_search(C_0, T_beam=T_beam, N_beam=N_beam,
                                    N_max=N_max)
        if gen_mode == 'naive':
            # NOTE(review): ``naive`` is not defined in this class — it is
            # presumably provided elsewhere; confirm before using this mode.
            return self.naive(C_0, T_beam=T_beam, N_max=N_max)
        raise ValueError(f"unknown gen_mode: {gen_mode!r}")

    def beam_search_clause(self, clause, T_beam=7, N_beam=20, N_max=100,
                           th=0.98):
        """Beam-search refinements starting from a single clause.

        Inputs
        ------
        clause : Clause
            the initial clause
        T_beam : int
            number of beam-search steps
        N_beam : int
            beam width
        N_max : int
            maximum number of clauses (currently unused — the pruning was
            disabled in the original; kept for interface compatibility)
        th : float
            score threshold (currently unused — thresholding was disabled)

        Returns
        -------
        C : Set[.logic.Clause]
            the clauses accepted along the search
        """
        beam = [clause]
        C = set()
        C_dic = {}   # score bookkeeping for every evaluated refinement
        seen = []    # refinements produced so far, to dedup across steps
        for _step in range(T_beam):
            candidates = {}
            refs = []
            for c in beam:
                refs_c = self.rgen.refinement_clause(c)
                # drop refinements that already appeared in earlier steps
                refs_c = list(set(refs_c).difference(set(seen)))
                seen.extend(refs_c)
                refs.extend(refs_c)
                # accept the current clause itself when it is well-formed
                # and not a known confounder
                if self._is_valid(c) and not self._is_confounded(c):
                    C.add(c)
                    print("Added: ", c)
            print('Evaluating ', len(refs), 'generated clauses.')
            scores = self.eval_clauses(refs)
            for ref, s in zip(refs, scores):
                # skip near-duplicates already present in the new beam
                if not self.is_in_beam(candidates, ref, s):
                    candidates[ref] = s
                    C_dic[ref] = s
            # keep only the N_beam highest-scoring refinements
            top = sorted(candidates.items(), key=lambda kv: kv[1],
                         reverse=True)[:N_beam]
            for ref, s in top:
                print(s, ref)
            beam = [ref for ref, _ in top]
            if not beam:
                break
        return C

    def is_in_beam(self, B, clause, score):
        """Return True when an equivalent clause is already in beam ``B``.

        Two clauses count as duplicates when they use the same set of
        predicates and their scores differ by less than 1e-2.

        BUG FIX: the predicate set of each beam entry was computed from the
        incoming ``clause`` instead of the stored clause ``ci``, making the
        predicate comparison always trivially true; it now uses ``ci``.
        """
        score = score.detach().cpu().numpy()
        preds = set([clause.head.pred] + [b.pred for b in clause.body])
        for ci, score_i in B.items():
            score_i = score_i.detach().cpu().numpy()
            preds_i = set([ci.head.pred] + [b.pred for b in ci.body])
            if preds == preds_i and np.abs(score - score_i) < 1e-2:
                return True
        return False

    def beam_search(self, C_0, T_beam=7, N_beam=20, N_max=100):
        """Beam-search from every initial clause and pool the results.

        Inputs
        ------
        C_0 : Set[.logic.Clause]
            set of initial clauses
        T_beam : int
            number of beam-search steps
        N_beam : int
            beam width
        N_max : int
            maximum number of clauses to generate

        Returns
        -------
        C : List[.logic.Clause]
            sorted list of generated clauses
        """
        C = set()
        for clause in C_0:
            C |= self.beam_search_clause(clause, T_beam, N_beam, N_max)
        C = sorted(C)
        print('======= BEAM SEARCHED CLAUSES ======')
        for c in C:
            print(c)
        return C

    def eval_clauses(self, clauses):
        """Score clauses by their summed positive-class probability.

        Builds a fresh NSFR model containing the candidate clauses, runs it
        over the positive data loader and, for each clause, sums the
        predicted probability of the target predicate ``kp`` over all
        examples (higher is better).

        Returns
        -------
        torch.Tensor of shape (len(clauses),)
        """
        n_clauses = len(clauses)
        print("Eval clauses: ", n_clauses)
        # fresh model evaluating exactly the candidate clauses
        NSFR = get_nsfr_model(self.args, self.lang, clauses, self.NSFR.atoms,
                              self.NSFR.bk, self.bk_clauses, self.device)
        score = torch.zeros((n_clauses,)).to(self.device)
        for sample in tqdm(self.pos_loader):
            # target_set is unpacked for loader compatibility but unused:
            # the BCE-based scoring was disabled in the original.
            imgs, target_set = map(lambda x: x.to(self.device), sample)
            batch_size = imgs.size(0)
            # valuations after forward inference: (clauses, batch, atoms)
            V_T_list = NSFR.clause_eval(imgs).detach()
            batch_score = torch.zeros((n_clauses, batch_size)).to(self.device)
            for ci, V_T in enumerate(V_T_list):
                # probability of target predicate 'kp' for each example
                batch_score[ci] = NSFR.predict(v=V_T, predname='kp').detach()
            # accumulate the per-clause total over this batch
            score += batch_score.sum(dim=1)
        return score