utils.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from __future__ import annotations
  2. import inspect
  3. from collections.abc import Sequence
  4. from typing import TYPE_CHECKING, Any
  5. import wandb
  6. from wandb.util import get_module
  7. if TYPE_CHECKING:
  8. np_array = get_module("numpy.array")
  9. torch_float_tensor = get_module("torch.FloatTensor")
  10. def chunkify(input_list, chunk_size) -> list:
  11. chunk_size = max(1, chunk_size)
  12. return [
  13. input_list[i : i + chunk_size] for i in range(0, len(input_list), chunk_size)
  14. ]
  15. def get_updated_kwargs(
  16. pipeline: Any, args: Sequence[Any], kwargs: dict[str, Any]
  17. ) -> dict[str, Any]:
  18. pipeline_call_parameters = list(
  19. inspect.signature(pipeline.__call__).parameters.items()
  20. )
  21. for idx, arg in enumerate(args):
  22. kwargs[pipeline_call_parameters[idx][0]] = arg
  23. for pipeline_parameter in pipeline_call_parameters:
  24. if pipeline_parameter[0] not in kwargs:
  25. kwargs[pipeline_parameter[0]] = pipeline_parameter[1].default
  26. if "generator" in kwargs:
  27. generator = kwargs["generator"]
  28. kwargs["generator"] = (
  29. {
  30. "seed": generator.initial_seed(),
  31. "device": generator.device,
  32. "random_state": generator.get_state().cpu().numpy().tolist(),
  33. }
  34. if generator is not None
  35. else None
  36. )
  37. if "ip_adapter_image" in kwargs and kwargs["ip_adapter_image"] is not None:
  38. wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
  39. return kwargs
  40. def postprocess_pils_to_np(image: list) -> np_array:
  41. np = get_module(
  42. "numpy",
  43. required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
  44. )
  45. return np.stack(
  46. [np.transpose(np.array(img).astype("uint8"), axes=(2, 0, 1)) for img in image],
  47. axis=0,
  48. )
  49. def postprocess_np_arrays_for_video(
  50. images: list[np_array], normalize: bool | None = False
  51. ) -> np_array:
  52. np = get_module(
  53. "numpy",
  54. required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
  55. )
  56. images = [(img * 255).astype("uint8") for img in images] if normalize else images
  57. return np.transpose(np.stack((images), axis=0), axes=(0, 3, 1, 2))
  58. def decode_sdxl_t2i_latents(pipeline: Any, latents: torch_float_tensor) -> list:
  59. """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).
  60. Args:
  61. pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
  62. [`diffusers`](https://huggingface.co/docs/diffusers).
  63. latents (torch.FloatTensor): The generated latents.
  64. Returns:
  65. List of `PIL` images corresponding to the generated latents.
  66. """
  67. torch = get_module(
  68. "torch",
  69. required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
  70. )
  71. with torch.no_grad():
  72. needs_upcasting = (
  73. pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
  74. )
  75. if needs_upcasting:
  76. pipeline.upcast_vae()
  77. latents = latents.to(
  78. next(iter(pipeline.vae.post_quant_conv.parameters())).dtype
  79. )
  80. images = pipeline.vae.decode(
  81. latents / pipeline.vae.config.scaling_factor, return_dict=False
  82. )[0]
  83. if needs_upcasting:
  84. pipeline.vae.to(dtype=torch.float16)
  85. if pipeline.watermark is not None:
  86. images = pipeline.watermark.apply_watermark(images)
  87. images = pipeline.image_processor.postprocess(images, output_type="pil")
  88. pipeline.maybe_free_model_hooks()
  89. return images