Source code for src.slot_attention.data

import os
import math
import random
import json

import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import pil_loader
import torchvision.transforms.functional as T
import h5py
import numpy as np


CLASSES = {
    "shape": ["sphere", "cube", "cylinder"],
    "size": ["large", "small"],
    "material": ["rubber", "metal"],
    "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],
}


def get_loader(dataset, batch_size, num_workers=8, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=True,
    )
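A minimal usage sketch (not part of the original module): get_loader wraps any torch.utils.data.Dataset, so a TensorDataset of random tensors can stand in for real data. The shapes below are illustrative assumptions only.

    dummy = torch.utils.data.TensorDataset(
        torch.randn(32, 3, 128, 128),  # stand-in images
        torch.randn(32, 10, 19),       # stand-in per-scene object features
    )
    loader = get_loader(dummy, batch_size=8, num_workers=0)
    images, objects = next(iter(loader))  # images: (8, 3, 128, 128), objects: (8, 10, 19)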
class CLEVR(torch.utils.data.Dataset):
    def __init__(self, base_path, split):
        # note: test isn't very useful since it doesn't have ground-truth scene information
        assert split in {"train", "val", "test"}
        self.base_path = base_path
        self.split = split
        self.max_objects = 10

        with self.img_db() as db:
            ids = db["image_ids"]
            self.image_id_to_index = {id: i for i, id in enumerate(ids)}
        self.image_db = None

        with open(self.scenes_path) as fd:
            scenes = json.load(fd)["scenes"]
        self.img_ids, self.scenes = self.prepare_scenes(scenes)

        self.category_dict = CLASSES
    def object_to_fv(self, obj, scene_directions):
        # Originally the x, y, z positions are in [-3, 3]; the slot attention
        # paper re-normalizes them to [0, 1] via coords = (obj["3d_coords"] + 3.) / 6.
        # Here we first convert the 3d coords into the camera frame using the
        # scene's direction vectors (conversion from the NS-VQA paper), then
        # normalize the result to [0, 1] for slot attention.
        position = [
            np.dot(obj["3d_coords"], scene_directions["right"]),
            np.dot(obj["3d_coords"], scene_directions["front"]),
            obj["3d_coords"][2],
        ]
        coords = [(p + 4.0) / 8.0 for p in position]

        one_hot = lambda key: [obj[key] == x for x in CLASSES[key]]
        material = one_hot("material")
        color = one_hot("color")
        shape = one_hot("shape")
        size = one_hot("size")
        assert sum(material) == 1
        assert sum(color) == 1
        assert sum(shape) == 1
        assert sum(size) == 1

        # concatenate all the attribute classes into one feature vector
        return coords + shape + size + material + color
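    # Worked example (hypothetical values, not taken from the dataset): an object
    # whose camera-frame position comes out as (-1.2, 0.8, 0.7) is normalized by
    # (p + 4.) / 8. to (0.35, 0.6, 0.5875); a large rubber red cube then adds
    # shape=[0, 1, 0], size=[1, 0], material=[1, 0], color=[0, 0, 0, 0, 1, 0, 0, 0],
    # giving an 18-dim feature vector per object.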
    def prepare_scenes(self, scenes_json):
        img_ids = []
        scenes = []
        for scene in scenes_json:
            img_idx = scene["image_index"]
            # different objects depending on bbox version or attribute version of CLEVR sets
            objects = [
                self.object_to_fv(obj, scene["directions"])
                for obj in scene["objects"]
            ]
            objects = torch.FloatTensor(objects).transpose(0, 1)
            num_objects = objects.size(1)
            # pad with 0s up to max_objects
            if num_objects < self.max_objects:
                objects = torch.cat(
                    [
                        objects,
                        torch.zeros(objects.size(0), self.max_objects - num_objects),
                    ],
                    dim=1,
                )
            # fill in the objectness mask
            mask = torch.zeros(self.max_objects)
            mask[:num_objects] = 1

            # prepend the objectness mask as the first feature of each object
            objects = torch.cat((mask.unsqueeze(dim=0), objects), dim=0)

            img_ids.append(img_idx)
            scenes.append(objects.T)
        return img_ids, scenes
    @property
    def images_folder(self):
        return os.path.join(self.base_path, "images", self.split)

    @property
    def scenes_path(self):
        if self.split == "test":
            raise ValueError("Scenes are not available for test")
        return os.path.join(
            self.base_path, "scenes", "CLEVR_{}_scenes.json".format(self.split)
        )
    def img_db(self):
        path = os.path.join(self.base_path, "{}-images.h5".format(self.split))
        return h5py.File(path, "r")
    def load_image(self, image_id):
        if self.image_db is None:
            self.image_db = self.img_db()
        index = self.image_id_to_index[image_id]
        image = self.image_db["images"][index]
        image = (image - 0.5) * 2.0  # Rescale to [-1, 1].
        return image
    def __getitem__(self, item):
        image_id = self.img_ids[item]
        image = self.load_image(image_id)
        objects = self.scenes[item]
        return image, objects

    def __len__(self):
        return len(self.scenes)
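End-to-end usage sketch, assuming the preprocessed train-images.h5 file and CLEVR_train_scenes.json live under a base path such as data/CLEVR_v1.0 (both the path and the batch size are illustrative assumptions):

    dataset = CLEVR("data/CLEVR_v1.0", "train")
    loader = get_loader(dataset, batch_size=64)
    images, objects = next(iter(loader))
    # images: preprocessed frames rescaled to [-1, 1]
    # objects: (64, 10, 19) tensors of [mask, coords, shape, size, material, color]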