Amplify PyTorch -> Candle Conversion, Part 2

Note: this is cross-posted from ferritin.

Intro

  • spin up GPU
  • set up remote dev env
  • clone HF repo locally
  • get initial tests running for Candle and PyTorch
  • add print statements in both and watch where the divergence happens (see the shape-print sketch below)
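
The workhorse for that last step is nothing fancier than printing tensor shapes on both sides and diffing the output. A toy, self-contained Candle-side example (the real prints live inside the model's forward pass):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // stand-in for an intermediate activation; the real tensors come from the model
    let x = Tensor::zeros((1, 48, 640), DType::F32, &Device::Cpu)?;
    // mirrors the Python `print(f"Encoded: {x.shape}")` instrumentation
    println!("Encoded: {:?}", x.dims()); // Encoded: [1, 48, 640]
    Ok(())
}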

Start

# roundtrip test
# cargo test -p ferritin-featurizers test_amplify_round_trip -- --exact
# python run.py
#
# sequence   MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR
# py        AMSVVGIDLGTTSCRVAVARAGGIETIANEYSDRKTPACISFGPKNRA
# rust      CECVACVMGKRGGVNTSPYQSAATRMKTWKRIRNPHFNCVIVPFISPC

Freqs_Cis

Rotary encoding.

rust:   Precomputed freqs shape - cos: [2048, 32], sin: [2048, 32]
python: Freqs CIS: torch.Size([2048, 32])
python: Freqs CIS sliced: torch.Size([48, 32])
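
For reference, a minimal sketch of that precompute step in Candle, assuming the usual RoPE recipe (theta = 10000, head_dim = 64, a 2048-position cache); the function and names are illustrative, not the ferritin code:

use candle_core::{Device, Result, Tensor};

// minimal RoPE frequency precompute; theta, head_dim, and max_len are assumptions
fn precompute_freqs(head_dim: usize, max_len: usize, device: &Device) -> Result<(Tensor, Tensor)> {
    let theta = 10000f32;
    // inverse frequency for each channel pair: head_dim / 2 = 32 values
    let inv_freq: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq.as_slice(), device)?; // [32]
    let t = Tensor::arange(0f32, max_len as f32, device)?;    // [2048]
    // outer product of positions and frequencies -> [2048, 32]
    let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
    Ok((freqs.cos()?, freqs.sin()?)) // cos: [2048, 32], sin: [2048, 32]
}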

Freqs_Cis Sliced

** issue here **: the sliced shapes don't match.

Begin tracking it down…

rust:   freqs_cis shape: [48, 32, 2]
python: Freqs CIS sliced: torch.Size([48, 32])

The PyTorch side presumably stores freqs_cis as a complex tensor (head_dim/2 = 32 complex values per position), while Candle has no complex dtype, so the Rust side carries cos and sin as an extra trailing dimension of size 2. Different shape, potentially the same information.
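
A sketch of what that stack-and-slice could look like on the Candle side (illustrative, not the ferritin code):

use candle_core::{Result, Tensor};

// cos, sin: [2048, 32] each; returns [seq_len, 32, 2]
fn slice_freqs(cos: &Tensor, sin: &Tensor, seq_len: usize) -> Result<Tensor> {
    let freqs_cis = Tensor::stack(&[cos, sin], 2)?; // [2048, 32, 2]
    freqs_cis.narrow(0, 0, seq_len)                 // e.g. [48, 32, 2]
}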

Encoded

Per residue -> hidden dim.

rust:   x shape, freqs_cis shape: [1, 48, 640]
python: Encoded: torch.Size([1, 48, 640])
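
That step is a token-embedding lookup from per-residue token ids into the 640-dim hidden state. A minimal Candle sketch of the shape flow (the vocab size of 27 and the names are illustrative assumptions, and the weights here are just zeros):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Module, VarBuilder};

fn encode_demo() -> Result<Tensor> {
    let device = Device::Cpu;
    // zero-initialized weights purely to show shapes; the real model loads safetensors
    let vb = VarBuilder::zeros(DType::F32, &device);
    let embed = embedding(27, 640, vb.pp("encoder"))?;         // vocab_size x hidden_dim
    let tokens = Tensor::zeros((1, 48), DType::U32, &device)?; // [batch, seq_len]
    embed.forward(&tokens)                                     // [1, 48, 640]
}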

Attention Block

# before reshape
# after reshape
# apply rotary embeddings

rust:
AttentionBlock: xq_shape: [1, 48, 640]
Attempting reshape to: [1, 48, 10, 64]
Output shapes - xq: [1, 48, 10, 64], xk: [1, 48, 10, 64]

python:
EncoderBlock_ATT. xq: torch.Size([1, 48, 640]), xk: torch.Size([1, 48, 640]), xv: torch.Size([1, 48, 640])
EncoderBlock_ATT_reshaped. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64]), xv: torch.Size([1, 48, 10, 64])
EncoderBlock_after_rotary. xq: torch.Size([1, 48, 10, 64]), xk: torch.Size([1, 48, 10, 64])
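
Both sides split the 640-wide hidden state into 10 heads of 64 before applying rotary embeddings; in Candle that head split is roughly (a sketch, with counts taken from the logs above):

use candle_core::{Result, Tensor};

// [batch, seq_len, hidden] -> [batch, seq_len, n_heads, head_dim]
fn split_heads(x: &Tensor, n_heads: usize) -> Result<Tensor> {
    let (b, seq_len, hidden) = x.dims3()?;             // e.g. [1, 48, 640]
    x.reshape((b, seq_len, n_heads, hidden / n_heads)) // e.g. [1, 48, 10, 64]
}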
# cross attention

rust:
ATTENTION: [1, 48, 10, 64]
ATTENTION_reshaped: [1, 48, 640]
ATTENTION_output: [1, 48, 640]
ATTENTION_output_drop: [1, 48, 640]

python:
attn_weights: None
attention: torch.Size([1, 48, 10, 64])
attention_view: torch.Size([1, 48, 640])
attn_scores: torch.Size([1, 48, 640])
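
After attention, the heads are merged back to [1, 48, 640] and pushed through the output projection and dropout. Roughly, in Candle (the layer handle, dropout flag, and names are illustrative):

use candle_core::{Result, Tensor};
use candle_nn::{ops, Linear, Module};

fn merge_and_project(attn: &Tensor, wo: &Linear, dropout_p: f32, train: bool) -> Result<Tensor> {
    let (b, seq_len, n_heads, head_dim) = attn.dims4()?; // [1, 48, 10, 64]
    // ensure a contiguous layout before collapsing the head dims
    let merged = attn.contiguous()?.reshape((b, seq_len, n_heads * head_dim))?; // [1, 48, 640]
    let out = wo.forward(&merged)?;                      // [1, 48, 640]
    if train { ops::dropout(&out, dropout_p) } else { Ok(out) }
}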

FFN Block

# rust
FFN_norm shape [1, 48, 640]
FFN_forward shape [1, 48, 640]
FFN_dropout shape [1, 48, 640]

# python
ATT_BLOCK_OUT. attn: torch.Size([1, 48, 640]),  contact: None
FFN_NORM. ffn_norm: torch.Size([1, 48, 640])
FFN_BLOCK: x:torch.Size([1, 48, 640])
FFN_BLOCK_FFN: x:torch.Size([1, 48, 640])
FFN_BLOCK_DROP: x:torch.Size([1, 48, 640])
FFN_FINAL.: torch.Size([1, 48, 640])
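
The FFN block is a pre-norm feed-forward pass with dropout, shape-preserving at [1, 48, 640]. A generic sketch in Candle (the layer types, names, and the SiLU nonlinearity are assumptions for illustration, not necessarily AMPLIFY's exact FFN, which may be a gated SwiGLU-style variant):

use candle_core::{Result, Tensor};
use candle_nn::{ops, Linear, Module, RmsNorm};

// shape-preserving FFN pass: [1, 48, 640] in, [1, 48, 640] out
fn ffn_block(x: &Tensor, norm: &RmsNorm, w_in: &Linear, w_out: &Linear, p: f32) -> Result<Tensor> {
    let h = norm.forward(x)?;               // FFN_norm:    [1, 48, 640]
    let h = ops::silu(&w_in.forward(&h)?)?; // expand to the FFN width + nonlinearity
    let h = w_out.forward(&h)?;             // FFN_forward: [1, 48, 640]
    ops::dropout(&h, p)                     // FFN_dropout: [1, 48, 640]
}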

Spot a Potential Issue

Transpose!

if x.is_cuda:
    # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
    # H the number of heads, and K the embedding size per head
    attn = memory_efficient_attention(
        query=xq,
        key=xk,
        value=xv,
        attn_bias=attention_mask,
        p=self.config.dropout_prob if self.training else 0,
    )
else:
    # Input and output are of dimension (B, H, M, K)
    attn = scaled_dot_product_attention(
        query=xq.transpose(1, 2),
        key=xk.transpose(1, 2),
        value=xv.transpose(1, 2),
        attn_mask=attention_mask,
        dropout_p=self.config.dropout_prob if self.training else 0,
    ).transpose(1, 2)  # <----- Transpose!!

# rust
# ATTENTION_pretranspose: [1, 48, 10, 64]
# ATTENTION: [1, 10, 48, 64]

# python
# ATTN CALC: torch.Size([1, 10, 48, 64])
# ATTN CALC TRANSPOSE: torch.Size([1, 48, 10, 64])

It looks like my scaled dot product fn outputs the correct dimensions. Hmm. Let's check the dimensions of the inputs.

print(f"ATTN CALC IN: xq: {xq.shape}")
xq_permute = xq.transpose(1, 2)
print(f"ATTN CALC IN: xq_permute:{xq_permute.shape}")
attn = scaled_dot_product_attention(
    query=xq_permute,
    key=xk.transpose(1, 2),
    value=xv.transpose(1, 2),
    attn_mask=attention_mask,
    dropout_p=self.config.dropout_prob if self.training else 0,
)

# ATTN CALC IN: xq: torch.Size([1, 48, 10, 64])
# ATTN CALC IN:  xq_permute:torch.Size([1, 10, 48, 64])

No permute on the Rust side! Let's add those in.


let xq_permute = xq.permute((0, 2, 1, 3))?;

println!("ATTN CALC IN: xq: {:?}", xq.dims());
println!("ATTN CALC IN: xq_permute: {:?}", xq_permute.dims());

let attn = self.scaled_dot_product_attention(
    &xq,
    &xk,
    &xv,
    pad_mask.as_ref(),
    dropout_prob,
    false,
)?;

// ATTN CALC IN: xq: [1, 48, 10, 64]
// ATTN CALC IN: xq_permute: [1, 10, 48, 64]
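
For reference, the attention math itself wants the (B, H, M, K) layout so that scores are computed per head across the sequence. A minimal Candle-style scaled dot-product attention sketch (mask and dropout omitted; this is not the ferritin implementation) makes the required permute obvious:

use candle_core::{Result, Tensor};
use candle_nn::ops;

// q, k, v in (B, H, M, K) layout, assumed contiguous; returns (B, H, M, K)
fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(3)? as f64;
    // scores over the sequence dim, per head: (B, H, M, M)
    let scores = (q.matmul(&k.t()?)? / head_dim.sqrt())?;
    let weights = ops::softmax_last_dim(&scores)?;
    // weighted sum of values: back to (B, H, M, K)
    weights.matmul(v)
}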

AAAAAnd … got it! Woo hoo!

Gist: https://gist.github.com/zachcp/c731fdf837465aa5a44e6ecaed8e99fa