Source code for src.slot_attention.model

"""
Slot Attention model, based on the implementation by tkipf and the corresponding paper (Locatello et al., 2020).
"""
from torch import nn
import torch
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from torchsummary import summary


class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        # learnable mean and log standard deviation of the slot initialisation distribution
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim))

        self.project_q = nn.Linear(dim, dim)
        self.project_k = nn.Linear(dim, dim)
        self.project_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim)
        )

        self.norm_inputs = nn.LayerNorm(dim, eps=1e-05)
        self.norm_slots = nn.LayerNorm(dim, eps=1e-05)
        self.norm_mlp = nn.LayerNorm(dim, eps=1e-05)

    def forward(self, inputs, num_slots=None):
        b, n, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots

        # sample the initial slots from N(mu, exp(log_sigma)); the parameter stores the
        # log standard deviation, so it is exponentiated before sampling
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_log_sigma.expand(b, n_s, -1)
        slots = torch.normal(mu, sigma.exp())

        inputs = self.norm_inputs(inputs)
        k, v = self.project_k(inputs), self.project_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.project_q(slots)

            # attention: softmax over the slot axis, followed by a weighted mean over the inputs
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', v, attn)

            # GRU update of the slots, followed by a residual MLP
            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )
            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_mlp(slots))

        return slots
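

# Illustrative shape check (a sketch added for documentation purposes, not part of the
# original module; the helper name is made up): Slot Attention maps a set of `n` input
# feature vectors of size `dim` to `num_slots` slot vectors of the same size,
# independently of `n`.
def _demo_slot_attention():
    slot_attn = SlotAttention(num_slots=7, dim=64)
    feats = torch.rand(2, 32 * 32, 64)   # (batch, num_inputs, dim)
    slots = slot_attn(feats)             # (batch, num_slots, dim)
    assert slots.shape == (2, 7, 64)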


class SlotAttention_encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(SlotAttention_encoder, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.network(x)


class MLP(nn.Module):
    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x):
        return self.network(x)


def build_grid(resolution):
    # builds a fixed coordinate grid of shape (1, resolution[0], resolution[1], 4)
    # containing the normalized (row, col) positions and their complements
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    return np.concatenate([grid, 1.0 - grid], axis=-1)
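

# Illustrative check (a sketch, not part of the original module; the helper name is made
# up): build_grid((H, W)) returns a float32 array of shape (1, H, W, 4) holding the
# normalized coordinates of every grid cell together with their complements.
def _demo_build_grid():
    grid = build_grid((32, 32))
    assert grid.shape == (1, 32, 32, 4)
    assert grid.dtype == np.float32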


class SoftPositionEmbed(nn.Module):
    """Adds soft positional embedding with learnable projection."""

    def __init__(self, hidden_size, resolution, device="cuda"):
        """Builds the soft position embedding layer.

        Args:
            hidden_size: Size of input feature dimension.
            resolution: Tuple of integers specifying width and height of grid.
        """
        super().__init__()
        self.dense = nn.Linear(4, hidden_size)
        self.grid = torch.FloatTensor(build_grid(resolution))
        self.grid = self.grid.to(device)
        self.resolution = resolution[0]  # assumes a square grid
        self.hidden_size = hidden_size

    def forward(self, inputs):
        return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution))
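

# Illustrative sketch (not part of the original module; the helper name is made up):
# SoftPositionEmbed projects the fixed 4-channel coordinate grid to `hidden_size`
# channels and adds it to the encoder feature map, leaving the shape unchanged.
def _demo_soft_position_embed():
    embed = SoftPositionEmbed(hidden_size=64, resolution=(32, 32), device="cpu")
    feats = torch.rand(2, 64, 32, 32)    # (B, C, H, W) encoder features
    out = embed(feats)
    assert out.shape == feats.shape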


class SlotAttention_classifier(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SlotAttention_classifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.network(x)


class SlotAttention_model(nn.Module):
    def __init__(self, n_slots, n_iters, n_attr,
                 in_channels=3,
                 encoder_hidden_channels=64,
                 attention_hidden_channels=128,
                 device="cuda"):
        super(SlotAttention_model, self).__init__()
        self.n_slots = n_slots
        self.n_iters = n_iters
        self.n_attr = n_attr + 1  # additional attribute indicating whether a slot contains an object or is empty
        self.device = device

        self.encoder_cnn = SlotAttention_encoder(
            in_channels=in_channels, hidden_channels=encoder_hidden_channels)
        self.encoder_pos = SoftPositionEmbed(
            encoder_hidden_channels, (32, 32), device=device)
        self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05)
        self.mlp = MLP(hidden_channels=encoder_hidden_channels)
        self.slot_attention = SlotAttention(num_slots=n_slots,
                                            dim=encoder_hidden_channels,
                                            iters=n_iters,
                                            eps=1e-8,
                                            hidden_dim=attention_hidden_channels)
        self.mlp_classifier = SlotAttention_classifier(
            in_channels=encoder_hidden_channels, out_channels=self.n_attr)

    def forward(self, x):
        x = self.encoder_cnn(x)            # (B, 3, 128, 128) -> (B, C, 32, 32)
        x = self.encoder_pos(x)            # add soft positional embedding
        x = torch.flatten(x, start_dim=2)  # (B, C, 32 * 32)
        x = x.permute(0, 2, 1)             # (B, 32 * 32, C): one feature vector per location
        x = self.layer_norm(x)
        x = self.mlp(x)
        x = self.slot_attention(x)         # (B, n_slots, C)
        x = self.mlp_classifier(x)         # (B, n_slots, n_attr + 1)
        return x


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.rand(1, 3, 128, 128).to(device)
    net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18,
                              encoder_hidden_channels=64,
                              attention_hidden_channels=128,
                              device=device).to(device)
    output = net(x)
    print(output.shape)
    summary(net, (3, 128, 128), device=device)