"""
Slot attention model based on code of tkipf and the corresponding paper Locatello et al. 2020
"""
from torch import nn
import torch
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from torchsummary import summary
class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        # learnable mean and log-std used to sample the initial slots
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim))

        self.project_q = nn.Linear(dim, dim)
        self.project_k = nn.Linear(dim, dim)
        self.project_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim)
        )

        self.norm_inputs = nn.LayerNorm(dim, eps=1e-05)
        self.norm_slots = nn.LayerNorm(dim, eps=1e-05)
        self.norm_mlp = nn.LayerNorm(dim, eps=1e-05)
    def forward(self, inputs, num_slots=None):
        # inputs: (batch, num_inputs, dim)
        b, n, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots

        # sample the initial slots via reparameterisation;
        # slots_log_sigma stores a log-std, so exponentiate before sampling
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_log_sigma.exp().expand(b, n_s, -1)
        slots = mu + sigma * torch.randn_like(mu)

        inputs = self.norm_inputs(inputs)
        k, v = self.project_k(inputs), self.project_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.project_q(slots)

            # attention coefficients: softmax over the slot dimension (dim=1),
            # so the slots compete for explaining each input location
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps
            # normalise over the input locations -> weighted mean of the values
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', v, attn)

            # GRU update of each slot, followed by a residual MLP
            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )
            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_mlp(slots))

        return slots
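
# Usage sketch (illustrative, not part of the original module): SlotAttention
# expects a flattened set of feature vectors of shape (batch, num_inputs, dim)
# and returns (batch, num_slots, dim). Shapes below are hypothetical.
#
#   feats = torch.rand(2, 32 * 32, 64)       # e.g. a flattened 32x32 feature map
#   slot_attn = SlotAttention(num_slots=7, dim=64, iters=3)
#   slots = slot_attn(feats)                 # -> torch.Size([2, 7, 64])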
class SlotAttention_encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(SlotAttention_encoder, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels,
                      (5, 5), stride=(1, 1), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels,
                      (5, 5), stride=(2, 2), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels,
                      (5, 5), stride=(2, 2), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels,
                      (5, 5), stride=(1, 1), padding=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.network(x)
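
# Note (sketch): the two stride-2 convolutions downsample the input by a factor
# of 4, so a 128x128 image yields a 32x32 feature map, matching the (32, 32)
# resolution passed to SoftPositionEmbed in SlotAttention_model below.
#
#   enc = SlotAttention_encoder(in_channels=3, hidden_channels=64)
#   enc(torch.rand(1, 3, 128, 128)).shape    # -> torch.Size([1, 64, 32, 32])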
class MLP(nn.Module):
    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x):
        return self.network(x)
def build_grid(resolution):
    """Build a (1, H, W, 4) grid of normalised coordinates and their complements."""
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    return np.concatenate([grid, 1.0 - grid], axis=-1)
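
# Example (illustrative):
#
#   build_grid((32, 32)).shape               # -> (1, 32, 32, 4)
#
# The four channels are the normalised (row, column) coordinates and their
# complements, which SoftPositionEmbed projects to the feature dimension below.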
class SoftPositionEmbed(nn.Module):
    """Adds soft positional embedding with learnable projection."""

    def __init__(self, hidden_size, resolution, device="cuda"):
        """Builds the soft position embedding layer.

        Args:
            hidden_size: Size of the input feature dimension.
            resolution: Tuple of integers specifying the height and width of the grid.
        """
        super().__init__()
        self.dense = nn.Linear(4, hidden_size)
        self.grid = torch.FloatTensor(build_grid(resolution))
        self.grid = self.grid.to(device)
        self.resolution = resolution[0]
        self.hidden_size = hidden_size

    def forward(self, inputs):
        # project the (1, H, W, 4) grid to (1, H, W, hidden_size) and move the
        # channel dimension to the front so it can be added to the CNN features
        emb = self.dense(self.grid).permute(0, 3, 1, 2)
        return inputs + emb
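
# Usage sketch (illustrative): the embedding is added element-wise, so the
# output has the same shape as the channels-first input feature map.
#
#   pos = SoftPositionEmbed(hidden_size=64, resolution=(32, 32), device="cpu")
#   pos(torch.rand(1, 64, 32, 32)).shape     # -> torch.Size([1, 64, 32, 32])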
class SlotAttention_classifier(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SlotAttention_classifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.network(x)
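
# Note (sketch): nn.Linear acts on the last dimension, so the classifier is
# applied to each slot independently, e.g. with in_channels=64 and
# out_channels=19 it maps (b, n_slots, 64) to (b, n_slots, 19) sigmoid scores.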
class SlotAttention_model(nn.Module):
    def __init__(self, n_slots, n_iters, n_attr,
                 in_channels=3,
                 encoder_hidden_channels=64,
                 attention_hidden_channels=128,
                 device="cuda"):
        super(SlotAttention_model, self).__init__()
        self.n_slots = n_slots
        self.n_iters = n_iters
        # one additional attribute indicating whether a slot contains an object or is empty
        self.n_attr = n_attr + 1
        self.device = device

        self.encoder_cnn = SlotAttention_encoder(
            in_channels=in_channels, hidden_channels=encoder_hidden_channels)
        self.encoder_pos = SoftPositionEmbed(
            encoder_hidden_channels, (32, 32), device=device)
        self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05)
        self.mlp = MLP(hidden_channels=encoder_hidden_channels)
        self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels,
                                            iters=n_iters, eps=1e-8,
                                            hidden_dim=attention_hidden_channels)
        self.mlp_classifier = SlotAttention_classifier(
            in_channels=encoder_hidden_channels, out_channels=self.n_attr)
    def forward(self, x):
        x = self.encoder_cnn(x)           # (b, c, h, w), h = w = 32 for 128x128 inputs
        x = self.encoder_pos(x)           # add soft positional embedding
        x = torch.flatten(x, start_dim=2)
        x = x.permute(0, 2, 1)            # (b, h*w, c): one feature vector per location
        x = self.layer_norm(x)
        x = self.mlp(x)
        x = self.slot_attention(x)        # (b, n_slots, c)
        x = self.mlp_classifier(x)        # (b, n_slots, n_attr + 1) attribute scores
        return x
if __name__ == "__main__":
x = torch.rand(1, 3, 128, 128)
net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18,
encoder_hidden_channels=64, attention_hidden_channels=128)
output = net(x)
print(output.shape)
summary(net, (3, 128, 128))