August 31, 2023
I have a few comments about this paper that tries to justify the use of LayerNorm in transformers using geometry.
The authors claim that LayerNorm maps every key to a \((d-2)\)-sphere, where \(d\) is the model dimension, and that this hypersphere is guaranteed to be orthogonal to the all-ones vector. This is true for single-headed attention. But for multiheaded attention it’s more complicated: the LayerNorm acts on the full \(d\)-dimensional key *before* it gets split into heads, so the slice each head actually sees is neither a fixed length nor orthogonal to the all-ones vector of the head’s own subspace. And the authors only tested with one head.
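This is easy to check numerically. Here’s a minimal NumPy sketch; the sizes (512 dimensions, 8 heads) are hypothetical, just bigger than anything in the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_heads = 512, 8                 # hypothetical sizes, just for illustration
head_dim = d // n_heads

def layer_norm(x, eps=1e-5):
    # Center and scale, as LayerNorm does (ignoring the learned affine part).
    x = x - x.mean(axis=-1, keepdims=True)
    return x / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

keys = layer_norm(rng.normal(size=(4, d)))   # a few LayerNormed keys

# Full keys: zero dot product with the ones vector, fixed norm ~sqrt(d).
print(keys @ np.ones(d))                     # all ~0
print(np.linalg.norm(keys, axis=-1))         # all ~sqrt(512) ~ 22.6

# Per-head slices: neither property survives the split into heads.
slices = keys[:, :head_dim]                  # what the first head actually sees
print(slices @ np.ones(head_dim))            # generally nonzero
print(np.linalg.norm(slices, axis=-1))       # varies from key to key
```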
With that said, you could get something like this working with multiple heads. You just have to do the LayerNorm after reshaping your key vectors so they include a “head” dimension, so each head’s keys get normalized in their own subspace (there’s a sketch of where this goes below, after the plots). That’s kind of an awkward place to put a normalization layer, and I haven’t seen anyone else do it that way. But I did try adding it to my music neural net, and … maybe it kind of helped? I mean, look at this plot of the attention matrices from layer two of an earlier model.
And now look at the model with the per-head LayerNorm.
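For what it’s worth, here is roughly where the normalization sits in the attention block. This is a minimal PyTorch sketch of the idea, not my actual code (or the authors’); the class name and shapes are mine:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # scaled_dot_product_attention needs PyTorch 2.x

class PerHeadNormAttention(nn.Module):
    """Multi-head self-attention with LayerNorm applied to each head's keys.

    The keys are reshaped to (batch, heads, seq, head_dim) *before* being
    normalized, so the LayerNorm acts within each head's subspace.
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # LayerNorm over the last axis of size head_dim, i.e. per head.
        self.key_norm = nn.LayerNorm(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        split = lambda z: z.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        q = split(self.q_proj(x))
        k = self.key_norm(split(self.k_proj(x)))   # normalize after the head split
        v = split(self.v_proj(x))
        out = F.scaled_dot_product_attention(q, k, v)
        return self.out_proj(out.transpose(1, 2).reshape(b, t, d))
```

I only normalize the keys here, since that’s what I’m talking about in this post; whether to do the same to the queries is a separate choice.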
There was nothing rigorous about this! I made all sorts of changes to the model between taking these pictures! But I’m pretty sure the LayerNorm was at least a major contributor to this change.

The authors of the paper talk about “unselectable” keys, and I think they’re kind of on to something there, but with one major caveat. A key ends up “unselectable” if it points in the same general direction as other, bigger keys; those other keys capture the attention that would otherwise have gone to it. But in that case you’re still paying attention to really similar keys, so is this actually a problem? I don’t think the authors made a strong case here. Their main example is a transformer learning a “majority” task, which sounds both contrived and trivial. It’s weird that they didn’t just pick a more conventional task and try tweaking a well-known model. If I were to speculate, I would guess normalizing the keys helps the training process somehow, more than it fixes any “unselectability” problem.
Also, I lied earlier. I’m not doing LayerNorm anymore; I just do the scaling part and skip the centering (which makes it basically RMSNorm). I was really unimpressed with the authors’ idea that the “projection” effect of the centering was somehow important. The idea was that, since centering makes every key orthogonal to the all-ones vector, the model could align a query with that vector and get uniform attention over every key. But (1) why would you want that, and (2) it’s a high-dimensional space! Any two random vectors are going to be almost orthogonal anyway! The authors tested with small model dimensions (8 at most), but I’m currently over 500 with my model. So, yeah, the projection probably isn’t doing much. And I’ve seen other successful models that nix the centering step.
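The “almost orthogonal” bit is easy to sanity-check. A quick NumPy estimate of the average absolute cosine similarity between random vectors at dimension 8 versus 512 (512 standing in for “over 500”):

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_abs_cosine(d, trials=10_000):
    # Average |cosine similarity| between pairs of random Gaussian vectors in R^d.
    a = rng.normal(size=(trials, d))
    b = rng.normal(size=(trials, d))
    cos = np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
    return np.abs(cos).mean()

print(mean_abs_cosine(8))    # ~0.28: noticeably non-orthogonal
print(mean_abs_cosine(512))  # ~0.035: nearly orthogonal (falls off like 1/sqrt(d))
```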
My recommendation is to try normalizing each head separately without centering, and see if it improves your model (sketch below). But don’t expect any big changes. (In terms of quality, it didn’t feel like my model improved much, but it’s hard to tell.) Don’t give this paper too much credence, though. The things it says really only apply to toy models.
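Concretely, that recommendation looks something like the module below: scale-only, RMS-style normalization over each head’s key slice, as a drop-in for the per-head LayerNorm in the earlier sketch. Again, this is my own sketch; the name and the optional learned gain are mine:

```python
import torch
import torch.nn as nn

class PerHeadScaleOnlyNorm(nn.Module):
    """Scale-only normalization over the last (head_dim) axis: no centering.

    Equivalent in spirit to RMSNorm, applied per head. Swap it in where the
    per-head LayerNorm sat in the attention sketch above.
    """
    def __init__(self, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(head_dim))  # optional learned scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., head_dim). Divide by the root-mean-square instead of the std.
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return self.gain * x / (rms + self.eps)
```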