| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- from __future__ import annotations
- import inspect
- from collections.abc import Sequence
- from typing import TYPE_CHECKING, Any
- import wandb
- from wandb.util import get_module
- if TYPE_CHECKING:
- np_array = get_module("numpy.array")
- torch_float_tensor = get_module("torch.FloatTensor")
def chunkify(input_list, chunk_size) -> list:
    """Split a sliceable sequence into consecutive chunks.

    Args:
        input_list: Any sequence supporting ``len()`` and slicing (list, tuple,
            string, ...). Each chunk is a slice of it, so chunks keep its type.
        chunk_size: Maximum number of items per chunk. Values below 1 are
            treated as 1 so the slicing step is never zero or negative.

    Returns:
        A list of slices covering ``input_list`` in order; only the last one
        may be shorter than ``chunk_size``. An empty input yields ``[]``.
    """
    step = max(1, chunk_size)
    starts = range(0, len(input_list), step)
    return [input_list[begin : begin + step] for begin in starts]
def _serialize_generator(generator: Any) -> dict[str, Any] | None:
    """Return a loggable snapshot (seed, device, RNG state) of one generator, or ``None``."""
    if generator is None:
        return None
    return {
        "seed": generator.initial_seed(),
        "device": generator.device,
        # NOTE(review): assumes ``get_state()`` returns a tensor-like object, as
        # ``torch.Generator.get_state`` does.
        "random_state": generator.get_state().cpu().numpy().tolist(),
    }


def get_updated_kwargs(
    pipeline: Any, args: Sequence[Any], kwargs: dict[str, Any]
) -> dict[str, Any]:
    """Return the pipeline call's arguments as a single keyword mapping for logging.

    Positional ``args`` are bound, in order, to the positional parameters of
    ``pipeline.__call__``. Every other named parameter missing from ``kwargs``
    is filled with its signature default (``inspect.Parameter.empty`` for a
    parameter without a default). Variadic ``*args`` / ``**kwargs`` parameters
    are skipped so they do not show up as keys holding a placeholder default.

    A ``generator`` entry is replaced by a dict with its seed, device and RNG
    state. A list/tuple of generators becomes a list of such dicts, and ``None``
    stays ``None``. If ``ip_adapter_image`` is set, it is logged to W&B as an
    image.

    Note: ``kwargs`` is updated in place and also returned.

    Args:
        pipeline: Object whose ``__call__`` signature describes the call.
        args: Positional arguments the pipeline was called with (excluding the
            pipeline itself).
        kwargs: Keyword arguments the pipeline was called with.

    Returns:
        The updated ``kwargs`` mapping.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    parameters = list(inspect.signature(pipeline.__call__).parameters.values())

    # Bind positional arguments only to parameters that can accept them.
    positional_names = [p.name for p in parameters if p.kind in positional_kinds]
    for name, arg in zip(positional_names, args):
        kwargs[name] = arg

    # Fill in defaults for named parameters the caller did not pass.
    for parameter in parameters:
        if parameter.kind in variadic_kinds:
            continue
        kwargs.setdefault(parameter.name, parameter.default)

    if "generator" in kwargs:
        generator = kwargs["generator"]
        # Diffusers pipelines accept either one generator or a list of them.
        if isinstance(generator, (list, tuple)):
            kwargs["generator"] = [_serialize_generator(g) for g in generator]
        else:
            kwargs["generator"] = _serialize_generator(generator)

    if kwargs.get("ip_adapter_image") is not None:
        wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
    return kwargs
def postprocess_pils_to_np(image: list) -> np_array:
    """Stack images into one ``uint8`` NumPy array with a channel-first layout.

    Args:
        image: List of images convertible with ``np.array`` (e.g. PIL images).
            Each converted array must be 3-D, presumably
            ``(height, width, channels)``; the transpose below needs three axes.

    Returns:
        A ``uint8`` array whose per-image axes are reordered ``(2, 0, 1)``
        (i.e. ``(channels, height, width)`` for HWC input), stacked along a
        new leading axis.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )

    def to_channel_first(frame):
        # Cast first, then move the last axis to the front (HWC -> CHW).
        return np.array(frame).astype("uint8").transpose(2, 0, 1)

    return np.stack([to_channel_first(frame) for frame in image], axis=0)
def postprocess_np_arrays_for_video(
    images: list[np_array], normalize: bool | None = False
) -> np_array:
    """Stack per-frame arrays into one array with axes ``(frames, C, H, W)``.

    Args:
        images: Frames as NumPy arrays, presumably ``(height, width, channels)``.
            The stacked array must be 4-D for the final transpose.
        normalize: When truthy, frames are treated as floats expected in
            ``[0, 1]``: they are scaled by 255, clipped to ``[0, 255]`` and cast
            to ``uint8``. When falsy, frames are stacked unchanged.

    Returns:
        The stacked frames with axes reordered by ``(0, 3, 1, 2)``, i.e. from
        ``(frames, H, W, C)`` to ``(frames, C, H, W)``.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    if normalize:
        # Clip before casting. NumPy does not define the result of casting an
        # out-of-range float to uint8, so a slightly negative value could wrap
        # to a bright pixel. In-range values are unaffected.
        images = [np.clip(img * 255, 0, 255).astype("uint8") for img in images]
    return np.transpose(np.stack(images, axis=0), axes=(0, 3, 1, 2))
def decode_sdxl_t2i_latents(pipeline: Any, latents: torch_float_tensor) -> list:
    """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).

    Runs without gradient tracking. If the VAE is float16 and its config sets
    ``force_upcast``, the VAE is temporarily upcast for decoding and always
    restored to float16 afterwards, even if decoding raises. A watermark is
    applied when the pipeline has one. The result is post-processed to PIL
    images, and the pipeline's ``maybe_free_model_hooks`` is called.

    Args:
        pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
            [`diffusers`](https://huggingface.co/docs/diffusers).
        latents (torch.FloatTensor): The generated latents.

    Returns:
        List of `PIL` images corresponding to the generated latents.
    """
    torch = get_module(
        "torch",
        required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
    )
    with torch.no_grad():
        # NOTE(review): presumably the fp16 VAE is numerically unsafe when
        # ``force_upcast`` is set, hence decoding in a wider dtype.
        needs_upcasting = (
            pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
        )
        try:
            if needs_upcasting:
                pipeline.upcast_vae()
                # Match the latents to the dtype of the (now upcast) VAE weights.
                latents = latents.to(
                    next(iter(pipeline.vae.post_quant_conv.parameters())).dtype
                )
            images = pipeline.vae.decode(
                latents / pipeline.vae.config.scaling_factor, return_dict=False
            )[0]
        finally:
            # Put the caller's VAE back to float16 even when decoding fails, so a
            # logging error does not leave the user's pipeline upcast.
            if needs_upcasting:
                pipeline.vae.to(dtype=torch.float16)
        watermark = getattr(pipeline, "watermark", None)
        if watermark is not None:
            images = watermark.apply_watermark(images)
        images = pipeline.image_processor.postprocess(images, output_type="pil")
        pipeline.maybe_free_model_hooks()
    return images
|