LMPNN Featurizers
Note: this is crossposted to GitHub.
Intro
One of the motivating reasons for creating ferritin
was to be able to recreate and run protein language
models in the browser via WASM. This post describes a work-in-progress effort to port
the LigandMPNN protein model from Python/PyTorch to Rust/Candle.
My general strategy has been to nibble at the model from both ends: begin a naive translation of the model itself,
and recreate the protein-processing portion in Rust.
Coordinates → Tensors
Protein language models require converting proteins into compact Tensor
representations of molecular features.
I created the LMPNNFeatures trait to declare the functions I need. These functions
operate on AtomCollections and return candle_core::Tensors. The function signatures
are below, along with the ProteinFeatures struct, which holds a set of these
Tensors and mimics the LigandMPNN feature dictionary.
use candle_core::{Device, Result, Tensor};

/// Convert the AtomCollection into a struct that can be passed to a model.
pub trait LMPNNFeatures {
    /// Convert an AtomCollection to a set of Features.
    fn featurize(&self, device: &Device) -> Result<ProteinFeatures>;
    /// Create a Tensor of dimensions [ <# residues>, 4 <N/CA/C/O>, 3 <xyz> ]
    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;
    /// Create a Tensor of dimensions [ <# residues>, 37, 3 <xyz> ]
    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;
    /// Create 3 Tensors of dimensions [ <# heavy atoms>, 3 <xyz> ]
    /// (positions, elements, mask)
    fn to_numeric_ligand_atoms(&self, device: &Device)
        -> Result<(Tensor, Tensor, Tensor)>;
}
/// ProteinFeatures
/// Essential Tensors are Aligned Coordinates of Atoms by Atom Type
pub struct ProteinFeatures {
s: Tensor, // protein amino acids as i32 Tensor
x: Tensor, // protein coords by residue [1, #res, 37, 4]
y: Tensor, // ligand coords
y_t: Tensor, // encoded ligand atom names
...
...
}
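As a quick usage sketch, featurization is then a single call on an AtomCollection. The structure-loading step is omitted here; this is illustrative, not the crate's documented example.

use candle_core::{Device, Result};

// `atoms` stands for an AtomCollection parsed from a PDB/CIF file.
fn featurize_example(atoms: &AtomCollection) -> Result<ProteinFeatures> {
    let device = Device::Cpu;
    // Build the full feature bundle the model consumes...
    let features = atoms.featurize(&device)?;
    // ...or construct individual Tensors directly.
    let _backbone = atoms.to_numeric_backbone_atoms(&device)?;
    Ok(features)
}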
One nice ergonomic crate I came across was strum.
To generate a Tensor with aligned coordinates I need to iterate
through the atom types, and strum's EnumIter derive lets me do that. Below you can see
how that enum is used to build up a preallocated vector, which we populate and
then reshape into a Tensor. Very nice.
use strum::{Display, EnumIter, EnumString, IntoEnumIterator};

/// AAAtom. This is where the different atom types are registered.
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)]
pub enum AAAtom {
N = 0, CA = 1, C = 2, CB = 3, O = 4,
CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
NZ = 35, OXT = 36,
Unknown = -1,
}
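Since Display and EnumString are derived, each variant round-trips through its string name, which is exactly what the atom lookup below relies on. A quick check (hypothetical test code, not from the crate):

use std::str::FromStr;

#[test]
fn aaatom_string_roundtrip() {
    // Display derive: variant -> string
    assert_eq!(AAAtom::CA.to_string(), "CA");
    // EnumString derive: string -> variant
    assert_eq!(AAAtom::from_str("CB").unwrap(), AAAtom::CB);
}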
// By iterating over the enum we can extract all of our coordinate data as follows.
impl LMPNNFeatures for AtomCollection {
    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
        let res_count = self.iter_residues_aminoacid().count();
        let mut atom37_data = vec![0f32; res_count * 37 * 3];
        // Enumerate rather than index by res_id: residue IDs are often
        // 1-based or non-contiguous, while the Tensor index must be 0-based.
        for (res_idx, residue) in self.iter_residues_aminoacid().enumerate() {
            for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
                if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
                    let [x, y, z] = atom.coords;
                    let base_idx = (res_idx * 37 + atom_type as usize) * 3;
                    atom37_data[base_idx] = *x;
                    atom37_data[base_idx + 1] = *y;
                    atom37_data[base_idx + 2] = *z;
                }
            }
        }
        // Create tensor with shape [residues, 37, 3]
        Tensor::from_vec(atom37_data, (res_count, 37, 3), device)
    }
}
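The backbone variant from the trait can follow exactly the same pattern, restricted to the four backbone atoms. A minimal sketch (this would live inside the same impl block above; it is not the actual ferritin implementation):

fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
    let backbone = [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O];
    let res_count = self.iter_residues_aminoacid().count();
    let mut data = vec![0f32; res_count * 4 * 3];
    for (res_idx, residue) in self.iter_residues_aminoacid().enumerate() {
        for (slot, atom_type) in backbone.iter().enumerate() {
            if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
                let [x, y, z] = atom.coords;
                let base_idx = (res_idx * 4 + slot) * 3;
                data[base_idx] = *x;
                data[base_idx + 1] = *y;
                data[base_idx + 2] = *z;
            }
        }
    }
    // Shape: [residues, 4 (N/CA/C/O), 3 (xyz)]
    Tensor::from_vec(data, (res_count, 4, 3), device)
}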
The LigandMPNN/ProteinMPNN Model
The section above covered the first half of the problem: converting proteins into features. This section starts from the other direction: taking the PyTorch code and translating it into Candle.
As of Nov 5th this is still a work in progress, but it is not in terrible shape. Below is a snippet of the high-level model, where you can see a set of encoder and decoder layers and a set of defaults. Internally there are some tricky bits, made trickier than necessary by the many ways LigandMPNN can be invoked. This results in a lot of branches and conditions within the model, and I had some trouble teasing the model code apart from the data handling. But hopefully you can catch a glimpse of what the model should look like.
pub struct ProteinMPNN {
config: ProteinMPNNConfig, // device here ??
decoder_layers: Vec<DecLayer>,
device: Device,
encoder_layers: Vec<EncLayer>,
features: ProteinFeaturesModel, // this needs to be a model with weights etc
w_e: Linear,
w_out: Linear,
w_s: Linear,
}
impl ProteinMPNN {
pub fn new(config: ProteinMPNNConfig, vb: VarBuilder) {}
fn predict(&self) {}
fn train(&mut self) {}
fn encode(&self, features: &ProteinFeatures) {}
fn sample(&self, features: &ProteinFeatures) {}
pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) {}
}
// Model params are contained here. This use of a `Config` struct with
// methods that supply the model-specific values is how most of the
// models in candle_transformers get their args.
impl ProteinMPNNConfig {
pub fn proteinmpnn() -> Self {
Self {
atom_context_num: 0,
augment_eps: 0.0,
dropout_ratio: 0.1,
edge_features: 128,
hidden_dim: 128,
k_neighbors: 24,
ligand_mpnn_use_side_chain_context: false,
model_type: ModelTypes::ProteinMPNN,
node_features: 128,
num_decoder_layers: 3,
num_encoder_layers: 3,
num_letters: 48,
num_rbf: 16,
scale_factor: 30.0,
vocab: 48,
}
}
}
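For context on how this config gets used, here is the typical candle loading pattern: build a VarBuilder over the safetensors weights and hand it to the constructor. This is a sketch only; the weights path is a placeholder, and it assumes ProteinMPNN::new will eventually return a constructed model.

use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn load_proteinmpnn(device: &Device) -> candle_core::Result<()> {
    let config = ProteinMPNNConfig::proteinmpnn();
    // Memory-map the safetensors weights; the filename here is a placeholder.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["proteinmpnn.safetensors"], DType::F32, device)?
    };
    // The constructor is still a stub; eventually this should return a model.
    let _model = ProteinMPNN::new(config, vb);
    Ok(())
}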
CLI and Invocation Configs
There are a LOT of possible configurations that LigandMPNN supports. Based on aggregating
the example CLI options in the LigandMPNN repo, I partitioned them into related structs - e.g. AABias, ResidueControl, etc. -
and consolidated all of the configs into a common file. The complicated structure of this config is
not great and I am not sure I am going to stick with it. But I do have the beginnings of a testing setup
that can run CLI tests. At the moment there is only one command, ferritin-featurizer featurize <input> <output>,
which will take a PDB or CIF file and create a featurized safetensors file.
//ligandmpnn/config.rs
//! - `ModelTypes` - Enum of supported model architectures
//! - `ProteinMPNNConfig` - Core model parameters
//! - `AABiasConfig` - Amino acid biasing controls
//! - `LigandMPNNConfig` - LigandMPNN specific settings
//! - `MembraneMPNNConfig` - MembraneMPNN specific settings
//! - `MultiPDBConfig` - Multi-PDB mode configuration
//! - `ResidueControl` - Residue-level design controls
//! - `RunConfig` - Runtime execution parameters
pub struct RunConfig {
pub temperature: Option<f32>,
pub verbose: Option<i32>,
pub save_stats: Option<i32>,
pub batch_size: Option<i32>,
pub number_of_batches: Option<i32>,
pub file_ending: Option<String>,
pub zero_indexed: Option<i32>,
pub homo_oligomer: Option<i32>,
pub fasta_seq_separation: Option<String>,
}
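Since every field is an Option, one way to keep the call sites clean is to resolve defaults in one place. A small sketch; the default values here are illustrative assumptions, not necessarily LigandMPNN's actual defaults:

impl RunConfig {
    /// Resolve the optional CLI temperature to a concrete default (value is an assumption).
    pub fn temperature_or_default(&self) -> f32 {
        self.temperature.unwrap_or(0.1)
    }
    /// Resolve the optional batch size to a concrete default (value is an assumption).
    pub fn batch_size_or_default(&self) -> i32 {
        self.batch_size.unwrap_or(1)
    }
}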
Todo
I am happy with the progress so far, but the next portion of work needs to tie the pieces together into
a coherent, usable model. Because this is my first foray into Candle,
and because ProteinMPNN is a graph neural network, it's not an easy first project to implement. Therefore, I think I am going to take a detour to
implement a simpler model and then, once I have my bearings, return to finish the job.
I think the best candidate for the first complete model would be Amplify:
a model that touts its small size relative to incumbents; whose weights are already in safetensors format
and available on HuggingFace; and whose architecture looks to be a more standard transformer-style model,
which matches Candle's strengths.