add local flux schnell
This commit is contained in:
parent
e2f4d271f0
commit
2fca7f89f0
1
.gitignore
vendored
1
.gitignore
vendored
@ -160,3 +160,4 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
FLUX.1-schnell
|
||||
|
52
flux-schnell-local.py
Normal file
52
flux-schnell-local.py
Normal file
@ -0,0 +1,52 @@
|
||||
from diffusers import FluxPipeline
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
def slugify(text):
|
||||
# remove non-word characters and foreign characters
|
||||
text = re.sub(r"[^\w\s]", "", text)
|
||||
text = re.sub(r"\s+", "-", text)
|
||||
return text
|
||||
|
||||
prompt = "the town center of a small futuristic settlement with a fountain and gardens.the buildings are shaped like dodecahedrons, similar to small geodesic domes with a main road leading out to grassy plains sparsely dotted with very broad tall trees"
|
||||
height, width = 720, 1280
|
||||
|
||||
ckpt_id = "./FLUX.1-schnell"
|
||||
|
||||
DIR_NAME="/home/tonydero/remdirs/immich/"
|
||||
dirpath = Path(DIR_NAME)
|
||||
# create parent dir if doesn't exist
|
||||
# dirpath.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# denoising
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
ckpt_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_safetensors=True,
|
||||
)
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.enable_sequential_cpu_offload() # offloads modules to CPU on a submodule level (rather than model level)
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_images_per_prompt=8,
|
||||
num_inference_steps=4,
|
||||
max_sequence_length=128,
|
||||
guidance_scale=0.0,
|
||||
)
|
||||
# print('Max mem allocated (GB) while denoising:', torch.cuda.max_memory_allocated() / (1024 ** 3))
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(image)
|
||||
# image.save("./whitehenge.png")
|
||||
# plt.show()
|
||||
for idx, image in enumerate(output.images):
|
||||
timestamp = datetime.now().strftime("%Y%m%d%-H%M%S")
|
||||
image_name = f'{slugify(prompt[:64])}-{idx}-{timestamp}.png'
|
||||
image_path = dirpath / image_name
|
||||
image.save(image_path)
|
Loading…
Reference in New Issue
Block a user