Amplify Pytorch->Candle Conversion Part 2
Cross posted on Ferritin
Intro
In the previous post I created the Candle version of the Amplify PyTorch model. However, when I started to use it, it was immediately clear that my model was wrong. This post briefly recounts the troubleshooting process used to bring the two models to parity, measured by identical sequence-encoding and attention-weight outputs.
The Problem
My base Rust/Candle model looked like crap. Encode + run + decode should return the same sequence. Something was wrong and needed to be fixed.
# encode() + model.forward() + decode()
input MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR
output-py AMSVVGIDLGTTSCRVAVARAGGIETIANEYSDRKTPACISFGPKNRA
output-rust CECVACVMGKRGGVNTSPYQSAATRMKTWKRIRNPHFNCVIVPFISPC <--- very wrong!!
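One way to put a number on "very wrong" is per-position identity between the input and the decoded output. The sketch below is illustrative rather than code from the repo: `identity` is a hypothetical helper, and dropping the first and last output characters assumes those two extra positions come from begin/end special tokens.

```python
def identity(reference: str, decoded: str, trim: int = 1) -> float:
    """Fraction of positions where the decoded sequence matches the reference.

    `trim` characters are dropped from each end of `decoded`; this assumes the
    extra positions come from special tokens (an assumption, not from the post).
    """
    core = decoded[trim:len(decoded) - trim] if trim else decoded
    if len(core) != len(reference):
        raise ValueError(f"length mismatch: {len(core)} vs {len(reference)}")
    matches = sum(a == b for a, b in zip(reference, core))
    return matches / len(reference)


seq = "MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR"
out_py = "AMSVVGIDLGTTSCRVAVARAGGIETIANEYSDRKTPACISFGPKNRA"
out_rust = "CECVACVMGKRGGVNTSPYQSAATRMKTWKRIRNPHFNCVIVPFISPC"

py_identity = identity(seq, out_py)
rust_identity = identity(seq, out_rust)
print(f"python: {py_identity:.2f}  rust: {rust_identity:.2f}")
```

On the sequences above, the PyTorch output scores far higher than the Rust output, which points at a wiring problem rather than small numerical drift.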
The Approach
Boot up a GPU machine and clone the Python and Rust versions of the model so that both use the same tokenizer, weights, and input sequence. Then, using print statements, walk through the code until a discrepancy shows up.
The TL;DR
Within the attention block, the tensors are transposed, attention is calculated, and the result is transposed again. Because the model has TWO transpose operations, the final shape was correct, which is why my first attempt “worked”: the encoded bits travelled through the model, but the connections were wrong, and that's why the output looked random. Once I figured that out I was able to fix the transposes and create a test. I also saved the outputs of the PyTorch attention layers so they could be compared to the Rust/Candle version.
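To see how a missing transpose can keep every shape intact while scrambling the result, consider a toy attention in pure Python. This is an illustration under stated assumptions, not the model code: nested lists stand in for tensors, there is no mask or dropout, the sizes are tiny, and the same tensor is used as query, key, and value. `attend_over_sequence` mixes information across the M positions within each head (the intended behaviour), while `attend_over_heads` mimics what happens when a (B, M, H, K) tensor is fed to a routine that expects (B, H, M, K).

```python
import math


def sdpa(q, k, v):
    """Scaled dot-product attention on 2D lists: q, k, v are each [L][K]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        out.append([
            sum(e / total * vj[d] for e, vj in zip(exps, v))
            for d in range(len(v[0]))
        ])
    return out


M, H, K = 4, 2, 3  # toy sizes: positions, heads, per-head width
x = [[[math.sin(7 * m + 3 * h + d) for d in range(K)] for h in range(H)]
     for m in range(M)]  # shape [M][H][K], batch axis omitted


def attend_over_sequence(x):
    """Intended: for each head, attend across the M positions."""
    out = [[None] * H for _ in range(M)]
    for h in range(H):
        head = [x[m][h] for m in range(M)]  # [M][K], i.e. the transposed view
        for m, vec in enumerate(sdpa(head, head, head)):
            out[m][h] = vec
    return out


def attend_over_heads(x):
    """Buggy: for each position, attend across the H heads."""
    return [sdpa(x[m], x[m], x[m]) for m in range(M)]


def shape(t):
    """Shape of a 3-level nested list."""
    return (len(t), len(t[0]), len(t[0][0]))


good, bad = attend_over_sequence(x), attend_over_heads(x)
max_diff = max(
    abs(good[m][h][d] - bad[m][h][d])
    for m in range(M) for h in range(H) for d in range(K)
)
print(shape(good), shape(bad), max_diff)
```

Both results have the same shape, yet the values differ, which is why shape printouts alone do not flag this kind of bug.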
Below are some notes I took during the process that might be of interest to others trying to port a model. The links to the related PRs are at the bottom. The code in the gist may also be helpful to show how I set up my GPU environment.
Setup
- spin up GPU
- set up remote dev env
- clone HF repo locally
- get initial tests running for Candle and PyTorch
- add print statements in both and watch where the divergence happens
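The last step can be made mechanical by recording a (label, shape) checkpoint on each side and scanning for the first mismatch. This is a hypothetical sketch: `first_divergence` and the checkpoint labels are made up for illustration, and the shapes are ones reported in the walkthrough later in this post. A shape comparison is only a first filter, since tensors can agree in shape and still differ in value.

```python
from typing import List, Optional, Tuple

Checkpoint = Tuple[str, Tuple[int, ...]]


def first_divergence(a: List[Checkpoint], b: List[Checkpoint]) -> Optional[str]:
    """Return the label of the first checkpoint whose shapes differ, else None."""
    for (label_a, shape_a), (label_b, shape_b) in zip(a, b):
        if label_a != label_b:
            raise ValueError(f"checkpoints out of order: {label_a} vs {label_b}")
        if shape_a != shape_b:
            return label_a
    return None


# Labels are hypothetical; shapes are taken from the logs in this post.
python_log = [
    ("encoded", (1, 48, 640)),
    ("xq_reshaped", (1, 48, 10, 64)),
    ("attn_input", (1, 10, 48, 64)),
]
rust_log = [
    ("encoded", (1, 48, 640)),
    ("xq_reshaped", (1, 48, 10, 64)),
    ("attn_input", (1, 48, 10, 64)),
]

print(first_divergence(python_log, rust_log))  # prints: attn_input
```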
Code Walkthrough / Steppin
Freqs_Cis
Rotary encoding.
rust: Precomputed freqs shape - cos: [2048, 32], sin: [2048, 32]
python: Freqs CIS: torch.Size([2048, 32])
python: Freqs CIS sliced: torch.Size([48, 32])
Freqs_Cis Sliced
** possible issue here **
Begin tracking down….
rust: freqs_cis shape: [48, 32, 2]
python: Freqs CIS sliced: torch.Size([48, 32])
Claude seems confident this is correct. Let's keep looking.
The key difference is in how the complex numbers are being represented. In the Python version, torch.polar() creates a complex tensor directly, while in the Rust version, we’re storing the real (cosine) and imaginary (sine) parts separately in the last dimension. To match the Python version’s output shape exactly, you could modify the Rust code to return just a [2048, 32] tensor where each element is a complex number. However, since Candle might not have direct complex number support, the current approach of keeping real and imaginary parts separate in a [2048, 32, 2] tensor is actually a valid and common alternative representation.
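The equivalence of the two layouts can be checked without any tensor library. The sketch below is not the code from either model: it assumes the standard rotary formula with base 10000 (the post does not state the base), uses `cmath.rect` as a stand-in for `torch.polar`, and uses nested lists in place of tensors.

```python
import cmath
import math


def rotary_angles(dim: int, end: int, theta: float = 10000.0):
    """Angles of shape [end][dim // 2]; theta=10000 is an assumed default."""
    inv_freq = [1.0 / (theta ** (2 * i / dim)) for i in range(dim // 2)]
    return [[pos * f for f in inv_freq] for pos in range(end)]


angles = rotary_angles(dim=64, end=48)

# Complex layout, shape [48][32]: one unit-magnitude complex number per entry.
as_complex = [[cmath.rect(1.0, a) for a in row] for row in angles]

# Split layout, shape [48][32][2]: a trailing axis holding (cos, sin).
as_pairs = [[(math.cos(a), math.sin(a)) for a in row] for row in angles]

# Largest gap between the two layouts across all entries.
max_gap = max(
    max(abs(c.real - p[0]), abs(c.imag - p[1]))
    for crow, prow in zip(as_complex, as_pairs)
    for c, p in zip(crow, prow)
)


def rotate_complex(x0: float, x1: float, c: complex):
    """Rotate the pair (x0, x1) by multiplying x0 + i*x1 with c."""
    z = complex(x0, x1) * c
    return z.real, z.imag


def rotate_pairs(x0: float, x1: float, cos_a: float, sin_a: float):
    """Same rotation written with real arithmetic only."""
    return x0 * cos_a - x1 * sin_a, x0 * sin_a + x1 * cos_a


r_complex = rotate_complex(0.5, -1.25, as_complex[7][3])
r_pairs = rotate_pairs(0.5, -1.25, *as_pairs[7][3])
print(max_gap, r_complex, r_pairs)
```

The extra trailing axis of length 2 is therefore a difference in storage, not in the numbers being applied, as long as the rotation is written out with real arithmetic on the split side.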
Encoded
Per residue -> hidden dim.
rust: x shape, freqs_cis shape: [1, 48, 640]
python: Encoded: torch.Size([1, 48, 640])
Attention Block
# before reshape
# after reshape
# apply rotary embeddings
# rust:
AttentionBlock: xq_shape: [1, 48, 640]
Attempting reshape to: [1, 48, 10, 64]
Output shapes - xq: [1, 48, 10, 64], xk: [1, 48, 10, 64]
# python:
EncoderBlock_ATT. xq: torch.Size([1, 48, 640]), xk: torch.Size([1, 48, 640]), xv: torch.Size([1, 48, 640])
EncoderBlock_ATT_reshaped. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64]), xv: torch.Size([1, 48, 10, 64])
EncoderBlock_after_rotary. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64])
# cross attention
# rust:
ATTENTION: [1, 48, 10, 64]
ATTENTION_reshaped: [1, 48, 640]
ATTENTION_output: [1, 48, 640]
ATTENTION_output_drop: [1, 48, 640]
# python:
attn_weights: None
attention: torch.Size([1, 48, 10, 64])
attention_view: torch.Size([1, 48, 640])
attn_scores: torch.Size([1, 48, 640])
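The reshape from [1, 48, 640] to [1, 48, 10, 64] only regroups the hidden dimension: with a row-major layout, hidden index j lands in head j // 64 at offset j % 64. A small pure-Python sketch of that mapping follows, with nested lists standing in for tensors; `split_heads` is a hypothetical helper, not a function from either codebase.

```python
def split_heads(row, n_heads: int):
    """Split one position's hidden vector [H*K] into [H][K], row-major."""
    if len(row) % n_heads != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    k = len(row) // n_heads
    return [row[h * k:(h + 1) * k] for h in range(n_heads)]


hidden, n_heads = 640, 10
row = list(range(hidden))          # stand-in for one residue's hidden vector
heads = split_heads(row, n_heads)  # shape [10][64]

j = 200
h, offset = divmod(j, hidden // n_heads)  # head index and offset within the head
print(len(heads), len(heads[0]), heads[h][offset] == row[j])
```

Because the reshape never moves data between positions, it cannot by itself mix up which residue attends to which; that depends on which axis the attention routine treats as the sequence.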
FFN Block
# rust
FFN_norm shape [1, 48, 640]
FFN_forward shape [1, 48, 640]
FFN_dropout shape [1, 48, 640]
# python
ATT_BLOCK_OUT. attn: torch.Size([1, 48, 640]), contact: None
FFN_NORM. ffn_norm: torch.Size([1, 48, 640])
FFN_BLOCK: x:torch.Size([1, 48, 640])
FFN_BLOCK_FFN: x:torch.Size([1, 48, 640])
FFN_BLOCK_DROP: x:torch.Size([1, 48, 640])
FFN_FINAL.: torch.Size([1, 48, 640])
I Spot an Issue
Transpose!
if x.is_cuda:
    # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
    # H the number of heads, and K the embedding size per head
    attn = memory_efficient_attention(
        query=xq,
        key=xk,
        value=xv,
        attn_bias=attention_mask,
        p=self.config.dropout_prob if self.training else 0,
    )
else:
    # Input and output are of dimension (B, H, M, K)
    attn = scaled_dot_product_attention(
        query=xq.transpose(1, 2),
        key=xk.transpose(1, 2),
        value=xv.transpose(1, 2),
        attn_mask=attention_mask,
        dropout_p=self.config.dropout_prob if self.training else 0,
    ).transpose(1, 2)  # <----- Transpose!!
# rust
# ATTENTION_pretranspose: [1, 48, 10, 64]
# ATTENTION: [1, 10, 48, 64]
# python
# ATTN CALC: torch.Size([1, 10, 48, 64])
# ATTN CALC TRANSPOSE: torch.Size([1, 48, 10, 64])
It looks like my scaled dot product fn outputs the correct dimensions. Hmm. Let's check the dimensions of the inputs.
xq_permute = xq.transpose(1, 2)
attn = scaled_dot_product_attention(
    query=xq_permute,
    key=xk.transpose(1, 2),
    value=xv.transpose(1, 2),
    attn_mask=attention_mask,
    dropout_p=self.config.dropout_prob if self.training else 0,
)
print(f"ATTN CALC IN: xq: {xq.shape}")
print(f"ATTN CALC IN: xq_permute:{xq_permute.shape}")
# ATTN CALC IN: xq: torch.Size([1, 48, 10, 64])
# ATTN CALC IN: xq_permute:torch.Size([1, 10, 48, 64])
The Rust code had no permute! Let's add those in.
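// Permute (B, M, H, K) -> (B, H, M, K) before attention, as the PyTorch
// scaled_dot_product_attention branch does.
let xq_permute = xq.permute((0, 2, 1, 3))?;
let xk_permute = xk.permute((0, 2, 1, 3))?;
let xv_permute = xv.permute((0, 2, 1, 3))?;
let attn = self.scaled_dot_product_attention(
    &xq_permute, &xk_permute, &xv_permute, pad_mask.as_ref(), dropout_prob, false)?;
// The result then needs the matching permute back to (B, M, H, K),
// mirroring PyTorch's trailing .transpose(1, 2).
println!("ATTN CALC IN: xq: {:?}", xq.dims());
println!("ATTN CALC IN: xq_permute: {:?}", xq_permute.dims());
// ATTN CALC IN: xq: [1, 48, 10, 64]
// ATTN CALC IN: xq_permute: [1, 10, 48, 64]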
AAAAnd … got it! Woo hoo!
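A parity check in the spirit of the test described in the intro can be sketched as follows. This is a hypothetical helper, not the test from the PRs: it compares two nested lists of floats (standing in for a saved reference tensor and the port's output) and reports the first index path where they disagree beyond a tolerance. The tolerance values and the sample numbers are arbitrary placeholders.

```python
import math
from typing import Optional, Tuple


def first_mismatch(ref, got, rel_tol: float = 1e-5, abs_tol: float = 1e-6,
                   path: Tuple[int, ...] = ()) -> Optional[Tuple[int, ...]]:
    """Return the index path of the first differing element, or None if all close."""
    ref_is_seq = isinstance(ref, (list, tuple))
    got_is_seq = isinstance(got, (list, tuple))
    if ref_is_seq != got_is_seq:
        return path  # one side is nested deeper than the other
    if ref_is_seq:
        if len(ref) != len(got):
            return path  # shape disagreement at this level
        for i, (r, g) in enumerate(zip(ref, got)):
            hit = first_mismatch(r, g, rel_tol, abs_tol, path + (i,))
            if hit is not None:
                return hit
        return None
    if math.isclose(ref, got, rel_tol=rel_tol, abs_tol=abs_tol):
        return None
    return path


reference = [[0.10, 0.20, 0.30], [0.40, 0.50, 0.60]]       # made-up values
ported_ok = [[0.10, 0.20, 0.30], [0.40, 0.50, 0.6000001]]  # within tolerance
ported_bad = [[0.10, 0.20, 0.30], [0.50, 0.40, 0.60]]      # two entries swapped

print(first_mismatch(reference, ported_ok))
print(first_mismatch(reference, ported_bad))
```

Reporting the index path rather than a bare pass/fail makes it easier to tell a scrambled axis from a small numerical difference.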
Links
PRs Since Last Post:
- #39. QC Pass.
- #40. Factor out the Tests.
- #41. PyTorch Compatibility. The relevant work is here.
- #42. Tensor Test Data and QC Pass. Added a safetensor test file from PyTorch that can be used to test the Candle outputs.
Troubleshooting Gist: