import scipy.optimize
import torch
import torch.nn.functional as F
import numpy as np
import os
[docs]def save_args(args, writer):
# store args as txt file
with open(os.path.join(writer.log_dir, 'args.txt'), 'w') as f:
for arg in vars(args):
f.write(f"\n{arg}: {getattr(args, arg)}")
[docs]def hungarian_matching(attrs, preds_attrs, verbose=0):
"""
Receives unordered predicted set and orders this to match the nearest GT set.
:param attrs:
:param preds_attrs:
:param verbose:
:return:
"""
assert attrs.shape[1] == preds_attrs.shape[1]
assert attrs.shape == preds_attrs.shape
from scipy.optimize import linear_sum_assignment
matched_preds_attrs = preds_attrs.clone()
for sample_id in range(attrs.shape[0]):
# using euclidean distance
cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu()
idx_mapping = linear_sum_assignment(cost_matrix)
# convert to tuples of [(row_id, col_id)] of the cost matrix
idx_mapping = [(idx_mapping[0][i], idx_mapping[1][i]) for i in range(len(idx_mapping[0]))]
for i, (row_id, col_id) in enumerate(idx_mapping):
matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :]
if verbose:
print('GT: {}'.format(attrs[sample_id]))
print('Pred: {}'.format(preds_attrs[sample_id]))
print('Cost Matrix: {}'.format(cost_matrix))
print('idx mapping: {}'.format(idx_mapping))
print('Matched Pred: {}'.format(matched_preds_attrs[sample_id]))
print('\n')
# exit()
return matched_preds_attrs
[docs]def hungarian_loss(predictions, targets, thread_pool):
# permute dimensions for pairwise distance computation between all slots
predictions = predictions.permute(0, 2, 1)
targets = targets.permute(0, 2, 1)
# predictions and targets shape :: (n, c, s)
predictions, targets = outer(predictions, targets)
# squared_error shape :: (n, s, s)
squared_error = F.smooth_l1_loss(predictions, targets.expand_as(predictions), reduction="none").mean(1)
squared_error_np = squared_error.detach().cpu().numpy()
indices = thread_pool.map(hungarian_loss_per_sample, squared_error_np)
losses = [
sample[row_idx, col_idx].mean()
for sample, (row_idx, col_idx) in zip(squared_error, indices)
]
total_loss = torch.mean(torch.stack(list(losses)))
return total_loss
[docs]def hungarian_loss_per_sample(sample_np):
return scipy.optimize.linear_sum_assignment(sample_np)
[docs]def outer(a, b=None):
""" Compute outer product between a and b (or a and a if b is not specified). """
if b is None:
b = a
size_a = tuple(a.size()) + (b.size()[-1],)
size_b = tuple(b.size()) + (a.size()[-1],)
a = a.unsqueeze(dim=-1).expand(*size_a)
b = b.unsqueeze(dim=-2).expand(*size_b)
return a, b