import itertools
import torch
from tqdm import tqdm
from fol.logic_ops import unify, subs_list, subs
class TensorEncoder(object):
    """The tensor encoder for differentiable inference.

    A class for tensor encoding in the differentiable forward-chaining approach.
    For every clause and every fact that unifies with the clause head, the
    (substituted) clause body is grounded by enumerating substitutions for its
    remaining variables, and the fact indices of the ground body atoms are
    stored in an index tensor.

    Args:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        device (torch.device): The device to be used.

    Attrs:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        device (torch.device): The device on which tensors are allocated.
        G (int): The number of ground atoms.
        C (int): The number of clauses.
        L (int): The maximum length of the clause bodies (at least 1).
        S (int): The maximum number of substitutions for body atoms (at least 1).
        fact_index_dic ({atom -> int}): The dictionary that maps an atom to its index.

    Note:
        Index 0 is treated as FALSE: atoms missing from ``facts`` and unused
        substitution slots map to it. Short bodies are padded with index 1,
        presumably the TRUE atom given the name ``pad_by_true`` -- confirm
        against how ``facts`` is built.
        ``build_head_unifier_dic`` is available, but ``__init__`` does not call
        it, so no ``head_unifier_dic`` attribute is set here.
    """

    def __init__(self, lang, facts, clauses, device):
        self.lang = lang
        self.facts = facts
        self.clauses = clauses
        self.device = device
        self.G = len(facts)
        self.C = len(clauses)
        self.fact_index_dic = self.build_fact_index_dic()
        self.S = self.get_max_subs_num(clauses)
        # At least 1 so that the tensor always has a body axis, even for empty bodies.
        self.L = max([len(clause.body) for clause in clauses] + [1])

    def _ground_heads(self, clause):
        """Yield ``(fact_index, clause_)`` for every fact that unifies with the clause head.

        ``clause_`` is ``clause`` with the head unifier applied, presumably
        leaving only the existentially quantified variables in its body.

        Args:
            clause (clause): A clause.

        Yields:
            (int, clause): The index of the unifying fact and the substituted clause.
        """
        for fi, fact in enumerate(self.facts):
            unify_flag, theta = unify([clause.head, fact])
            if unify_flag:
                yield fi, subs_list(clause, theta)

    def get_max_subs_num(self, clauses):
        """Compute S (the maximum number of substitutions for body atoms) from clauses.

        Args:
            clauses (list(clause)): A set of clauses.

        Returns:
            S (int): The maximum number of substitutions for existentially
                quantified variables in the body atoms, taken over all clauses
                and all facts that unify with their heads. It is at least 1:
                a fully ground body still occupies one substitution slot in
                ``body_to_tensor``, and with no unifying (clause, fact) pair
                there would otherwise be nothing to take the maximum of.
        """
        counts = [len(self.generate_subs(clause_.body))
                  for clause in clauses
                  for _, clause_ in self._ground_heads(clause)]
        return max(counts + [1])

    def encode(self):
        """Compute the index tensor for the differentiable inference.

        Returns:
            I (tensor): The index tensor of shape (C, G, S, L).
        """
        I = torch.zeros((self.C, self.G, self.S, self.L),
                        dtype=torch.long, device=self.device)
        for ci, clause in enumerate(self.clauses):
            I[ci] = self.build_I_c(clause)
        return I

    def build_I_c(self, clause):
        """Build the index tensor for a given clause.

        Args:
            clause (clause): A clause.

        Returns:
            I_c (tensor): The index tensor for the given clause, of shape
                (G, S, L). Rows for facts that do not unify with the clause
                head stay all 0 (FALSE).
        """
        I_c = torch.zeros((self.G, self.S, self.L),
                          dtype=torch.long, device=self.device)
        for fi, clause_ in self._ground_heads(clause):
            # Convert the body atoms into indices.
            I_c[fi] = self.body_to_tensor(clause_.body)
        return I_c

    def build_fact_index_dic(self):
        """Build the dictionary {fact -> index}.

        Returns:
            dic ({atom -> int}): A dictionary that maps each fact to its position
                in ``self.facts`` (for duplicated facts, the last position wins).
        """
        return {fact: i for i, fact in enumerate(self.facts)}

    def build_head_unifier_dic(self):
        """Build the dictionary {(head, fact) -> unifier}.

        Only pairs that actually unify are stored.

        Returns:
            dic ({(atom, atom) -> substitution}): A dictionary that maps a
                (clause head, fact) pair to their unifier.
        """
        dic = {}
        heads = {c.head for c in self.clauses}
        for head in heads:
            for fact in self.facts:
                unify_flag, theta_list = unify([head, fact])
                if unify_flag:
                    dic[(head, fact)] = theta_list
        return dic

    def body_to_tensor(self, body):
        """Convert the body atoms into a tensor.

        Args:
            body (list(atom)): The body atoms, which may contain existentially
                quantified variables.

        Returns:
            I_c_b (tensor): The index tensor of shape (S, L). Row ``i`` holds the
                fact indices of the body grounded by the ``i``-th substitution,
                padded with 1 up to length L. Unused rows stay 0 (FALSE).
        """
        # Every row starts as 0, the index of FALSE. Substitution slots beyond
        # the number actually generated therefore need no explicit fill.
        I_c_b = torch.zeros((self.S, self.L), dtype=torch.long, device=self.device)
        n_vars = len({v for atom in body for v in atom.all_vars()})
        assert n_vars <= 10, \
            'Too many existentially quantified variables in the body: ' + str(body)
        if n_vars == 0:
            # The body atoms are already ground, so they occupy a single row.
            I_c_b[0] = self.pad_by_true(self.facts_to_index(body))
            return I_c_b
        # The body has existentially quantified variables, e.g.
        #   body atoms: [in(img,O1), shape(O1,square)]
        #   theta_list: [[(O1,obj1)], [(O1,obj2)]]
        theta_list = self.generate_subs(body)
        n_substs = len(theta_list)
        assert n_substs <= self.S, 'Exceeded the maximum number of substitution patterns to existential variables: n_substs is: ' + \
            str(n_substs) + ' but max num is: ' + str(self.S)
        # Ground the body for each substitution and store its index row.
        for i, theta in enumerate(theta_list):
            ground_body = [subs_list(bi, theta) for bi in body]
            I_c_b[i] = self.pad_by_true(self.facts_to_index(ground_body))
        return I_c_b

    def pad_by_true(self, x):
        """Pad a 1-D index tensor with ones for bodies shorter than the longest clause.

        Args:
            x (tensor): A 1-D tensor of fact indices of length <= L.

        Returns:
            x_padded (tensor): A 1-D tensor of length L: ``x`` followed by ones.
        """
        assert x.size(
            0) <= self.L, 'x.size(0) exceeds max_body_len: ' + str(self.L)
        diff = self.L - x.size(0)
        if diff == 0:
            return x
        x_pad = torch.ones(diff, dtype=torch.long, device=self.device)
        return torch.cat([x, x_pad])

    def generate_subs(self, body):
        """Generate substitutions from the given body atoms.

        Generate the possible substitutions from the given list of atoms. If the
        body contains any variables, the substitutions are generated by
        enumerating constants of the variables' data type. Distinct variables are
        bound to distinct constants (``itertools.permutations``).
        !!! ASSUMPTION: all variables in the body have the same data type,
        e.g. variables O1 (object) and Y (color) cannot appear in one clause !!!

        Args:
            body (list(atom)): The body atoms, which may contain existentially
                quantified variables.

        Returns:
            theta_list (list(substitution)): The list of substitutions for the
                given body atoms; ``[]`` if the body has no variables. Each
                substitution is a list of ``(var, const)`` pairs, with variables
                in order of first appearance in ``body``,
                e.g. [[(Z, red)], [(Z, yellow)], [(Z, blue)]].
        """
        # Collect the variables and their data types from the body atoms.
        dtypes = []
        found_vars = []
        for atom in body:
            for i, term in enumerate(atom.terms):
                if term.is_var():
                    dtypes.append(atom.pred.dtypes[i])
                    found_vars.append(term)
        # There are no variables in the body.
        if not dtypes:
            return []
        # Check the data type consistency.
        assert len(set(dtypes)) == 1, "Invalid existentially quantified variables. " + \
            str(len(set(dtypes))) + " data types in the body: " + str(body) + " dypes: " + str(dtypes)
        # dict.fromkeys de-duplicates while keeping first-occurrence order. A set
        # would make the variable (and hence substitution) order hash-dependent.
        var_order = list(dict.fromkeys(found_vars))
        consts = self.lang.get_by_dtype(dtypes[0])
        # Pair the variables with each ordered selection of distinct constants.
        return [list(zip(var_order, subs_consts))
                for subs_consts in itertools.permutations(consts, len(var_order))]

    def facts_to_index(self, atoms):
        """Convert the given ground atoms into a 1-D long tensor of fact indices.

        Atoms that are not in ``self.facts`` map to 0 (FALSE).
        """
        return torch.tensor([self.get_fact_index(nf) for nf in atoms],
                            dtype=torch.long, device=self.device)

    def get_fact_index(self, fact):
        """Return the index of a fact in the ordered set of all facts.

        Unknown atoms deliberately map to 0, the FALSE index.
        """
        return self.fact_index_dic.get(fact, 0)