import torch
import torch.nn as nn
from fol.logic import NeuralPredicate
from tqdm import tqdm
class FactsConverter(nn.Module):
    """
    FactsConverter converts the output from the perception module to the valuation vector.
    """

    def __init__(self, lang, perception_module, valuation_module, device=None):
        super(FactsConverter, self).__init__()
        self.e = perception_module.e  # number of entities
        self.d = perception_module.d  # dimension of each entity representation
        self.lang = lang
        self.vm = valuation_module  # valuation functions
        self.device = device

    def __str__(self):
        return "FactsConverter(entities={}, dimension={})".format(self.e, self.d)

    def __repr__(self):
        return "FactsConverter(entities={}, dimension={})".format(self.e, self.d)

    def forward(self, Z, G, B):
        # Z: perception output, G: ground atoms, B: background-knowledge atoms
        return self.convert(Z, G, B)

    def get_params(self):
        return self.vm.get_params()

    def init_valuation(self, n, batch_size):
        v = torch.zeros((batch_size, n)).to(self.device)
        v[:, 1] = 1.0  # the constant true atom always evaluates to 1.0
        return v

    def filter_by_datatype(self):
        # placeholder; not implemented
        pass

    def to_vec(self, term, zs):
        # placeholder; not implemented
        pass

    def __convert(self, Z, G):
        # Z: batched perception output; convert each example independently
        vs = []
        for zs in tqdm(Z):
            vs.append(self.convert_i(zs, G))
        return torch.stack(vs)

    def convert(self, Z, G, B):
        batch_size = Z.size(0)
        V = torch.zeros((batch_size, len(G)),
                        dtype=torch.float32, device=self.device)
        for i, atom in enumerate(G):
            if type(atom.pred) == NeuralPredicate and i > 1:
                # neural predicates are scored by the valuation module
                V[:, i] = self.vm(Z, atom)
            elif atom in B:
                # background-knowledge atoms are given the value 1.0
                V[:, i] += torch.ones(batch_size,
                                      dtype=torch.float32, device=self.device)
        # the constant true atom (index 1) always evaluates to 1.0
        V[:, 1] = torch.ones(batch_size,
                             dtype=torch.float32, device=self.device)
        return V

    def convert_i(self, zs, G):
        # single-example counterpart of convert; zs is one perception output
        v = self.init_valuation(len(G), batch_size=1)[0]
        for i, atom in enumerate(G):
            if type(atom.pred) == NeuralPredicate and i > 1:
                v[i] = self.vm.eval(atom, zs)
        return v

    def call(self, pred):
        # identity mapping; returns the given predicate unchanged
        return pred
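

# A minimal usage sketch (an addition, not part of the original module). The
# mock classes below are hypothetical stand-ins: they avoid constructing real
# fol.logic predicates, so only the background-knowledge and true-atom paths
# of convert() are exercised; a real run would pass NeuralPredicate atoms and
# trained perception/valuation modules instead.
if __name__ == "__main__":
    class _MockPerception(nn.Module):
        e = 3  # number of entities
        d = 4  # dimension per entity

    class _MockValuation(nn.Module):
        def get_params(self):
            return list(self.parameters())

    class _MockAtom:
        def __init__(self, name):
            self.name = name
            self.pred = None  # not a NeuralPredicate

        def __eq__(self, other):
            return self.name == other.name

        def __hash__(self):
            return hash(self.name)

    G = [_MockAtom(".false"), _MockAtom(".true"),
         _MockAtom("p(a)"), _MockAtom("q(b)")]
    B = [G[2]]  # p(a) is background knowledge
    fc = FactsConverter(lang=None, perception_module=_MockPerception(),
                        valuation_module=_MockValuation(), device="cpu")
    Z = torch.randn(2, _MockPerception.e, _MockPerception.d)
    V = fc(Z, G, B)
    print(V)  # shape (2, 4): column 1 (true atom) and column 2 (p(a)) are 1.0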