Using generative ML to design native-looking genes
Once you have designed a protein, you must express it in a host organism in order to test its function. In most bio labs, the way this is done is pretty cool. The central dogma of biology states that DNA is transcribed to RNA (by a polymerase), which is translated (by ribosomes) into proteins. So in order to have the cell make the protein you have designed, you actually have to write out the DNA that encodes it, along with expression instructions (for example, how much protein to make, or to make it only in response to a signal molecule, to name two common ones).
These instructions appear inline with the coding sequence. For example, if you want to make protein A, you might write out some DNA that could be represented schematically as
instructions        coding sequence                  more instructions
|-------------------|--------------------------------|-----------------|
ATCGCGC CGCGCGC     ATG GCG AGA CGC GAT TAA TAA      GCGCGACGG
where the "coding sequence" is a sequence of codons that translate to the protein of interest. A codon is a sequence of three nucleotide bases, for example ATG or CCC. DNA uses only four bases: A, T, C, and G. As a biologist, I am obligated to point out that all life on Earth uses the same four bases to write and read its DNA, and nearly every living thing on Earth uses the same genetic code to translate codons into amino acids (mitochondria and a few microbes use slight variants). I wrote the coding sequence above with spaces between the codons to emphasize them, but real DNA sequences have no spaces or separators between the different logical parts.
In this schematic, the sequences are extremely short, but in most synthetic biology systems, the first set of instructions (before the coding sequence) can be thousands of bases. The coding sequences themselves tend to be between 300 and 2,000 bases long. And the second set of instructions, at the end, can also be thousands of bases. As such, a complete sequence for expressing a gene can be many thousands of bases long, just a long string of ATCAGGGATCGAGGAAAAAAACTAGCTGATTATG.
Basics of heterologous expression
First, we'll talk about using a bacterial host. Escherichia coli (E. coli) is the most common host used in protein design research because it is easy to work with and very well studied. To get a cell to express our gene, though, we need to do some work first.
What is the core problem in designing genes?
We need to "back-translate" our protein sequence into codons. Each of the 20 amino acids is encoded by between one and six different codons, so for every residue in the protein we must choose one codon.
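To get a feel for how large this choice is, here is a minimal sketch that builds the standard genetic code (NCBI translation table 1) and counts how many distinct coding sequences encode a given protein. The names STANDARD_CODE, synonymous_codons and count_coding_sequences are hypothetical helpers for illustration only.

```python
from itertools import product
from math import prod

# Standard genetic code (NCBI table 1): amino acids listed in the order produced
# by iterating over the bases T, C, A, G at codon positions 1, 2 and 3.
AMINO_ACIDS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
STANDARD_CODE = {
    "".join(bases): aa for bases, aa in zip(product("TCAG", repeat=3), AMINO_ACIDS)
}

def synonymous_codons(code: dict[str, str]) -> dict[str, list[str]]:
    """Invert a codon -> amino acid table into amino acid -> [codons]."""
    table: dict[str, list[str]] = {}
    for codon, aa in code.items():
        table.setdefault(aa, []).append(codon)
    return table

def count_coding_sequences(protein: str) -> int:
    """Number of distinct DNA sequences encoding `protein` ('*' marks a stop)."""
    table = synonymous_codons(STANDARD_CODE)
    return prod(len(table[aa]) for aa in protein)

table = synonymous_codons(STANDARD_CODE)
print(len(table["L"]), len(table["I"]), len(table["M"]))  # 6 3 1
print(count_coding_sequences("MSEN"))  # 1 * 6 * 2 * 2 = 24
```

Because the count is a product over positions, it grows exponentially with protein length, so enumerating every option is out of the question for real proteins.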
What are some easy ways to solve this?
One very simple way to create new genes is the following. First, collect a set of existing genes. For each one, check that its length is a multiple of 3 so it can be split into codons; I usually keep only genes that start with ATG. Then break all the genes up into codons and count them. You'll get a dictionary like this, with 64 entries, one for each possible codon:
{
"ATG": 71032,
"TCT": 68721,
"GCT": 80207,
"GTG": 80750,
....
"TGG": 37691,
"TAG": 2504,
"TAA": 2716,
"TTA": 5531,
"TGA": 1228
}
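The counting procedure described above might look like this in plain Python. This is a minimal sketch, not Espresso's implementation; count_codons and the toy genes list are made up for illustration. Sequences containing characters other than A, T, C and G (such as N) would add extra keys, so filter those out first if your data has them.

```python
from collections import Counter
from itertools import product

ALL_CODONS = ["".join(c) for c in product("ATCG", repeat=3)]

def count_codons(genes):
    """Count codon usage across coding sequences, skipping malformed genes."""
    counts = Counter({codon: 0 for codon in ALL_CODONS})  # keep all 64 keys
    for gene in genes:
        gene = gene.upper()
        # keep only genes that split cleanly into codons and start with ATG
        if len(gene) % 3 != 0 or not gene.startswith("ATG"):
            continue
        counts.update(gene[i:i + 3] for i in range(0, len(gene), 3))
    return dict(counts)

genes = ["ATGTCTGCTTAA", "ATGGCTGCTTGA", "ATGGC", "TTGAAATAG"]
counts = count_codons(genes)
print(counts["ATG"], counts["GCT"], counts["TAA"])  # 2 3 1
```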
This happens to be real data from the built-in Yarrowia lipolytica position-independent codon model in my open source codon optimization package, Espresso, which you can find on GitHub.
Then, to design a gene, you normalize the counts to frequencies and use them to sample among the synonymous codons for each amino acid.
For example, the amino acid isoleucine (Ile, I) is encoded by three different codons: ATT (count: 70,891), ATA (count: 6,669), and ATC (count: 76,798). Dividing each count by their total (154,358), we get the following frequencies for isoleucine:
codon   count    frequency
ATT     70891    0.46
ATA      6669    0.04
ATC     76798    0.50
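The normalization step is a one-liner per amino acid. This sketch reproduces the isoleucine table from the counts quoted above; normalize is a hypothetical helper name.

```python
def normalize(counts: dict[str, int]) -> dict[str, float]:
    """Convert raw codon counts for one amino acid into frequencies."""
    total = sum(counts.values())
    return {codon: n / total for codon, n in counts.items()}

isoleucine = {"ATT": 70_891, "ATA": 6_669, "ATC": 76_798}
freqs = normalize(isoleucine)
print(sum(isoleucine.values()))  # 154358
print({codon: round(f, 2) for codon, f in freqs.items()})
# {'ATT': 0.46, 'ATA': 0.04, 'ATC': 0.5}
```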
To sample from this distribution, you can do something like:
from random import choices
choices(["ATT", "ATA", "ATC"], weights=[0.46, 0.04, 0.50])[0]  # choices returns a list; [0] takes the single sampled codon
If you're interested in an optimized implementation of this, see my post introducing the Espresso project that follows this one.
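Putting the pieces together, a position-independent back-translator samples one codon per residue, independently at each position. This is a sketch under the assumption that you already have per-codon counts; the tiny CODON_COUNTS table below covers only M, I and the stop codons (values taken from the Yarrowia lipolytica counts quoted above), and back_translate is an illustrative name, not Espresso's API.

```python
import random

# Counts keyed by amino acid; a real table would cover all 20 amino acids plus stop.
CODON_COUNTS = {
    "M": {"ATG": 71_032},
    "I": {"ATT": 70_891, "ATA": 6_669, "ATC": 76_798},
    "*": {"TAA": 2_716, "TAG": 2_504, "TGA": 1_228},
}

def back_translate(protein: str, codon_counts: dict, rng: random.Random) -> str:
    """Sample a coding sequence, choosing each codon in proportion to its count."""
    codons = []
    for aa in protein:
        options = codon_counts[aa]
        # random.choices accepts unnormalized weights, so raw counts work directly
        codons.append(rng.choices(list(options), weights=list(options.values()))[0])
    return "".join(codons)

rng = random.Random(0)  # seeded for reproducibility
dna = back_translate("MIII*", CODON_COUNTS, rng)
print(dna, len(dna))
```

Because every position is sampled on its own, this approach ignores context such as neighboring codons, which is part of what motivates the ML framing below.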
How can we frame this as a machine learning problem?
There are a number of ways to frame this problem using ML. For example, we could train a bigram model, or more generally a neural network that accepts some input context and chooses a codon for a specific position. One way to think about the task is as language modeling or machine translation, where one sequence of tokens is transformed into another.
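As an illustration of the bigram idea, the sketch below conditions each codon choice on the previous codon as well as on the current amino acid, using counts gathered from native genes. It is a toy under stated assumptions (a small inline subset of the genetic code, no smoothing, a made-up training set), not the approach used in the rest of this post.

```python
import random
from collections import Counter, defaultdict

# Toy subset of the genetic code, assumed sufficient for the example genes below.
CODE = {"ATG": "M", "ATT": "I", "ATC": "I", "ATA": "I", "TAA": "*", "TGA": "*"}

def split_codons(gene: str) -> list[str]:
    return [gene[i:i + 3] for i in range(0, len(gene), 3)]

def fit_bigram(genes):
    """Build counts[(previous codon, amino acid)][codon] from native genes."""
    counts = defaultdict(Counter)
    for gene in genes:
        prev = "<bos>"
        for codon in split_codons(gene):
            counts[(prev, CODE[codon])][codon] += 1
            prev = codon
    return counts

def sample_bigram(protein, counts, rng):
    prev, out = "<bos>", []
    for aa in protein:
        options = counts[(prev, aa)]  # an unseen context is empty and would raise here
        codon = rng.choices(list(options), weights=list(options.values()))[0]
        out.append(codon)
        prev = codon
    return "".join(out)

genes = ["ATGATTATCTAA", "ATGATTATCTGA", "ATGATCATTTAA"]
model = fit_bigram(genes)
print(sample_bigram("MII*", model, random.Random(1)))
```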
We frame the problem as translating a series of protein residue tokens (n = 20 amino acids, plus a stop symbol) into a series of codon tokens (n = 64, all 4^3 three-letter combinations of A, T, C, and G).
Of course, in biology, translation is almost exactly the opposite problem: creating a functional protein from a coding sequence that has been transcribed to mRNA. So this problem has both conceptual simplicity in the problem domain and hilarious terminological conflict in the application domain. As someone who often sits at the intersection of computers and biology, I appreciate the irony.
To improve on the current state-of-the-art gene design strategy offered by all the major vendors, such as Twist, GenScript, and IDT, I explored a deep learning approach to this problem. A full description of how I approached the problem can be found in a previous post.
While I was working on this, at least two papers came out on the subject! I didn't know about them during my research project, but I am happy to see that their overall strategy and results are about the same as mine, another sign of how powerful and flexible the transformer architecture is for solving biological problems. The two papers I read most recently are:
- "Deep learning-based codon optimization with large-scale synonymous variant datasets enables generalized tunable protein expression", which is super cool because they train both a base model as well as a fine-tuned model based on expression data!
- "ICOR: Improving codon optimization with recurrent neural networks", which uses a bidirectional LSTM instead of a transformer model
Basically the game is this. We define a vocabulary of the 64 codons (AAA, AAC, …) and a vocabulary of protein letters (A, C, D, ...), and we train a transformer to translate from the protein representation to the codon representation.
Codon transformer model implementation
We can implement this model in a very straightforward way using the Transformer module in PyTorch 2. I have previously constructed a dataset of codon usage from fungal organisms and uploaded it to Hugging Face. The README describes the bioinformatics workflows used to construct the dataset.
To get started, you can simply use a Google Colab notebook.
import math
from random import shuffle
from typing import Iterable, List
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from biotite.sequence.io.fasta import FastaFile
from biotite.sequence import NucleotideSequence
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.nn import Transformer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from timeit import default_timer as timer
from rich.progress import track
from rich import print
from datasets import load_dataset
We'll first construct the tokenization for the protein and codon data. Here, we're using the TorchText framework as a nod to natural language translation, of which this is a simplified case: the input and output sequences always have the same length. We'll use both beginning-of-sequence and end-of-sequence tokens.
protein_vocab = list("ACDEFGHIKLMNPQRSTVWY*") # ["A", "C", "D", "E", ...]
codon_vocab = list("".join(x) for x in product("ATCG", repeat=3)) # ["AAA", "AAT", "AAC", ...]
SRC_LANGUAGE = "protein"
TGT_LANGUAGE = "codons"
# Placeholders for transforms
token_transform = {}
vocab_transform = {}
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
# first the protein vocab!
vocab_transform[SRC_LANGUAGE] = build_vocab_from_iterator([protein_vocab], min_freq=1, specials=special_symbols, special_first=True)
# now the codon vocab!
vocab_transform[TGT_LANGUAGE] = build_vocab_from_iterator([codon_vocab], min_freq=1, specials=special_symbols, special_first=True)
for lang in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[lang].set_default_index(UNK_IDX)
# now the token transforms
token_transform[SRC_LANGUAGE] = lambda x: x.split(" ")
token_transform[TGT_LANGUAGE] = lambda x: x.split(" ")
Let's try a few examples to check that the code behaves as expected.
# example, transform a protein sequence into tokens
vocab_transform[SRC_LANGUAGE](token_transform[SRC_LANGUAGE]('M S E N *'))
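This returns [15, 20, 8, 16, 4], which makes sense: build_vocab_from_iterator puts the four special tokens at indices 0 to 3 and then orders the remaining tokens (all with frequency 1 here) alphabetically, so the stop symbol * gets index 4 and M gets 15. The same works for codons: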
# example, get tokens for a codon sequence (coding sequence)
vocab_transform[TGT_LANGUAGE](token_transform[TGT_LANGUAGE]('ATG AGT GGG AAA ATG ATG ATG GCC CGA CCC ATA'))
which outputs [18, 15, 46, 4, 18, 18, 18, 41, 28, 25, 16].
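torchtext is no longer under active development, so it may not be installable alongside newer PyTorch versions. If that happens, a plain dictionary can stand in for build_vocab_from_iterator. The sketch below is an assumed equivalent, not part of the original code: it mimics torchtext's ordering (specials first, then tokens by descending frequency and then alphabetically, which is purely alphabetical here because every token appears once) and reproduces the indices shown above. build_vocab and encode are hypothetical names.

```python
from itertools import product

SPECIALS = ["<unk>", "<pad>", "<bos>", "<eos>"]
UNK_IDX = 0

def build_vocab(tokens):
    """Map specials to 0..3, then the remaining unique tokens in sorted order."""
    return {tok: i for i, tok in enumerate(SPECIALS + sorted(set(tokens)))}

def encode(vocab, text):
    """Space-separated string -> list of indices; unknown tokens map to <unk>."""
    return [vocab.get(tok, UNK_IDX) for tok in text.split(" ")]

protein_vocab = build_vocab("ACDEFGHIKLMNPQRSTVWY*")
codon_vocab = build_vocab("".join(c) for c in product("ATCG", repeat=3))

print(encode(protein_vocab, "M S E N *"))  # [15, 20, 8, 16, 4]
print(encode(codon_vocab, "ATG AGT GGG AAA ATG ATG ATG GCC CGA CCC ATA"))
# [18, 15, 46, 4, 18, 18, 18, 41, 28, 25, 16]
print(len(protein_vocab), len(codon_vocab))  # 25 68
```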
So now we can tokenize protein and codon sequences. Let's define a series of transforms for taking in data and providing it to the model.
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func
# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))
# ``src`` and ``tgt`` language text transforms to convert raw strings into tensor indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln],  # Tokenization
                                               vocab_transform[ln],  # Numericalization
                                               tensor_transform)     # Add BOS/EOS and create tensor
# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))
    # pad to the longest sequence; output shape is (seq_len, batch) since batch_first=False
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch
Now we'll turn to the model itself. This is nicely described in a tutorial. "First part is the embedding layer. This layer converts tensor of input indices into corresponding tensor of input embeddings. These embedding are further augmented with positional encodings to provide position information of input tokens to the model. The second part is the actual Transformer model. Finally, the output of the Transformer model is passed through linear layer that gives unnormalized probabilities for each token in the target language."
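For reference, the PositionalEncoding module below implements the sinusoidal encoding from the original Transformer paper ("Attention Is All You Need"). For position pos and embedding size d (emb_size), each pair of dimensions 2i and 2i+1 is filled with a sine and a cosine of the same frequency:

```latex
PE_{(pos,\,2i)} = \sin\left(pos \cdot 10000^{-2i/d}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(pos \cdot 10000^{-2i/d}\right), \qquad
10000^{-2i/d} = \exp\left(-\frac{2i \ln 10000}{d}\right)
```

The last identity is what the code's den term computes: torch.arange(0, emb_size, 2) yields the even indices 2i, which are multiplied by -math.log(10000) / emb_size and exponentiated.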
Before we get to the model, we'll declare the device PyTorch should use and set a seed for reproducibility.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)
Now for the transformer model:
# functions for generating the autoregressive mask
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)
    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)
    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)
    def encode(self, src, src_mask):
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)
    def decode(self, tgt, memory, tgt_mask):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
    def sample(self, src, src_mask, max_len, start_symbol, temperature=1.0):
        src = src.to(DEVICE)
        src_mask = src_mask.to(DEVICE)
        memory = self.encode(src, src_mask).to(DEVICE)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            out = self.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            logits = self.generator(out[:, -1]) / temperature
            # next_word = torch.argmax(logits, dim=-1).item()  # for greedy decoding
            next_word = torch.multinomial(F.softmax(logits, dim=-1), 1).item()
            ys = torch.cat([ys,
                            torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == EOS_IDX:
                break
        return ys
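The temperature argument in sample divides the logits before the softmax: values below 1 sharpen the distribution toward the most likely token, values above 1 flatten it, and very small values approach greedy decoding. Here is a standard-library sketch of that scaling; the logits are made-up numbers, and softmax_with_temperature is a hypothetical helper.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, made numerically stable by subtracting the max."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate codons
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```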
Before we write the training loop, let's instantiate a model with a reasonable set of starting parameters to see how large it is.
BATCH_SIZE = 32
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 64
NHEAD = 8
FFN_HID_DIM = 256
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
model = model.to(DEVICE)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.98), eps=1e-9)
n_params = sum(p.numel() for p in model.parameters())
print("model parameters:", n_params)
print("approx. tokens:", 200_000 * 256)
print("tokens per param:", 200_000 * 256 / n_params)
This provides the following output (to which I've added thousands separators)
model parameters: 360_836
approx. tokens: 51_200_000
tokens per param: 141.89
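As a sanity check, the 360,836 figure can be reproduced by hand. The sketch below tallies weights and biases for each submodule with the hyperparameters above, assuming PyTorch's default nn.Transformer layout (biases everywhere, a packed Q/K/V projection in each attention block, and a final LayerNorm after both the encoder and decoder stacks). The positional encoding is a registered buffer, so it contributes no parameters.

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias

def layer_norm(d):
    return 2 * d  # elementwise scale + shift

def attention(d):
    return linear(d, 3 * d) + linear(d, d)  # packed Q/K/V projection + output projection

d, ff = 64, 256                # EMB_SIZE, FFN_HID_DIM
src_vocab, tgt_vocab = 25, 68  # 21 protein tokens / 64 codons, plus 4 specials each
n_enc = n_dec = 3

encoder_layer = attention(d) + linear(d, ff) + linear(ff, d) + 2 * layer_norm(d)
decoder_layer = 2 * attention(d) + linear(d, ff) + linear(ff, d) + 3 * layer_norm(d)
transformer = n_enc * encoder_layer + n_dec * decoder_layer + 2 * layer_norm(d)

embeddings = src_vocab * d + tgt_vocab * d  # nn.Embedding has no bias
generator = linear(d, tgt_vocab)
total = transformer + embeddings + generator
print(encoder_layer, decoder_layer, total)  # 49984 66752 360836
```

Almost all of the budget sits in the transformer layers; the embeddings and output projection are small because both vocabularies are tiny.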
To begin, let's try training this roughly 0.36 M-parameter model on the dataset of about 50 million tokens. Note that here we're only counting input tokens.
Before we begin training, we need a simple dataset class that transforms our data and presents it to the model.
class CodonDataset:
    """Space-separated (protein, codons) string pairs built from a Hugging Face split"""
    def __init__(self, dataset, max_length=510, max_samples=None):
        x, y = [], []
        n_tokens = 0
        n_samples = 0
        for item in dataset:
            # keep only proteins shorter than max_length residues
            if len(item["protein-sequence"]) < max_length:
                n_tokens += len(item["protein-sequence"])
                x.append(" ".join(item["protein-sequence"]))
                sequence = item["dna-sequence"]
                # split the coding sequence into non-overlapping codons
                y.append(" ".join(list(sequence[i:i+3] for i in range(0, len(sequence), 3))))
                n_samples += 1
                if n_samples == max_samples:
                    break
        print("total tokens:", n_tokens)
        self.x = x
        self.y = y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        # returns a tuple (protein, codons) space separated
        return (self.x[idx], self.y[idx])
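Each item the dataset returns is a pair of space-separated strings with exactly one codon per residue, which is the format collate_fn expects. The hypothetical to_pair helper below makes that format explicit and adds a length check that the class above does not perform; treat it as an optional sanity check rather than part of the original pipeline.

```python
def to_pair(protein: str, dna: str) -> tuple[str, str]:
    """Format a protein and its coding sequence as space-separated token strings."""
    if len(dna) != 3 * len(protein):
        raise ValueError("coding sequence must contain exactly one codon per residue")
    codons = [dna[i:i + 3] for i in range(0, len(dna), 3)]
    return " ".join(protein), " ".join(codons)

src, tgt = to_pair("MSEN*", "ATGAGTGAAAATTAA")
print(src)  # M S E N *
print(tgt)  # ATG AGT GAA AAT TAA
```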
We'll use a codon dataset I uploaded to Hugging Face for this tutorial.
# load dataset, downloading if necessary
dataset = load_dataset("alxcarln/codons")
# create dataset and loader objects
train_iter = CodonDataset(dataset["train"], max_samples=None)
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_iter = CodonDataset(dataset["validation"], max_samples=None)
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)