File size: 2,247 Bytes
809fb87
 
 
 
 
 
 
ff2b104
809fb87
ff2b104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d43f920
 
 
 
 
809fb87
 
 
 
 
 
ff2b104
809fb87
 
 
ff2b104
809fb87
 
 
 
 
 
 
 
ff2b104
809fb87
 
 
ff2b104
 
 
 
 
809fb87
 
 
 
ff2b104
809fb87
 
 
 
ff2b104
 
809fb87
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import pandas as pd
from pathlib import Path
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm
from uuid import uuid4

# Maps the model names offered in the UI dropdown to the identifiers that
# predict() passes to get_pretrained().
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
}

def predict(model, sequence_file, pairs_file):
    """Predict an interaction score for every protein pair in an uploaded file.

    Args:
        model: Key of ``model_map`` selecting which pretrained model to load.
        sequence_file: Uploaded FASTA file object; only its ``.name`` path
            is read.
        pairs_file: Uploaded ``.csv`` or ``.tsv`` file object with exactly
            two columns of protein identifiers. pandas consumes the first
            row as a header, and its labels are then replaced by
            ``protein1`` / ``protein2``.

    Returns:
        A tuple ``(results, file_path)`` where ``results`` is a DataFrame
        with the columns ``Protein 1``, ``Protein 2`` and ``Interaction``,
        and ``file_path`` is the path of a TSV copy written under ``/tmp``.

    Raises:
        gr.Error: If an upload is missing, the FASTA file cannot be turned
            into a dictionary, the pairs file has an unsupported extension
            or column count, or a pair references a protein that is not in
            the FASTA file.
    """
    if sequence_file is None or pairs_file is None:
        raise gr.Error("Please upload both a sequence file and a pairs file")

    run_id = uuid4()

    gr.Info("Loading model...")
    # Embedding a one-residue dummy sequence up front; presumably this forces
    # the language model to load before the main loop -- TODO confirm.
    _ = lm_embed("M")

    # Bound to a separate name so the `model` argument is not shadowed.
    predictor = get_pretrained(model_map[model])

    gr.Info("Loading files...")
    try:
        seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
    except ValueError as e:
        # The error must be raised (not just constructed) to stop the run and
        # be shown to the user; otherwise `seqs` is undefined below.
        raise gr.Error("Invalid FASTA file - duplicate entry") from e

    suffix = Path(pairs_file.name).suffix.lower()
    if suffix == ".csv":
        separator = ","
    elif suffix == ".tsv":
        separator = "\t"
    else:
        raise gr.Error("Invalid pairs file - expected a .csv or .tsv file")
    # dtype=str keeps identifiers comparable with the string FASTA keys even
    # when they look numeric.
    pairs = pd.read_csv(pairs_file.name, sep=separator, dtype=str)
    if pairs.shape[1] != 2:
        raise gr.Error("Invalid pairs file - expected exactly two columns")
    pairs.columns = ["protein1", "protein2"]

    # Fail before any prediction work if a pair references an unknown protein.
    requested = pd.concat([pairs["protein1"], pairs["protein2"]])
    missing = sorted({str(name) for name in requested if name not in seqs})
    if missing:
        raise gr.Error(
            "Proteins not found in FASTA file: " + ", ".join(missing)
        )

    gr.Info("Predicting...")
    results = []
    total = len(pairs)
    # NOTE(review): Gradio documents gr.Progress as a default-valued function
    # parameter; confirm that an instance created here is actually tracked.
    progress = gr.Progress(track_tqdm=True)
    pair_iter = tqdm(zip(pairs["protein1"], pairs["protein2"]), total=total)
    for i, (prot1, prot2) in enumerate(pair_iter):
        gr.Info(f"[{i+1}/{total}]")
        lm1 = lm_embed(str(seqs[prot1].seq))
        lm2 = lm_embed(str(seqs[prot2].seq))
        interaction = predictor.predict(lm1, lm2).item()
        results.append([prot1, prot2, interaction])
        progress((i, total))

    results = pd.DataFrame(
        results, columns=["Protein 1", "Protein 2", "Interaction"]
    )

    # The per-run UUID keeps result files of different runs from colliding.
    file_path = f"/tmp/{run_id}.tsv"
    results.to_csv(file_path, sep="\t", index=False, header=True)

    return results, file_path

# Gradio interface around predict(): three inputs (model choice, FASTA file,
# pairs file) and two outputs matching predict()'s returned
# (results table, results file path) tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            label="Model",
            choices=["D-SCRIPT", "Topsy-Turvy"],
            value="Topsy-Turvy",
        ),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
    ],
    outputs=[
        gr.DataFrame(
            label="Results",
            headers=["Protein 1", "Protein 2", "Interaction"],
        ),
        gr.File(label="Download results", type="file"),
    ],
)

if __name__ == "__main__":
    # Turn on the request queue (max_size=20), then start the app.
    demo.queue(max_size=20).launch()