multimodal.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. import logging
  2. from collections.abc import Sequence
  3. from typing import Any
  4. import wandb
  5. from wandb.sdk.integration_utils.auto_logging import Response
  6. from .utils import (
  7. chunkify,
  8. decode_sdxl_t2i_latents,
  9. get_updated_kwargs,
  10. postprocess_np_arrays_for_video,
  11. postprocess_pils_to_np,
  12. )
  13. logger = logging.getLogger(__name__)
  14. SUPPORTED_MULTIMODAL_PIPELINES = {
  15. "BlipDiffusionPipeline": {
  16. "table-schema": [
  17. "Reference-Image",
  18. "Prompt",
  19. "Negative-Prompt",
  20. "Source-Subject-Category",
  21. "Target-Subject-Category",
  22. "Generated-Image",
  23. ],
  24. "kwarg-logging": [
  25. "reference_image",
  26. "prompt",
  27. "neg_prompt",
  28. "source_subject_category",
  29. "target_subject_category",
  30. ],
  31. "kwarg-actions": [wandb.Image, None, None, None, None],
  32. },
  33. "BlipDiffusionControlNetPipeline": {
  34. "table-schema": [
  35. "Reference-Image",
  36. "Control-Image",
  37. "Prompt",
  38. "Negative-Prompt",
  39. "Source-Subject-Category",
  40. "Target-Subject-Category",
  41. "Generated-Image",
  42. ],
  43. "kwarg-logging": [
  44. "reference_image",
  45. "condtioning_image",
  46. "prompt",
  47. "neg_prompt",
  48. "source_subject_category",
  49. "target_subject_category",
  50. ],
  51. "kwarg-actions": [wandb.Image, wandb.Image, None, None, None, None],
  52. },
  53. "StableDiffusionControlNetPipeline": {
  54. "table-schema": [
  55. "Control-Image",
  56. "Prompt",
  57. "Negative-Prompt",
  58. "Generated-Image",
  59. ],
  60. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  61. "kwarg-actions": [wandb.Image, None, None],
  62. },
  63. "StableDiffusionControlNetImg2ImgPipeline": {
  64. "table-schema": [
  65. "Source-Image",
  66. "Control-Image",
  67. "Prompt",
  68. "Negative-Prompt",
  69. "Generated-Image",
  70. ],
  71. "kwarg-logging": ["image", "control_image", "prompt", "negative_prompt"],
  72. "kwarg-actions": [wandb.Image, wandb.Image, None, None],
  73. },
  74. "StableDiffusionControlNetInpaintPipeline": {
  75. "table-schema": [
  76. "Source-Image",
  77. "Mask-Image",
  78. "Control-Image",
  79. "Prompt",
  80. "Negative-Prompt",
  81. "Generated-Image",
  82. ],
  83. "kwarg-logging": [
  84. "image",
  85. "mask_image",
  86. "control_image",
  87. "prompt",
  88. "negative_prompt",
  89. ],
  90. "kwarg-actions": [wandb.Image, wandb.Image, wandb.Image, None, None],
  91. },
  92. "CycleDiffusionPipeline": {
  93. "table-schema": [
  94. "Source-Image",
  95. "Prompt",
  96. "Source-Prompt",
  97. "Generated-Image",
  98. ],
  99. "kwarg-logging": [
  100. "image",
  101. "prompt",
  102. "source_prompt",
  103. ],
  104. "kwarg-actions": [wandb.Image, None, None],
  105. },
  106. "StableDiffusionInstructPix2PixPipeline": {
  107. "table-schema": [
  108. "Source-Image",
  109. "Prompt",
  110. "Negative-Prompt",
  111. "Generated-Image",
  112. ],
  113. "kwarg-logging": [
  114. "image",
  115. "prompt",
  116. "negative_prompt",
  117. ],
  118. "kwarg-actions": [wandb.Image, None, None],
  119. },
  120. "PaintByExamplePipeline": {
  121. "table-schema": [
  122. "Source-Image",
  123. "Example-Image",
  124. "Mask-Prompt",
  125. "Generated-Image",
  126. ],
  127. "kwarg-logging": [
  128. "image",
  129. "example_image",
  130. "mask_image",
  131. ],
  132. "kwarg-actions": [wandb.Image, wandb.Image, wandb.Image],
  133. },
  134. "RePaintPipeline": {
  135. "table-schema": [
  136. "Source-Image",
  137. "Mask-Prompt",
  138. "Generated-Image",
  139. ],
  140. "kwarg-logging": [
  141. "image",
  142. "mask_image",
  143. ],
  144. "kwarg-actions": [wandb.Image, wandb.Image],
  145. },
  146. "StableDiffusionPipeline": {
  147. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  148. "kwarg-logging": ["prompt", "negative_prompt"],
  149. "kwarg-actions": [None, None],
  150. },
  151. "KandinskyCombinedPipeline": {
  152. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  153. "kwarg-logging": ["prompt", "negative_prompt"],
  154. "kwarg-actions": [None, None],
  155. },
  156. "KandinskyV22CombinedPipeline": {
  157. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  158. "kwarg-logging": ["prompt", "negative_prompt"],
  159. "kwarg-actions": [None, None],
  160. },
  161. "LatentConsistencyModelPipeline": {
  162. "table-schema": ["Prompt", "Generated-Image"],
  163. "kwarg-logging": ["prompt"],
  164. "kwarg-actions": [None],
  165. },
  166. "LDMTextToImagePipeline": {
  167. "table-schema": ["Prompt", "Generated-Image"],
  168. "kwarg-logging": ["prompt"],
  169. "kwarg-actions": [None],
  170. },
  171. "StableDiffusionPanoramaPipeline": {
  172. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  173. "kwarg-logging": ["prompt", "negative_prompt"],
  174. "kwarg-actions": [None, None],
  175. },
  176. "PixArtAlphaPipeline": {
  177. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  178. "kwarg-logging": ["prompt", "negative_prompt"],
  179. "kwarg-actions": [None, None],
  180. },
  181. "StableDiffusionSAGPipeline": {
  182. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  183. "kwarg-logging": ["prompt", "negative_prompt"],
  184. "kwarg-actions": [None, None],
  185. },
  186. "SemanticStableDiffusionPipeline": {
  187. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  188. "kwarg-logging": ["prompt", "negative_prompt"],
  189. "kwarg-actions": [None, None],
  190. },
  191. "WuerstchenCombinedPipeline": {
  192. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  193. "kwarg-logging": ["prompt", "negative_prompt"],
  194. "kwarg-actions": [None, None],
  195. },
  196. "IFPipeline": {
  197. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  198. "kwarg-logging": ["prompt", "negative_prompt"],
  199. "kwarg-actions": [None, None],
  200. },
  201. "AltDiffusionPipeline": {
  202. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  203. "kwarg-logging": ["prompt", "negative_prompt"],
  204. "kwarg-actions": [None, None],
  205. },
  206. "StableDiffusionAttendAndExcitePipeline": {
  207. "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"],
  208. "kwarg-logging": ["prompt", "negative_prompt"],
  209. "kwarg-actions": [None, None],
  210. },
  211. "KandinskyImg2ImgCombinedPipeline": {
  212. "table-schema": [
  213. "Source-Image",
  214. "Prompt",
  215. "Negative-Prompt",
  216. "Generated-Image",
  217. ],
  218. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  219. "kwarg-actions": [wandb.Image, None, None],
  220. },
  221. "KandinskyInpaintCombinedPipeline": {
  222. "table-schema": [
  223. "Source-Image",
  224. "Prompt",
  225. "Negative-Prompt",
  226. "Generated-Image",
  227. ],
  228. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  229. "kwarg-actions": [wandb.Image, None, None],
  230. },
  231. "KandinskyV22Img2ImgCombinedPipeline": {
  232. "table-schema": [
  233. "Source-Image",
  234. "Prompt",
  235. "Negative-Prompt",
  236. "Generated-Image",
  237. ],
  238. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  239. "kwarg-actions": [wandb.Image, None, None],
  240. },
  241. "KandinskyV22InpaintCombinedPipeline": {
  242. "table-schema": [
  243. "Source-Image",
  244. "Prompt",
  245. "Negative-Prompt",
  246. "Generated-Image",
  247. ],
  248. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  249. "kwarg-actions": [wandb.Image, None, None],
  250. },
  251. "AnimateDiffPipeline": {
  252. "table-schema": [
  253. "Prompt",
  254. "Negative-Prompt",
  255. "Number-of-Frames",
  256. "Generated-Video",
  257. ],
  258. "kwarg-logging": ["prompt", "negative_prompt", "num_frames"],
  259. "kwarg-actions": [None, None, None],
  260. "output-type": "video",
  261. },
  262. "StableVideoDiffusionPipeline": {
  263. "table-schema": [
  264. "Input-Image",
  265. "Frames-Per-Second",
  266. "Generated-Video",
  267. ],
  268. "kwarg-logging": ["image", "fps"],
  269. "kwarg-actions": [wandb.Image, None],
  270. "output-type": "video",
  271. },
  272. "AudioLDMPipeline": {
  273. "table-schema": [
  274. "Prompt",
  275. "Negative-Prompt",
  276. "Audio-Length-in-Seconds",
  277. "Generated-Audio",
  278. ],
  279. "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"],
  280. "kwarg-actions": [None, None, None],
  281. "output-type": "audio",
  282. },
  283. "AudioLDM2Pipeline": {
  284. "table-schema": [
  285. "Prompt",
  286. "Negative-Prompt",
  287. "Audio-Length-in-Seconds",
  288. "Generated-Audio",
  289. ],
  290. "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"],
  291. "kwarg-actions": [None, None, None],
  292. "output-type": "audio",
  293. },
  294. "MusicLDMPipeline": {
  295. "table-schema": [
  296. "Prompt",
  297. "Negative-Prompt",
  298. "Audio-Length-in-Seconds",
  299. "Generated-Audio",
  300. ],
  301. "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"],
  302. "kwarg-actions": [None, None, None],
  303. "output-type": "audio",
  304. },
  305. "StableDiffusionPix2PixZeroPipeline": {
  306. "table-schema": [
  307. "Prompt",
  308. "Negative-Prompt",
  309. "Generated-Image",
  310. ],
  311. "kwarg-logging": ["prompt", "negative_prompt"],
  312. "kwarg-actions": [None, None],
  313. },
  314. "PNDMPipeline": {
  315. "table-schema": [
  316. "Batch-Size",
  317. "Number-of-Inference-Steps",
  318. "Generated-Image",
  319. ],
  320. "kwarg-logging": ["batch_size", "num_inference_steps"],
  321. "kwarg-actions": [None, None],
  322. },
  323. "ShapEPipeline": {
  324. "table-schema": [
  325. "Prompt",
  326. "Generated-Video",
  327. ],
  328. "kwarg-logging": ["prompt"],
  329. "kwarg-actions": [None],
  330. "output-type": "video",
  331. },
  332. "StableDiffusionImg2ImgPipeline": {
  333. "table-schema": [
  334. "Source-Image",
  335. "Prompt",
  336. "Negative-Prompt",
  337. "Generated-Image",
  338. ],
  339. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  340. "kwarg-actions": [wandb.Image, None, None],
  341. },
  342. "StableDiffusionInpaintPipeline": {
  343. "table-schema": [
  344. "Source-Image",
  345. "Mask-Image",
  346. "Prompt",
  347. "Negative-Prompt",
  348. "Generated-Image",
  349. ],
  350. "kwarg-logging": ["image", "mask_image", "prompt", "negative_prompt"],
  351. "kwarg-actions": [wandb.Image, wandb.Image, None, None],
  352. },
  353. "StableDiffusionDepth2ImgPipeline": {
  354. "table-schema": [
  355. "Source-Image",
  356. "Prompt",
  357. "Negative-Prompt",
  358. "Generated-Image",
  359. ],
  360. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  361. "kwarg-actions": [wandb.Image, None, None],
  362. },
  363. "StableDiffusionImageVariationPipeline": {
  364. "table-schema": [
  365. "Source-Image",
  366. "Generated-Image",
  367. ],
  368. "kwarg-logging": [
  369. "image",
  370. ],
  371. "kwarg-actions": [wandb.Image],
  372. },
  373. "StableDiffusionPipelineSafe": {
  374. "table-schema": [
  375. "Prompt",
  376. "Negative-Prompt",
  377. "Generated-Image",
  378. ],
  379. "kwarg-logging": ["prompt", "negative_prompt"],
  380. "kwarg-actions": [None, None],
  381. },
  382. "StableDiffusionUpscalePipeline": {
  383. "table-schema": [
  384. "Source-Image",
  385. "Prompt",
  386. "Negative-Prompt",
  387. "Upscaled-Image",
  388. ],
  389. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  390. "kwarg-actions": [wandb.Image, None, None],
  391. },
  392. "StableDiffusionAdapterPipeline": {
  393. "table-schema": [
  394. "Source-Image",
  395. "Prompt",
  396. "Negative-Prompt",
  397. "Generated-Image",
  398. ],
  399. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  400. "kwarg-actions": [wandb.Image, None, None],
  401. },
  402. "StableDiffusionGLIGENPipeline": {
  403. "table-schema": [
  404. "Prompt",
  405. "GLIGEN-Phrases",
  406. "GLIGEN-Boxes",
  407. "GLIGEN-Inpaint-Image",
  408. "Negative-Prompt",
  409. "Generated-Image",
  410. ],
  411. "kwarg-logging": [
  412. "prompt",
  413. "gligen_phrases",
  414. "gligen_boxes",
  415. "gligen_inpaint_image",
  416. "negative_prompt",
  417. ],
  418. "kwarg-actions": [None, None, None, wandb.Image, None],
  419. },
  420. "VersatileDiffusionTextToImagePipeline": {
  421. "table-schema": [
  422. "Prompt",
  423. "Negative-Prompt",
  424. "Generated-Image",
  425. ],
  426. "kwarg-logging": ["prompt", "negative_prompt"],
  427. "kwarg-actions": [None, None],
  428. },
  429. "VersatileDiffusionImageVariationPipeline": {
  430. "table-schema": [
  431. "Source-Image",
  432. "Negative-Prompt",
  433. "Generated-Image",
  434. ],
  435. "kwarg-logging": ["image", "negative_prompt"],
  436. "kwarg-actions": [wandb.Image, None],
  437. },
  438. "VersatileDiffusionDualGuidedPipeline": {
  439. "table-schema": [
  440. "Source-Image",
  441. "Prompt",
  442. "Negative-Prompt",
  443. "Generated-Image",
  444. ],
  445. "kwarg-logging": ["image", "prompt", "negative_prompt"],
  446. "kwarg-actions": [wandb.Image, None, None],
  447. },
  448. "LDMPipeline": {
  449. "table-schema": [
  450. "Batch-Size",
  451. "Number-of-Inference-Steps",
  452. "Generated-Image",
  453. ],
  454. "kwarg-logging": ["batch_size", "num_inference_steps"],
  455. "kwarg-actions": [None, None],
  456. },
  457. "TextToVideoSDPipeline": {
  458. "table-schema": [
  459. "Prompt",
  460. "Negative-Prompt",
  461. "Number-of-Frames",
  462. "Generated-Video",
  463. ],
  464. "kwarg-logging": ["prompt", "negative_prompt", "num_frames"],
  465. "output-type": "video",
  466. },
  467. "TextToVideoZeroPipeline": {
  468. "table-schema": [
  469. "Prompt",
  470. "Negative-Prompt",
  471. "Number-of-Frames",
  472. "Generated-Video",
  473. ],
  474. "kwarg-logging": ["prompt", "negative_prompt", "video_length"],
  475. },
  476. "AmusedPipeline": {
  477. "table-schema": [
  478. "Prompt",
  479. "Guidance Scale",
  480. "Generated-Image",
  481. ],
  482. "kwarg-logging": [
  483. "prompt",
  484. "guidance_scale",
  485. ],
  486. "kwarg-actions": [None, None],
  487. },
  488. "StableDiffusionXLControlNetPipeline": {
  489. "table-schema": [
  490. "Prompt-1",
  491. "Prompt-2",
  492. "Control-Image",
  493. "Negative-Prompt-1",
  494. "Negative-Prompt-2",
  495. "Generated-Image",
  496. ],
  497. "kwarg-logging": [
  498. "prompt",
  499. "prompt_2",
  500. "image",
  501. "negative_prompt",
  502. "negative_prompt_2",
  503. ],
  504. "kwarg-actions": [None, None, wandb.Image, None, None],
  505. },
  506. "StableDiffusionXLControlNetImg2ImgPipeline": {
  507. "table-schema": [
  508. "Prompt-1",
  509. "Prompt-2",
  510. "Input-Image",
  511. "Control-Image",
  512. "Negative-Prompt-1",
  513. "Negative-Prompt-2",
  514. "Generated-Image",
  515. ],
  516. "kwarg-logging": [
  517. "prompt",
  518. "prompt_2",
  519. "image",
  520. "control_image",
  521. "negative_prompt",
  522. "negative_prompt_2",
  523. ],
  524. "kwarg-actions": [None, None, wandb.Image, wandb.Image, None, None],
  525. },
  526. "Kandinsky3Pipeline": {
  527. "table-schema": [
  528. "Prompt",
  529. "Negative-Prompt",
  530. "Generated-Image",
  531. ],
  532. "kwarg-logging": [
  533. "prompt",
  534. "negative_prompt",
  535. ],
  536. "kwarg-actions": [None, None],
  537. },
  538. "Kandinsky3Img2ImgPipeline": {
  539. "table-schema": [
  540. "Prompt",
  541. "Negative-Prompt",
  542. "Input-Image",
  543. "Generated-Image",
  544. ],
  545. "kwarg-logging": [
  546. "prompt",
  547. "negative_prompt",
  548. "image",
  549. ],
  550. "kwarg-actions": [None, None, wandb.Image],
  551. },
  552. "StableDiffusionXLPipeline": {
  553. "table-schema": [
  554. "Prompt",
  555. "Negative-Prompt",
  556. "Prompt-2",
  557. "Negative-Prompt-2",
  558. "Generated-Image",
  559. ],
  560. "kwarg-logging": [
  561. "prompt",
  562. "negative_prompt",
  563. "prompt_2",
  564. "negative_prompt_2",
  565. ],
  566. "kwarg-actions": [None, None, None, None],
  567. },
  568. "StableDiffusionXLImg2ImgPipeline": {
  569. "table-schema": [
  570. "Prompt",
  571. "Negative-Prompt",
  572. "Prompt-2",
  573. "Negative-Prompt-2",
  574. "Input-Image",
  575. "Generated-Image",
  576. ],
  577. "kwarg-logging": [
  578. "prompt",
  579. "negative_prompt",
  580. "prompt_2",
  581. "negative_prompt_2",
  582. "image",
  583. ],
  584. "kwarg-actions": [None, None, None, None, wandb.Image],
  585. },
  586. }
  587. class DiffusersMultiModalPipelineResolver:
  588. """Resolver for request and responses from [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index) multi-modal Diffusion Pipelines, providing necessary data transformations, formatting, and logging.
  589. This resolver is internally involved in the
  590. `__call__` for `wandb.integration.diffusers.pipeline_resolver.DiffusersPipelineResolver`.
  591. This is based on `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
  592. Args:
  593. pipeline_name: (str) The name of the Diffusion Pipeline.
  594. """
  595. def __init__(self, pipeline_name: str, pipeline_call_count: int) -> None:
  596. self.pipeline_name = pipeline_name
  597. self.pipeline_call_count = pipeline_call_count
  598. columns = []
  599. if pipeline_name in SUPPORTED_MULTIMODAL_PIPELINES:
  600. columns += SUPPORTED_MULTIMODAL_PIPELINES[pipeline_name]["table-schema"]
  601. else:
  602. wandb.Error("Pipeline not supported for logging")
  603. self.wandb_table = wandb.Table(columns=columns)
  604. def __call__(
  605. self,
  606. args: Sequence[Any],
  607. kwargs: dict[str, Any],
  608. response: Response,
  609. start_time: float,
  610. time_elapsed: float,
  611. ) -> Any:
  612. """Main call method for the `DiffusersPipelineResolver` class.
  613. Args:
  614. args: (Sequence[Any]) List of arguments.
  615. kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
  616. response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
  617. the request.
  618. start_time: (float) Time when request started.
  619. time_elapsed: (float) Time elapsed for the request.
  620. Returns:
  621. Packed data as a dictionary for logging to wandb, None if an exception occurred.
  622. """
  623. try:
  624. # Get the pipeline and the args
  625. pipeline, args = args[0], args[1:]
  626. # Update the Kwargs so that they can be logged easily
  627. kwargs = get_updated_kwargs(pipeline, args, kwargs)
  628. # Get the pipeline configs
  629. pipeline_configs = dict(pipeline.config)
  630. pipeline_configs["pipeline-name"] = self.pipeline_name
  631. if "workflow" not in wandb.config:
  632. wandb.config.update(
  633. {
  634. "workflow": [
  635. {
  636. "pipeline": pipeline_configs,
  637. "params": kwargs,
  638. "stage": f"Pipeline-Call-{self.pipeline_call_count}",
  639. }
  640. ]
  641. }
  642. )
  643. else:
  644. existing_workflow = wandb.config.workflow
  645. updated_workflow = existing_workflow + [
  646. {
  647. "pipeline": pipeline_configs,
  648. "params": kwargs,
  649. "stage": f"Pipeline-Call-{self.pipeline_call_count}",
  650. }
  651. ]
  652. wandb.config.update(
  653. {"workflow": updated_workflow}, allow_val_change=True
  654. )
  655. # Return the WandB loggable dict
  656. return self.prepare_loggable_dict(pipeline, response, kwargs)
  657. except Exception as e:
  658. logger.warning(e)
  659. return None
  660. def get_output_images(self, response: Response) -> list:
  661. """Unpack the generated images, audio, video, etc. from the Diffusion Pipeline's response.
  662. Args:
  663. response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
  664. the request.
  665. Returns:
  666. List of generated images, audio, video, etc.
  667. """
  668. if "output-type" not in SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]:
  669. return response.images
  670. else:
  671. if (
  672. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  673. == "video"
  674. ):
  675. if self.pipeline_name in ["ShapEPipeline"]:
  676. return response.images
  677. return response.frames
  678. elif (
  679. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  680. == "audio"
  681. ):
  682. return response.audios
  683. def log_media(self, image: Any, loggable_kwarg_chunks: list, idx: int) -> None:
  684. """Log the generated images, audio, video, etc. from the Diffusion Pipeline's response along with an optional caption to a media panel in the run.
  685. Args:
  686. image: (Any) The generated images, audio, video, etc. from the Diffusion
  687. Pipeline's response.
  688. loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
  689. """
  690. if "output-type" not in SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]:
  691. try:
  692. caption = ""
  693. if self.pipeline_name in [
  694. "StableDiffusionXLPipeline",
  695. "StableDiffusionXLImg2ImgPipeline",
  696. ]:
  697. prompt_index = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  698. "kwarg-logging"
  699. ].index("prompt")
  700. prompt2_index = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  701. "kwarg-logging"
  702. ].index("prompt_2")
  703. caption = f"Prompt-1: {loggable_kwarg_chunks[prompt_index][idx]}\nPrompt-2: {loggable_kwarg_chunks[prompt2_index][idx]}"
  704. else:
  705. prompt_index = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  706. "kwarg-logging"
  707. ].index("prompt")
  708. caption = loggable_kwarg_chunks[prompt_index][idx]
  709. except ValueError:
  710. caption = None
  711. wandb.log(
  712. {
  713. f"Generated-Image/Pipeline-Call-{self.pipeline_call_count}": wandb.Image(
  714. image, caption=caption
  715. )
  716. }
  717. )
  718. else:
  719. if (
  720. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  721. == "video"
  722. ):
  723. try:
  724. prompt_index = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  725. "kwarg-logging"
  726. ].index("prompt")
  727. caption = loggable_kwarg_chunks[prompt_index][idx]
  728. except ValueError:
  729. caption = None
  730. wandb.log(
  731. {
  732. f"Generated-Video/Pipeline-Call-{self.pipeline_call_count}": wandb.Video(
  733. postprocess_pils_to_np(image), fps=4, caption=caption
  734. )
  735. }
  736. )
  737. elif (
  738. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  739. == "audio"
  740. ):
  741. try:
  742. prompt_index = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  743. "kwarg-logging"
  744. ].index("prompt")
  745. caption = loggable_kwarg_chunks[prompt_index][idx]
  746. except ValueError:
  747. caption = None
  748. wandb.log(
  749. {
  750. f"Generated-Audio/Pipeline-Call-{self.pipeline_call_count}": wandb.Audio(
  751. image, sample_rate=16000, caption=caption
  752. )
  753. }
  754. )
  755. def add_data_to_table(
  756. self, image: Any, loggable_kwarg_chunks: list, idx: int
  757. ) -> None:
  758. """Populate the row of the `wandb.Table`.
  759. Args:
  760. image: (Any) The generated images, audio, video, etc. from the Diffusion
  761. Pipeline's response.
  762. loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
  763. idx: (int) Chunk index.
  764. """
  765. table_row = []
  766. kwarg_actions = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  767. "kwarg-actions"
  768. ]
  769. for column_idx, loggable_kwarg_chunk in enumerate(loggable_kwarg_chunks):
  770. if kwarg_actions[column_idx] is None:
  771. table_row.append(
  772. loggable_kwarg_chunk[idx]
  773. if loggable_kwarg_chunk[idx] is not None
  774. else ""
  775. )
  776. else:
  777. table_row.append(kwarg_actions[column_idx](loggable_kwarg_chunk[idx]))
  778. if "output-type" not in SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]:
  779. table_row.append(wandb.Image(image))
  780. else:
  781. if (
  782. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  783. == "video"
  784. ):
  785. table_row.append(wandb.Video(postprocess_pils_to_np(image), fps=4))
  786. elif (
  787. SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]["output-type"]
  788. == "audio"
  789. ):
  790. table_row.append(wandb.Audio(image, sample_rate=16000))
  791. self.wandb_table.add_data(*table_row)
  792. def prepare_loggable_dict(
  793. self, pipeline: Any, response: Response, kwargs: dict[str, Any]
  794. ) -> dict[str, Any]:
  795. """Prepare the loggable dictionary, which is the packed data as a dictionary for logging to wandb, None if an exception occurred.
  796. Args:
  797. pipeline: (Any) The Diffusion Pipeline.
  798. response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
  799. the request.
  800. kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
  801. Returns:
  802. Packed data as a dictionary for logging to wandb, None if an exception occurred.
  803. """
  804. # Unpack the generated images, audio, video, etc. from the Diffusion Pipeline's response.
  805. images = self.get_output_images(response)
  806. if (
  807. self.pipeline_name == "StableDiffusionXLPipeline"
  808. and kwargs["output_type"] == "latent"
  809. ):
  810. images = decode_sdxl_t2i_latents(pipeline, response.images)
  811. # Account for exception pipelines for text-to-video
  812. if self.pipeline_name in ["TextToVideoSDPipeline", "TextToVideoZeroPipeline"]:
  813. video = postprocess_np_arrays_for_video(
  814. images, normalize=self.pipeline_name == "TextToVideoZeroPipeline"
  815. )
  816. wandb.log(
  817. {
  818. f"Generated-Video/Pipeline-Call-{self.pipeline_call_count}": wandb.Video(
  819. video, fps=4, caption=kwargs["prompt"]
  820. )
  821. }
  822. )
  823. loggable_kwarg_ids = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  824. "kwarg-logging"
  825. ]
  826. table_row = [
  827. kwargs[loggable_kwarg_ids[idx]]
  828. for idx in range(len(loggable_kwarg_ids))
  829. ]
  830. table_row.append(wandb.Video(video, fps=4))
  831. self.wandb_table.add_data(*table_row)
  832. else:
  833. loggable_kwarg_ids = SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name][
  834. "kwarg-logging"
  835. ]
  836. # chunkify loggable kwargs
  837. loggable_kwarg_chunks = []
  838. for loggable_kwarg_id in loggable_kwarg_ids:
  839. loggable_kwarg_chunks.append(
  840. kwargs[loggable_kwarg_id]
  841. if isinstance(kwargs[loggable_kwarg_id], list)
  842. else [kwargs[loggable_kwarg_id]]
  843. )
  844. # chunkify the generated media
  845. images = chunkify(images, len(loggable_kwarg_chunks[0]))
  846. for idx in range(len(loggable_kwarg_chunks[0])):
  847. for image in images[idx]:
  848. # Log media to media panel
  849. self.log_media(image, loggable_kwarg_chunks, idx)
  850. # Populate the row of the wandb_table
  851. self.add_data_to_table(image, loggable_kwarg_chunks, idx)
  852. return {
  853. f"Result-Table/Pipeline-Call-{self.pipeline_call_count}": self.wandb_table
  854. }