Goal. Quantify whether a single latent unit cleanly corresponds to a
concept.
Components.
- Feature capacity — accuracy of the best single latent for the concept.
- Local disentanglement — remove that latent and re-evaluate;
monosemantic concepts should drop toward chance.
- Global disentanglement — track marginal gains when adding more latents;
truly monosemantic concepts should not need backup.
Procedure (decision-tree based; a code sketch follows the list).
- Train a shallow decision tree on SAE latents to localize the most informative latent
(root node).
- Record accuracy as you include top-k features along the tree path; compute marginal
gains (global).
- Retrain while excluding the top feature(s) to measure the drop (local).
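A minimal sketch of this procedure with scikit-learn; the `latents`/`labels` arrays, the shallow tree depth, the ranking via `feature_importances_` (rather than a literal walk down the tree path), and the in-sample scoring are illustrative assumptions, not the exact reference setup.

```python
# Illustrative sketch of the decision-tree procedure; a held-out split is omitted.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def rank_and_score(latents, labels, k=5, max_depth=3):
    """Fit a shallow Gini tree on SAE latents, then report accuracy with the
    best single latent (root) and cumulative accuracy with the top-k latents."""
    tree = DecisionTreeClassifier(criterion="gini", max_depth=max_depth).fit(latents, labels)
    ranked = np.argsort(tree.feature_importances_)[::-1][:k]  # most informative latents first

    accs_cum = []
    for i in range(1, k + 1):
        sub = DecisionTreeClassifier(criterion="gini", max_depth=max_depth)
        sub.fit(latents[:, ranked[:i]], labels)
        accs_cum.append(sub.score(latents[:, ranked[:i]], labels))
    accs_0 = accs_cum[0]  # accuracy of the root / best single latent
    return ranked, accs_0, accs_cum

def acc_top_removed(latents, labels, ranked, p=1, max_depth=3):
    """Retrain with the top-p latents excluded (the 'local' measurement)."""
    keep = np.setdiff1d(np.arange(latents.shape[1]), ranked[:p])
    tree = DecisionTreeClassifier(criterion="gini", max_depth=max_depth)
    tree.fit(latents[:, keep], labels)
    return tree.score(latents[:, keep], labels)
```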
Local disentanglement.
\[
\mathrm{FMS}_{\text{local}@p} \;=\; 2 \times \big(\,\mathrm{accs}_{0} - \mathrm{accs}_{p}\,\big)
\]
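In code, assuming `accs_0` and `accs_p` come from the sketch above:

```python
def fms_local(accs_0: float, accs_p: float) -> float:
    # 2 * (accs_0 - accs_p): a drop from accs_0 near 1 to chance (0.5 for a
    # binary concept) yields a score near 1; no drop yields 0.
    return 2.0 * (accs_0 - accs_p)
```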
Global disentanglement.
\[
A(n) \;=\; \sum_{i=1}^{n}\big(\mathrm{accs\_cum}_{i} - \mathrm{accs}_{0}\big), \qquad
\mathrm{FMS}_{\text{global}} \;=\; 1 - \frac{A(n)}{n}
\]
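A direct translation, assuming `accs_cum[i-1]` holds the cumulative accuracy \(\mathrm{accs\_cum}_i\) with the top-\(i\) latents (so the first term of \(A(n)\) is zero):

```python
def fms_global(accs_0: float, accs_cum: list[float]) -> float:
    # A(n) = sum_i (accs_cum_i - accs_0); no marginal gain beyond the root
    # gives A(n) = 0 and a score of 1.
    n = len(accs_cum)
    a_n = sum(acc - accs_0 for acc in accs_cum)
    return 1.0 - a_n / n
```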
Overall score.
\[
\mathrm{FMS}@p \;=\; \frac{1}{|C|}\sum_{i=1}^{|C|}
\mathrm{accs}^{c_i}_{0}\;\times\;
\frac{\mathrm{FMS}^{c_i}_{\text{local}@p} + \mathrm{FMS}^{c_i}_{\text{global}}}{2}
\]
Features are ranked by a decision tree trained with the Gini criterion; the root split gives \(\mathrm{accs}_0\), the tree path gives \(\mathrm{accs\_cum}\), and iteratively retraining with the top (root) features removed estimates locality.
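Putting the pieces together over the concept set \(C\) (the triple-per-concept structure is an assumption about how the intermediate results are stored):

```python
def fms_at_p(per_concept: list[tuple[float, float, float]]) -> float:
    """per_concept holds one (accs_0, fms_local_at_p, fms_global) triple per concept."""
    return sum(a0 * (loc + glo) / 2.0 for a0, loc, glo in per_concept) / len(per_concept)
```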
Idea. Reserve a small set of latent indices for labeled concepts and condition them during training
so each index becomes monosemantic by design.
- Encoder activations: Sigmoid(TopK) to obtain sparse, interpretable [0,1] latents.
- Conditioning loss: Binary cross-entropy on reserved indices;
if concept \(c\) is present, drive latent \(f_{j(c)}\) toward 1, else toward 0.
- Detection: inspect index \(j(c)\) directly at inference.
- Steering: use the decoder column \(D_{\cdot,i}\) as a steering vector; modify the residual stream as
\(\hat{\mathbf x} = \mathbf x + \alpha \sum\nolimits_{i=0}^{c} \beta_i\, \gamma_i\, D_{\cdot,i}\)
with steering strength \(\alpha\), per-concept normalization factor \(\beta_i\), and balancing term \(\gamma_i\).
Architecture.
\[
\operatorname{SAE}(x) = D(\sigma(E(x))),\quad
E(x) = W_{\text{enc}}x + b_{\text{enc}} = h,\quad
D(f) = W_{\text{dec}}f + b_{\text{dec}} = \hat x,\quad
\sigma(h) = \mathrm{Sigmoid}(\mathrm{TopK}(h)) = f
\]
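A minimal PyTorch sketch of this architecture; the dimensions, all names, and the choice to zero non-top-k pre-activations before the sigmoid (so the latents stay sparse) are assumptions rather than the reference implementation.

```python
import torch
import torch.nn as nn

class ConditionedSAE(nn.Module):
    def __init__(self, d_model: int, n_latents: int, k: int):
        super().__init__()
        self.enc = nn.Linear(d_model, n_latents)  # W_enc, b_enc
        self.dec = nn.Linear(n_latents, d_model)  # W_dec, b_dec
        self.k = k

    def activation(self, h: torch.Tensor) -> torch.Tensor:
        # sigma(h) = Sigmoid(TopK(h)): keep the k largest pre-activations,
        # squash them into (0, 1), and zero everything else.
        idx = torch.topk(h, self.k, dim=-1).indices
        mask = torch.zeros_like(h).scatter_(-1, idx, 1.0)
        return torch.sigmoid(h) * mask

    def forward(self, x: torch.Tensor):
        h = self.enc(x)         # E(x) = W_enc x + b_enc = h
        f = self.activation(h)  # f = Sigmoid(TopK(h))
        x_hat = self.dec(f)     # D(f) = W_dec f + b_dec = x_hat
        return x_hat, f
```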
Losses. Normalized MSE reconstruction and BCE conditioning on a reserved block
\( f[0{:}c]=(f_0,\dots,f_c) \):
\[
\mathcal{L}_r = \frac{\lVert \hat x - x \rVert^2}{\lVert x \rVert^2}, \qquad
\mathcal{L}_c = \mathrm{BCE}(f[0{:}c], y)
= -\frac{1}{c+1}\sum_{i=0}^{c}\big(y_i \log f_i + (1-y_i)\log(1-f_i)\big), \qquad
\mathcal{L}_{\text{total}} = \mathcal{L}_r + \mathcal{L}_c
\]
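A sketch of the combined objective, assuming `y` is a multi-hot concept label over the reserved block, the reserved latents occupy the first `n_reserved` indices, and `eps` guards the normalization and log terms:

```python
import torch
import torch.nn.functional as F

def sae_loss(x, x_hat, f, y, n_reserved: int, eps: float = 1e-8):
    # Normalized MSE reconstruction: ||x_hat - x||^2 / ||x||^2 (per sample, then averaged).
    l_r = ((x_hat - x).pow(2).sum(-1) / (x.pow(2).sum(-1) + eps)).mean()
    # BCE conditioning on the reserved latents f_0 .. f_c.
    f_res = f[..., :n_reserved].clamp(eps, 1 - eps)
    l_c = F.binary_cross_entropy(f_res, y)
    return l_r + l_c
```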
Detection & Steering. For concept \(i\), decoder column \(D_{\cdot,i}\in\mathbb{R}^d\) is the
steering direction. Normalize and combine with
\[
\beta_i = \frac{\lVert x \rVert_2}{\lVert D_{\cdot,i} \rVert_2}, \qquad
\gamma_i \in \{1,\, f_i,\, 1-f_i\}, \qquad
\hat x = x + \alpha \sum_{i=0}^{c} \big( \beta_i\, \gamma_i\, D_{\cdot,i} \big)
\]
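A steering sketch using the decoder columns of the `ConditionedSAE` above; `alpha`, the `gamma_mode` switch, and the loop over concept indices are illustrative choices, not the author's exact code.

```python
import torch

def steer(x, sae, concept_ids, f, alpha: float = 1.0, gamma_mode: str = "one"):
    """Shift the residual stream x along the decoder columns D_{., i}."""
    x_hat = x.clone()
    for i in concept_ids:
        d_i = sae.dec.weight[:, i]                                  # decoder column D_{., i}
        beta = x.norm(dim=-1, keepdim=True) / (d_i.norm() + 1e-8)   # match the norm of x
        if gamma_mode == "one":
            gamma = 1.0
        elif gamma_mode == "f":
            gamma = f[..., i:i + 1]          # scale by the current activation
        else:
            gamma = 1.0 - f[..., i:i + 1]    # scale by how inactive the concept is
        x_hat = x_hat + alpha * beta * gamma * d_i
    return x_hat
```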