import sys
import torch
import torch.nn as nn
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression
from slot_attention.model import SlotAttention_model
import sys
sys.path.insert(0, 'src/yolov5')
class YOLOPerceptionModule(nn.Module):
    """A perception module using YOLO.

    Attrs:
        e (int): The maximum number of entities.
        d (int): The dimension of the object-centric vector.
        device (device): The device where the model and tensors are loaded.
        train_ (bool): The flag if the parameters are trained.
        model: The pretrained YOLO detection network.
        preprocess (tensor->tensor): Reshapes the YOLO output into the
            unified format of the perception module.
    """

    def __init__(self, e, d, device, train=False):
        super().__init__()
        self.e = e  # num of entities
        self.d = d  # num of dimension
        self.device = device
        self.train_ = train  # the parameters should be trained or not
        self.model = self.load_model(
            path='src/weights/yolov5/best.pt', device=device)
        # YOLO returns class labels; the preprocess step decomposes them
        # into per-attribute one-hots scaled by the class probability.
        self.preprocess = YOLOPreprocess(device)

    def load_model(self, path, device):
        """Load the YOLO network from `path` and move it to `device`.

        Gradients are frozen unless the module was built with train=True.
        """
        print("Loading YOLO model...")
        yolo_net = attempt_load(weights=path)
        yolo_net.to(device)
        if not self.train_:
            for param in yolo_net.parameters():
                param.requires_grad = False
        return yolo_net

    def forward(self, imgs):
        """Detect objects in `imgs` and return the unified representation."""
        pred = self.model(imgs)[0]  # yolo model returns tuple
        # yolov5.utils.general.non_max_suppression returns List[tensor]
        # with length of batch size; the number of objects can vary image
        # to image, so the result is zero-padded to exactly `e` entities.
        yolo_output = self.pad_result(
            non_max_suppression(pred, max_det=self.e))
        return self.preprocess(yolo_output)

    def pad_result(self, output):
        """Pad each per-image detection tensor with zero rows.

        (batch, n_obj, 6) -> (batch, n_max_obj, 6)
        """
        padded_list = []
        for objs in output:
            if objs.size(0) < self.e:
                diff = self.e - objs.size(0)
                # allocate directly on the target device instead of a
                # CPU allocation followed by .to()
                zero_tensor = torch.zeros((diff, 6), device=self.device)
                padded_list.append(torch.cat([objs, zero_tensor], dim=0))
            else:
                padded_list.append(objs)
        return torch.stack(padded_list)
class SlotAttentionPerceptionModule(nn.Module):
    """A perception module using Slot Attention.

    Attrs:
        e (int): The maximum number of entities (n_slots).
        d (int): The dimension of the object-centric vector
            (encoder_hidden_channels).
        device (device): The device where the model and tensors are loaded.
        train_ (bool): The flag if the parameters are trained.
        model: The slot attention model.
    """

    def __init__(self, e, d, device, train=False):
        super().__init__()
        self.e = e  # num of entities -> n_slots=10
        self.d = d  # num of dimension -> encoder_hidden_channels=64
        self.device = device
        self.train_ = train  # the parameters should be trained or not
        self.model = self.load_model()

    def load_model(self):
        """Load the pretrained slot attention network.

        The checkpoint is mapped onto the target device, so a single code
        path serves both CPU and GPU (the original duplicated the whole
        branch just to pass map_location on CPU).
        """
        sa_net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18,
                                     encoder_hidden_channels=64,
                                     attention_hidden_channels=128,
                                     device=self.device)
        log = torch.load("src/weights/slot_attention/best.pt",
                         map_location=self.device)
        sa_net.load_state_dict(log['weights'], strict=True)
        sa_net.to(self.device)
        if not self.train_:
            for param in sa_net.parameters():
                param.requires_grad = False
        return sa_net

    def forward(self, imgs):
        """Return the slot representations for a batch of images."""
        return self.model(imgs)
class YOLOPreprocess(nn.Module):
    """Convert raw YOLO detections into the unified object-centric format.

    Attrs:
        device (device): The device where the model to be loaded.
        img_size (int): The size of the (resized) image to normalize the
            xy-coordinates.
        classes (list(str)): The classes of objects.
        colors (tensor(int)): The one-hot encodings of the colors, one row
            per class (each color repeated for the 3 shapes).
        shapes (tensor(int)): The one-hot encodings of the shapes, one row
            per class (square/circle/triangle cycling within each color).
    """

    def __init__(self, device, img_size=128):
        super().__init__()
        self.device = device
        self.img_size = img_size
        self.classes = ['red square', 'red circle', 'red triangle',
                        'yellow square', 'yellow circle', 'yellow triangle',
                        'blue square', 'blue circle', 'blue triangle']
        # class index -> one-hot color: red for classes 0-2, yellow for
        # 3-5, blue for 6-8 (matches self.classes ordering above)
        self.colors = torch.tensor([[1, 0, 0]] * 3 +
                                   [[0, 1, 0]] * 3 +
                                   [[0, 0, 1]] * 3, device=device)
        # class index -> one-hot shape: square, circle, triangle repeating
        self.shapes = torch.tensor([[1, 0, 0],
                                    [0, 1, 0],
                                    [0, 0, 1]] * 3, device=device)

    def forward(self, x):
        """A preprocess function for the YOLO model.

        Args:
            x (tensor): The output of the YOLO model, shape
                (batch, n_obj, 6); each row is [x1, y1, x2, y2, prob, class].

        Returns:
            Z (tensor): The preprocessed object-centric representation Z,
                shape (batch, n_obj, 11); each row is
                [x1, y1, x2, y2, color1, color2, color3,
                 shape1, shape2, shape3, objectness].
                x1, y1, x2, y2 are normalized to [0, 1]. The probability for
                each attribute is obtained by copying the probability of the
                classification of the YOLO model.
        """
        obj_num = x.size(1)
        object_list = []
        for i in range(obj_num):
            zi = x[:, i]
            class_id = zi[:, -1].to(torch.int64)
            prob = zi[:, -2].unsqueeze(-1)  # classification probability
            # scale the one-hot attribute encodings by the class probability
            color = self.colors[class_id] * prob
            shape = self.shapes[class_id] * prob
            xyxy = zi[:, 0:4] / self.img_size
            object_list.append(torch.cat([xyxy, color, shape, prob], dim=-1))
        return torch.stack(object_list, dim=1).to(self.device)