Source code for src.infer


import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_utils import softor, weight_sum
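
# `softor` (imported above from torch_utils, which is not shown on this page)
# is a differentiable disjunction over valuation tensors: given a list, it
# stacks the tensors along `dim`, then smooths the element-wise maximum over
# that dimension. The reference sketch below is an assumption for
# illustration only, not the repository's exact definition.
def _softor_reference(xs, dim=0, gamma=0.01):
    if isinstance(xs, list):
        xs = torch.stack(xs, dim)
    # scaled log-sum-exp: approaches the element-wise max (logical or)
    # as gamma -> 0
    return gamma * torch.logsumexp(xs / gamma, dim=dim)
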


def init_identity_weights(X, device):
    # identity matrix scaled by 100: the softmax over each row is then
    # (numerically) a one-hot vector selecting the corresponding clause
    ones = torch.ones((X.size(0),), dtype=torch.float32) * 100
    return torch.diag(ones).to(device)

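# Usage sketch for init_identity_weights (`I_dummy` is a hypothetical index
# tensor; only its first dimension, the number of clauses C, matters here):
def _demo_identity_weights():
    I_dummy = torch.zeros((3, 5, 4, 2), dtype=torch.long)  # C = 3 clauses
    W = init_identity_weights(I_dummy, device='cpu')  # 3 x 3, diagonal of 100s
    # the row-wise softmax taken in InferModule.r is then (numerically)
    # one-hot, so in evaluation mode each weight row selects its own clause
    print(torch.softmax(W, dim=1))
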
class InferModule(nn.Module):
    def __init__(self, I, infer_step, gamma=0.01, device=None, train=False,
                 m=1, I_bk=None):
        """
        In the constructor we instantiate the clause functions and the
        clause-weight matrix and assign them as member variables.
        """
        super(InferModule, self).__init__()
        self.I = I
        self.I_bk = I_bk
        self.infer_step = infer_step
        self.m = m
        self.C = self.I.size(0)
        self.G = self.I.size(1)
        self.gamma = gamma
        self.device = device
        self.train_ = train
        self.beta = 0.1  # softmax temperature
        if not train:
            self.W = init_identity_weights(I, device)
        else:
            # to learn the clause weights, initialize W randomly:
            self.W = nn.Parameter(torch.tensor(
                np.random.rand(m, I.size(0)),
                requires_grad=True, dtype=torch.float32).to(device))
        # clause functions
        self.cs = [ClauseFunction(I[i], gamma=gamma)
                   for i in range(self.I.size(0))]
        if I_bk is not None:
            self.cs_bk = [ClauseFunction(I_bk[i], gamma=gamma)
                          for i in range(self.I_bk.size(0))]
            self.W_bk = init_identity_weights(I_bk, device)

    def get_params(self):
        assert self.train_, "Infer module is not in training mode."
        return [self.W]

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must
        return a Tensor of output data. We can use Modules defined in the
        constructor as well as arbitrary operators on Tensors.
        """
        R = x
        if self.I_bk is None:
            # T-step forward chaining, amalgamating old and new valuations
            # with the soft-or at every step
            for t in range(self.infer_step):
                R = softor([R, self.r(R)], dim=1, gamma=self.gamma)
        else:
            for t in range(self.infer_step):
                R = softor([R, self.r(R), self.r_bk(R)],
                           dim=1, gamma=self.gamma)
        return R

    def r(self, x):
        B = x.size(0)  # batch size
        # apply each clause c_i and stack to a tensor C
        # C * B * G
        C = torch.stack([self.cs[i](x) for i in range(self.I.size(0))], 0)
        # take the weighted sum using the m weight vectors
        # m * C
        # W_star = torch.softmax(self.W * (1 / self.beta), 1)
        W_star = torch.softmax(self.W, 1)
        # m * C * B * G
        W_tild = W_star.unsqueeze(dim=-1).unsqueeze(dim=-1).expand(
            self.m, self.C, B, self.G)
        # m * C * B * G
        C_tild = C.unsqueeze(dim=0).expand(self.m, self.C, B, self.G)
        # m * B * G
        H = torch.sum(W_tild * C_tild, dim=1)
        # take the soft-or to compose a logic program with m clauses
        # B * G
        R = softor(H, dim=0, gamma=self.gamma)
        return R

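    # Shape walk-through for r() with hypothetical sizes m=2 programs,
    # C=3 clauses, B=1, G=4 (for orientation only):
    #   C      : 3 * 1 * 4      one valuation tensor per clause
    #   W_star : 2 * 3          each of the m rows is a distribution over clauses
    #   W_tild : 2 * 3 * 1 * 4  (and C_tild likewise)
    #   H      : 2 * 1 * 4      per-program weighted sum over the clauses
    #   R      : 1 * 4          soft-or over the m programs
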
    def r_bk(self, x):
        B = x.size(0)  # batch size
        # apply each background clause c_i and stack to a tensor C
        # C * B * G
        C = torch.stack([self.cs_bk[i](x)
                         for i in range(self.I_bk.size(0))], 0)
        # the background clauses carry fixed identity weights, so simply take
        # the soft-or over the clause dimension
        # B * G
        return softor(C, dim=0, gamma=self.gamma)

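# Usage sketch for InferModule (a minimal, hypothetical setup; in practice I
# encodes a set of clauses and is built elsewhere in the repository, not drawn
# at random as here):
def _demo_infer_module():
    C, G, S, L, B = 1, 4, 2, 2, 2
    # hypothetical index tensor: entry (c, g, s, l) points at the atom whose
    # valuation fills the l-th body slot of clause c, head atom g,
    # substitution s
    I = torch.randint(0, G, (C, G, S, L))
    im = InferModule(I, infer_step=2, gamma=0.01, device='cpu')
    x = torch.rand(B, G)  # initial valuations of the G ground atoms
    R = im(x)             # B * G valuations after 2 inference steps
    print(R.shape)        # torch.Size([2, 4])
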
class ClauseInferModule(nn.Module):
    def __init__(self, I, infer_step, gamma=0.01, device=None, train=False,
                 m=1, I_bk=None):
        """
        Infer module that applies each clause separately; the results are not
        amalgamated over the clauses.
        """
        super(ClauseInferModule, self).__init__()
        self.I = I
        self.I_bk = I_bk
        self.infer_step = infer_step
        self.m = m
        self.C = self.I.size(0)
        self.G = self.I.size(1)
        self.gamma = gamma
        self.device = device
        self.train_ = train
        if not train:
            self.W = init_identity_weights(I, device)
        else:
            # to learn the clause weights, initialize W randomly:
            self.W = nn.Parameter(torch.Tensor(
                np.random.normal(size=(m, I.size(0)))).to(device))
        # clause functions
        self.cs = [ClauseFunction(I[i], gamma=gamma)
                   for i in range(self.I.size(0))]
        if I_bk is not None:
            self.cs_bk = [ClauseFunction(I_bk[i], gamma=gamma)
                          for i in range(self.I_bk.size(0))]
            self.W_bk = init_identity_weights(I_bk, device)
        assert m == self.C, "Invalid m and C: " + \
            str(m) + " and " + str(self.C)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must
        return a Tensor of output data. We can use Modules defined in the
        constructor as well as arbitrary operators on Tensors.
        """
        B = x.size(0)
        # C * B * G
        R = x.unsqueeze(dim=0).expand(self.C, B, self.G)
        if self.I_bk is None:
            for t in range(self.infer_step):
                R = softor([R, self.r(R)], dim=1, gamma=self.gamma)
        else:
            for t in range(self.infer_step):
                # also infer with the background knowledge; its B * G result
                # is broadcast to each of the C clause slots
                R = softor([R, self.r(R),
                            self.r_bk(R).unsqueeze(dim=0).expand(
                                self.C, B, self.G)],
                           dim=2, gamma=self.gamma)
        return R

    def r(self, x):
        # x: C * B * G
        B = x.size(1)  # batch size
        # infer from the i-th valuation tensor using the i-th clause and
        # stack the results to a tensor of shape C * B * G
        C = torch.stack([self.cs[i](x[i])
                         for i in range(self.I.size(0))], 0)
        return C

    def r_bk(self, x):
        # x: C * B * G; the background clauses see only the first row of
        # valuation tensors
        x = x[0]
        B = x.size(0)  # batch size
        # apply each background clause c_i and stack to a tensor C
        # C * B * G
        C = torch.stack([self.cs_bk[i](x)
                         for i in range(self.I_bk.size(0))], 0)
        # B * G
        return softor(C, dim=0, gamma=self.gamma)

class ClauseFunction(nn.Module):
    """
    A class of the clause function.
    """

    def __init__(self, I_i, gamma=0.01):
        super(ClauseFunction, self).__init__()
        # index tensor G * S * L, where S is the number of possible
        # substitutions
        self.I_i = I_i
        self.L = I_i.size(-1)  # number of body atoms
        self.S = I_i.size(-2)  # max number of possible substitutions
        self.gamma = gamma

    def forward(self, x):
        batch_size = x.size(0)  # batch size
        # B * G
        V = x
        # B * G -> B * G * S * L
        V_tild = V.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, self.S, self.L)
        # G * S * L -> B * G * S * L
        I_i_tild = self.I_i.repeat(batch_size, 1, 1, 1)
        # gather the body-atom valuations, take the product over the body
        # (conjunction), then the soft-or over substitutions (existential
        # quantification)
        # B * G
        C = softor(torch.prod(torch.gather(V_tild, 1, I_i_tild), 3),
                   dim=2, gamma=self.gamma)
        return C
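
# Usage sketch for ClauseFunction (hypothetical shapes): for each head atom g,
# the valuations of its L body atoms are gathered per substitution, multiplied
# together (soft conjunction), and soft-or'ed over the S substitutions (soft
# existential quantification).
def _demo_clause_function():
    G, S, L, B = 4, 2, 2, 3
    I_i = torch.randint(0, G, (G, S, L))  # index tensor for a single clause
    cf = ClauseFunction(I_i, gamma=0.01)
    x = torch.rand(B, G)  # B * G valuation tensor
    print(cf(x).shape)    # torch.Size([3, 4])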