Amplify Pytorch->Candle Conversion
_Note: this is cross-posted from ferritin._
Intro
In the previous post I left off with the comment that Amplify would be a better first candidate to port to Candle. In this post I am happy to report an initial, but as-of-yet-untested, port of the Amplify model from PyTorch to Candle. Here are a few of the lessons learned.
# runnable test you can use/modify to start playing
# git clone git@github.com:zachcp/ferritin.git
# cd ferritin
cargo run --example amplify
# Initial test with "METVALMETVAL".to_string()
AMPLIFY.forward(): calculating logits
Encoded Logits Dimension: Tensor[dims 1, 14, 27; f32],
indices: [25, 11, 21, 17, 7, 15, 8, 17, 16, 7, 18, 25, 25, 15]
Decoded Values: C E Y K A D G K P A Q C C D
Parts
Weights
The weights were the easiest part, as they were already hosted on Hugging Face and could be retrieved via the candle-hf-hub crate. Because I was able to load these weights, I could also print the name and dimensions of every tensor in the file, which was very helpful later for making sure I had the model copied correctly - i.e. it should use and match all of the weights in the safetensors file.
// Imports (assuming the hf-hub-style API re-exported by the candle-hf-hub crate)
use candle_hf_hub::api::sync::Api;
use candle_hf_hub::{Repo, RepoType};
use safetensors::SafeTensors;

// Setup HF API and model info
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";

// Initialize the Hugging Face API client
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
    model_id.to_string(),
    RepoType::Model,
    revision.to_string(),
));

// Load and analyze the safetensors file
let weights_path = repo.get("model.safetensors")?;
let weights = std::fs::read(&weights_path)?;
let tensors = SafeTensors::deserialize(&weights)?;

// Print all tensor names and their shapes
println!("Model tensors:");
tensors.names().iter().for_each(|tensor_name| {
    if let Ok(tensor_info) = tensors.tensor(tensor_name) {
        println!(
            "Tensor: {:<44} || Shape: {:?}",
            tensor_name,
            tensor_info.shape(),
        );
    }
});
Tokenizer
From the perspective of tokenizers, using Amplify was a good choice. The Amplify Hugging Face repo has a tokenizer.json file and uses the Python huggingface tokenizers library. There is a comparable Rust version, tokenizers, that uses the same underlying code and logic. A bit of Claude magic and I had a functional Rust tokenizer, ProteinTokenizer, that could retrieve and encode sequences using the Amplify sequence vocab. This vocabulary lets us convert from char to int and vice versa.
// examples/amplify/main.rs
// (same Api/Repo/RepoType imports as the weights example above;
//  ProteinTokenizer is defined in this crate)

// The HF Repo
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";

// Initialize the Hugging Face API client
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
    model_id.to_string(),
    RepoType::Model,
    revision.to_string(),
));

// Create and use the tokenizer...
let tokenizer = repo.get("tokenizer.json")?;
let protein_tokenizer = ProteinTokenizer::new(tokenizer)?;
println!("Successfully created the tokenizer!");
let pmatrix = protein_tokenizer.encode(&["METVALMETVAL".to_string()], Some(20), true, false)?;
// pmatrix: Tensor[dims 14; i64] <- the sequence is now a set of ints
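For reference, the underlying tokenizers crate can load the same tokenizer.json directly; ProteinTokenizer wraps this kind of logic plus the padding/tensor handling shown above. A minimal, hedged sketch (assuming the file has already been downloaded locally, e.g. via the repo.get call above):

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the same tokenizer.json that the Python `tokenizers` library uses
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // char -> int: encode a protein sequence into token ids
    let encoding = tokenizer.encode("METVALMETVAL", false)?;
    println!("ids: {:?}", encoding.get_ids());

    // int -> char: decode the ids back into a sequence
    println!("decoded: {}", tokenizer.decode(encoding.get_ids(), true)?);
    Ok(())
}
```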
Model
In addition to the weights and tokenizer mentioned above, I also had access to the PyTorch source code and model info on GitHub. There are only 3 files, and I was able to use Claude to get a decent first-pass conversion between PyTorch and Candle. The major issues I encountered were related to:
- Hallucinating/inferring PyTorch methods that don’t exist in Candle. This led to some spelunking in the Candle source code and my first PR to the Candle repo.
- Differences in idioms - e.g. Candle uses enums like D::Minus1 to specify some dimensions and never accepts -1 (see the short example after this list).
- Direct Python -> Rust translations that feel or look clunky. These can sometimes be addressed with a few rounds of "Can you write this more idiomatically?" or "What would this code look like if it used iterators?" or other prompts to get Claude to rustify the code.
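As a quick illustration of the D::Minus1 idiom, here is a small self-contained sketch (the shapes are arbitrary; this is not code from the Amplify port):

```rust
use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // PyTorch would be t.sum(dim=-1); Candle uses the D::Minus1 enum instead of -1
    let summed = t.sum(D::Minus1)?;
    println!("{:?}", summed.dims()); // [2, 3]
    Ok(())
}
```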
So my initial test was to get a load function working so that I could load the weights into an appropriately shaped model, as specified by the source code and the model weights; a rough Candle sketch follows the structure printout below. The PyTorch model structure looks like this:
AMPLIFY(
(encoder): Embedding(27, 960, padding_idx=0)
(transformer_encoder): ModuleList(
(0-31): 32 x EncoderBlock(
(q): Linear(in_features=960, out_features=960, bias=False)
(k): Linear(in_features=960, out_features=960, bias=False)
(v): Linear(in_features=960, out_features=960, bias=False)
(wo): Linear(in_features=960, out_features=960, bias=False)
(resid_dropout): Dropout(p=0, inplace=False)
(ffn): SwiGLU(
(w12): Linear(in_features=960, out_features=5120, bias=False)
(w3): Linear(in_features=2560, out_features=960, bias=False)
)
(attention_norm): RMSNorm()
(ffn_norm): RMSNorm()
(ffn_dropout): Dropout(p=0, inplace=False)
)
)
(layer_norm_2): RMSNorm()
(decoder): Linear(in_features=960, out_features=27, bias=True)
)
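On the Candle side, this kind of loading typically goes through a VarBuilder over the safetensors file, with each submodule pulling its weights by name. The sketch below is illustrative only - the load_amplify helper and the single embedding layer are placeholders, not ferritin's actual structs:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{Embedding, VarBuilder};

// Hypothetical helper: ferritin's real load function builds the full
// AMPLIFY struct (encoder, 32 encoder blocks, norms, decoder).
fn load_amplify(weights_path: &std::path::Path, device: &Device) -> Result<Embedding> {
    // Memory-map the safetensors file into a VarBuilder
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, device)? };
    // Pull a weight by the name it has in the safetensors file,
    // e.g. the token embedding table: Embedding(27, 960)
    let encoder = candle_nn::embedding(27, 960, vb.pp("encoder"))?;
    Ok(encoder)
}
```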
Encoding
The last major mechanical task, once the model was loaded and up and running, was to test whether I could run a new, encoded sequence through it and get results. This was the most time-consuming part, as it is where all of the internal connections have to be correctly mapped/mimicked. I won’t get into all of the details, but I did run into a bit of trouble with:
- Rotary embeddings. I didn’t know what those were or how they work, and they do some funky things with dimensions. I ended up using Claude heavily to iterate on possible translations, but I am not sure I have it right; this is the most likely place I’ve failed to replicate the architecture if my tests against the PyTorch model don’t pass.
- Xformers. The Amplify team used xformers in a few places for optimization: in the feed-forward part of the network, in the SwiGLU activation step, and in a memory-efficient attention block. I tried to translate the intention of the model as best I could to plain Candle (a rough SwiGLU sketch is shown below). I am SURE that there are more optimal ways to implement these.
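Here is a hedged sketch of what a plain-Candle SwiGLU feed-forward block can look like, matching the fused w12 (960 -> 5120) and w3 (2560 -> 960) shapes in the printout above; ferritin's actual implementation may differ:

```rust
use candle_core::{Result, Tensor, D};
use candle_nn::{Linear, Module};

// Sketch only: a SwiGLU block with a fused w12 projection.
struct SwiGLU {
    w12: Linear, // hidden -> 2 * ffn_dim (e.g. 960 -> 5120)
    w3: Linear,  // ffn_dim -> hidden (e.g. 2560 -> 960)
}

impl SwiGLU {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x12 = self.w12.forward(x)?;        // [batch, seq, 2 * ffn_dim]
        let chunks = x12.chunk(2, D::Minus1)?; // split into the two halves
        let gated = (candle_nn::ops::silu(&chunks[0])? * &chunks[1])?; // SiLU(x1) * x2
        self.w3.forward(&gated)                // back to the hidden size
    }
}
```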
My goal was to get a working model that is as close to a literal translation as possible and then add tests so I can make sure that any architectural tweaks are not altering the fundamental scientific outputs.
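The shape of those tests will probably be something like the sketch below: dump reference logits from the PyTorch model for a known sequence, then assert that the Candle logits match within a tolerance. Everything here is hypothetical - run_amplify_forward and load_reference_logits are placeholder helpers:

```rust
#[test]
fn logits_match_pytorch_reference() -> candle_core::Result<()> {
    // Hypothetical reference logits exported from the PyTorch AMPLIFY model
    let expected: Vec<f32> = load_reference_logits("METVALMETVAL"); // placeholder helper
    // Hypothetical helper that tokenizes, runs the Candle model, and returns logits
    let logits = run_amplify_forward("METVALMETVAL")?;
    let ours = logits.flatten_all()?.to_vec1::<f32>()?;
    assert_eq!(ours.len(), expected.len());
    for (a, b) in ours.iter().zip(expected.iter()) {
        assert!((a - b).abs() < 1e-4, "logit mismatch: {a} vs {b}");
    }
    Ok(())
}
```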
Claude Sonnet
Once again, amazing. I think the updated Sonnet 3.5 got a big logic update. I am going to paste a bit of one of my interactions with Sonnet to give a taste of how it helped me reason about a technical field that I am relatively new to. In this bit I have given Claude my Rust code and asked it to trace the path of an encoded sequence through the model while keeping track of the shape of the tensors at each step. This is quite a remarkable response:
Claude’s Response
Let’s trace the path of a sequence through the AMPLIFY model, following the dimensions at each step. I’ll use a concrete example to make it clearer:
Let’s say we input a protein sequence of length 8 with a batch size of 1:
- Initial Input
  Input sequence: "METVAL..." (8 tokens)
  Shape: [1, 8] // [batch_size, seq_length]
- Embedding Layer (self.encoder)
  Converts tokens to embeddings
  Input: [1, 8]
  Output: [1, 8, 640] // [batch, seq_len, hidden_size]
- Transformer Encoder Blocks (repeated 24 times). Each block contains:
  a. Self-Attention
    - Input normalization (RMSNorm)
      Input: [1, 8, 640] Output: [1, 8, 640]
    - Query/Key/Value projections
      Input: [1, 8, 640] Q,K,V each: [1, 8, 640]
    - Reshape for attention heads
      Reshape to: [1, 8, 10, 64] // [batch, seq_len, n_heads, head_dim]
      // where n_heads=10, head_dim=640/10=64
    - Rotary Position Embeddings
      freqs_cis: [8, 32, 2] // [seq_len, head_dim/2, 2]
      Applied to Q,K maintaining shape: [1, 8, 10, 64]
    - Attention Computation
      Q * K^T -> [1, 10, 8, 8] // [batch, n_heads, seq_len, seq_len]
      After softmax -> [1, 10, 8, 8]
      After V multiply -> [1, 8, 10, 64]
    - Project back
      Reshape: [1, 8, 640] Linear: [1, 8, 640]
  b. Feed Forward Network
    - Input normalization
      Input: [1, 8, 640] Output: [1, 8, 640]
    - FFN computation (with SwiGLU)
      W12: [1, 8, 2560] // expanded internal dimension
      After SwiGLU: [1, 8, 1280]
      Final: [1, 8, 640]
- Final Layer Norm
  Input: [1, 8, 640] Output: [1, 8, 640]
- Decoder (Final Linear Layer)
  Input: [1, 8, 640] Output: [1, 8, 27] // [batch, seq_len, vocab_size]
The key transformations happen in the attention mechanism where:
- The sequence is first projected into Q,K,V
- These are split into multiple attention heads
- Position information is added via rotary embeddings
- Attention is computed across the sequence
- The results are merged back and processed through a feed-forward network
The rotary embeddings are particularly tricky because they need to:
- Handle complex number rotations (hence the pairs of values)
- Be broadcastable across the batch and head dimensions
- Align properly with the sequence length and head dimensions
This path helps explain why the shapes need to align precisely at each step, and why the rotary embedding shape mismatch is causing issues. Would you like me to elaborate on any particular part of this flow?
Todo
- implement a set of tests based on the PyTorch model.
- a style and perf pass
- WASM!