amongusrickroll68's picture
Create class_name
9660022
raw
history blame
No virus
1.39 kB
import tensorflow as tf
import numpy as np
from PIL import Image
from io import BytesIO
from scipy.stats import truncnorm
from skimage.transform import resize
from transformers import CLIPProcessor, CLIPModel
class TextToImageGenerator:
    """Turn a text prompt into JPEG bytes using CLIP and a Keras generator.

    The prompt is embedded with the CLIP text encoder; the embedding and a
    random noise vector are passed to a Keras generator model, and the
    generator output is rescaled, resized and JPEG-encoded.
    """

    def __init__(
        self,
        generator_path='path/to/generator/model',
        clip_name='openai/clip-vit-base-patch32',
        noise_dim=256,
        image_size=(256, 256),
    ):
        """Load the CLIP encoder, its processor and the generator model.

        Args:
            generator_path: Location passed to ``tf.keras.models.load_model``.
                The default is the placeholder path from the original code.
            clip_name: Checkpoint name used for both the CLIP model and its
                processor.
            noise_dim: Width of the random noise vector fed to the generator.
            image_size: ``(width, height)`` the output image is resized to.
        """
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)
        self.generator = tf.keras.models.load_model(generator_path)
        self.noise_dim = noise_dim
        self.image_size = image_size

    def generate_image(self, prompt):
        """Generate an image for ``prompt`` and return it as JPEG bytes.

        Args:
            prompt: Text describing the desired image.

        Returns:
            The encoded JPEG image as ``bytes``.
        """
        # CLIPModel is the PyTorch implementation, so the processor must
        # produce PyTorch tensors. Truncation keeps over-long prompts within
        # the encoder's maximum context length instead of crashing.
        encoded_prompt = self.processor(
            text=prompt, return_tensors="pt", padding=True, truncation=True
        )
        # Unpack so input_ids / attention_mask arrive as keyword arguments;
        # passing the mapping positionally would bind it to input_ids.
        text_features = self.clip.get_text_features(**encoded_prompt)
        # Hand the embedding to the Keras generator as a TensorFlow tensor.
        text_features = tf.convert_to_tensor(
            text_features.detach().cpu().numpy()
        )
        # One noise row per encoded prompt (a single row for a single string).
        noise = tf.random.normal([text_features.shape[0], self.noise_dim])
        # NOTE(review): assumes the generator takes [text_features, noise]
        # and returns a batch of images; only the first one is used.
        image_features = self.generator(
            [text_features, noise], training=False
        )[0]
        return self._postprocess_image(image_features)

    @staticmethod
    def _to_uint8_pixels(image_features):
        """Map values from [-1, 1] to uint8 pixels in [0, 255].

        Values outside [-1, 1] are clipped; fractional results are truncated.
        """
        scaled = (np.asarray(image_features, dtype=np.float32) + 1.0) / 2.0
        scaled = np.clip(scaled, 0.0, 1.0)
        return (scaled * 255).astype(np.uint8)

    def _postprocess_image(self, image_features):
        """Convert raw generator output into resized JPEG bytes.

        Args:
            image_features: Array-like generator output, expected in [-1, 1].

        Returns:
            The JPEG-encoded image as ``bytes``.
        """
        pixels = self._to_uint8_pixels(image_features)
        # Pillow cannot build an image from an (H, W, 1) array; drop the
        # trailing singleton channel so it is treated as grayscale.
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        image = Image.fromarray(pixels)
        # JPEG cannot store an alpha channel, so anything other than
        # grayscale or RGB is converted before saving.
        if image.mode not in ('L', 'RGB'):
            image = image.convert('RGB')
        image = image.resize(self.image_size)
        image_buffer = BytesIO()
        image.save(image_buffer, format='JPEG')
        return image_buffer.getvalue()