mschuh commited on
Commit
02327b3
β€’
1 Parent(s): 9c26395

Update to ZeroGPU

Browse files
app.py CHANGED
@@ -128,4 +128,4 @@ iface = gr.Interface(
128
  theme=theme
129
  )
130
 
131
- iface.launch()
 
128
  theme=theme
129
  )
130
 
131
+ iface.launch(share=True)
model/barlow_twins.py CHANGED
@@ -9,10 +9,12 @@ import os
9
  import pickle
10
  import inspect
11
  from tqdm.auto import trange
 
12
 
13
  from model.base_model import BaseModel
14
 
15
 
 
16
  class BarlowTwins(BaseModel):
17
  def __init__(
18
  self,
 
9
  import pickle
10
  import inspect
11
  from tqdm.auto import trange
12
+ import spaces
13
 
14
  from model.base_model import BaseModel
15
 
16
 
17
+ @spaces.GPU
18
  class BarlowTwins(BaseModel):
19
  def __init__(
20
  self,
model/base_model.py CHANGED
@@ -2,8 +2,10 @@ from typing import Tuple, Any, Union
2
  import torch
3
  from torch import nn
4
  import numpy as np
 
5
 
6
 
 
7
  class BaseModel(nn.Module):
8
  def __init__(self):
9
  super(BaseModel, self).__init__()
 
2
  import torch
3
  from torch import nn
4
  import numpy as np
5
+ import spaces
6
 
7
 
8
+ @spaces.GPU
9
  class BaseModel(nn.Module):
10
  def __init__(self):
11
  super(BaseModel, self).__init__()
model/model.py CHANGED
@@ -17,6 +17,7 @@ import torch
17
  from typing import *
18
  from rdkit import RDLogger
19
  RDLogger.DisableLog("rdApp.*")
 
20
 
21
  from xgboost import XGBClassifier, DMatrix
22
 
@@ -26,7 +27,7 @@ from model.barlow_twins import BarlowTwins
26
  from utils.sequence import uniprot2sequence, encode_sequences
27
 
28
 
29
-
30
  class DTIModel:
31
  def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
32
  self.bt_model = BarlowTwins()
 
17
  from typing import *
18
  from rdkit import RDLogger
19
  RDLogger.DisableLog("rdApp.*")
20
+ import spaces
21
 
22
  from xgboost import XGBClassifier, DMatrix
23
 
 
27
  from utils.sequence import uniprot2sequence, encode_sequences
28
 
29
 
30
+ @spaces.GPU
31
  class DTIModel:
32
  def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
33
  self.bt_model = BarlowTwins()
utils/sequence.py CHANGED
@@ -8,6 +8,7 @@ import concurrent.futures
8
  from tqdm.auto import tqdm
9
  import multiprocessing
10
  from multiprocessing import Pool
 
11
 
12
 
13
  ENCODERS = {
@@ -49,6 +50,7 @@ def uniprot2sequence(uniprot_id):
49
  return None
50
 
51
 
 
52
  def encode_sequences(sequences: list, encoder: str):
53
  if encoder not in ENCODERS.keys():
54
  raise ValueError(f"Invalid encoder: {encoder}")
 
8
  from tqdm.auto import tqdm
9
  import multiprocessing
10
  from multiprocessing import Pool
11
+ import spaces
12
 
13
 
14
  ENCODERS = {
 
50
  return None
51
 
52
 
53
+ @spaces.GPU
54
  def encode_sequences(sequences: list, encoder: str):
55
  if encoder not in ENCODERS.keys():
56
  raise ValueError(f"Invalid encoder: {encoder}")