File size: 1,918 Bytes
3437d14
ce52bcc
 
afba882
3437d14
c4d60f8
 
 
ced1852
c4d60f8
 
 
a79df8e
6d84961
bf6dd26
6d84961
bf6dd26
c4d60f8
ce52bcc
c4d60f8
 
 
 
d1c87c7
afba882
a931815
c4d60f8
 
 
 
 
 
79f0451
b289879
 
 
 
 
 
c4d60f8
b289879
 
 
c4d60f8
a931815
c4d60f8
 
afba882
a79df8e
c4d60f8
 
 
 
a79df8e
afba882
bf6dd26
a79df8e
a931815
 
 
 
ce52bcc
3b2506d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
from transformers import AutoTokenizer, EsmModel
import torch
import json

def embed(aa_seq, tokenizer, model):
    """Return the model's last hidden states for a single amino-acid sequence.

    Args:
        aa_seq: amino-acid sequence string to tokenize.
        tokenizer: callable that accepts ``return_tensors="pt"`` and returns a
            mapping of keyword arguments for ``model``.
        model: callable whose output exposes ``last_hidden_state``.

    Returns:
        ``last_hidden_state`` converted to nested Python lists so it can be
        JSON-serialised. Presumably shaped (1, tokens, hidden_dim) for one
        input string — TODO confirm against the tokenizer in use.
    """
    inputs = tokenizer(aa_seq, return_tensors="pt")
    # Inference only: with autograd disabled, no graph is built and no
    # activations are retained for backprop. This matters for large checkpoints.
    with torch.no_grad():
        outputs = model(**inputs)
    # .cpu() is a no-op for CPU tensors but keeps .numpy() safe if the model
    # is ever moved to an accelerator.
    return outputs.last_hidden_state.detach().cpu().numpy().tolist()

# Select which ESM-2 checkpoint to use for embedding.
model_name = st.selectbox(
    'Choose a model',
    ["facebook/esm2_t6_8M_UR50D", "facebook/esm2_t48_15B_UR50D"])

# Upload the AA sequences file. Expected shape (what embed_upload_file reads):
# {"uid1": ["SEQ_A", "SEQ_B"], ...}
uploaded_file = st.file_uploader("Upload JSON with AA sequences", type='json')
data = None  # stays None until a valid file has been uploaded
if uploaded_file is not None:
    try:
        data = json.load(uploaded_file)
    except ValueError as err:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        # Report the problem and end this script run instead of crashing the
        # app with a traceback.
        st.error(f'Could not parse the uploaded file as JSON: {err}')
        st.stop()
    if not (
        isinstance(data, dict)
        and all(
            isinstance(seqs, list) and all(isinstance(seq, str) for seq in seqs)
            for seqs in data.values()
        )
    ):
        st.error('Expected a JSON object mapping each id to a list of AA sequence strings.')
        st.stop()

def embed_upload_file(upload_dict_dania, tokenizer, model):
    """Embed every uploaded sequence and offer the results as a JSON download.

    Args:
        upload_dict_dania: mapping of id -> list of AA sequence strings, e.g.
            ``{'uid1': ['aa', 'aan']}``.
        tokenizer: tokenizer forwarded to ``embed``.
        model: model forwarded to ``embed``.

    Side effects:
        Draws a status line, a progress bar and a "Download JSON file" button
        on the Streamlit page. The downloaded JSON nests each sequence's
        ``embed`` output under its id, e.g.
        ``{'uid1': {'aa': [[[0.1298, ...]]], 'aan': [[[0.1298, ...]]]}}``.
    """
    total = len(upload_dict_dania)
    output = {}

    # Placeholders that are overwritten in place as the loop advances.
    latest_iteration = st.empty()
    bar = st.progress(0)

    for done, (uid, seqs) in enumerate(upload_dict_dania.items(), start=1):
        latest_iteration.text(f'Iteration {uid}')
        # Sequences are the keys, so a repeated sequence under one id keeps a
        # single entry.
        output[uid] = {seq: embed(seq, tokenizer, model) for seq in seqs}
        # st.progress takes an int percentage in [0, 100]. The previous
        # `idx + 1` went out of range beyond 100 ids and never approached
        # 100% for fewer. Advance only after this id has been processed.
        bar.progress(int(100 * done / total))

    json_data = json.dumps(output)

    st.download_button(
        label = "Download JSON file",
        data = json_data,
        file_name = "esm-2 last hidden states.json",
        mime = 'application/json'
    )

    
if st.button('Get embedding'):
    # `data` is only bound after an upload, so test `uploaded_file` first;
    # `or` short-circuits before `data` is read. This avoids the NameError
    # raised when the button was clicked without a file.
    if uploaded_file is None or data is None:
        st.warning('Please upload a JSON file with AA sequences first.')
    else:
        st.write('You selected model:', model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name)
        embed_upload_file(data, tokenizer, model)

st.write('Also, Dania is not gay')