Alex Carlin

Score entropy discrete diffusion models for protein design

I really enjoyed this Stanford course on deep generative models taught by Stefano Ermon. The teaching was super engaging, the material was interesting and deep, and I learned a lot about the field from the super clear explication and focused presentation.

There was a fantastic lecture at the very end by Aaron Lou, a PhD student in Stefano's lab, on his work on diffusion models for discrete data (and his award-winning paper). The approach is called score entropy discrete diffusion (SEDD). It's a principled diffusion model for discrete data like text, built on a quantity called the concrete score. A previous paper from Ermon's lab describes the concept of a concrete score:

Representing probability distributions by the gradient of their density functions has proven effective in modeling a wide range of continuous data modalities. However, this representation is not applicable in discrete domains where the gradient is undefined. To this end, we propose an analogous score function called the "Concrete score" [...]. Given a predefined neighborhood structure, the Concrete score of any input is defined by the rate of change of the probabilities with respect to local directional changes of the input. (Emphasis mine.)
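Here is a toy illustration of the idea. It is my own sketch, not code from the SEDD repo, and it ignores the diffusion noise level entirely. It uses a hand-picked distribution over length-2 sequences on a two-letter alphabet and defines the neighborhood of x as every sequence one substitution away. It then computes the ratios p(y)/p(x) for those neighbors. SEDD's network learns to estimate ratios of this form, and the earlier paper works with the same ratios shifted by 1, i.e. p(y)/p(x) - 1.

```python
ALPHABET = "AC"  # toy two-letter alphabet (assumption, for illustration only)

# A hand-picked toy distribution over all length-2 sequences (sums to 1).
p = {"AA": 0.4, "AC": 0.3, "CA": 0.2, "CC": 0.1}


def hamming1_neighbors(x, alphabet=ALPHABET):
    """All sequences that differ from x at exactly one position."""
    out = []
    for i, current in enumerate(x):
        for a in alphabet:
            if a != current:
                out.append(x[:i] + a + x[i + 1:])
    return out


def concrete_score(x, p):
    """Map each single-substitution neighbor y of x to the ratio p(y) / p(x)."""
    return {y: p[y] / p[x] for y in hamming1_neighbors(x)}


for x in sorted(p):
    print(x, concrete_score(x, p))
```

A ratio above 1 means the neighbor is more probable than x. A model that knows these ratios for every single-substitution neighbor is, in effect, scoring every point mutation of x relative to x.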

Aaron's explanation of the model made me think that it might work a lot better for protein sequences than autoregressive models do. Some of the limitations of autoregressive models are compounded by biases in protein data: protein sequences aren't "left to right", and "one amino acid at a time" isn't the way to understand a protein. Unlike autoregressive models, SEDD directly models sequences that differ at one position in order to calculate the concrete score. This is highly analogous to the way we experimentally score protein sequences that are mutated at one position, and it is a very nice parallel with the way we think about protein sequences evolving and being engineered in the lab.
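The single-substitution neighborhood of a protein is easy to enumerate. With 20 amino acids, a sequence of length L has 19 × L single mutants, which is the same set of variants you would assay when experimentally scoring every single-position mutation. Below is a minimal sketch; the function name and the example sequence are hypothetical.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def single_mutants(seq):
    """Yield (mutation_label, mutant_sequence) for every single substitution.

    Labels use wild-type / 1-based position / mutant notation, e.g. "M1A".
    The unmutated sequence itself is not included.
    """
    for i, wt in enumerate(seq):
        for aa in AMINO_ACIDS:
            if aa != wt:
                yield f"{wt}{i + 1}{aa}", seq[:i] + aa + seq[i + 1:]


wild_type = "MKTAYIAKQR"  # arbitrary example sequence (hypothetical)
mutants = dict(single_mutants(wild_type))
print(len(mutants))  # 19 substitutions x 10 positions = 190
```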

This made me think that SEDD models have the right kind of inductive biases for modeling protein sequence data, might perform very well in modeling the data distribution of protein sequences, and may be very useful for tasks we care about in protein design. Namely, a few things we care about for a protein sequence model:

You can of course play around with different decoding orders and the like, but autoregressive models are fundamentally one-at-a-time models. And we know that autoregressive models perform poorly when asked to assign a relative score to two protein sequences that differ in only one token, even though in natural proteins a single-token change can make a sequence dramatically less likely if it mutates away a functionally important amino acid.

Experiments training score entropy discrete diffusion models for protein design

Of course, I never feel like I understand anything until I can create it. I was super curious to try this out on my favorite discrete data: protein sequences! I forked Aaron's implementation of SEDD, added some simple data loaders and tokenization for protein sequences, and ran some training experiments on an A100 using the test dataset from my protein transformers project.

It's pretty cool: even with the GPT-2 tokenizer left in place, within just a couple thousand steps the model is already producing new proteins that ESMFold predicts will fold into the correct structure. In contrast, proteins produced by a GPT-like model at this stage of training are not predicted to fold well.

However, in contrast to the 50,257-token vocabulary used in GPT-2, the "token space" for proteins is just the 20 amino acids plus a few special tokens (sequence start and end, padding, unknown, and mask). For example, here's the vocab for my protein tokenizer:

{
    "<s>": 0,
    "<pad>": 1,
    "</s>": 2,
    "<unk>": 3,
    "<mask>": 4,
    "A": 5,
    "C": 6,
    "D": 7,
    "E": 8,
    "F": 9,
    "G": 10,
    "H": 11,
    "I": 12,
    "K": 13,
    "L": 14,
    "M": 15,
    "N": 16,
    "P": 17,
    "Q": 18,
    "R": 19,
    "S": 20,
    "T": 21,
    "V": 22,
    "W": 23,
    "Y": 24
}
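Because every residue is a single character, the mapping this vocabulary implies is simple: one token per amino acid, optionally wrapped in <s> and </s>. Here is a minimal pure-Python sketch of that mapping. It assumes characters outside the vocabulary should become <unk>, and it is only an illustration of the ids involved, not the GPT2TokenizerFast-based tokenizer this post actually uses.

```python
# Same vocabulary as above: 5 special tokens followed by the 20 amino acids.
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}
ID_TO_TOKEN = {i: tok for tok, i in VOCAB.items()}


def encode(seq, add_special_tokens=True):
    """Map each residue to its id; characters outside the vocab become <unk>."""
    ids = [VOCAB.get(ch, VOCAB["<unk>"]) for ch in seq.upper()]
    if add_special_tokens:
        ids = [VOCAB["<s>"]] + ids + [VOCAB["</s>"]]
    return ids


def decode(ids):
    """Drop special tokens (ids below 5) and join the remaining residues."""
    return "".join(ID_TO_TOKEN[i] for i in ids if i >= len(SPECIAL_TOKENS))


print(len(VOCAB))      # 25
print(encode("MKV"))   # [0, 15, 13, 22, 2]
```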

So modeling proteins requires a few changes to the original implementation:

  1. in run_train.py, we replace the tokenizer used for sampling with a protein tokenizer that uses an amino acid vocabulary
  2. in data.py we update the tokenizer used when creating datasets
  3. we need a different method of evaluating generative perplexity and other useful metrics
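On the third point, generative perplexity is the exponentiated average negative log-likelihood that an evaluator model assigns to the generated samples. In the original text setting that evaluator is a pretrained GPT-2 model, which is not a sensible judge of amino acid strings. For proteins it would have to be something like a protein language model (an assumption on my part). The sketch below only shows the metric itself, given per-token log-probabilities from whatever evaluator you choose; the function name is hypothetical.

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) over all tokens.

    `token_log_probs` is a list of per-sample lists, each holding the
    evaluator model's log p(token | context) for every token of one sample.
    """
    total_nll = 0.0
    n_tokens = 0
    for seq_log_probs in token_log_probs:
        total_nll -= sum(seq_log_probs)
        n_tokens += len(seq_log_probs)
    return math.exp(total_nll / n_tokens)


# Sanity check: an evaluator that is uniform over the 20 amino acids assigns
# log(1/20) to every residue, so its perplexity is (up to rounding) 20.
uniform = [[math.log(1 / 20)] * 100 for _ in range(8)]
print(perplexity(uniform))
```

Lower is better. A value near 20 means the evaluator finds the samples no more plausible than uniformly random residues, which gives a useful floor for comparing SEDD and autoregressive samples.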

I chose to implement tokenization for proteins by supplying a modified vocabulary to the existing GPT2TokenizerFast. I also tried writing a tokenizer class from scratch, but reusing the existing implementation with a modified vocabulary worked better, because a from-scratch class has to replicate many implementation details.

The new tokenizer is initialized from the files vocab.json and merges.txt, which are generated when the script runs.

from collections import OrderedDict
from transformers import GPT2TokenizerFast
import json

# Define amino acids and special tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
all_tokens = special_tokens + amino_acids

# Create the vocabulary
vocab = OrderedDict((token, idx) for idx, token in enumerate(all_tokens))

# Save the vocabulary
with open('vocab.json', 'w') as f:
    json.dump(vocab, f)

# Create a merges.txt containing only the version header: with no BPE merges,
# every amino acid letter is encoded as its own token
with open('merges.txt', 'w') as f:
    f.write('#version: 0.2\n')

# Initialize the tokenizer
tokenizer = GPT2TokenizerFast(
    vocab_file='vocab.json',
    merges_file='merges.txt',
    bos_token='<s>',
    eos_token='</s>',
    unk_token='<unk>',
    pad_token='<pad>',
    mask_token='<mask>'
)

After training the model under several configurations on the AcyP dataset (select it with "acyp" as the dataset name in the config), the SEDD model appears to model the distribution of these homologous sequences very well, and it generates highly convincing sequences that ESMFold predicts will fold well.

Training on UniRef50

The next step would be to train some SEDD models of different sizes on the UniRef50 dataset and compare them to autoregressive models trained on the same data. Some good models to compare against: ProGen, RITA, ProtT5. Some model sizes to try: 10M, 100M, 1B, 10B params.

I adapted the existing data loading code to load the UniRef50 dataset, but haven't trained on it fully yet because of the computational cost. In my experiments, it takes about 4 hours to preprocess the dataset using an A100 before training begins. Once training begins, the tiny model on short sequences of length 128 runs about 1,000 steps of batch size 128 per minute on the A100, which is about 16 million tokens per minute. UniRef50 contains around 40 million sequences with an average length of 256, for a total of about 10 billion tokens. I estimate that a single pass over the data will take over 10 hours on the A100 (for the tiny model).
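The estimate is simple arithmetic. Here it is spelled out with the numbers from the paragraph above, for one pass over the data and ignoring preprocessing and evaluation time:

```python
# Throughput measured for the tiny model on an A100 (numbers from the text).
steps_per_minute = 1_000
batch_size = 128
seq_len = 128
tokens_per_minute = steps_per_minute * batch_size * seq_len  # 16,384,000

# Approximate size of UniRef50 (numbers from the text).
n_sequences = 40_000_000
avg_len = 256
total_tokens = n_sequences * avg_len  # 10,240,000,000

minutes_per_epoch = total_tokens / tokens_per_minute  # 625.0
print(f"{tokens_per_minute / 1e6:.1f}M tokens/min")    # 16.4M tokens/min
print(f"{total_tokens / 1e9:.2f}B tokens")             # 10.24B tokens
print(f"{minutes_per_epoch / 60:.1f} hours per epoch")  # 10.4 hours per epoch
```

Training time scales linearly from here: several epochs, or a larger model with lower throughput, multiply this figure.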

I've provided the code to load UniRef50 in the repo:

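# data.py: new branch in the dataset-name if/elif chain
# (load_dataset comes from the Hugging Face datasets library)
elif name == "uniref50":
    dataset = load_dataset("agemagician/uniref50", cache_dir=cache_dir)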

You can train on UniRef50 by providing the dataset name "uniref50" in the Hydra config.
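With a block length of 128 and an average UniRef50 length of 256, most proteins will not fit in a single training example. One common way to handle this is to append an end-of-sequence token to each tokenized protein, concatenate everything, and cut the stream into fixed-length blocks. The sketch below illustrates that approach under my own assumptions; it is not necessarily what the repo's preprocessing does, and the function name is hypothetical.

```python
EOS_ID = 2  # "</s>" in the protein vocabulary above


def pack_into_blocks(tokenized_seqs, block_size=128, eos_id=EOS_ID):
    """Concatenate token-id lists (each followed by EOS) and split into blocks.

    Leftover tokens that do not fill a whole block are dropped.
    """
    stream = []
    for ids in tokenized_seqs:
        stream.extend(ids)
        stream.append(eos_id)
    n_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]


# Three toy "proteins" of lengths 200, 300 and 100 (token ids are arbitrary).
toy = [[5] * 200, [6] * 300, [7] * 100]
blocks = pack_into_blocks(toy)
print(len(blocks))  # (201 + 301 + 101) // 128 = 603 // 128 = 4
```

Packing wastes no compute on padding, but a block can start mid-protein and span two proteins. Padding or truncating each protein to its own block is the main alternative.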

I think the connection to evolution is also much clearer for the SEDD model than for autoregressive models. As I understand it, SEDD directly models sequence neighbors one “mutation” away, which is conceptually similar to the underlying data-generating process (evolution). I think Aaron’s score entropy discrete diffusion model will work really well for all kinds of biological sequences, both proteins and DNA, and I'm excited to evaluate it against protein transformer models.