AMPLIFY PyTorch -> Candle Conversion, Part 2
Note: this is cross-posted from ferritin.
Intro
- spin up a GPU
- set up a remote dev environment
- clone the HF repo locally
- get initial tests running for Candle and PyTorch
- add print statements in both and watch where the divergence happens (a small shape-tracing helper is sketched below)
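Since the whole hunt comes down to comparing tensor shapes print by print, something like this tiny helper (hypothetical, not part of ferritin) keeps the Candle-side output uniform and easy to diff against the PyTorch prints:

use candle_core::Tensor;

// Hypothetical helper: print a labelled shape so the Candle output can be
// diffed line-by-line against the matching PyTorch print statements.
fn trace_shape(label: &str, t: &Tensor) {
    println!("{label}: {:?}", t.dims());
}

Called as trace_shape("AttentionBlock: xq_shape", &xq), it produces lines like the ones below.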
Start
# roundtrip test
# cargo test -p ferritin-featurizers test_amplify_round_trip -- --exact
# python run.py
#
# sequence MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR
# py       AMSVVGIDLGTTSCRVAVARAGGIETIANEYSDRKTPACISFGPKNRA
# rust     CECVACVMGKRGGVNTSPYQSAATRMKTWKRIRNPHFNCVIVPFISPC
The PyTorch run roughly reconstructs the input sequence; the Rust run produces noise, so the Candle forward pass is diverging somewhere. Time to walk through it layer by layer.
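For context, the py/rust strings above come from decoding the model's per-position predictions back into residues. A rough sketch of that decode step (the VOCAB ordering here is a placeholder, not AMPLIFY's actual tokenizer):

use candle_core::{D, Result, Tensor};

// Take per-position logits [1, L, vocab], argmax over the vocabulary,
// and map each index back to a residue character.
const VOCAB: &[u8] = b"ACDEFGHIKLMNPQRSTVWY"; // placeholder ordering

fn decode_argmax(logits: &Tensor) -> Result<String> {
    let ids = logits.argmax(D::Minus1)?;         // [1, L], u32 indices
    let ids = ids.squeeze(0)?.to_vec1::<u32>()?; // Vec<u32> of length L
    Ok(ids
        .into_iter()
        .map(|i| *VOCAB.get(i as usize).unwrap_or(&b'?') as char)
        .collect())
}

If the weights load correctly, the decoded Rust string should track the PyTorch one; here it clearly doesn't, so the next step is to print shapes at each stage of the forward pass.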
Freqs_Cis
Rotary encoding.
rust: Precomputed freqs shape - cos: [2048, 32], sin: [2048, 32]
python: Freqs CIS: torch.Size([2048, 32])
python: Freqs CIS sliced: torch.Size([48, 32])
So far so good: both sides precompute [2048, 32] rotary tables.
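Those tables come from the usual RoPE recipe: an outer product of positions with per-channel inverse frequencies. A minimal Candle sketch (not the ferritin code; head_dim = 64 and max_seq_len = 2048 match the shapes above, and theta is typically 10000.0):

use candle_core::{DType, Device, Result, Tensor};

// Build cos/sin tables of shape [max_seq_len, head_dim / 2], i.e. [2048, 32].
fn precompute_freqs(
    head_dim: usize,
    max_seq_len: usize,
    theta: f32,
    dev: &Device,
) -> Result<(Tensor, Tensor)> {
    // inv_freq[i] = 1 / theta^(2i / head_dim) for i in 0..head_dim/2
    let inv_freq: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq.as_slice(), dev)?; // [head_dim/2]
    let positions = Tensor::arange(0u32, max_seq_len as u32, dev)?.to_dtype(DType::F32)?;
    // Outer product of positions and inverse frequencies -> [max_seq_len, head_dim/2]
    let freqs = positions.unsqueeze(1)?.broadcast_mul(&inv_freq.unsqueeze(0)?)?;
    Ok((freqs.cos()?, freqs.sin()?))
}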
Freqs_Cis Sliced
** issue here **
Time to begin tracking this down.
rust: freqs_cis shape: [48, 32, 2]
python: Freqs CIS sliced: torch.Size([48, 32])
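Before chasing this too far, note that the two shapes may describe the same data: PyTorch's freqs_cis is a complex64 tensor of shape [48, 32], while the Rust tensor's trailing dimension of 2 looks like the cos/sin (real/imaginary) pair stored explicitly. A sketch of producing that [48, 32, 2] layout from the tables above (again hypothetical, not the ferritin code):

use candle_core::{Result, Tensor};

// Slice the precomputed tables to the current sequence length and stack
// cos/sin into a trailing dimension: [48, 32] + [48, 32] -> [48, 32, 2].
fn slice_freqs(cos: &Tensor, sin: &Tensor, seq_len: usize) -> Result<Tensor> {
    let cos = cos.narrow(0, 0, seq_len)?; // [seq_len, head_dim/2]
    let sin = sin.narrow(0, 0, seq_len)?; // [seq_len, head_dim/2]
    Tensor::stack(&[&cos, &sin], 2)       // [seq_len, head_dim/2, 2]
}

So this is probably a representation difference rather than the real bug; keep going.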
Encoded
Per-residue tokens -> hidden dim.
rust: x shape, freqs_cis shape: [1, 48, 640]
python: Encoded: torch.Size([1, 48, 640])
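The encoder step itself is an embedding lookup: token ids of shape [1, 48] come out as [1, 48, 640]. A self-contained Candle sketch (the vocab size is a placeholder):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};

// Embedding lookup mapping [1, 48] token ids to [1, 48, 640] hidden states.
fn demo_encode(dev: &Device) -> Result<Tensor> {
    let (vocab_size, hidden) = (27, 640); // vocab size is a placeholder
    let vb = VarBuilder::zeros(DType::F32, dev);
    let emb = candle_nn::embedding(vocab_size, hidden, vb)?;
    let ids = Tensor::zeros((1, 48), DType::U32, dev)?; // stand-in token ids
    emb.forward(&ids) // [1, 48, 640]
}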
Attention Block
The prints below cover three stages: before the head reshape, after the reshape, and after applying the rotary embeddings.
rust:
AttentionBlock: xq_shape: [1, 48, 640]
Attempting reshape to: [1, 48, 10, 64]
Output shapes - xq: [1, 48, 10, 64], xk: [1, 48, 10, 64]
python:
EncoderBlock_ATT. xq: torch.Size([1, 48, 640]), xk: torch.Size([1, 48, 640]), xv: torch.Size([1, 48, 640])
EncoderBlock_ATT_reshaped. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64]), xv: torch.Size([1, 48, 10, 64])
EncoderBlock_after_rotary. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64])
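The reshape both sides are doing is the standard head split, [B, L, H*K] -> [B, L, H, K], with 640 = 10 heads x 64 dims. In Candle that's just a reshape; a sketch:

use candle_core::{Result, Tensor};

// Head split: [B, L, H*K] -> [B, L, H, K], e.g. [1, 48, 640] -> [1, 48, 10, 64].
fn split_heads(x: &Tensor, n_heads: usize) -> Result<Tensor> {
    let (b, l, d) = x.dims3()?;
    x.reshape((b, l, n_heads, d / n_heads))
}

So far the shapes agree through the reshape and the rotary application.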
Then the attention computation:
rust:
ATTENTION: [1, 48, 10, 64]
ATTENTION_reshaped: [1, 48, 640]
ATTENTION_output: [1, 48, 640]
ATTENTION_output_drop: [1, 48, 640]
python:
attn_weights: None
attention: torch.Size([1, 48, 10, 64])
attention_view: torch.Size([1, 48, 640])
attn_scores: torch.Size([1, 48, 640])
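The view back to [1, 48, 640] is the inverse merge before the output projection. A sketch of that step (not the ferritin code):

use candle_core::{Result, Tensor};

// Head merge: [B, L, H, K] -> [B, L, H*K], e.g. [1, 48, 10, 64] -> [1, 48, 640].
fn merge_heads(attn: &Tensor) -> Result<Tensor> {
    let (b, l, h, k) = attn.dims4()?;
    attn.contiguous()?.reshape((b, l, h * k))
}

Shapes match here as well.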
FFN Block
# rust
FFN_norm shape [1, 48, 640]
FFN_forward shape [1, 48, 640]
FFN_dropout shape [1, 48, 640]
# python
ATT_BLOCK_OUT. attn: torch.Size([1, 48, 640]), contact: None
FFN_NORM. ffn_norm: torch.Size([1, 48, 640])
FFN_BLOCK: x:torch.Size([1, 48, 640])
FFN_BLOCK_FFN: x:torch.Size([1, 48, 640])
FFN_BLOCK_DROP: x:torch.Size([1, 48, 640])
FFN_FINAL.: torch.Size([1, 48, 640])
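For readers following along, AMPLIFY's feed-forward block is (to my understanding) a SwiGLU-style gate/up/down MLP. A rough sketch of a block that preserves the [B, L, 640] shape, with illustrative wiring rather than the exact ferritin code:

use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module};

// Rough sketch of a SwiGLU-style FFN that keeps the [B, L, 640] shape:
// down( silu(gate(x)) * up(x) ).
struct SwiGluFfn {
    w_gate: Linear, // 640 -> intermediate
    w_up: Linear,   // 640 -> intermediate
    w_down: Linear, // intermediate -> 640
}

impl SwiGluFfn {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let gate = candle_nn::ops::silu(&self.w_gate.forward(x)?)?;
        let up = self.w_up.forward(x)?;
        self.w_down.forward(&(gate * up)?) // back to [B, L, 640]
    }
}

Every stage prints the same shape on both sides, which suggests the divergence is in values rather than dimensions.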
Spot a Potential Issue
Transpose!
if x.is_cuda:
    # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
    # H the number of heads, and K the embedding size per head
    attn = memory_efficient_attention(
        query=xq,
        key=xk,
        value=xv,
        attn_bias=attention_mask,
        p=self.config.dropout_prob if self.training else 0,
    )
else:
    # Input and output are of dimension (B, H, M, K)
    attn = scaled_dot_product_attention(
        query=xq.transpose(1, 2),
        key=xk.transpose(1, 2),
        value=xv.transpose(1, 2),
        attn_mask=attention_mask,
        dropout_p=self.config.dropout_prob if self.training else 0,
    ).transpose(1, 2)  # <----- Transpose!!
# rust
# ATTENTION_pretranspose: [1, 48, 10, 64]
# ATTENTION: [1, 10, 48, 64]
# python
# ATTN CALC: torch.Size([1, 10, 48, 64])
# ATTN CALC TRANSPOSE: torch.Size([1, 48, 10, 64])
It looks like my scaled-dot-product fn outputs the correct dimensions. Hmm. Let's check the dimensions of the inputs.
print(f"ATTN CALC IN: xq: {xq.shape}")
xq_permute = xq.transpose(1, 2)
print(f"ATTN CALC IN: xq_permute:{xq_permute.shape}")
attn = scaled_dot_product_attention(
query=xq_permute,
key=xk.transpose(1, 2),
value=xv.transpose(1, 2),
attn_mask=attention_mask,
dropout_p=self.config.dropout_prob if self.training else 0,
)
# ATTN CALC IN: xq: torch.Size([1, 48, 10, 64])
# ATTN CALC IN: xq_permute:torch.Size([1, 10, 48, 64])
So the PyTorch side permutes (B, M, H, K) -> (B, H, M, K) before calling attention, but the Rust side has no permute at all. Let's add one and check the shapes:
let xq_permute = xq.permute((0, 2, 1, 3))?;
println!("ATTN CALC IN: xq: {:?}", xq.dims());
println!("ATTN CALC IN: xq_permute: {:?}", xq_permute.dims());
let attn = self.scaled_dot_product_attention(
    &xq,
    &xk,
    &xv,
    pad_mask.as_ref(),
    dropout_prob,
    false,
)?;
// ATTN CALC IN: xq: [1, 48, 10, 64]
// ATTN CALC IN: xq_permute: [1, 10, 48, 64]
AAAAAnd … got it! Woo hoo!
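To put the whole shape dance in one place, here is a minimal, self-contained sketch (not the ferritin code; see the gist below for the real fix) of a scaled dot-product attention that takes the (B, M, H, K) tensors the block produces, permutes to (B, H, M, K) internally, and transposes back on the way out, mirroring the PyTorch branch above:

use candle_core::{Result, Tensor};

// Inputs are (B, M, H, K); the attention math happens in (B, H, M, K); the
// output is transposed back to (B, M, H, K), matching the PyTorch code.
fn scaled_dot_product_attention(
    xq: &Tensor,           // (B, M, H, K)
    xk: &Tensor,           // (B, M, H, K)
    xv: &Tensor,           // (B, M, H, K)
    mask: Option<&Tensor>, // assumed to be an additive mask over the scores
) -> Result<Tensor> {
    let (_b, _m, _h, k) = xq.dims4()?;
    // (B, M, H, K) -> (B, H, M, K)
    let q = xq.permute((0, 2, 1, 3))?.contiguous()?;
    let kt = xk.permute((0, 2, 1, 3))?.contiguous()?;
    let v = xv.permute((0, 2, 1, 3))?.contiguous()?;

    let scale = 1.0 / (k as f64).sqrt();
    // (B, H, M, K) x (B, H, K, M) -> (B, H, M, M)
    let mut scores = (q.matmul(&kt.t()?)? * scale)?;
    if let Some(m) = mask {
        scores = scores.broadcast_add(m)?;
    }
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    // (B, H, M, M) x (B, H, M, K) -> (B, H, M, K), then back to (B, M, H, K)
    probs.matmul(&v)?.transpose(1, 2)
}

Whether the permute happens at the call site (passing xq_permute and friends) or inside the attention function as above is a style choice; what matters is that the matmuls see (B, H, M, K) and the caller gets (B, M, H, K) back.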
Gist https://gist.github.com/zachcp/c731fdf837465aa5a44e6ecaed8e99fa