import torch
import torch.nn as nn
import torch.nn.functional as F
from valuation_func import *
class YOLOValuationModule(nn.Module):
    """A module to call valuation functions on YOLO object-centric output.

    Attrs:
        lang (language): The language.
        device (device): The device.
        layers (nn.ModuleList): The list of valuation functions.
        vfs (dic(str->nn.Module)): The dictionary that maps a predicate name
            to the corresponding valuation function.
        attrs (dic(term->tensor)): The dictionary that maps an attribute term
            to the corresponding one-hot encoding.
        dataset (str): The dataset.
    """

    def __init__(self, lang, device, dataset):
        super().__init__()
        self.lang = lang
        self.device = device
        self.layers, self.vfs = self.init_valuation_functions(device, dataset)
        # attr_term -> vector representation dic
        self.attrs = self.init_attr_encodings(device)
        self.dataset = dataset

    def init_valuation_functions(self, device, dataset=None):
        """Build the valuation functions and load pretrained neural predicates.

        Args:
            device (device): The device.
            dataset (str): The dataset (currently unused; kept for backward
                compatibility with existing callers).

        Returns:
            layers (nn.ModuleList): The list of valuation functions.
            vfs (dic(str->nn.Module)): The dictionary that maps a predicate
                name to the corresponding valuation function.
        """
        layers = []
        vfs = {}  # a dictionary: pred_name -> valuation function

        v_color = YOLOColorValuationFunction()
        vfs['color'] = v_color
        layers.append(v_color)

        v_shape = YOLOShapeValuationFunction()
        vfs['shape'] = v_shape
        # BUG FIX: v_shape was never appended to `layers` before, so it was
        # missing from the nn.ModuleList and its parameters went unregistered.
        layers.append(v_shape)

        v_in = YOLOInValuationFunction()
        vfs['in'] = v_in
        layers.append(v_in)

        # Pretrained neural predicate: closeby.
        v_closeby = YOLOClosebyValuationFunction(device)
        vfs['closeby'] = v_closeby
        vfs['closeby'].load_state_dict(torch.load(
            'src/weights/neural_predicates/closeby_pretrain.pt', map_location=device))
        vfs['closeby'].eval()
        layers.append(v_closeby)

        # Pretrained neural predicate: online.
        v_online = YOLOOnlineValuationFunction(device)
        vfs['online'] = v_online
        vfs['online'].load_state_dict(torch.load(
            'src/weights/neural_predicates/online_pretrain.pt', map_location=device))
        vfs['online'].eval()
        layers.append(v_online)

        return nn.ModuleList(layers), vfs

    def init_attr_encodings(self, device):
        """Encode color and shape attribute terms into one-hot encodings.

        Args:
            device (device): The device.

        Returns:
            attrs (dic(term->tensor)): The dictionary that maps an attribute
                term to the corresponding one-hot encoding.
        """
        attr_names = ['color', 'shape']
        attrs = {}
        for dtype_name in attr_names:
            terms = self.lang.get_by_dtype_name(dtype_name)
            # num_classes is constant per dtype; hoisted out of the inner loop.
            num_classes = len(terms)
            for term in terms:
                term_index = self.lang.term_index(term)
                # The index tensor is created on `device`, so the resulting
                # one-hot is already on `device` (the former trailing
                # `one_hot.to(device)` discarded its result and was a no-op).
                one_hot = F.one_hot(torch.tensor(term_index).to(device),
                                    num_classes=num_classes)
                attrs[term] = one_hot
        return attrs

    def forward(self, zs, atom):
        """Convert the object-centric representation to a valuation tensor.

        Args:
            zs (tensor): The object-centric representation (the output of the
                YOLO model).
            atom (atom): The target atom to compute its probability.

        Returns:
            A batch of the probabilities of the target atom. Atoms whose
            predicate has no valuation function evaluate to zeros.
        """
        if atom.pred.name in self.vfs:
            args = [self.ground_to_tensor(term, zs) for term in atom.terms]
            # call valuation function
            return self.vfs[atom.pred.name](*args)
        else:
            return torch.zeros((zs.size(0), )).to(
                torch.float32).to(self.device)

    def ground_to_tensor(self, term, zs):
        """Ground terms into tensor representations.

        Args:
            term (term): The term to be grounded.
            zs (tensor): The object-centric representation.

        Returns:
            The tensor representation of the input term: an object slice of
            `zs`, a precomputed one-hot attribute encoding, or None for the
            image dtype.
        """
        term_index = self.lang.term_index(term)
        if term.dtype.name == 'object':
            return zs[:, term_index]
        elif term.dtype.name in ('color', 'shape'):
            return self.attrs[term]
        elif term.dtype.name == 'image':
            return None
        else:
            assert 0, "Invalid datatype of the given term: " + \
                str(term) + ':' + term.dtype.name
class SlotAttentionValuationModule(nn.Module):
    """A module to call valuation functions on slot-attention output.

    Attrs:
        lang (language): The language.
        device (device): The device.
        layers (nn.ModuleList): The list of valuation functions.
        vfs (dic(str->nn.Module)): The dictionary that maps a predicate name
            to the corresponding valuation function.
    """

    def __init__(self, lang, device, pretrained=True):
        super().__init__()
        self.lang = lang
        self.device = device
        # Closed attribute vocabularies; list order defines the one-hot layout.
        self.colors = ["cyan", "blue", "yellow",
                       "purple", "red", "green", "gray", "brown"]
        self.shapes = ["sphere", "cube", "cylinder"]
        self.sizes = ["large", "small"]
        self.materials = ["rubber", "metal"]
        self.sides = ["left", "right"]
        self.layers, self.vfs = self.init_valuation_functions(
            device, pretrained)

    def init_valuation_functions(self, device, pretrained):
        """Build the valuation functions and optionally load pretrained weights.

        Args:
            device (device): The device.
            pretrained (bool): The flag if the neural predicates are
                pretrained or not.

        Returns:
            layers (nn.ModuleList): The list of valuation functions.
            vfs (dic(str->nn.Module)): The dictionary that maps a predicate
                name to the corresponding valuation function.
        """
        # pred name -> valuation function (insertion order matters: it fixes
        # the ModuleList layout, matching the original explicit list).
        vfs = {
            'color': SlotAttentionColorValuationFunction(device),
            'shape': SlotAttentionShapeValuationFunction(device),
            'in': SlotAttentionInValuationFunction(device),
            'size': SlotAttentionSizeValuationFunction(device),
            'material': SlotAttentionMaterialValuationFunction(device),
            'rightside': SlotAttentionRightSideValuationFunction(device),
            'leftside': SlotAttentionLeftSideValuationFunction(device),
            'front': SlotAttentionFrontValuationFunction(device),
        }
        if pretrained:
            # Only these three predicates have pretrained weight files.
            for pred in ('rightside', 'leftside', 'front'):
                vfs[pred].load_state_dict(torch.load(
                    'src/weights/neural_predicates/' + pred + '_pretrain.pt',
                    map_location=device))
                vfs[pred].eval()
            print('Pretrained neural predicates have been loaded!')
        return nn.ModuleList(list(vfs.values())), vfs

    def forward(self, zs, atom):
        """Convert the object-centric representation to a valuation tensor.

        Args:
            zs (tensor): The object-centric representation (the output of the
                slot attention model).
            atom (atom): The target atom to compute its probability.

        Returns:
            A batch of the probabilities of the target atom.
        """
        # term: logical term
        # arg: vector representation of the term
        args = [self.ground_to_tensor(term, zs) for term in atom.terms]
        # call valuation function
        return self.vfs[atom.pred.name](*args)

    def ground_to_tensor(self, term, zs):
        """Ground terms into tensor representations.

        Args:
            term (term): The term to be grounded.
            zs (tensor): The object-centric representation.

        Returns:
            The object slice of `zs`, None for the image dtype, or a batched
            one-hot encoding for attribute terms.
        """
        term_index = self.lang.term_index(term)
        if term.dtype.name == 'object':
            return zs[:, term_index]
        elif term.dtype.name == 'image':
            return None
        else:
            # other attributes
            return self.term_to_onehot(term, batch_size=zs.size(0))

    def term_to_onehot(self, term, batch_size):
        """Ground an attribute term into a batched one-hot tensor.

        Args:
            term (term): The term to be grounded.
            batch_size (int): The batch size to expand the encoding to.

        Returns:
            The one-hot tensor representation of the input term, expanded to
            (batch_size, vocabulary length).
        """
        vocab = {'color': self.colors,
                 'shape': self.shapes,
                 'material': self.materials,
                 'size': self.sizes,
                 'side': self.sides}.get(term.dtype.name)
        # BUG FIX: this used to be `assert True, ...`, which can never fire,
        # so an unknown dtype silently returned None.
        assert vocab is not None, 'Invalid term: ' + str(term)
        return self.to_onehot_batch(
            vocab.index(term.name), len(vocab), batch_size)

    def to_onehot_batch(self, i, length, batch_size):
        """Compute the one-hot encoding that is expanded to the batch size.

        Args:
            i (int): The hot index.
            length (int): The length of the encoding.
            batch_size (int): The batch size.

        Returns:
            A (batch_size, length) float tensor with column i set to 1.0.
        """
        onehot = torch.zeros(batch_size, length, ).to(self.device)
        onehot[:, i] = 1.0
        return onehot