pipeline_resolver.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from collections.abc import Sequence
  2. from typing import Any
  3. from wandb.sdk.integration_utils.auto_logging import Response
  4. from .resolvers import (
  5. SUPPORTED_MULTIMODAL_PIPELINES,
  6. DiffusersMultiModalPipelineResolver,
  7. )
  8. class DiffusersPipelineResolver:
  9. """Resolver for `DiffusionPipeline` request and responses from [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index), providing necessary data transformations, formatting, and logging.
  10. This is based off `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
  11. """
  12. def __init__(self) -> None:
  13. self.wandb_table = None
  14. self.pipeline_call_count = 1
  15. def __call__(
  16. self,
  17. args: Sequence[Any],
  18. kwargs: dict[str, Any],
  19. response: Response,
  20. start_time: float,
  21. time_elapsed: float,
  22. ) -> Any:
  23. """Main call method for the `DiffusersPipelineResolver` class.
  24. Args:
  25. args: (Sequence[Any]) List of arguments.
  26. kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
  27. response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
  28. the request.
  29. start_time: (float) Time when request started.
  30. time_elapsed: (float) Time elapsed for the request.
  31. Returns:
  32. Packed data as a dictionary for logging to wandb, None if an exception occurred.
  33. """
  34. pipeline_name = args[0].__class__.__name__
  35. resolver = None
  36. if pipeline_name in SUPPORTED_MULTIMODAL_PIPELINES:
  37. resolver = DiffusersMultiModalPipelineResolver(
  38. pipeline_name, self.pipeline_call_count
  39. )
  40. self.pipeline_call_count += 1
  41. loggable_dict = resolver(args, kwargs, response, start_time, time_elapsed)
  42. return loggable_dict