# Source code for src.slot_attention.utils

import scipy.optimize
import torch
import torch.nn.functional as F
import numpy as np
import os


def save_args(args, writer):
    """Persist all argparse arguments as a plain-text file.

    Writes one ``name: value`` pair per line to ``args.txt`` inside the
    writer's log directory (``writer.log_dir``).

    :param args: argparse.Namespace (or any object whose vars() are the args)
    :param writer: summary writer exposing a ``log_dir`` attribute
    """
    # store args as txt file
    out_path = os.path.join(writer.log_dir, 'args.txt')
    with open(out_path, 'w') as f:
        for name in vars(args):
            f.write(f"\n{name}: {getattr(args, name)}")
def hungarian_matching(attrs, preds_attrs, verbose=0):
    """
    Receives unordered predicted set and orders this to match the nearest GT set.

    Each sample's predicted slots are reordered so that slot ``row_id`` of the
    result is the prediction that the Hungarian algorithm (minimum-cost
    assignment on pairwise euclidean distance) matched to GT slot ``row_id``.

    :param attrs: ground-truth set; indexing implies shape (batch, n_slots, n_attrs)
    :param preds_attrs: predicted set, same shape as ``attrs`` (asserted)
    :param verbose: if truthy, print GT, predictions, cost matrix and matching
    :return: ``preds_attrs`` with slots permuted per sample to best match ``attrs``
    """
    assert attrs.shape[1] == preds_attrs.shape[1]
    assert attrs.shape == preds_attrs.shape
    matched_preds_attrs = preds_attrs.clone()
    for sample_id in range(attrs.shape[0]):
        # using euclidean distance; cost on CPU for scipy
        cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu()
        # use the module-level scipy.optimize import (the original re-imported locally)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
        # tuples of (row_id, col_id) of the cost matrix
        idx_mapping = list(zip(row_ind, col_ind))
        for row_id, col_id in idx_mapping:
            matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :]
        if verbose:
            print('GT: {}'.format(attrs[sample_id]))
            print('Pred: {}'.format(preds_attrs[sample_id]))
            print('Cost Matrix: {}'.format(cost_matrix))
            print('idx mapping: {}'.format(idx_mapping))
            print('Matched Pred: {}'.format(matched_preds_attrs[sample_id]))
            print('\n')
    return matched_preds_attrs
def hungarian_loss(predictions, targets, thread_pool):
    """Permutation-invariant set loss between predicted and target slots.

    Builds a per-sample (slots x slots) smooth-L1 cost matrix (averaged over
    channels), solves the optimal assignment for every sample via the thread
    pool, and returns the mean matched cost over the batch.

    :param predictions: tensor permuted internally from (n, s, c) to (n, c, s)
    :param targets: tensor of the same shape as ``predictions``
    :param thread_pool: pool exposing ``map`` to parallelize the assignments
    :return: scalar tensor, mean matched loss over the batch
    """
    # permute dimensions for pairwise distance computation between all slots
    predictions = predictions.permute(0, 2, 1)
    targets = targets.permute(0, 2, 1)

    # predictions and targets shape :: (n, c, s)
    predictions, targets = outer(predictions, targets)
    # pairwise_cost shape :: (n, s, s)
    pairwise_cost = F.smooth_l1_loss(
        predictions, targets.expand_as(predictions), reduction="none"
    ).mean(1)
    pairwise_cost_np = pairwise_cost.detach().cpu().numpy()

    # solve one assignment problem per sample, in parallel
    assignments = thread_pool.map(hungarian_loss_per_sample, pairwise_cost_np)

    per_sample = [
        cost[rows, cols].mean()
        for cost, (rows, cols) in zip(pairwise_cost, assignments)
    ]
    return torch.mean(torch.stack(per_sample))
def hungarian_loss_per_sample(sample_np):
    """Solve the minimum-cost assignment for one sample's cost matrix.

    :param sample_np: 2-D cost matrix (numpy array)
    :return: ``(row_indices, col_indices)`` of the optimal assignment
    """
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(sample_np)
    return row_ind, col_ind
def outer(a, b=None):
    """
    Compute outer product between a and b (or a and a if b is not specified).

    Both inputs gain one trailing/penultimate dimension and are expanded so
    that every pair of last-axis entries lines up for element-wise comparison.
    NOTE(review): the expand targets only line up when ``a`` and ``b`` share
    the same last dimension, as in ``hungarian_loss`` — confirm for other uses.
    """
    if b is None:
        b = a
    expanded_a = a.unsqueeze(dim=-1).expand(*a.size(), b.size(-1))
    expanded_b = b.unsqueeze(dim=-2).expand(*b.size(), a.size(-1))
    return expanded_a, expanded_b