"""BioCLIP-based image classifiers: zero-shot prediction over custom labels and
tree-of-life species prediction with roll-up to higher taxonomic ranks."""

import collections
import heapq
import json
from enum import Enum
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer
from torchvision import transforms

HF_DATAFILE_REPO = "imageomics/bioclip-demo"
HF_DATAFILE_REPO_TYPE = "space"

PRED_FILENAME_KEY = "file_name"
PRED_CLASSIFICATION_KEY = "classification"
PRED_SCORE_KEY = "score"

# Prompt templates from OpenAI's CLIP ImageNet zero-shot evaluation; a class's
# text embedding is averaged over all of these prompts.
OPENAI_IMAGENET_TEMPLATE = [
    lambda c: f"a bad photo of a {c}.",
    lambda c: f"a photo of many {c}.",
    lambda c: f"a sculpture of a {c}.",
    lambda c: f"a photo of the hard to see {c}.",
    lambda c: f"a low resolution photo of the {c}.",
    lambda c: f"a rendering of a {c}.",
    lambda c: f"graffiti of a {c}.",
    lambda c: f"a bad photo of the {c}.",
    lambda c: f"a cropped photo of the {c}.",
    lambda c: f"a tattoo of a {c}.",
    lambda c: f"the embroidered {c}.",
    lambda c: f"a photo of a hard to see {c}.",
    lambda c: f"a bright photo of a {c}.",
    lambda c: f"a photo of a clean {c}.",
    lambda c: f"a photo of a dirty {c}.",
    lambda c: f"a dark photo of the {c}.",
    lambda c: f"a drawing of a {c}.",
    lambda c: f"a photo of my {c}.",
    lambda c: f"the plastic {c}.",
    lambda c: f"a photo of the cool {c}.",
    lambda c: f"a close-up photo of a {c}.",
    lambda c: f"a black and white photo of the {c}.",
    lambda c: f"a painting of the {c}.",
    lambda c: f"a painting of a {c}.",
    lambda c: f"a pixelated photo of the {c}.",
    lambda c: f"a sculpture of the {c}.",
    lambda c: f"a bright photo of the {c}.",
    lambda c: f"a cropped photo of a {c}.",
    lambda c: f"a plastic {c}.",
    lambda c: f"a photo of the dirty {c}.",
    lambda c: f"a jpeg corrupted photo of a {c}.",
    lambda c: f"a blurry photo of the {c}.",
    lambda c: f"a photo of the {c}.",
    lambda c: f"a good photo of the {c}.",
    lambda c: f"a rendering of the {c}.",
    lambda c: f"a {c} in a video game.",
    lambda c: f"a photo of one {c}.",
    lambda c: f"a doodle of a {c}.",
    lambda c: f"a close-up photo of the {c}.",
    lambda c: f"a photo of a {c}.",
    lambda c: f"the origami {c}.",
    lambda c: f"the {c} in a video game.",
    lambda c: f"a sketch of a {c}.",
    lambda c: f"a doodle of the {c}.",
    lambda c: f"a origami {c}.",
    lambda c: f"a low resolution photo of a {c}.",
    lambda c: f"the toy {c}.",
    lambda c: f"a rendition of the {c}.",
    lambda c: f"a photo of the clean {c}.",
    lambda c: f"a photo of a large {c}.",
    lambda c: f"a rendition of a {c}.",
    lambda c: f"a photo of a nice {c}.",
    lambda c: f"a photo of a weird {c}.",
    lambda c: f"a blurry photo of a {c}.",
    lambda c: f"a cartoon {c}.",
    lambda c: f"art of a {c}.",
    lambda c: f"a sketch of the {c}.",
    lambda c: f"a embroidered {c}.",
    lambda c: f"a pixelated photo of a {c}.",
    lambda c: f"itap of the {c}.",
    lambda c: f"a jpeg corrupted photo of the {c}.",
    lambda c: f"a good photo of a {c}.",
    lambda c: f"a plushie {c}.",
    lambda c: f"a photo of the nice {c}.",
    lambda c: f"a photo of the small {c}.",
    lambda c: f"a photo of the weird {c}.",
    lambda c: f"the cartoon {c}.",
    lambda c: f"art of the {c}.",
    lambda c: f"a drawing of the {c}.",
    lambda c: f"a photo of the large {c}.",
    lambda c: f"a black and white photo of a {c}.",
    lambda c: f"the plushie {c}.",
    lambda c: f"a dark photo of a {c}.",
    lambda c: f"itap of a {c}.",
    lambda c: f"graffiti of the {c}.",
    lambda c: f"a toy {c}.",
    lambda c: f"itap of my {c}.",
    lambda c: f"a photo of a cool {c}.",
    lambda c: f"a photo of a small {c}.",
    lambda c: f"a tattoo of the {c}.",
]


def get_cached_datafile(filename: str):
    return hf_hub_download(repo_id=HF_DATAFILE_REPO, filename=filename,
                           repo_type=HF_DATAFILE_REPO_TYPE)
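
# A minimal sketch (not used by the pipeline) of the prompt-ensembling math that
# CustomLabelsClassifier.get_txt_features() applies below: embed every template
# prompt for a class, average the unit-normalized embeddings, then re-normalize.
# Random vectors stand in for real encode_text() output here, and the 512
# dimension is illustrative only.
def _prompt_ensemble_sketch() -> torch.Tensor:
    torch.manual_seed(0)
    fake_embs = torch.randn(len(OPENAI_IMAGENET_TEMPLATE), 512)  # one row per prompt
    mean_emb = F.normalize(fake_embs, dim=-1).mean(dim=0)        # average of unit vectors
    return mean_emb / mean_emb.norm()                            # back to unit length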

def get_txt_emb():
    txt_emb_npy = get_cached_datafile("txt_emb_species.npy")
    return torch.from_numpy(np.load(txt_emb_npy))


def get_txt_names():
    txt_names_json = get_cached_datafile("txt_emb_species.json")
    with open(txt_names_json) as fd:
        txt_names = json.load(fd)
    return txt_names


preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


class Rank(Enum):
    KINGDOM = 0
    PHYLUM = 1
    CLASS = 2
    ORDER = 3
    FAMILY = 4
    GENUS = 5
    SPECIES = 6

    def get_label(self):
        return self.name.lower()


# The datafile of names ('txt_emb_species.json') contains the species epithet.
# To create a label for a species we concatenate the genus and species epithet.
SPECIES_LABEL = Rank.SPECIES.get_label()
SPECIES_EPITHET_LABEL = "species_epithet"
COMMON_NAME_LABEL = "common_name"


def create_bioclip_model(model_str="hf-hub:imageomics/bioclip", device="cuda"):
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    return torch.compile(model)


def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
    return get_tokenizer(tokenizer_str)


class CustomLabelsClassifier(object):
    def __init__(self, device: Union[str, torch.device] = 'cpu'):
        self.device = device
        self.model = create_bioclip_model(device=device)
        self.tokenizer = create_bioclip_tokenizer()

    def get_txt_features(self, classnames):
        # Build one embedding per class by averaging the normalized embeddings
        # of all template prompts, then re-normalizing (CLIP prompt ensembling).
        all_features = []
        for classname in classnames:
            txts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATE]
            txts = self.tokenizer(txts).to(self.device)
            txt_features = self.model.encode_text(txts)
            txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
            txt_features /= txt_features.norm()
            all_features.append(txt_features)
        all_features = torch.stack(all_features, dim=1)
        return all_features

    @torch.no_grad()
    def predict(self, image_path: str, cls_ary: List[str]) -> List[dict]:
        img = PIL.Image.open(image_path)
        classes = [cls.strip() for cls in cls_ary]
        txt_features = self.get_txt_features(classes)

        img = preprocess_img(img).to(self.device)
        img_features = self.model.encode_image(img.unsqueeze(0))
        img_features = F.normalize(img_features, dim=-1)
        logits = (self.model.logit_scale.exp() * img_features @ txt_features).squeeze()
        probs = F.softmax(logits, dim=0).to("cpu").tolist()

        pred_list = []
        for cls, prob in zip(classes, probs):
            pred_list.append({
                PRED_FILENAME_KEY: image_path,
                PRED_CLASSIFICATION_KEY: cls,
                PRED_SCORE_KEY: prob,
            })
        return pred_list


def predict_classifications_from_list(img: str, cls_ary: List[str],
                                      device: Union[str, torch.device] = 'cpu') -> List[dict]:
    # `img` is a path to an image file.
    classifier = CustomLabelsClassifier(device=device)
    return classifier.predict(img, cls_ary)


def get_tol_classification_labels(rank: Rank) -> List[str]:
    names = []
    for i in range(rank.value + 1):
        i_rank = Rank(i)
        if i_rank == Rank.SPECIES:
            names.append(SPECIES_EPITHET_LABEL)
        rank_name = i_rank.name.lower()
        names.append(rank_name)
    if rank == Rank.SPECIES:
        names.append(COMMON_NAME_LABEL)
    return names


def create_classification_dict(names: List, rank: Rank) -> dict[str, str]:
    # `names` is one entry of the names datafile:
    # [[kingdom, ..., species_epithet], common_name].
    scientific_names = names[0]
    common_name = names[1]
    classification_dict = {}
    for idx, label in enumerate(get_tol_classification_labels(rank=rank)):
        if label == SPECIES_LABEL:
            value = scientific_names[-2] + " " + scientific_names[-1]
        elif label == COMMON_NAME_LABEL:
            value = common_name
        else:
            value = scientific_names[idx]
        classification_dict[label] = value
    return classification_dict
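
# Illustrative sketch of what create_classification_dict() produces for one
# entry of the names file. The taxon below is a made-up example of the expected
# entry shape, not a guaranteed row of 'txt_emb_species.json'.
def _classification_dict_sketch() -> dict[str, str]:
    example_entry = [
        ["Animalia", "Chordata", "Aves", "Passeriformes", "Corvidae", "Corvus", "corax"],
        "common raven",
    ]
    # Returns {"kingdom": "Animalia", ..., "genus": "Corvus",
    #          "species_epithet": "corax", "species": "Corvus corax",
    #          "common_name": "common raven"}
    return create_classification_dict(example_entry, Rank.SPECIES)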

def join_names(classification_dict: dict[str, str]) -> str:
    return " ".join(classification_dict.values())


class TreeOfLifeClassifier(object):
    def __init__(self, device: Union[str, torch.device] = 'cpu'):
        self.device = device
        self.model = create_bioclip_model(device=device)
        self.txt_emb = get_txt_emb().to(device)
        self.txt_names = get_txt_names()

    def encode_image(self, img: PIL.Image.Image) -> torch.Tensor:
        img = preprocess_img(img).to(self.device)
        img_features = self.model.encode_image(img.unsqueeze(0))
        return img_features

    def predict_species(self, img: PIL.Image.Image) -> torch.Tensor:
        img_features = self.encode_image(img)
        img_features = F.normalize(img_features, dim=-1)
        logits = (self.model.logit_scale.exp() * img_features @ self.txt_emb).squeeze()
        probs = F.softmax(logits, dim=0)
        return probs

    def format_species_probs(self, image_path: str, probs: torch.Tensor,
                             k: int = 5) -> List[dict]:
        topk = probs.topk(k)
        result = []
        for i, prob in zip(topk.indices, topk.values):
            item = {PRED_FILENAME_KEY: image_path}
            item.update(create_classification_dict(self.txt_names[i], Rank.SPECIES))
            item[PRED_SCORE_KEY] = prob.item()
            result.append(item)
        return result

    def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
                             min_prob: float = 1e-9, k: int = 5) -> List[dict]:
        # Sum species-level probabilities that share the same labels up to `rank`.
        output = collections.defaultdict(float)
        name_to_class_dict = {}
        # squeeze(1) rather than squeeze() so a single match still yields a
        # 1-D tensor that can be iterated.
        for i in torch.nonzero(probs > min_prob).squeeze(1):
            classification_dict = create_classification_dict(self.txt_names[i], rank)
            name = join_names(classification_dict)
            output[name] += probs[i]
            name_to_class_dict[name] = classification_dict
        topk_names = heapq.nlargest(k, output, key=output.get)
        prediction_ary = []
        for name in topk_names:
            item = {PRED_FILENAME_KEY: image_path}
            item.update(name_to_class_dict[name])
            item[PRED_SCORE_KEY] = output[name].item()
            prediction_ary.append(item)
        return prediction_ary

    @torch.no_grad()
    def predict(self, image_path: str, rank: Rank, min_prob: float = 1e-9,
                k: int = 5) -> List[dict]:
        img = PIL.Image.open(image_path)
        probs = self.predict_species(img)
        if rank == Rank.SPECIES:
            return self.format_species_probs(image_path, probs, k)
        return self.format_grouped_probs(image_path, probs, rank, min_prob, k)


def predict_classification(img: str, rank: Rank, device: Union[str, torch.device] = 'cpu',
                           min_prob: float = 1e-9, k: int = 5) -> List[dict]:
    """
    Predicts from the entire tree of life. If targeting a higher rank than
    species, then this function predicts among all species, then sums up
    species-level probabilities for the given rank.
    """
    classifier = TreeOfLifeClassifier(device=device)
    return classifier.predict(img, rank, min_prob, k)
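
# A minimal usage sketch, assuming a local image at the placeholder path
# "example.jpg". Both calls download the BioCLIP weights (and, for the
# tree-of-life classifier, the cached text embeddings) on first use.
if __name__ == "__main__":
    # Zero-shot classification over caller-supplied labels.
    for pred in predict_classifications_from_list(
            "example.jpg", ["insect", "bird", "mammal"], device="cpu"):
        print(pred)
    # Tree-of-life prediction, with species probabilities summed up to family rank.
    for pred in predict_classification("example.jpg", Rank.FAMILY, device="cpu"):
        print(pred)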