src.slot_attention package

Submodules

src.slot_attention.data module

class src.slot_attention.data.CLEVR(base_path, split)[source]

Bases: Dataset

property images_folder
img_db()[source]
load_image(image_id)[source]
object_to_fv(obj, scene_directions)[source]
prepare_scenes(scenes_json)[source]
property scenes_path
src.slot_attention.data.get_loader(dataset, batch_size, num_workers=8, shuffle=True)[source]
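get_loader presumably wraps the dataset in a shuffled, batched loader (e.g. a PyTorch DataLoader with the given batch_size, num_workers, and shuffle flag). A minimal pure-Python sketch of the same batching semantics, with no dependency on this package (batch_iter is a hypothetical helper, not part of the module):

```python
import random

def batch_iter(dataset, batch_size, shuffle=True, seed=0):
    """Yield lists of up to `batch_size` items; the last batch may be smaller."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # deterministic shuffle for the sketch
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

# Toy usage on a plain list standing in for a Dataset.
batches = list(batch_iter(list(range(10)), batch_size=4, shuffle=True))
```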

src.slot_attention.model module

Slot Attention model, based on the implementation by tkipf and the corresponding paper, Locatello et al. 2020 ("Object-Centric Learning with Slot Attention").

class src.slot_attention.model.MLP(hidden_channels)[source]

Bases: Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class src.slot_attention.model.SlotAttention(num_slots, dim, iters=3, eps=1e-08, hidden_dim=128)[source]

Bases: Module

forward(inputs, num_slots=None)[source]


training: bool
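The SlotAttention module implements the iterative attention of Locatello et al. 2020: slots compete for input features via dot-product attention normalized over the slot axis, then each slot is updated with a weighted mean of the inputs. A simplified NumPy sketch of one iteration, omitting the learned q/k/v projections, layer norms, GRU update, and residual MLP of the full module:

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def slot_attention_step(inputs, slots):
    """One simplified iteration: inputs (n, d), slots (k, d) -> updated slots (k, d)."""
    d = inputs.shape[-1]
    logits = inputs @ slots.T / np.sqrt(d)             # (n, k) dot-product attention
    attn = softmax(logits, axis=1)                     # slots compete for each input
    weights = attn / attn.sum(axis=0, keepdims=True)   # normalize per slot over inputs
    return weights.T @ inputs                          # weighted mean of inputs per slot

rng = np.random.default_rng(0)
inputs = rng.normal(size=(16, 8))   # e.g. 16 encoder features of dim 8
slots = rng.normal(size=(4, 8))     # 4 randomly initialized slots
for _ in range(3):                  # iters=3, as in the default signature
    slots = slot_attention_step(inputs, slots)
```

The key design choice is normalizing the attention over the *slot* axis (rather than the input axis, as in standard attention), which forces slots to explain disjoint parts of the input.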
class src.slot_attention.model.SlotAttention_classifier(in_channels, out_channels)[source]

Bases: Module

forward(x)[source]


training: bool
class src.slot_attention.model.SlotAttention_encoder(in_channels, hidden_channels)[source]

Bases: Module

forward(x)[source]


training: bool
class src.slot_attention.model.SlotAttention_model(n_slots, n_iters, n_attr, in_channels=3, encoder_hidden_channels=64, attention_hidden_channels=128, device='cuda')[source]

Bases: Module

forward(x)[source]


training: bool
class src.slot_attention.model.SoftPositionEmbed(hidden_size, resolution, device='cuda')[source]

Bases: Module

Adds soft positional embedding with learnable projection.

forward(inputs)[source]


training: bool
src.slot_attention.model.build_grid(resolution)[source]
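In the reference Slot Attention code, build_grid returns a normalized coordinate grid concatenated with its complement, which SoftPositionEmbed then projects and adds to the feature map. A NumPy sketch of that convention (this package's implementation may differ in detail):

```python
import numpy as np

def build_grid_sketch(resolution):
    """Coordinate grid of shape (1, H, W, 4): each position holds (x, y, 1-x, 1-y)."""
    ranges = [np.linspace(0.0, 1.0, num=r) for r in resolution]
    grid = np.stack(np.meshgrid(*ranges, indexing="ij"), axis=-1)   # (H, W, 2)
    grid = grid[np.newaxis].astype(np.float32)                      # (1, H, W, 2)
    return np.concatenate([grid, 1.0 - grid], axis=-1)              # (1, H, W, 4)

g = build_grid_sketch((4, 4))
```

Including both the coordinates and their complements gives the subsequent learnable linear projection symmetric access to distance from either image border.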

src.slot_attention.preprocess-images module

src.slot_attention.train module

src.slot_attention.utils module

src.slot_attention.utils.hungarian_loss(predictions, targets, thread_pool)[source]
src.slot_attention.utils.hungarian_loss_per_sample(sample_np)[source]
src.slot_attention.utils.hungarian_matching(attrs, preds_attrs, verbose=0)[source]

Receives an unordered set of predictions and reorders it to match the nearest ground-truth (GT) set. Parameters: attrs (ground-truth attribute sets), preds_attrs (predicted attribute sets), verbose (verbosity level). Returns the matching between predictions and ground truth.
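Because set predictions have no canonical order, the loss must be computed over the cost-minimizing assignment of predictions to targets. The module presumably solves this with the Hungarian algorithm (e.g. scipy.optimize.linear_sum_assignment); the brute-force sketch below, illustrative only, makes the objective explicit for small sets:

```python
import numpy as np
from itertools import permutations

def brute_force_matching(attrs, preds_attrs):
    """Return the permutation of preds_attrs minimizing total pairwise L2 cost."""
    # Pairwise cost matrix: cost[i, j] = distance from GT i to prediction j.
    cost = np.linalg.norm(attrs[:, None, :] - preds_attrs[None, :, :], axis=-1)
    n = len(attrs)
    best = min(permutations(range(n)),
               key=lambda p: sum(cost[i, p[i]] for i in range(n)))
    return list(best)

attrs = np.array([[0.0, 0.0], [1.0, 1.0]])
preds = np.array([[0.9, 1.1], [0.1, -0.1]])  # predictions arrive in swapped order
match = brute_force_matching(attrs, preds)   # recovers the swap
```

Brute force is O(n!) and only viable for tiny sets; the Hungarian algorithm solves the same assignment in O(n^3), which is why it is used in practice.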

src.slot_attention.utils.outer(a, b=None)[source]

Compute the outer product of a and b (or of a with itself if b is not specified).
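For vectors this matches np.outer; a quick illustration of the described b=None default (the helper below is a sketch, not the module's actual code):

```python
import numpy as np

def outer_sketch(a, b=None):
    """Outer product of a and b; uses a with itself when b is omitted."""
    if b is None:
        b = a
    return np.outer(a, b)

m = outer_sketch(np.array([1.0, 2.0, 3.0]))  # 3x3 symmetric matrix a a^T
```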

src.slot_attention.utils.save_args(args, writer)[source]

Module contents