# imgen-diffusers/flux-pywebio.py
# PyWebIO web UI for generating images with a local FLUX.1-schnell diffusers pipeline.

from pywebio import start_server
from pywebio.input import *
from pywebio.output import *
import pywebio.pin as pin
import time
from diffusers import FluxPipeline
import torch
from pathlib import Path
import re
from datetime import datetime
# Initial values pre-filled into the web form built by main().
default_prompt: str = (
    "A cat wearing an old racing helmet and goggles"
    " holding a sign that says zoom zoom"
)
default_width: int = 1024
default_height: int = 512
default_numimgs: int = 4
# Relative path, so it is resolved against the process's working directory;
# presumably a locally downloaded FLUX.1-schnell checkpoint — confirm.
ckpt_id: str = "./FLUX.1-schnell"
# Load the FLUX pipeline once at import time; flux_run() reuses this single
# global object for every request. bfloat16 weights take half the memory of
# float32, and use_safetensors=True loads the .safetensors weight files.
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
# Memory savers for the VAE (per the diffusers docs): process large images in
# tiles, and decode a batch one image at a time instead of all at once.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Lowest-VRAM offload mode; trades generation speed for memory.
pipe.enable_sequential_cpu_offload() # offloads modules to CPU on a submodule level (rather than model level)
def slugify(text):
    r"""Turn *text* into a hyphen-separated fragment that is safe in a filename.

    Punctuation and symbols are dropped and each run of whitespace becomes a
    single hyphen. ``\w`` matches Unicode word characters, so letters and
    digits from any script (and ``_``) are kept. Path characters such as
    ``/``, ``\`` and ``.`` are removed, so the result cannot point outside the
    directory it is joined to.

    Leading/trailing whitespace is trimmed after punctuation removal so that a
    prompt truncated on a space (e.g. ``prompt[:64]``) does not leave a
    dangling ``-`` that doubles up with the next separator in the filename.

    >>> slugify("Zoom, zoom! ")
    'Zoom-zoom'
    """
    text = re.sub(r"[^\w\s]", "", text)
    # Strip only after punctuation is gone: "a !" -> "a " -> "a".
    return re.sub(r"\s+", "-", text.strip())
def set_dir(candidates=("/home/tonydero/remdirs/immich/", "/nas/dockerdata/immich/local/")):
    """Return the first usable output directory from *candidates*.

    By default the blazar host's path is tried first, then the magnetar
    host's path.

    Args:
        candidates: Directory paths (``str`` or ``Path``) to try, in order.

    Returns:
        A ``Path`` for the first candidate that is an existing directory, or
        ``False`` if none is. ``False`` (not ``None``) is kept for callers that
        rely on the original return value; both are falsy.
    """
    for candidate in candidates:
        path = Path(candidate)
        # is_dir() rather than exists(): a stray regular file at one of these
        # paths would otherwise be chosen and every image save would then fail.
        if path.is_dir():
            return path
    return False
def chk_prompt_len(prompt, max_len=370):
    """Return True if *prompt* is at most *max_len* characters long.

    The limit counts characters, not model tokens.

    Args:
        prompt: The prompt string to check.
        max_len: Maximum allowed length in characters (inclusive).

    Returns:
        ``True`` when ``len(prompt) <= max_len``, otherwise ``False``.
    """
    # NOTE(review): 370 is the original hard-coded limit; its origin (perhaps
    # a budget tied to max_sequence_length=128 used when generating) is not
    # documented — confirm.
    return len(prompt) <= max_len
def flux_run(prompt, height, width, num_images_per_prompt, num_inference_steps,
             max_sequence_length, dirpath):
    """Generate images for *prompt* with the global ``pipe`` and save them as PNGs.

    Args:
        prompt: Text prompt passed to the pipeline.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_images_per_prompt: Number of images generated in this batch.
        num_inference_steps: Denoising steps passed to the pipeline.
        max_sequence_length: Maximum prompt sequence length passed to the pipeline.
        dirpath: Directory the PNGs are written to (``Path`` or ``str``).

    Returns:
        List of the ``Path`` objects written, one per generated image.

    Files are named ``<slug>-<idx>-<YYYYmmdd-HHMMSS>.png``, where ``<slug>`` is
    the slugified first 64 characters of the prompt.
    """
    # FIX: add check to make sure there isn't already one running from another source (e.g. from phone when on computer)
    output = pipe(
        prompt,
        height=height,
        width=width,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        max_sequence_length=max_sequence_length,
        guidance_scale=0.0,  # as in the diffusers FLUX.1-schnell example
    )
    # print('Max mem allocated (GB) while denoising:', torch.cuda.max_memory_allocated() / (1024 ** 3))
    dirpath = Path(dirpath)
    # One timestamp for the whole batch; idx keeps the names unique. The
    # date/time separator and zero-padded %H (instead of the glibc-only %-H)
    # make names unambiguous, portable, and chronologically sortable.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = slugify(prompt[:64])
    saved = []
    for idx, image in enumerate(output.images):
        image_path = dirpath / f"{stem}-{idx}-{timestamp}.png"
        image.save(image_path)
        saved.append(image_path)
    return saved
def btn_click(btn_val):
    """Handle a click on one of the buttons created in ``main``.

    'Portrait', 'Landscape' and 'Square' fill the width/height inputs with a
    preset size. 'Generate' checks the output directory and the prompt length,
    runs ``flux_run`` with the current form values (4 inference steps,
    max_sequence_length 128), and reports the elapsed time.

    Args:
        btn_val: Label of the clicked button, as passed by ``put_buttons``.
    """
    # Preset (width, height) pairs for the layout buttons.
    layouts = {
        'Portrait': (512, 1024),
        'Landscape': (1024, 512),
        'Square': (512, 512),
    }
    if btn_val in layouts:
        # Presets only change the form, so they must work even when no output
        # directory is available (previously they showed a generate error).
        pin.pin.width, pin.pin.height = layouts[btn_val]
        return
    if btn_val != 'Generate':
        return

    dirpath = set_dir()
    if not dirpath:
        put_error("Failed to generate. Check directory location.", closable=True)
        return
    if not chk_prompt_len(pin.pin.prompt):
        put_error("Prompt too long.", closable=True)
        return

    with put_loading():
        start_time = datetime.now()
        put_text("Started generating images at " + start_time.strftime("%H:%M:%S"))
        flux_run(pin.pin.prompt, pin.pin.height, pin.pin.width, pin.pin.numimgs, 4, 128, dirpath)
        run_time = datetime.now() - start_time
    # Emitted after the context: per the PyWebIO docs, exiting put_loading()
    # removes the spinner together with any output produced inside it.
    put_success("Images finished generating in " + str(run_time), closable=True)
def main():  # PyWebIO application function
    """Build the generator page: prompt/size/count inputs and the action buttons."""
    # TODO: figure out how to put the labels in line with the inputs and make the number ones smaller
    # TODO: add the remaining parameters?
    # TODO: add ability to loop for more generation in one run?
    fields = (
        ("prompt", "Prompt:", "text", default_prompt),
        ("width", "Width:", "number", default_width),
        ("height", "Height:", "number", default_height),
        ("numimgs", "Number of images:", "number", default_numimgs),
    )
    for pin_name, caption, input_type, initial in fields:
        pin.put_input(pin_name, label=caption, type=input_type, value=initial)
    # The clicked button's label is handed to btn_click.
    put_buttons(['Generate', 'Portrait', 'Landscape', 'Square'], onclick=btn_click)
if __name__ == "__main__":
    # Start the (blocking) web server only when run as a script, so importing
    # this module, e.g. to reuse slugify, does not start a server.
    # NOTE(review): debug=True is a development setting; consider disabling it
    # if port 4972 is reachable by untrusted clients.
    start_server(main, port=4972, debug=True, auto_open_webbrowser=False)