Amplify PyTorch->Candle Conversion Part 3
Note: this is cross-posted from ferritin.
Now that we’ve got the model together, I want to start shaping the API. We’d like to make it easy to download and invoke the model, and to work with its outputs. The code below is a piece of the tests as of Nov 15, which shows:
- easy loading of the model, which returns:
  - a Tokenizer that converts protein sequences to and from integers
  - the 120M AMPLIFY model with weights loaded
- a function on the output that retrieves the contact map
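To make the "sequences to and from integers" idea concrete, here is a minimal sketch of a character-level tokenizer. This is not ferritin's tokenizer: `ToyTokenizer`, its 20-letter vocabulary, and the id ordering are all made up for illustration, and it skips the special tokens a real vocabulary would carry.

```rust
use std::collections::HashMap;

/// Toy character-level tokenizer: one integer id per amino-acid letter.
/// Hypothetical vocabulary; the ordering does NOT match AMPLIFY's.
struct ToyTokenizer {
    to_id: HashMap<char, u32>,
    to_char: Vec<char>,
}

impl ToyTokenizer {
    fn new() -> Self {
        // The 20 standard amino acids; real vocabularies also add special tokens.
        let to_char: Vec<char> = "ACDEFGHIKLMNPQRSTVWY".chars().collect();
        let to_id = to_char
            .iter()
            .enumerate()
            .map(|(i, &c)| (c, i as u32))
            .collect();
        Self { to_id, to_char }
    }

    /// Map each residue to its id; returns None on an unknown character.
    fn encode(&self, seq: &str) -> Option<Vec<u32>> {
        seq.chars().map(|c| self.to_id.get(&c).copied()).collect()
    }

    /// Map ids back to residues; returns None on an out-of-range id.
    fn decode(&self, ids: &[u32]) -> Option<String> {
        ids.iter()
            .map(|&i| self.to_char.get(i as usize).copied())
            .collect()
    }
}
```

Calling `encode` and then `decode` on the same sequence should hand back the original string, which is the same round trip the test further down relies on.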
::: {.callout-note}
Internally I am using candle_hf_hub, which will download and cache the model weights to ~/.cache/huggingface/hub. See the cache docs.
:::
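If you want to find the cached weights on disk, the Hugging Face hub cache stores each model repo in a folder named `models--<org>--<name>`, with the files under a `snapshots/<revision>/` subfolder. Here is a small sketch that builds that folder path. `model_cache_dir` is a helper I made up for illustration, and the repo id in the usage note is an assumption on my part, not something taken from the loader.

```rust
use std::path::{Path, PathBuf};

/// Build the directory the Hugging Face hub cache uses for a model repo.
/// `cache_root` would typically be `~/.cache/huggingface/hub`.
/// Hypothetical helper, not part of ferritin or the hub crate.
fn model_cache_dir(cache_root: &Path, repo_id: &str) -> PathBuf {
    // "org/name" becomes "models--org--name" on disk.
    let folder = format!("models--{}", repo_id.replace('/', "--"));
    cache_root.join(folder)
}
```

For example, with a cache root of `/home/user/.cache/huggingface/hub` and an assumed repo id of `chandar-lab/AMPLIFY_120M`, this gives `/home/user/.cache/huggingface/hub/models--chandar-lab--AMPLIFY_120M`.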
Core API in place
use candle_core::D; // needed for D::Minus1 below
use ferritin_featurizers::AMPLIFY;

fn test_amplify_full_model() -> Result<(), Box<dyn std::error::Error>> {
    // Load the model and the tokenizer
    let (tokenizer, amplify) = AMPLIFY::load_from_huggingface()?;

    // Test sequence from the AMPLIFY test suite
    let AMPLIFY_TEST_SEQ = "MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR";

    // Encode the sequence
    let pmatrix = tokenizer.encode(&[AMPLIFY_TEST_SEQ.to_string()], None, true, false)?;
    let pmatrix = pmatrix.unsqueeze(0)?; // [batch, length]

    // Run the sequence through the model
    let encoded = amplify.forward(&pmatrix, None, true, true)?;

    // There are <VOCABSIZE=27> logits per character in the sequence;
    // argmax picks the index of the highest one.
    let predictions = encoded.logits.argmax(D::Minus1)?;
    let indices: Vec<u32> = predictions.to_vec2::<u32>()?[0].to_vec();
    let decoded = tokenizer.decode(indices.as_slice(), true)?;

    // Woo hoo! If this passes we have round-tripped.
    // We strongly expect the model to recover the inputs.
    assert_eq!(AMPLIFY_TEST_SEQ, decoded.replace(" ", ""));

    // What if we want the contact map of the residues?
    // For that we need to retrieve it from the attentions.
    if let Some(norm) = &encoded.get_contact_map()? {
        // contact map dims: [seqlen, seqlen, 240]
    }
    // ...
    Ok(())
}
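Two steps in the test are easy to picture without candle: the argmax over the vocabulary dimension, and how a contact-map tensor ends up with 240 channels. Below is a plain-Rust sketch of both on nested `Vec`s. The helper names are hypothetical, and I'm assuming the 240 channels are one attention map per layer and head (for example 24 layers × 10 heads); treat that as an assumption rather than a confirmed detail of the model.

```rust
/// For each sequence position, return the index of the largest logit.
/// `logits` is [seq_len][vocab_size], mirroring an argmax over the last dimension.
/// Ties resolve to the first maximum; an empty row yields 0.
fn argmax_last_dim(logits: &[Vec<f32>]) -> Vec<u32> {
    logits
        .iter()
        .map(|row| {
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}

/// Stack per-head attention maps into one [seq_len][seq_len][channels] array.
/// `maps` holds one [seq_len][seq_len] matrix per (layer, head) pair, so the
/// channel count equals `maps.len()`.
fn stack_attention(maps: &[Vec<Vec<f32>>]) -> Vec<Vec<Vec<f32>>> {
    let seq_len = maps.first().map_or(0, |m| m.len());
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| maps.iter().map(|m| m[i][j]).collect())
                .collect()
        })
        .collect()
}
```

Feeding `stack_attention` 240 maps of size `[seq_len][seq_len]` returns an array shaped `[seq_len][seq_len][240]`, which matches the dimensions noted in the test. A real contact predictor would usually post-process these channels further; this sketch only shows the stacking.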
Lots more work to do, of course, but it looks like this is shaping up!