import numpy as np
import torch
import torch.nn as nn

from logic_utils import get_index_by_predname
class NSFReasoner(nn.Module):
    """The Neuro-Symbolic Forward Reasoner.

    Args:
        perception_module (nn.Module): The perception module.
        facts_converter (nn.Module): The facts converter module.
        infer_module (nn.Module): The differentiable forward-chaining inference module.
        clause_infer_module (nn.Module): The differentiable clause-wise inference module.
        atoms (list(atom)): The set of ground atoms (facts).
        bk (list(atom)): The background knowledge.
        clauses (list(clause)): The set of candidate clauses.
        train (bool): Whether the clause weights are being trained.
    """

    def __init__(self, perception_module, facts_converter, infer_module, clause_infer_module, atoms, bk, clauses, train=False):
        super().__init__()
        self.pm = perception_module
        self.fc = facts_converter
        self.im = infer_module
        self.cim = clause_infer_module
        self.atoms = atoms
        self.bk = bk
        self.clauses = clauses
        self._train = train

    def get_clauses(self):
        """Return the highest-weight clause for each clause-weight vector."""
        clause_ids = [np.argmax(w.detach().cpu().numpy()) for w in self.im.W]
        return [self.clauses[ci] for ci in clause_ids]

    def _summary(self):
        print("facts: ", len(self.atoms))
        print("I: ", self.im.I.shape)

    def get_params(self):
        return self.im.get_params()  # + self.fc.get_params()

    def forward(self, x):
        # obtain the object-centric representation
        zs = self.pm(x)
        # convert to the valuation tensor
        V_0 = self.fc(zs, self.atoms, self.bk)
        # perform T-step forward-chaining reasoning
        V_T = self.im(V_0)
        return V_T
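
    # A shape sketch of the forward pass (E objects, D feature dimensions,
    # and T inference steps are assumptions about the submodules, not values
    # fixed in this class):
    #   x   : B * C * H * W    input images
    #   zs  : B * E * D        object-centric representations
    #   V_0 : B * |atoms|      initial valuation tensor
    #   V_T : B * |atoms|      valuation after T forward-chaining steps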

    def clause_eval(self, x):
        # obtain the object-centric representation
        zs = self.pm(x)
        # convert to the valuation tensor
        V_0 = self.fc(zs, self.atoms, self.bk)
        # perform clause-wise forward-chaining reasoning
        V_T = self.cim(V_0)
        return V_T
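
    # Note: unlike forward(), this presumably evaluates each candidate clause
    # separately through the clause-wise inference module (self.cim), e.g. to
    # score clauses during learning; the exact output shape depends on the
    # clause_infer_module implementation.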

    def predict(self, v, predname):
        """Extract a value from the valuation tensor using a given predicate."""
        # v: batch * |atoms|
        target_index = get_index_by_predname(
            pred_str=predname, atoms=self.atoms)
        return v[:, target_index]
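
    # Usage sketch (the predicate name 'kp' is a hypothetical example):
    #   V_T = self.forward(x)          # B * |atoms|
    #   p = self.predict(V_T, 'kp')    # (B,) valuation of the target atom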

    def predict_multi(self, v, prednames):
        """Extract values from the valuation tensor using given predicates.

        e.g. prednames = ['kp1', 'kp2', 'kp3']
        """
        # v: batch * |atoms|
        target_indices = []
        for predname in prednames:
            target_index = get_index_by_predname(
                pred_str=predname, atoms=self.atoms)
            target_indices.append(target_index)
        prob = torch.cat([v[:, i].unsqueeze(-1)
                          for i in target_indices], dim=1)
        B = v.size(0)
        N = len(prednames)
        assert prob.size(0) == B and prob.size(
            1) == N, 'Invalid shape in the prediction.'
        return prob
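
    # Design note: the unsqueeze-and-cat above is equivalent to
    # torch.stack([v[:, i] for i in target_indices], dim=1); both yield a
    # B * N tensor of predicate valuations.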

    def print_program(self):
        """Print a summary of the logic programs using continuous weights."""
        print('====== LEARNED PROGRAM ======')
        C = self.clauses
        Ws_softmaxed = torch.softmax(self.im.W, 1)
        print("Ws_softmaxed: ", np.round(
            Ws_softmaxed.detach().cpu().numpy(), 2))
        for i, W_ in enumerate(Ws_softmaxed):
            max_i = np.argmax(W_.detach().cpu().numpy())
            print('C_' + str(i) + ': ',
                  C[max_i], np.round(np.array(W_[max_i].detach().cpu().item()), 2))
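
    # Illustrative output (weights made up): with clauses C = [c0, c1, c2]
    # and softmaxed weights [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]], this prints
    # C_0: c1 0.8 and C_1: c0 0.7 as the learned program.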

    def print_valuation_batch(self, valuation, n=40):
        # self.print_program()
        for b in range(valuation.size(0)):
            print('===== BATCH: ', b, '=====')
            v = valuation[b].detach().cpu().numpy()
            idxs = np.argsort(-v)
            for i in idxs:
                if v[i] > 0.1:
                    print(i, self.atoms[i], ': ', round(v[i], 3))

    def get_valuation_text(self, valuation):
        text_batch = ''  # texts for each batch
        for b in range(valuation.size(0)):
            top_atoms = self.get_top_atoms(valuation[b].detach().cpu().numpy())
            text = '----BATCH ' + str(b) + '----\n'
            text += self.atoms_to_text(top_atoms)
            text += '\n'
            # texts.append(text)
            text_batch += text
        return text_batch

    def get_top_atoms(self, v):
        """Return atoms whose valuation exceeds the threshold 0.7."""
        top_atoms = []
        for i, atom in enumerate(self.atoms):
            if v[i] > 0.7:
                top_atoms.append(atom)
        return top_atoms

    def atoms_to_text(self, atoms):
        """Concatenate atoms into a comma-separated string."""
        text = ''
        for atom in atoms:
            text += str(atom) + ', '
        return text
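

# ---------------------------------------------------------------------------
# Usage sketch. All names below (perception, converter, infer, clause_infer,
# images, 'kp', ...) are hypothetical placeholders; the submodules are built
# elsewhere in the project. Kept as a comment so importing this module stays
# side-effect free.
#
#   nsfr = NSFReasoner(perception, converter, infer, clause_infer,
#                      atoms, bk, clauses, train=True)
#   V_T = nsfr(images)                  # B * |atoms| valuation tensor
#   score = nsfr.predict(V_T, 'kp')     # valuation of a target predicate
#   nsfr.print_valuation_batch(V_T)     # inspect high-valued atoms
#   nsfr.print_program()                # show the highest-weight clauses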