Alex Carlin

Inverse folding unrolled (part 1)

In this series of posts and the accompanying Colab notebook, we'll explore a problem in protein AI known as inverse folding: designing a sequence of amino acids that will fit a given "fold", or precise 3-D structure of a protein. To tackle it, we will train a transformer model to predict the identity of each amino acid from the surrounding protein structure!

Conceptually, this is similar to many NLP problems: predict a token from its context. However, it differs from training a protein LLM, as we did in a previous post, in that the structure transformer is given the protein structure, not just the sequence.

The training data is also severely limited: there are only about 200,000 known protein structures in total. If that sounds daunting, that's because it is! In this series, we'll explore transformer models that learn to design protein sequences using these existing structures as training data.

The provided Google Colab notebook contains every bit of code we will need, including a full-featured transformer implementation so you can see that there is nothing scary going on under the hood. Overall, it is a couple hundred lines of basic PyTorch to build the structure transformer model and deal with the training data. Let's dive in!

Part 1: The dataset

We'll begin with the most important part of any machine learning challenge: the dataset. Conceptually, our model learns to predict the amino acid at a given position in a protein by looking at the surrounding structure. That structural context is the "message" that the decoder transforms into an embedding for the amino acid at that position. So which proteins is the model learning from? How diverse are they? And how do we effectively partition them into training, validation, and test data?

Dataset contents

The RCSB Protein Data Bank (PDB) is the gold-standard structural database for biology. As of mid-2023, the PDB contains 209,159 structures, 9,503 of which were deposited this year.

However, training on the whole PDB has some problems. First, it is highly redundant: some protein folds are represented by many structures and others by very few, which biases the training data. But the main problem is evaluation: how will we tell whether our model is really learning, rather than memorizing structures it has already seen? We need a set of proteins that is not just sequence-diverse but structurally diverse. In fact, what we want is a set of protein structures that don't appear in the training set at all.

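So how do we solve this problem? We use the CATH-based dataset that others in the field have used for inverse folding. CATH classifies protein domains hierarchically by structure (Class, Architecture, Topology, Homologous superfamily), so the data can be split such that structures in the test set share no topology with those in the training set.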
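
To make the idea of a structure-based split concrete, here is a minimal sketch of a grouped split. It assumes each entry is a dict with hypothetical "name" and "cath" keys, where "cath" is a single CATH code like "3.30.70.330" whose first three fields identify the topology. Real chains can contain several domains with different CATH codes, which this simplification ignores; the point is only that every topology lands in exactly one split.

```python
import random
from collections import defaultdict

def split_by_topology(entries, frac_val=0.1, frac_test=0.1, seed=0):
    """Split entries so that each CATH topology lands in exactly one split.

    `entries` is a list of dicts with hypothetical keys "name" and "cath",
    where "cath" is a CATH code such as "3.30.70.330". The topology is the
    first three fields (class.architecture.topology).
    """
    # Group chain names by topology.
    groups = defaultdict(list)
    for e in entries:
        topology = ".".join(e["cath"].split(".")[:3])
        groups[topology].append(e["name"])

    # Shuffle topologies (not chains) deterministically.
    topologies = sorted(groups)
    random.Random(seed).shuffle(topologies)
    n_val = int(len(topologies) * frac_val)
    n_test = int(len(topologies) * frac_test)

    splits = {"train": [], "validation": [], "test": []}
    for i, topo in enumerate(topologies):
        if i < n_val:
            key = "validation"
        elif i < n_val + n_test:
            key = "test"
        else:
            key = "train"
        # Every chain of this topology goes to the same split.
        splits[key].extend(groups[topo])
    return splits
```

Splitting by group rather than by individual chain is what keeps the same fold from showing up on both sides of the train/test boundary.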

Reading and preprocessing

To load the data, we will use a pre-existing dataset used by Ingraham as well as by ESM inverse folding. So this is the real deal: PDB structures split into training, validation, and test sets according to their CATH structural classification. You can learn more about CATH on the CATH database site.
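
Here's a minimal loader sketch. It assumes the data comes as a JSON Lines file with one chain per line, each holding "name", "seq" and "coords" fields like the example inspected below; the file layout and path are assumptions, so check your copy of the dataset. It converts each atom's coordinate list into a NumPy array, so missing values (stored as NaN or null) become NaN.

```python
import json
import numpy as np

def load_chain_set(path):
    """Load a JSON Lines file with one protein chain per line.

    Assumes each line looks like
    {"name": ..., "seq": ..., "coords": {"N": [[x, y, z], ...], "CA": ..., "C": ..., "O": ...}}
    """
    examples = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            entry = json.loads(line)  # json accepts NaN tokens by default
            # Convert each atom's list of [x, y, z] into an (L, 3) float array;
            # both NaN and null entries end up as NaN.
            entry["coords"] = {
                atom: np.asarray(xyz, dtype=float)
                for atom, xyz in entry["coords"].items()
            }
            examples.append(entry)
    return examples
```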

Looking at single examples

It's a great idea to look at individual examples of your training data, and even to try the prediction task on a few of them yourself. In this case, each example is a dict:

>>> example["seq"]
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPIL ...

>>> example["coords"].keys()
dict_keys(['N', 'CA', 'C', 'O'])

>>> example["coords"]['CA'][:10]
array([[   nan,    nan,    nan],
       [   nan,    nan,    nan],
       [   nan,    nan,    nan],
       [12.501, 39.048, 28.539],
       [15.552, 39.41 , 26.282],
       [17.791, 39.281, 29.375],
       [16.004, 36.186, 30.742],
       [16.137, 34.313, 27.425],
       [19.794, 35.327, 26.885],
       [20.706, 33.656, 30.141]])
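
The NaN rows in the output above are worth handling early. A minimal sketch, assuming an example dict shaped like the one shown: stack the four atom arrays into a single (L, 4, 3) array and build a boolean mask marking residues whose backbone atoms are all present. The helper names are hypothetical, not taken from the notebook.

```python
import numpy as np

BACKBONE_ATOMS = ["N", "CA", "C", "O"]

def backbone_array(example):
    """Stack the four (L, 3) atom arrays into one (L, 4, 3) array."""
    return np.stack(
        [np.asarray(example["coords"][a], dtype=float) for a in BACKBONE_ATOMS],
        axis=1,
    )

def resolved_mask(xyz):
    """True for residues whose N, CA, C and O coordinates are all finite."""
    return np.isfinite(xyz).all(axis=(1, 2))
```

For the example above, the first three entries of resolved_mask(backbone_array(example)) would be False, since their CA coordinates are NaN.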

We can see that the structure representation (the thing the model will see) consists of 3-D coordinates for each of 4 backbone atoms (N, CA, C, O) in each of L residues, where L = len(example["coords"]["CA"]). That makes 4 arrays of shape (L, 3). Note that the first three CA rows are NaN, most likely residues that were not resolved in the experimental structure, so they will need to be masked out. As we'll see in the code later, we stack these arrays into batches of shape (B, L, 4, 3), where B is the number of structures in the batch.
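
Because chains differ in length, stacking them into a (B, L, 4, 3) batch requires padding. Here is one way to do it, a sketch rather than the notebook's actual code: pad every chain to the longest length in the batch, zero out NaNs, and return a mask so the model can ignore padding and unresolved residues. The function name collate_backbones is hypothetical.

```python
import numpy as np
import torch

def collate_backbones(xyz_list):
    """Pad a list of (L_i, 4, 3) arrays into a (B, L_max, 4, 3) tensor.

    Returns the padded coordinates and a (B, L_max) boolean mask that is True
    only for real residues with all four backbone atoms present. NaNs and
    padding are replaced with zeros so they can't propagate through the model.
    """
    B = len(xyz_list)
    L_max = max(x.shape[0] for x in xyz_list)
    X = np.zeros((B, L_max, 4, 3), dtype=np.float32)
    mask = np.zeros((B, L_max), dtype=bool)
    for i, xyz in enumerate(xyz_list):
        L = xyz.shape[0]
        mask[i, :L] = np.isfinite(xyz).all(axis=(1, 2))  # fully resolved residues
        X[i, :L] = np.nan_to_num(xyz, nan=0.0)           # zero out missing atoms
    return torch.from_numpy(X), torch.from_numpy(mask)
```

Zero-filling alone would make padded residues look like real atoms at the origin, which is why the mask travels with the coordinates.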

This is another good chance to think about what we are asking the model to do. We will give the model a big stack of coordinates, say from the 48 surrounding residues, and ask it to encode those coordinates into a hidden representation (a vector of our model dimension), then decode that vector into a probability distribution over the 20 amino acids. The model's vocabulary of tokens is very limited: just those 20 amino acids.
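
To make the encode/decode picture concrete, here is a toy sketch in PyTorch. It is not the structure transformer we'll build later: it picks each residue's 48 nearest neighbors by CA distance (including the residue itself), encodes their raw coordinate offsets with a small MLP, and decodes to logits over the 20-letter vocabulary. The alphabet ordering, helper names and layer sizes are assumptions for illustration. The tokenizer only knows the 20 standard one-letter codes, so nonstandard letters such as X would need their own handling, and raw offsets are not rotation-invariant.

```python
import torch
import torch.nn as nn

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # assumed ordering of the 20-token vocabulary
AA_TO_IDX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

def tokenize(seq):
    """Map a sequence string to integer tokens in [0, 20)."""
    return torch.tensor([AA_TO_IDX[aa] for aa in seq], dtype=torch.long)

def knn_neighbors(ca, mask, k=48):
    """Indices of the k nearest residues (by CA distance) for each residue.

    ca: (B, L, 3) CA coordinates; mask: (B, L) bool, True for valid residues.
    Returns (B, L, min(k, L)) neighbor indices.
    """
    dist = torch.cdist(ca, ca)  # (B, L, L) pairwise distances
    # Push pairs involving masked-out residues to the end of the ranking.
    invalid = ~(mask[:, None, :] & mask[:, :, None])
    dist = dist.masked_fill(invalid, float("inf"))
    k = min(k, ca.shape[1])
    return dist.topk(k, dim=-1, largest=False).indices

class ToyNeighborHead(nn.Module):
    """Toy stand-in: encode neighbor offsets, decode to 20 amino acid logits."""

    def __init__(self, k=48, d_model=128):
        super().__init__()
        self.encode = nn.Sequential(nn.Linear(k * 3, d_model), nn.ReLU())
        self.decode = nn.Linear(d_model, len(AMINO_ACIDS))

    def forward(self, ca, neighbor_idx):
        B, L, k = neighbor_idx.shape
        # Gather neighbor CA coordinates: (B, L, k, 3).
        batch = torch.arange(B)[:, None, None]
        neighbors = ca[batch, neighbor_idx]
        offsets = neighbors - ca[:, :, None, :]  # relative to the center residue
        h = self.encode(offsets.reshape(B, L, k * 3))  # (B, L, d_model)
        return self.decode(h)  # (B, L, 20) logits
```

A softmax over the last dimension of the logits gives the probabilities over the 20 amino acids, and training would minimize cross-entropy against the tokenized native sequence.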

Dataset summary for these experiments

For these experiments, we’ll use the dataset assembled by John Ingraham, who I believe was the first to frame protein sequence design as a language modeling problem, and who developed the protein structure graph transformer we’ll explore later.

Probably the best thing you could do to improve upon this approach is to take a careful look at the training data. Examine the partitioning and check whether information about the test set has potentially “leaked” into the training set. For our purposes, however, we simply want to understand how the whole process works end to end, so we’ll use the existing dataset.
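
As a starting point for that kind of audit, here is a crude leakage check. It assumes chain names of the form "<pdb_id>.<chain>" (e.g. "1abc.A") and a splits dict mapping split names to lists of chain names; both are assumptions about the data layout. It flags PDB entries that contribute chains to more than one split, and test chains whose exact sequence also appears in training. A more thorough check would cluster by sequence identity or structural similarity rather than exact matches.

```python
from itertools import combinations

def shared_pdb_ids(splits):
    """Map each pair of splits to the PDB IDs they have in common.

    Assumes chain names look like "<pdb_id>.<chain>"; chains from the same
    PDB entry landing in different splits are a possible leak.
    """
    ids = {
        split: {name.split(".")[0].lower() for name in names}
        for split, names in splits.items()
    }
    overlaps = {}
    for a, b in combinations(sorted(ids), 2):
        shared = ids[a] & ids[b]
        if shared:
            overlaps[(a, b)] = sorted(shared)
    return overlaps

def shared_sequences(examples, splits):
    """Test chains whose exact sequence also occurs in the training split."""
    seq_by_name = {e["name"]: e["seq"] for e in examples}
    train_seqs = {seq_by_name[n] for n in splits["train"] if n in seq_by_name}
    return [n for n in splits["test"] if seq_by_name.get(n) in train_seqs]
```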

In the next post, we'll discuss the feature representation we'll use to present our protein structures to the model.