amongusrickroll68 commited on
Commit
9660022
1 Parent(s): 02ead63

Create class_name

Browse files
Files changed (1) hide show
  1. class_name +31 -0
class_name ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from scipy.stats import truncnorm
6
+ from skimage.transform import resize
7
+ from transformers import CLIPProcessor, CLIPModel
8
+
9
+ class TextToImageGenerator:
10
+ def __init__(self):
11
+ self.clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
12
+ self.processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
13
+ self.generator = tf.keras.models.load_model('path/to/generator/model')
14
+
15
+ def generate_image(self, prompt):
16
+ encoded_prompt = self.processor(prompt, return_tensors="tf").to_dict()
17
+ noise = tf.random.normal([1, 256])
18
+ text_features = self.clip.get_text_features(encoded_prompt)
19
+ image_features = self.generator([text_features, noise], training=False)[0]
20
+ image = self._postprocess_image(image_features)
21
+ return image
22
+
23
+ def _postprocess_image(self, image_features):
24
+ image_features = (image_features + 1) / 2 # scale from [-1, 1] to [0, 1]
25
+ image_features = np.clip(image_features, 0, 1) # clip any values outside of [0, 1]
26
+ image = Image.fromarray(np.uint8(image_features * 255))
27
+ image = image.resize((256, 256))
28
+ image_buffer = BytesIO()
29
+ image.save(image_buffer, format='JPEG')
30
+ image_data = image_buffer.getvalue()
31
+ return image_data