Alex Carlin

Simple, hackable transformers for protein design

I was inspired by Andrej Karpathy's makemore to provide a simple, hackable implementation of protein transformer models for education and research.

Transformer models have rapidly advanced the state of the art in protein design since their introduction in 2017. A brief overview of some current models can be found in another post. Here, I provide a complete, spelled-out GPT-2 style protein transformer model that is fast and flexible, and designed to be easy to hack on and experiment with.

Install

The code is available in the main repository on GitHub.

After cloning the repo, create the environment with Conda and install the dependencies with pip:

conda create --name proteins python=3.11 
conda activate proteins 
pip install tensorboard torch biotite 

Once these dependencies are in place, you don't need to install anything else to run the transformer models. Simply run

python train.py --help 

to see all options. To get started training, carry on below.

Training on a homologous sequence family

As an example, we can use sequence homologs from the dihydrofolate reductase family to train a protein transformer capable of designing new enzymes that fold and function the same way as the proteins in the training set. DHFR is a small enzyme that I picked because the sequence is short, it contains a well-studied nucleotide binding motif for its cofactor NADPH, and the family contains many hundreds of thousands of examples in public datasets.

To run the training for the provided DHFR dataset:

python train.py -i data/dhfr.fa -o dhfr

Once the model is trained, you can sample from it (the sampled sequences are written in FASTA format):

python train.py -o dhfr --sample-only

You will be sampling new proteins that are based on those in your training set.
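To make the idea of sampling concrete, here is a minimal sketch of autoregressive sampling in plain Python. It is not the repository's sampling code. The function `toy_next_token_probs` is a hypothetical stand-in for a trained model, and treating a sampled token 0 as a stop signal is an assumption of this sketch.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def toy_next_token_probs(tokens):
    """Hypothetical stand-in for a trained model.

    Returns one probability per token id: id 0 is the start/stop token
    (an assumption of this sketch), ids 1..20 are the amino acids.
    """
    stop_prob = min(1.0, 0.02 * len(tokens))  # stopping gets more likely as we go
    residue_prob = (1.0 - stop_prob) / len(AMINO_ACIDS)
    return [stop_prob] + [residue_prob] * len(AMINO_ACIDS)

def sample_sequence(next_token_probs, max_length, rng):
    """Sample one token at a time, feeding each choice back in as context."""
    tokens = [0]  # begin with the start token
    while len(tokens) - 1 < max_length:
        probs = next_token_probs(tokens)
        next_token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        if next_token == 0:  # assumed stop signal
            break
        tokens.append(next_token)
    # Drop the start token and map ids 1..20 back to letters
    return "".join(AMINO_ACIDS[t - 1] for t in tokens[1:])

print(sample_sequence(toy_next_token_probs, max_length=256, rng=random.Random(0)))
```

With a trained model, the probabilities would come from a softmax over the model's logits at the last position rather than from a fixed rule.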

How does it work?

We use a fairly standard transformer model trained on a next-token prediction task. Let's walk through the various parts of the code.

First, let's get an idea of the different files:

.
├── EXPERIMENTS.md
├── data.py
├── model.py
├── train.py
└── validation_with_experiment.ipynb

In data.py, we have code for loading the dataset and creating the samples for the model. The transformer architecture is completely spelled out in model.py. The script train.py ties the model and data together and also handles the input options.

Data loading

The dataset in this case consists of protein sequences. Protein sequences are strings over the amino acid alphabet: 20 characters, each represented by a single letter (ACDEFGHIKLMNPQRSTVWY). Across all organisms, natural protein sequences tend to be about 256 residues long (this is just a personal observation I've made from experience). The sequences in our dataset are a bit shorter than that, with 256 being the maximum length here.
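As a rough illustration of how such strings can be turned into integers for a model, here is a minimal tokenizer sketch. The exact mapping in data.py may differ. The names `stoi`, `itos`, `encode`, and `decode` are hypothetical, and reserving index 0 for a special token is an assumption of this sketch.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

# Assumption: index 0 is kept free for a special token, so residues map to 1..20
stoi = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}
itos = {i: aa for aa, i in stoi.items()}

def encode(sequence):
    """Map a protein string to a list of integer token ids."""
    return [stoi[aa] for aa in sequence]

def decode(token_ids):
    """Map token ids back to a protein string."""
    return "".join(itos[i] for i in token_ids)

tokens = encode("MKV")
print(tokens)          # [11, 9, 18]
print(decode(tokens))  # MKV
```

A sequence containing a character outside the 20-letter alphabet would raise a KeyError here, which is one place where an explicit unknown token could help.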

In data.py, we construct a Torch Dataset object called ProteinDataset. It produces samples in which y is x shifted by one position, so the target at each position is the next token:

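    def __getitem__(self, idx):
        word = self.proteins[idx]
        ix = self.encode(word)  # integer token ids for this sequence
        # Both tensors are initialized to 0, one slot longer than the longest sequence
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        x[1:1 + len(ix)] = ix  # x[0] stays 0, followed by the sequence
        y[:len(ix)] = ix  # y[i] is the token that follows x[i]
        # y[len(ix)] stays 0, so the target after the last residue is token 0
        y[len(ix) + 1:] = -1 # index -1 will mask the loss at the inactive locations
        return x, y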

We use 0 as the start token for our own convenience, and we don't bother defining tokens for end of sequence or unknown tokens. Perhaps we should! But to keep things simple, we use the above.
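To see the shift concretely, here is a small worked example that mirrors the indexing in `__getitem__` using plain Python lists instead of tensors. The helper `make_example` and the token ids are made up for illustration.

```python
def make_example(ix, max_word_length):
    """Build (x, y) as plain lists with the same indexing as __getitem__."""
    n = len(ix)
    assert n <= max_word_length, "sequence longer than max_word_length"
    x = [0] * (max_word_length + 1)
    y = [0] * (max_word_length + 1)
    x[1:1 + n] = ix  # start token 0, then the sequence
    y[:n] = ix       # the same tokens, one position earlier
    for i in range(n + 1, max_word_length + 1):
        y[i] = -1    # padding positions are masked out of the loss
    return x, y

# Three made-up token ids, padded to a maximum length of 5
x, y = make_example([11, 9, 18], max_word_length=5)
print(x)  # [0, 11, 9, 18, 0, 0]
print(y)  # [11, 9, 18, 0, -1, -1]
```

Reading the two lists column by column gives the training pairs: given 0 predict 11, given 11 predict 9, and so on. After the last residue the target is 0, and the remaining positions are ignored.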

We then provide a function create_datasets that will accept a dataset in the form of a FASTA file, and take the necessary steps to turn this into a training set, validation set, and test set that are formatted as above and can be used with the model.
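The following is a minimal sketch of the kind of steps such a function might take: parse the FASTA records, shuffle them, and split them three ways. It is not the actual `create_datasets` implementation. The helper names `read_fasta` and `split_sequences`, the split fractions, and the use of plain Python parsing (rather than a library such as biotite) are all assumptions made for illustration.

```python
import random

def read_fasta(lines):
    """Parse FASTA-formatted lines into a list of sequence strings."""
    sequences, current = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):  # a header line starts a new record
            if current:
                sequences.append("".join(current))
                current = []
        else:
            current.append(line)  # sequences may span several lines
    if current:
        sequences.append("".join(current))
    return sequences

def split_sequences(sequences, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle with a fixed seed, then split into train, validation, and test."""
    shuffled = list(sequences)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_frac)
    n_test = int(len(shuffled) * test_frac)
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]
    return train, val, test

fasta_text = """>seq1
MKVLA
GHT
>seq2
ACDEF
"""
sequences = read_fasta(fasta_text.splitlines())
print(sequences)  # ['MKVLAGHT', 'ACDEF']
```

Each of the three lists could then be wrapped in a dataset object like the one above. For homologous families, a random split can leave near-identical sequences on both sides, so a similarity-aware split may be worth considering.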

The transformer model

Conceptually, the default model is a decoder-only transformer with the same architecture as GPT-2, trained on next-token prediction. This is a nice clean starting point for us, but please note that the best-performing large protein transformer models (the ESM family from Meta AI) adopt a BERT-like architecture and a masked language modeling objective.

In model.py, we define the Transformer model. The forward pass, defined below, implicitly shows the overall architecture:

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # shape (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss
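The `ignore_index=-1` argument is what connects the loss to the -1 padding in the targets: those positions contribute nothing, and the loss is averaged over the remaining ones. Here is a minimal plain-Python sketch of that arithmetic, written without PyTorch so each step is visible. The function name `masked_cross_entropy` and the toy numbers are made up for illustration.

```python
import math

def masked_cross_entropy(logits, targets, ignore_index=-1):
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padded position: skipped entirely
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count

# Three positions over a vocabulary of 3 tokens; the last position is padding
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [5.0, 5.0, 5.0]]
targets = [0, 1, -1]
loss = masked_cross_entropy(logits, targets)
print(loss)
```

Changing the logits at the padded position leaves the loss unchanged, which is the behavior we want for sequences shorter than the block size.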