Update main.py
Browse files
main.py
CHANGED
@@ -18,6 +18,7 @@ app.add_middleware( # add the middleware
|
|
18 |
model_id = "runwayml/stable-diffusion-v1-5"
|
19 |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
20 |
pipe = pipe.to("cpu")
|
|
|
21 |
|
22 |
def dummy(images, **kwargs):
|
23 |
return images, False
|
@@ -31,7 +32,9 @@ def hello():
|
|
31 |
|
32 |
@app.get("/gen/{prompt}")
|
33 |
def generate_image(prompt: str):
|
34 |
-
image = pipe(prompt
|
|
|
|
|
35 |
# Save the image
|
36 |
image.save('static/image.png')
|
37 |
|
|
|
18 |
model_id = "runwayml/stable-diffusion-v1-5"
|
19 |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
20 |
pipe = pipe.to("cpu")
|
21 |
+
pipe.enable_attention_slicing()
|
22 |
|
23 |
def dummy(images, **kwargs):
|
24 |
return images, False
|
|
|
32 |
|
33 |
@app.get("/gen/{prompt}")
|
34 |
def generate_image(prompt: str):
|
35 |
+
image = pipe(prompt,
|
36 |
+
guidance_scale=8.5 # how strict to follow the prompt
|
37 |
+
).images[0]
|
38 |
# Save the image
|
39 |
image.save('static/image.png')
|
40 |
|