src.slot_attention package
Submodules
src.slot_attention.data module
src.slot_attention.model module
Slot attention model based on the code of tkipf and the corresponding paper (Locatello et al., 2020).
- class src.slot_attention.model.MLP(hidden_channels)[source]
Bases: Module
- forward(x)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
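The docs only expose the constructor signature `MLP(hidden_channels)`, so the exact depth and activation are not stated. A minimal sketch of what such a module's forward pass typically computes, assuming two linear layers with a ReLU in between (an assumption, not the confirmed architecture), written with NumPy rather than torch:

```python
import numpy as np

def mlp_forward(x, w1, b1, w2, b2):
    """Two-layer MLP with ReLU; layer count and activation are assumptions."""
    h = np.maximum(x @ w1 + b1, 0.0)  # hidden layer with ReLU
    return h @ w2 + b2                # linear output layer

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                     # batch of 4, input dim 8
w1 = rng.normal(size=(8, 16)); b1 = np.zeros(16)   # hidden_channels = 16
w2 = rng.normal(size=(16, 8)); b2 = np.zeros(8)
out = mlp_forward(x, w1, b1, w2, b2)
```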
- class src.slot_attention.model.SlotAttention(num_slots, dim, iters=3, eps=1e-08, hidden_dim=128)[source]
Bases: Module
- forward(inputs, num_slots=None)[source]
Defines the computation performed at every call (see the inherited Module note under MLP.forward() above).
- training: bool
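The signature `SlotAttention(num_slots, dim, iters=3, eps=1e-08, hidden_dim=128)` matches the iterative attention of Locatello et al. (2020): slots compete for input features via a softmax over the slot axis, then are updated from the attention-weighted mean of the values. A simplified NumPy sketch of one such iteration, where the paper's GRU/MLP slot update is replaced by direct assignment (an intentional simplification, not the module's actual update rule):

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def slot_attention_step(slots, inputs, Wq, Wk, Wv, eps=1e-8):
    """One slot-attention iteration, simplified (no GRU/MLP update)."""
    d = slots.shape[-1]
    q = slots @ Wq                        # queries from slots: (num_slots, d)
    k = inputs @ Wk                       # keys from inputs:   (num_inputs, d)
    v = inputs @ Wv                       # values from inputs: (num_inputs, d)
    logits = q @ k.T / np.sqrt(d)         # (num_slots, num_inputs)
    attn = softmax(logits, axis=0)        # slots compete for each input
    attn = attn / (attn.sum(axis=1, keepdims=True) + eps)  # weighted mean
    return attn @ v                       # updated slots

rng = np.random.default_rng(0)
d = 16
inputs = rng.normal(size=(25, d))         # e.g. flattened feature-map cells
slots = rng.normal(size=(4, d))           # num_slots = 4
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
for _ in range(3):                        # iters = 3, as in the signature
    slots = slot_attention_step(slots, inputs, Wq, Wk, Wv)
```

The softmax over the slot axis (rather than the input axis, as in standard attention) is what makes the slots partition the input among themselves.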
- class src.slot_attention.model.SlotAttention_classifier(in_channels, out_channels)[source]
Bases: Module
- forward(x)[source]
Defines the computation performed at every call (see the inherited Module note under MLP.forward() above).
- training: bool
- class src.slot_attention.model.SlotAttention_encoder(in_channels, hidden_channels)[source]
Bases: Module
- forward(x)[source]
Defines the computation performed at every call (see the inherited Module note under MLP.forward() above).
- training: bool
- class src.slot_attention.model.SlotAttention_model(n_slots, n_iters, n_attr, in_channels=3, encoder_hidden_channels=64, attention_hidden_channels=128, device='cuda')[source]
Bases: Module
- forward(x)[source]
Defines the computation performed at every call (see the inherited Module note under MLP.forward() above).
- training: bool
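The constructor arguments of SlotAttention_model suggest it composes the other modules on this page. The following is pseudocode inferred purely from those argument names; the actual layer order and reshaping inside the real forward() are assumptions:

```python
# Presumed SlotAttention_model pipeline (pseudocode, not runnable):
def slot_attention_model_forward(x):            # x: (B, in_channels, H, W)
    feats = encoder(x)                          # SlotAttention_encoder CNN
    feats = position_embed(feats)               # SoftPositionEmbed
    feats = feats.reshape(B, H * W, hidden)     # flatten spatial dims
    slots = slot_attention(feats)               # (B, n_slots, dim)
    return classifier(slots)                    # per-slot attribute logits
```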
- class src.slot_attention.model.SoftPositionEmbed(hidden_size, resolution, device='cuda')[source]
Bases: Module
Adds soft positional embedding with learnable projection.
- forward(inputs)[source]
Defines the computation performed at every call (see the inherited Module note under MLP.forward() above).
- training: bool
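"Adds soft positional embedding with learnable projection" typically means projecting a normalized coordinate grid to `hidden_size` channels and adding it to the feature map. A NumPy sketch, assuming the 4-channel (x, y, 1-x, 1-y) grid used in the reference slot-attention code (whether this module uses exactly that layout is an assumption):

```python
import numpy as np

def soft_position_embed(inputs, resolution, W, b):
    """Add a linear projection of a normalized coordinate grid to the features."""
    h, w = resolution
    ys, xs = np.meshgrid(np.linspace(0, 1, h), np.linspace(0, 1, w),
                         indexing="ij")
    grid = np.stack([xs, ys, 1 - xs, 1 - ys], axis=-1)  # (h, w, 4)
    return inputs + grid @ W + b                        # project to hidden_size

rng = np.random.default_rng(0)
hidden_size, res = 8, (5, 5)
inputs = rng.normal(size=(*res, hidden_size))
W = rng.normal(size=(4, hidden_size)); b = np.zeros(hidden_size)
out = soft_position_embed(inputs, res, W, b)
```

In the trainable version, W and b would be the "learnable projection" mentioned in the docstring.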
src.slot_attention.preprocess-images module
src.slot_attention.train module
src.slot_attention.utils module
- src.slot_attention.utils.hungarian_matching(attrs, preds_attrs, verbose=0)[source]
Receives an unordered predicted set and reorders it to match the nearest ground-truth (GT) set.
Parameters: attrs (ground-truth attribute set), preds_attrs (unordered predicted attribute set), verbose (verbosity level, default 0).
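To illustrate what such a matching does, here is a brute-force stand-in that reorders the predictions to minimize total squared L2 distance to the GT rows. The real hungarian_matching presumably uses the Hungarian algorithm (e.g. scipy.optimize.linear_sum_assignment) instead of this O(n!) search, and its exact cost function and return value are assumptions:

```python
import numpy as np
from itertools import permutations

def match_to_gt(attrs, preds_attrs):
    """Reorder preds_attrs so total squared L2 cost to attrs is minimal.
    Brute-force over permutations; only feasible for small sets."""
    n = len(attrs)
    cost = ((attrs[:, None, :] - preds_attrs[None, :, :]) ** 2).sum(-1)
    best = min(permutations(range(n)),
               key=lambda p: cost[np.arange(n), list(p)].sum())
    return preds_attrs[list(best)]

attrs = np.array([[0.0, 0.0], [1.0, 1.0]])
preds = np.array([[0.9, 1.1], [0.1, -0.1]])  # swapped relative to GT order
matched = match_to_gt(attrs, preds)           # rows reordered to align with GT
```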