import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_utils import softor, weight_sum
def init_identity_weights(X, device):
    """Build a square diagonal weight matrix with 100 on the diagonal.

    The matrix has ``X.size(0)`` rows and columns, dtype ``float32``, zeros
    off the diagonal, and is moved to ``device`` before being returned.
    A row-wise softmax of this matrix is numerically close to the identity,
    since each row holds 100 on the diagonal and 0 elsewhere.

    Args:
        X: tensor whose first dimension determines the matrix size.
        device: target passed to ``Tensor.to``.

    Returns:
        A ``(X.size(0), X.size(0))`` float32 tensor.
    """
    num_rows = X.size(0)
    diagonal = 100 * torch.ones(num_rows, dtype=torch.float32)
    return torch.diag(diagonal).to(device)
class InferModule(nn.Module):
    """Iterative soft inference that amalgamates clauses via weights.

    Each inference step applies every clause function to the current
    valuation, mixes the clause outputs with softmax-normalised weights
    ``W`` and merges the result into the valuation with ``softor``
    (imported from ``torch_utils``; presumably a smooth logical OR --
    confirm against that module).
    """

    def __init__(self, I, infer_step, gamma=0.01, device=None, train=False, m=1, I_bk=None):
        """Set up clause functions and clause weights.

        Args:
            I: index tensor; ``I[i]`` is handed to the i-th ``ClauseFunction``.
                ``I.size(0)`` is stored as ``C`` (number of clauses) and
                ``I.size(1)`` as ``G``.
            infer_step: number of inference iterations run by ``forward``.
            gamma: smoothing parameter forwarded to ``softor``.
            device: device on which the weight tensors are placed.
            train: if True, ``W`` is a learnable ``(m, C)`` parameter
                initialised uniformly in [0, 1); otherwise ``W`` is the fixed
                ``(C, C)`` matrix from ``init_identity_weights``.
            m: number of weight rows used when ``train`` is True.
            I_bk: optional index tensor for background-knowledge clauses,
                which are applied unweighted.
        """
        super().__init__()
        self.I = I
        self.I_bk = I_bk
        self.infer_step = infer_step
        self.m = m
        self.C = self.I.size(0)
        self.G = self.I.size(1)
        self.gamma = gamma
        self.device = device
        self.train_ = train
        self.beta = 0.1  # softmax temperature (not applied by r() at present)
        if not train:
            self.W = init_identity_weights(I, device)
        else:
            # To learn the clause weights, start from uniform random values.
            # nn.Parameter already enables gradients for the wrapped tensor.
            self.W = nn.Parameter(torch.tensor(
                np.random.rand(m, self.C), dtype=torch.float32).to(device))
        # One clause function per row of the index tensor.
        self.cs = [ClauseFunction(I[i], gamma=gamma)
                   for i in range(self.C)]
        if I_bk is not None:
            self.cs_bk = [ClauseFunction(I_bk[i], gamma=gamma)
                          for i in range(self.I_bk.size(0))]
            # Kept as an attribute for compatibility; r_bk() does not use it.
            self.W_bk = init_identity_weights(I_bk, device)

    def get_params(self):
        """Return the learnable parameters as a list containing ``W``.

        Raises:
            AssertionError: if the module was not built with ``train=True``.
        """
        assert self.train_, "Infer module is not in training mode."
        return [self.W]

    def forward(self, x):
        """Run ``infer_step`` rounds of inference starting from ``x``.

        In every round the current valuation, the weighted clause output
        ``r(R)`` and, when background clauses exist, ``r_bk(R)`` are stacked
        along dim 1 and merged with ``softor``.

        Args:
            x: valuation tensor; the in-code shape comments describe it as
                B * G (batch by ``G``).

        Returns:
            The valuation after the final round, same layout as ``x``.
        """
        R = x
        for _ in range(self.infer_step):
            candidates = [R, self.r(R)]
            if self.I_bk is not None:
                candidates.append(self.r_bk(R))
            R = softor(candidates, dim=1, gamma=self.gamma)
        return R

    def r(self, x):
        """Apply all clauses once and combine them using the weights ``W``.

        Args:
            x: valuation tensor (B * G per the shape comments).

        Returns:
            The ``softor`` over the weight rows of the weighted clause sums
            (B * G per the shape comments).
        """
        # Apply each clause c_i and stack: C * B * G
        C = torch.stack([clause(x) for clause in self.cs], 0)
        # Normalise every weight row over the clauses: rows * C
        W_star = torch.softmax(self.W, 1)
        # Broadcast (rows, C, 1, 1) against (1, C, B, G) and sum over the
        # clauses -> rows * B * G. Broadcasting takes the row count from W
        # itself, so the fixed (C, C) weights also work when self.m != C,
        # where an explicit expand to self.m rows raised a shape error.
        H = torch.sum(
            W_star.unsqueeze(dim=-1).unsqueeze(dim=-1) * C.unsqueeze(dim=0),
            dim=1)
        # Soft-or over the weight rows composes the final result: B * G
        return softor(H, dim=0, gamma=self.gamma)

    def r_bk(self, x):
        """Apply all background-knowledge clauses once, unweighted.

        Args:
            x: valuation tensor (B * G per the shape comments).

        Returns:
            The ``softor`` over the stacked background clause outputs.
        """
        # Apply each background clause and stack: C_bk * B * G
        C = torch.stack([clause(x) for clause in self.cs_bk], 0)
        # B * G
        return softor(C, dim=0, gamma=self.gamma)
class ClauseInferModule(nn.Module):
    """Infer module that keeps one valuation per clause.

    Unlike an amalgamating module, the result is not merged across clauses:
    the i-th clause function only ever updates the i-th valuation slice.
    """

    def __init__(self, I, infer_step, gamma=0.01, device=None, train=False, m=1, I_bk=None):
        """Set up per-clause functions and (stored) weights.

        Args:
            I: index tensor; ``I[i]`` is handed to the i-th ``ClauseFunction``.
                ``I.size(0)`` is stored as ``C`` and ``I.size(1)`` as ``G``.
            infer_step: number of inference iterations run by ``forward``.
            gamma: smoothing parameter forwarded to ``softor`` (imported
                from ``torch_utils``).
            device: device on which the weight tensors are placed.
            train: if True, ``W`` is a learnable ``(m, C)`` parameter drawn
                from a standard normal; otherwise the fixed matrix from
                ``init_identity_weights``.
            m: number of weight rows; must equal ``C``.
            I_bk: optional index tensor for background-knowledge clauses.

        Raises:
            AssertionError: if ``m`` differs from the number of clauses.
        """
        super().__init__()
        self.I = I
        self.I_bk = I_bk
        self.infer_step = infer_step
        self.m = m
        self.C = self.I.size(0)
        self.G = self.I.size(1)
        self.gamma = gamma
        self.device = device
        self.train_ = train
        if not train:
            self.W = init_identity_weights(I, device)
        else:
            # To learn the clause weights, start from normal random values.
            self.W = nn.Parameter(torch.Tensor(
                np.random.normal(size=(m, self.C))).to(device))
        # One clause function per row of the index tensor.
        self.cs = [ClauseFunction(I[i], gamma=gamma)
                   for i in range(self.C)]
        if I_bk is not None:
            self.cs_bk = [ClauseFunction(I_bk[i], gamma=gamma)
                          for i in range(self.I_bk.size(0))]
            # Kept as an attribute for compatibility; r_bk() does not use it.
            self.W_bk = init_identity_weights(I_bk, device)
        assert m == self.C, "Invalid m and C: " + \
            str(m) + ' and ' + str(self.C)

    def forward(self, x):
        """Run ``infer_step`` rounds of per-clause inference.

        The input is expanded to one copy per clause (C * B * G). Each round
        merges, with ``softor``, the current valuations, the per-clause
        outputs ``r(R)`` and, when background clauses exist, the
        background-knowledge result broadcast to every clause.

        Args:
            x: valuation tensor of shape ``(B, G)``.

        Returns:
            Tensor of per-clause valuations (C * B * G).
        """
        B = x.size(0)
        # C * B * G
        R = x.unsqueeze(dim=0).expand(self.C, B, self.G)
        if self.I_bk is None:
            for _ in range(self.infer_step):
                R = softor([R, self.r(R)], dim=1, gamma=self.gamma)
        else:
            for _ in range(self.infer_step):
                derived = self.r(R)
                # Background result is B * G; share it across all clauses.
                derived_bk = self.r_bk(R).unsqueeze(
                    dim=0).expand(self.C, B, self.G)
                R = softor([R, derived, derived_bk],
                           dim=2, gamma=self.gamma)
        return R

    def r(self, x):
        """Apply the i-th clause to the i-th valuation slice.

        Args:
            x: per-clause valuations, indexed by clause along dim 0
                (C * B * G).

        Returns:
            The clause outputs stacked along dim 0 (C * B * G).
        """
        return torch.stack([clause(x[i])
                            for i, clause in enumerate(self.cs)], 0)

    def r_bk(self, x):
        """Apply all background-knowledge clauses to the first slice of ``x``.

        Args:
            x: per-clause valuations (C * B * G); only ``x[0]`` is used.

        Returns:
            The ``softor`` over the stacked background clause outputs
            (B * G per the shape comments).
        """
        # Just use the first row of the per-clause valuations.
        first = x[0]
        # C_bk * B * G
        C = torch.stack([clause(first) for clause in self.cs_bk], 0)
        # B * G
        return softor(C, dim=0, gamma=self.gamma)
class ClauseFunction(nn.Module):
    """Differentiable evaluation of a single clause.

    The clause is described by an integer index tensor ``I_i`` whose last
    dimension enumerates body atoms and whose second-to-last dimension
    enumerates substitutions. NOTE(review): the shape comments in
    ``forward`` give ``I_i`` as G * S * L -- confirm against the code that
    builds it.
    """

    def __init__(self, I_i, gamma=0.01):
        """Store the index tensor and read its trailing sizes.

        Args:
            I_i: index tensor of the clause.
            gamma: smoothing parameter forwarded to ``softor``.
        """
        super().__init__()
        self.I_i = I_i
        self.L = I_i.size(-1)  # number of body atoms
        self.S = I_i.size(-2)  # max number of possible substitutions
        self.gamma = gamma

    def forward(self, x):
        """Evaluate the clause on a batch of valuations.

        Args:
            x: valuation tensor, B * G.

        Returns:
            The ``softor`` over dim 2 (substitutions) of the product over
            dim 3 (body atoms) of the gathered valuations; B * G per the
            shape comments.
        """
        n_batch = x.size(0)
        # B * G -> B * G * S * L
        valuations = x.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, self.S, self.L)
        # G * S * L -> B * G * S * L
        indices = self.I_i.repeat(n_batch, 1, 1, 1)
        # Look up the valuation of every body atom of every substitution.
        body_values = torch.gather(valuations, 1, indices)
        # Multiply over the body atoms: B * G * S
        conjunctions = torch.prod(body_values, 3)
        # Merge over the substitutions: B * G
        return softor(conjunctions, dim=2, gamma=self.gamma)