fibmskrt's picture
fix: convert path to string
1f294d6 verified
raw
history blame contribute delete
No virus
4.71 kB
import logging
import pathlib
import gradio as gr
import numpy as np
import pandas as pd
from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY
from utils import draw_grid_predict
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Algorithm versions (passed to the predictor config as `algorithm_version`)
# exposed for each MolFormer predictor in the dropdown.
MOLFORMER_VERSIONS = {
    "molformer_classification": ["bace", "bbbp", "hiv"],
    "molformer_regression": [
        "alpha",
        "cv",
        "g298",
        "gap",
        "h298",
        "homo",
        "lipo",
        "lumo",
        "mu",
        "r2",
        "u0",
    ],
    "molformer_multitask_classification": ["clintox", "sider", "tox21"],
}

# Factory predictors hidden from the dropdown. The bare MolFormer predictor
# names are hidden as well because the app re-adds them once per version
# listed in MOLFORMER_VERSIONS.
REMOVE = [
    "docking",
    "docking_tdc",
    "molecule_one",
    "askcos",
    "plogp",
    "similarity_seed",
    "activity_against_target",
    "organtox",
    *MOLFORMER_VERSIONS,
]

# Predictors whose output is labelled with several task names. Each value is
# a comma-separated list (no spaces around the commas) that main() splits on
# "," to label the prediction columns.
MODEL_PROP_DESCRIPTION = {
    "Tox21": ",".join(
        [
            "NR-AR",
            "NR-AR-LBD",
            "NR-AhR",
            "NR-Aromatase",
            "NR-ER",
            "NR-ER-LBD",
            "NR-PPAR-gamma",
            "SR-ARE",
            "SR-ATAD5",
            "SR-HSE",
            "SR-MMP",
            "SR-p53",
        ]
    ),
    "Sider": ",".join(
        [
            "Hepatobiliary disorders",
            "Metabolism and nutrition disorders",
            "Product issues",
            "Eye disorders",
            "Investigations",
            "Musculoskeletal disorders",
            "Gastrointestinal disorders",
            "Social circumstances",
            "Immune system disorders",
            "Reproductive system and breast disorders",
            "Benign & malignant",
            "General disorders",
            "Endocrine disorders",
            "Surgical & medical procedures",
            "Vascular disorders",
            "Blood & lymphatic disorders",
            "Skin & subcutaneous disorders",
            "Congenital & genetic disorders",
            "Infections",
            "Respiratory & thoracic disorders",
            "Psychiatric disorders",
            "Renal & urinary disorders",
            "Pregnancy conditions",
            "Ear disorders",
            "Cardiac disorders",
            "Nervous system disorders",
            "Injury & procedural complications",
        ]
    ),
    "Clintox": ",".join(["FDA approval", "Clinical trial failure"]),
}
def main(property: str, smiles: str, smiles_file: str):
    """Predict one molecular property and render the results as a grid.

    Exactly one of ``smiles`` / ``smiles_file`` must be provided.

    Args:
        property: Dropdown label of the predictor, e.g. ``"Qed"`` or, for
            MolFormer models, ``"Molformer_regression (lipo)"`` where the
            parenthesised part is the algorithm version.
        smiles: A single SMILES string; empty (or whitespace-only) when unused.
        smiles_file: ``None`` when unused; otherwise the uploaded ``.smi``
            file, given either as a path (``str`` or ``pathlib`` path) or as
            an object whose ``.name`` is the path. The file is read as
            tab-separated with no header, and its first column is taken as
            the SMILES.

    Returns:
        The output of ``draw_grid_predict`` for the molecules and their
        predictions (rounded to two decimals).

    Raises:
        ValueError: If both, or neither, of ``smiles`` and ``smiles_file``
            are given.
    """
    # Validate the inputs first so a bad request fails before a model is built.
    smiles = smiles.strip() if smiles else ""
    if smiles and smiles_file is not None:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if not smiles and smiles_file is None:
        raise ValueError("Pass either smiles or smiles_file.")

    if smiles:
        smiles_list = [smiles]
    else:
        # Upload objects expose the file path as `.name`; plain paths are
        # accepted too so main() can also be called directly.
        if isinstance(smiles_file, (str, pathlib.PurePath)):
            path = smiles_file
        else:
            path = smiles_file.name
        smiles_list = pd.read_csv(path, header=None, sep="\t")[0].tolist()

    # MolFormer labels carry the algorithm version in parentheses, e.g.
    # "Molformer_regression (lipo)": split it off the predictor name.
    version = None
    name, sep, rest = property.partition(" (")
    if sep and name.lower() in MOLFORMER_VERSIONS:
        property, version = name, rest.rstrip(")")

    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[property.lower()]
    # Multi-task predictors listed in MODEL_PROP_DESCRIPTION are pinned to "v0".
    kwargs = {"algorithm_version": "v0"} if property in MODEL_PROP_DESCRIPTION else {}
    if version is not None:
        kwargs["algorithm_version"] = version
    model = algo(config(**kwargs))

    props = np.array([model(s) for s in smiles_list]).round(2)
    # Single-output predictors yield a 1-D array; make it (n_molecules, 1).
    if props.ndim == 1:
        props = np.expand_dims(props, -1)

    if property in MODEL_PROP_DESCRIPTION:
        # Strip so labels never start with a space if a description uses ", ".
        property_names = [
            label.strip() for label in MODEL_PROP_DESCRIPTION[property].split(",")
        ]
    else:
        property_names = [property]
    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )
if __name__ == "__main__":
    # Dropdown choices: every predictor in the gt4sd factory (in reverse of
    # the factory's key order) except those listed in REMOVE. A set makes the
    # exclusion a single pass, and names in REMOVE that the installed gt4sd
    # does not provide are logged and skipped instead of raising KeyError.
    factory_names = list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    excluded = set(REMOVE)
    unknown = excluded.difference(factory_names)
    if unknown:
        logger.warning("Ignoring unknown predictors in REMOVE: %s", sorted(unknown))
    properties = [name.capitalize() for name in factory_names if name not in excluded]

    # MolFormer predictors are offered once per version, labelled
    # "<Name> (<version>)"; main() parses the version back out of the label.
    for key, versions in MOLFORMER_VERSIONS.items():
        properties.extend(f"{key.capitalize()} ({version})" for version in versions)

    # Example inputs, article and description stored next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"
    examples = [
        # The file example is given as a str path, not a pathlib.Path.
        ["Qed", "", str(metadata_root / "examples.smi")],
        [
            "Esol",
            "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1",
            None,
        ],
    ]
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Molecular properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Scscore"),
            gr.Textbox(
                label="Single SMILES",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.File(
                file_types=[".smi"],
                label="Multiple SMILES (tab-separated, `.smi` file)",
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True, share=True)