| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- from collections.abc import Sequence
- from typing import Any
- from wandb.sdk.integration_utils.auto_logging import Response
- from .resolvers import (
- SUPPORTED_MULTIMODAL_PIPELINES,
- DiffusersMultiModalPipelineResolver,
- )
class DiffusersPipelineResolver:
    """Resolve `DiffusionPipeline` requests and responses for logging.

    Resolver for [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index)
    pipeline calls. It selects a pipeline-specific resolver based on the class
    name of the pipeline object and delegates the data transformation and
    formatting to it.

    This is based on `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
    """

    def __init__(self) -> None:
        # Initialised to None here; not read or written anywhere else in this class.
        self.wandb_table = None
        # 1-based counter handed to the delegated resolver; bumped on every call.
        self.pipeline_call_count = 1

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Any:
        """Resolve a single pipeline call into loggable data.

        Args:
            args: Positional arguments of the pipeline call. `args[0]` must be
                the pipeline object; its class name selects the resolver.
            kwargs: Keyword arguments of the pipeline call.
            response: The response from the request.
            start_time: Time when the request started.
            time_elapsed: Time elapsed for the request.

        Returns:
            Whatever the delegated resolver returns (packed data for logging
            to wandb), or None if the pipeline class is not listed in
            `SUPPORTED_MULTIMODAL_PIPELINES`.
        """
        pipeline_name = args[0].__class__.__name__

        resolver = None
        if pipeline_name in SUPPORTED_MULTIMODAL_PIPELINES:
            resolver = DiffusersMultiModalPipelineResolver(
                pipeline_name, self.pipeline_call_count
            )
        # The counter advances for every call, supported or not.
        self.pipeline_call_count += 1

        if resolver is None:
            # Unsupported pipeline: there is nothing to log. Returning None
            # avoids calling a None resolver, which would raise TypeError.
            return None

        return resolver(args, kwargs, response, start_time, time_elapsed)
|