Alex Carlin

Inverse folding unrolled (part 2)

In this section of our series on inverse folding, we'll take a detailed look at the feature representation that we are using for our protein structures and sequences. This corresponds to Part 2 in the notebook.

Feature representations

We are going to look at several different feature representations, using code from John Ingraham's repository, which helpfully even includes plotting functions!

To begin, we'll use a function, featurize, that accepts a list of dictionaries in the format described in the previous post. Briefly, each training example is a dict with keys such as "coords", "seq", and "name", along with other features of the training set. featurize packs a list of examples into four outputs: three PyTorch tensors (X, S, and mask) and a NumPy array of lengths. The second argument is the device on which to place the tensors.

X, S, mask, lengths = featurize(batch, "cpu")
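To make the packing step concrete, here is a minimal NumPy sketch of what a featurize-style function does: stack the backbone atoms, pad every example to the longest length, and derive the mask from missing coordinates. It is not Ingraham's code; the name featurize_sketch, the ALPHABET string, and the input layout (a "coords" dict keyed by atom name, with NaN for unresolved atoms) are assumptions for illustration.

```python
import numpy as np

# Assumed conventions (hypothetical, for illustration only)
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"   # index -> amino acid
ATOMS = ["N", "CA", "C", "O"]        # backbone atom order along axis 2 of X


def featurize_sketch(batch):
    """Pack a list of examples into padded arrays X, S, mask, lengths.

    Each example is assumed to look like
    {"seq": "ACD...", "coords": {"N": [[x, y, z], ...], "CA": ..., "C": ..., "O": ...}}
    with NaN wherever an atom is unresolved.
    """
    B = len(batch)
    lengths = np.array([len(b["seq"]) for b in batch], dtype=np.int32)
    L_max = int(lengths.max())

    # Pad coordinates with NaN so padding looks exactly like missing atoms
    X = np.full((B, L_max, 4, 3), np.nan)
    S = np.zeros((B, L_max), dtype=np.int64)
    for i, b in enumerate(batch):
        n = len(b["seq"])
        X[i, :n] = np.stack([np.asarray(b["coords"][a], dtype=float) for a in ATOMS], axis=1)
        S[i, :n] = [ALPHABET.index(aa) for aa in b["seq"]]

    # A residue is valid only if all 4 x 3 backbone coordinates are finite
    mask = np.isfinite(X.sum(axis=(2, 3))).astype(np.float32)
    X = np.nan_to_num(X, nan=0.0)  # the mask now remembers where the NaNs were
    return X, S, mask, lengths


# Usage with two toy examples of lengths 4 and 3 (random coordinates)
rng = np.random.default_rng(0)


def toy_example(seq):
    return {"seq": seq, "coords": {a: rng.normal(size=(len(seq), 3)).tolist() for a in ATOMS}}


X, S, mask, lengths = featurize_sketch([toy_example("ACDE"), toy_example("GHI")])
print(X.shape, S.shape, mask.shape)  # (2, 4, 4, 3) (2, 4) (2, 4)
print(mask)                          # second row ends in 0: that position is padding
```

Padding the coordinates with NaN rather than zeros means a single np.isfinite check catches both padding and unresolved residues.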

What are these tensors?

X is a tensor of shape (18, 222, 4, 3): 18 examples in the batch, each padded to 222 residues, the length of the longest protein in the batch. To get a single example:

x = X[0]

which has shape (222, 4, 3). Every example in the batch is padded to the length of the longest protein, so they all share this shape. For each of the 222 residues, we have four sets of 3D coordinates, one for each backbone atom: N, CA, C, and O.
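Because the atom axis has a fixed order, pulling out one atom type is just an index. The sketch below assumes the order N, CA, C, O, which we believe is the order used by Ingraham's featurize (ATOM_INDEX and the two helpers are hypothetical names), and computes distances between consecutive CA atoms. On real structures these are about 3.8 Å for most residue pairs, so much larger values point to chain breaks.

```python
import numpy as np

# Assumed atom order along axis 1 of a single example x (shape (L, 4, 3))
ATOM_INDEX = {"N": 0, "CA": 1, "C": 2, "O": 3}


def atom_coords(x, atom):
    """Return the (L, 3) coordinates of one backbone atom type."""
    return x[:, ATOM_INDEX[atom], :]


def consecutive_ca_distances(x):
    """Distances between the CA atoms of residues i and i + 1, shape (L - 1,)."""
    ca = atom_coords(x, "CA")
    return np.linalg.norm(ca[1:] - ca[:-1], axis=-1)


# Toy chain: CA atoms placed 3.8 Å apart along the x-axis, other atoms left at 0
x = np.zeros((5, 4, 3))
x[:, ATOM_INDEX["CA"], 0] = 3.8 * np.arange(5)
print(consecutive_ca_distances(x))  # four distances, each 3.8 up to float rounding
```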

S is an integer representation of the sequences, of shape (B, L), where B is the batch size and L is the length of the longest sequence in the batch.

S.shape 
# (18, 222)

Batches must be rectangular tensors, so we pad the shorter sequences up to the common length and then mask out the padded positions, which we do not want the model to try to predict.
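Padding has a side effect worth seeing: if padded positions in S hold 0 (an assumption here, matching a zero-initialized array), they decode as a real amino acid unless we cut them off. This hypothetical decode helper uses lengths to drop the padding; the ALPHABET ordering is likewise an assumption.

```python
import numpy as np

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # assumed index -> amino-acid mapping


def decode(S, lengths):
    """Turn a padded (B, L) integer array back into strings, dropping padding."""
    return ["".join(ALPHABET[k] for k in row[:n]) for row, n in zip(S, lengths)]


S = np.array([[0, 1, 2, 3],
              [5, 6, 7, 0]])   # second row is padded with 0, which is also "A"
lengths = np.array([4, 3])
print(decode(S, lengths))      # ['ACDE', 'GHI']
```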

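That brings us to the mask, which has the same shape as S. Here is the mask for a single example, the third in its batch. (This printout has 380 entries, so it comes from a batch padded to 380 residues, like the one used below, rather than from the (18, 222) batch above.)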

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])

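Positions with a 0 are masked. Here the first 15 residues are masked, along with a six-residue span in the middle of the protein. In featurize, a residue is masked whenever any of its backbone coordinates is missing, so the mask covers unresolved residues as well as padding.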
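To list the masked spans programmatically, here is a small helper (masked_spans is our own, hypothetical name) that returns half-open [start, end) runs of zeros. We rebuild a mask with the same pattern as the printout above, reading the positions off it: 380 entries, with zeros at positions 0 to 14 and 201 to 206.

```python
import numpy as np


def masked_spans(mask):
    """Return [(start, end), ...] half-open runs where mask == 0."""
    m = np.asarray(mask) == 0
    # Pad with False on both sides so every run has both a start and an end
    padded = np.concatenate([[False], m, [False]]).astype(np.int8)
    diff = np.diff(padded)
    starts = np.flatnonzero(diff == 1)   # 0 -> 1 transitions: a masked run begins
    ends = np.flatnonzero(diff == -1)    # 1 -> 0 transitions: a masked run ends
    return list(zip(starts.tolist(), ends.tolist()))


mask = np.ones(380, dtype=np.float32)
mask[0:15] = 0
mask[201:207] = 0
print(masked_spans(mask))  # [(0, 15), (201, 207)]
```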

lengths is an array of length B (the batch size) listing the length of each protein. When a batch is built from proteins of similar length, as we attempt to do here, little padding is wasted; in this batch, all 10 proteins have exactly the same length:

array([380, 380, 380, 380, 380, 380, 380, 380, 380, 380], dtype=int32)
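One simple way to get batches like this is to sort proteins by length and fill each batch up to a budget on its padded size. This is a hypothetical helper (length_batches and max_tokens are our names), not necessarily how the notebook builds its batches.

```python
import numpy as np


def length_batches(lengths, max_tokens):
    """Group protein indices into batches of similar length (hypothetical helper).

    Proteins are sorted by length, then added to the current batch until the
    padded size (batch size x longest length) would exceed max_tokens.
    """
    order = np.argsort(lengths, kind="stable")
    batches, current = [], []
    for idx in order:
        n = int(lengths[idx])
        # Sorted ascending, so n is the longest length once this protein is added
        if current and n * (len(current) + 1) > max_tokens:
            batches.append(current)
            current = []
        current.append(int(idx))
    if current:
        batches.append(current)
    return batches


lengths = np.array([380, 120, 380, 125, 118, 380])
print(length_batches(lengths, max_tokens=1200))  # [[4, 1, 3], [0, 2, 5]]
```

Sorting first keeps short and long proteins apart, so the 380-residue proteins end up together in one batch with no padding at all.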

We can now instantiate a ProteinFeatures object (implemented here as a PyTorch nn.Module) and use it to generate the actual features from our nice rectangular batch. Let's try it.

features = ProteinFeatures(128, 128, features_type="full")
V, E, E_idx = features(X, lengths, mask) 

V.shape, E.shape, E_idx.shape 
#(torch.Size([10, 380, 128]),
# torch.Size([10, 380, 30, 128]),
# torch.Size([10, 380, 30]))

We get three tensors back: the node (vertex) features V, the edge features E, and E_idx, which holds the indices of each residue's neighbors.

Where does the number 30 come from? It is a hyperparameter of the model: the number of neighbors used to featurize each residue (in Ingraham's code, the top_k argument of ProteinFeatures, which defaults to 30). Each residue is connected by edges to its 30 nearest neighbors by CA distance, so another way to look at it is that the model predicts a residue's identity from the features of the 30 surrounding residues. In ProteinMPNN, this number is set to 48.
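The neighbor graph behind E_idx can be sketched in a few lines: compute all pairwise CA-CA distances, push masked pairs to the end, and keep the k smallest. This NumPy version (knn_indices is a hypothetical name) follows the idea rather than Ingraham's exact code and assumes the atom order N, CA, C, O; applied to a (10, 380, 4, 3) batch with k=30 it would return an E_idx of shape (10, 380, 30). Each valid residue's own index comes first, since its distance to itself is essentially zero.

```python
import numpy as np


def knn_indices(X, mask, k):
    """Indices of the k nearest residues (by CA-CA distance) for every residue.

    X:    (B, L, 4, 3) backbone coordinates, atom axis assumed ordered N, CA, C, O
    mask: (B, L), 1 for valid residues and 0 for missing or padded ones
    Returns E_idx (B, L, k) and the matching distances (B, L, k).
    """
    ca = X[:, :, 1, :]                                    # (B, L, 3)
    diff = ca[:, :, None, :] - ca[:, None, :, :]          # (B, L, L, 3)
    D = np.sqrt((diff ** 2).sum(axis=-1) + 1e-6)          # (B, L, L)
    valid_pair = (mask[:, :, None] * mask[:, None, :]) > 0
    D = np.where(valid_pair, D, np.inf)                   # invalid pairs rank last
    E_idx = np.argsort(D, axis=-1, kind="stable")[:, :, :k]
    return E_idx, np.take_along_axis(D, E_idx, axis=-1)


# Toy batch: 5 CA atoms spaced 1 Å apart on a line, k = 3
X = np.zeros((1, 5, 4, 3))
X[0, :, 1, 0] = np.arange(5)
mask = np.ones((1, 5))
E_idx, D_nbrs = knn_indices(X, mask, k=3)
print(E_idx[0].tolist())  # [[0, 1, 2], [1, 0, 2], [2, 1, 3], [3, 2, 4], [4, 3, 2]]
```

Sorting the full L × L matrix is wasteful for long proteins; np.argpartition, or torch.topk(D, k, largest=False) in PyTorch, selects the k smallest distances without a full sort.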

The 128s come from the two feature dimensions passed to ProteinFeatures (one for the edge features, one for the node features), which we set to the model's hidden dimension. So V holds 10 examples (the batch size), each represented by 380 vectors of length 128, one per residue.

E is much larger. For each of the 10 structures in the batch, and for each of the 380 residues in each structure, there are 30 edge vectors of size 128, one per neighbor. All of these vectors are used to update the node embedding.
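Using E_idx, each residue can look up the features of its neighbors, which is how the (B, L, K, C) edge tensor lines up with the (B, L, C) node tensor during message passing. A sketch with NumPy advanced indexing (gather_neighbors and cat_edges_and_neighbors are hypothetical names):

```python
import numpy as np


def gather_neighbors(h_V, E_idx):
    """Collect each residue's neighbor features: (B, L, C), (B, L, K) -> (B, L, K, C)."""
    batch_idx = np.arange(h_V.shape[0])[:, None, None]   # (B, 1, 1), broadcasts with E_idx
    return h_V[batch_idx, E_idx]


def cat_edges_and_neighbors(h_E, h_V, E_idx):
    """Concatenate edge features with neighbor node features along the last axis."""
    return np.concatenate([h_E, gather_neighbors(h_V, E_idx)], axis=-1)


# Toy example: B = 1, L = 4 residues, C = 2 features, K = 2 neighbors each
h_V = np.arange(8, dtype=float).reshape(1, 4, 2)
E_idx = np.array([[[0, 1], [1, 0], [2, 3], [3, 2]]])
h_E = np.ones((1, 4, 2, 2))
nbrs = gather_neighbors(h_V, E_idx)
print(nbrs.shape)     # (1, 4, 2, 2)
print(nbrs[0, 2, 1])  # features of residue 3: [6. 7.]
print(cat_edges_and_neighbors(h_E, h_V, E_idx).shape)  # (1, 4, 2, 4)
```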

This suggests why ProteinMPNN discards the node features (setting the node embeddings to zeros at the first layer). There is much more information in the geometry of the 30 surrounding residues than in the backbone angles of the single residue we are trying to predict, and the information in the current node's own features is largely redundant with what the edges already encode.
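To see what starting from zero node embeddings means for the first layer, here is a toy message-passing update in the style of an MPNN layer: each message is computed from the concatenation [h_V_i, h_E_ij, h_V_j], passed through a single linear map and a ReLU, and summed over neighbors. This is a deliberately simplified sketch (one weight matrix W, random stand-in features, hypothetical function names), not ProteinMPNN's actual layer. When the node features start at zero, the first update depends only on the edge features.

```python
import numpy as np


def message_passing_step(h_V, h_E, E_idx, W, mask):
    """One simplified message-passing update (a sketch, not ProteinMPNN's actual layer).

    message_ij = relu([h_V_i, h_E_ij, h_V_j] @ W); messages are summed over the
    K neighbors j and added to h_V_i. Padded or missing residues are zeroed by mask.
    """
    B, L, K = E_idx.shape
    C = h_V.shape[-1]
    h_V_j = h_V[np.arange(B)[:, None, None], E_idx]            # (B, L, K, C)
    h_V_i = np.broadcast_to(h_V[:, :, None, :], (B, L, K, C))  # (B, L, K, C)
    h_EV = np.concatenate([h_V_i, h_E, h_V_j], axis=-1)        # (B, L, K, 3C)
    messages = np.maximum(h_EV @ W, 0.0)                       # (B, L, K, C)
    return (h_V + messages.sum(axis=2)) * mask[:, :, None]


rng = np.random.default_rng(0)
B, L, K, C = 1, 6, 3, 4
h_E = rng.normal(size=(B, L, K, C))          # stand-in for edge features
E_idx = rng.integers(0, L, size=(B, L, K))   # stand-in for neighbor indices
W = rng.normal(size=(3 * C, C))
mask = np.ones((B, L))

h_V0 = np.zeros((B, L, C))                    # start the node embeddings at zero
h_V1 = message_passing_step(h_V0, h_E, E_idx, W, mask)

# With zero node features, only the edge block W[C:2C] contributes to the first update
edge_only = np.maximum(h_E @ W[C:2 * C], 0.0).sum(axis=2)
print(np.allclose(h_V1, edge_only))           # True
```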