modeling_perceiver.py 132 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302
  1. # Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Perceiver model."""
  15. import abc
  16. import math
  17. from collections.abc import Callable, Mapping
  18. from dataclasses import dataclass
  19. from functools import reduce
  20. from operator import __add__
  21. from typing import Any, Optional
  22. import numpy as np
  23. import torch
  24. from torch import nn
  25. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...modeling_outputs import BaseModelOutputWithCrossAttentions
  29. from ...modeling_utils import PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward
  31. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  32. from .configuration_perceiver import PerceiverConfig
# Mapping with string keys and integer values (presumably modality name -> size; confirm against callers).
ModalitySizeType = Mapping[str, int]
# Three-element tuple of tensors; the middle element is allowed to be None.
PreprocessorOutputType = tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]
# Any callable (arbitrary arguments) that returns a `PreprocessorOutputType`.
PreprocessorType = Callable[..., PreprocessorOutputType]
# Any callable (arbitrary arguments); the return type is left unconstrained.
PostprocessorType = Callable[..., Any]

# Module-level logger keyed by this module's name.
logger = logging.get_logger(__name__)
# Output container for the base model. Every field defaults to None, so a producer only has to
# fill in the entries it actually computed.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions.
    """
)
class PerceiverModelOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    """

    # NOTE(review): only `logits` is described in the docstring above; the remaining fields are
    # presumably documented by `auto_docstring` - confirm.
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
# Output container for a decoder: the decoded logits plus, optionally, cross-attention tensors.
# Both fields default to None.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver decoder outputs, with potential cross-attentions.
    """
)
class PerceiverDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Output of the basic decoder.
    """

    logits: torch.FloatTensor | None = None
    # NOTE(review): not described in the docstring above; presumably documented by `auto_docstring` - confirm.
    cross_attentions: tuple[torch.FloatTensor] | None = None
# Output container for masked language modeling: optional loss, prediction logits and the
# optional hidden-state / attention tuples. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver's masked language model outputs.
    """
)
class PerceiverMaskedLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Masked language modeling (MLM) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    # NOTE(review): the three fields below are not described in the docstring above; presumably
    # documented by `auto_docstring` - confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
# Output container shared by the classification-style heads named in `custom_intro` below:
# optional loss, logits and the optional hidden-state / attention tuples. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal
    autoencoding.
    """
)
class PerceiverClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    # NOTE(review): the three fields below are not described in the docstring above; presumably
    # documented by `auto_docstring` - confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
  104. class PerceiverEmbeddings(nn.Module):
  105. """Construct the latent embeddings."""
  106. def __init__(self, config):
  107. super().__init__()
  108. self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
  109. def forward(self, batch_size: int):
  110. return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
  111. class PerceiverSelfAttention(nn.Module):
  112. """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder."""
  113. def __init__(
  114. self,
  115. config,
  116. is_cross_attention=False,
  117. qk_channels=None,
  118. v_channels=None,
  119. num_heads=1,
  120. q_dim=None,
  121. kv_dim=None,
  122. ):
  123. super().__init__()
  124. self.num_heads = num_heads
  125. # Q and K must have the same number of channels.
  126. # Default to preserving Q's input's shape.
  127. if qk_channels is None:
  128. qk_channels = q_dim
  129. # V's num_channels determines the shape of the output of QKV-attention.
  130. # Default to the same number of channels used in the key-query operation.
  131. if v_channels is None:
  132. v_channels = qk_channels
  133. if qk_channels % num_heads != 0:
  134. raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
  135. if v_channels % num_heads != 0:
  136. raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")
  137. self.qk_channels = qk_channels
  138. self.v_channels = v_channels
  139. self.qk_channels_per_head = self.qk_channels // num_heads
  140. self.v_channels_per_head = self.v_channels // num_heads
  141. # Layer normalization
  142. self.layernorm1 = nn.LayerNorm(q_dim)
  143. self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity()
  144. # Projection matrices
  145. self.query = nn.Linear(q_dim, qk_channels)
  146. self.key = nn.Linear(kv_dim, qk_channels)
  147. self.value = nn.Linear(kv_dim, v_channels)
  148. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  149. def transpose_for_scores(self, x, channels_per_head):
  150. new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
  151. x = x.view(*new_x_shape)
  152. return x.permute(0, 2, 1, 3)
  153. def forward(
  154. self,
  155. hidden_states: torch.Tensor,
  156. attention_mask: torch.FloatTensor | None = None,
  157. inputs: torch.FloatTensor | None = None,
  158. inputs_mask: torch.FloatTensor | None = None,
  159. output_attentions: bool | None = False,
  160. ) -> tuple[torch.Tensor]:
  161. hidden_states = self.layernorm1(hidden_states)
  162. inputs = self.layernorm2(inputs)
  163. # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module,
  164. # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to.
  165. is_cross_attention = inputs is not None
  166. queries = self.query(hidden_states)
  167. if is_cross_attention:
  168. keys = self.key(inputs)
  169. values = self.value(inputs)
  170. attention_mask = inputs_mask
  171. else:
  172. keys = self.key(hidden_states)
  173. values = self.value(hidden_states)
  174. # Reshape channels for multi-head attention.
  175. # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head)
  176. queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
  177. keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
  178. values = self.transpose_for_scores(values, self.v_channels_per_head)
  179. # Take the dot product between the queries and keys to get the raw attention scores.
  180. attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
  181. batch_size, num_heads, seq_len, q_head_dim = queries.shape
  182. _, _, _, v_head_dim = values.shape
  183. hiddens = self.num_heads * v_head_dim
  184. attention_scores = attention_scores / math.sqrt(q_head_dim)
  185. if attention_mask is not None:
  186. # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function)
  187. attention_scores = attention_scores + attention_mask
  188. # Normalize the attention scores to probabilities.
  189. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  190. # This is actually dropping out entire tokens to attend to, which might
  191. # seem a bit unusual, but is taken from the original Transformer paper.
  192. attention_probs = self.dropout(attention_probs)
  193. context_layer = torch.matmul(attention_probs, values)
  194. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  195. new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
  196. context_layer = context_layer.view(*new_context_layer_shape)
  197. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  198. return outputs
  199. class PerceiverSelfOutput(nn.Module):
  200. def __init__(self, config, input_channels, output_channels):
  201. super().__init__()
  202. self.dense = nn.Linear(input_channels, output_channels)
  203. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  204. hidden_states = self.dense(hidden_states)
  205. return hidden_states
  206. class PerceiverAttention(nn.Module):
  207. """Attention module, including a dense block."""
  208. def __init__(
  209. self,
  210. config,
  211. is_cross_attention=False,
  212. qk_channels=None,
  213. v_channels=None,
  214. num_heads=1,
  215. q_dim=None,
  216. kv_dim=None,
  217. use_query_residual=True,
  218. ):
  219. super().__init__()
  220. # MultiHead attention
  221. if is_cross_attention and qk_channels is None:
  222. if config.cross_attention_shape_for_attention == "q":
  223. qk_channels = q_dim
  224. elif config.cross_attention_shape_for_attention == "kv":
  225. qk_channels = kv_dim
  226. else:
  227. raise ValueError(
  228. f"Unknown value {config.cross_attention_shape_for_attention} for "
  229. "cross_attention_shape_for_attention."
  230. )
  231. else:
  232. if qk_channels is None:
  233. qk_channels = q_dim
  234. if v_channels is None:
  235. v_channels = qk_channels
  236. self.self = PerceiverSelfAttention(
  237. config,
  238. is_cross_attention=is_cross_attention,
  239. qk_channels=qk_channels,
  240. v_channels=v_channels,
  241. num_heads=num_heads,
  242. q_dim=q_dim,
  243. kv_dim=kv_dim,
  244. )
  245. # dense block
  246. output_channels = None
  247. if is_cross_attention:
  248. output_channels = q_dim
  249. else:
  250. if output_channels is None:
  251. output_channels = v_channels
  252. self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
  253. self.use_query_residual = use_query_residual
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. attention_mask: torch.FloatTensor | None = None,
  258. inputs: torch.FloatTensor | None = None,
  259. inputs_mask: torch.FloatTensor | None = None,
  260. output_attentions: bool | None = False,
  261. ) -> tuple[torch.Tensor]:
  262. self_outputs = self.self(
  263. hidden_states,
  264. attention_mask,
  265. inputs,
  266. inputs_mask,
  267. output_attentions,
  268. )
  269. # Output projection
  270. attention_output = self.output(self_outputs[0])
  271. # Optionally include a residual to the original queries.
  272. # Consider omitting the residual if the semantics of query and output
  273. # are different, e.g. if queries are positions and outputs are pixels.
  274. if self.use_query_residual:
  275. attention_output = attention_output + hidden_states
  276. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  277. return outputs
  278. class PerceiverMLP(nn.Module):
  279. """A Transformer-style dense module to follow attention."""
  280. def __init__(self, config, input_size, widening_factor):
  281. super().__init__()
  282. self.dense1 = nn.Linear(input_size, widening_factor * input_size)
  283. if isinstance(config.hidden_act, str):
  284. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  285. else:
  286. self.intermediate_act_fn = config.hidden_act
  287. self.dense2 = nn.Linear(widening_factor * input_size, input_size)
  288. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  289. hidden_states = self.dense1(hidden_states)
  290. hidden_states = self.intermediate_act_fn(hidden_states)
  291. hidden_states = self.dense2(hidden_states)
  292. return hidden_states
  293. class PerceiverLayer(nn.Module):
  294. def __init__(
  295. self,
  296. config,
  297. is_cross_attention=False,
  298. qk_channels=None,
  299. v_channels=None,
  300. num_heads=1,
  301. q_dim=None,
  302. kv_dim=None,
  303. widening_factor=4,
  304. use_query_residual=True,
  305. ):
  306. super().__init__()
  307. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  308. self.seq_len_dim = 1
  309. self.attention = PerceiverAttention(
  310. config,
  311. is_cross_attention=is_cross_attention,
  312. qk_channels=qk_channels,
  313. v_channels=v_channels,
  314. num_heads=num_heads,
  315. q_dim=q_dim,
  316. kv_dim=kv_dim,
  317. use_query_residual=use_query_residual,
  318. )
  319. self.layernorm = nn.LayerNorm(q_dim)
  320. self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor)
  321. def forward(
  322. self,
  323. hidden_states: torch.Tensor,
  324. attention_mask: torch.FloatTensor | None = None,
  325. inputs: torch.FloatTensor | None = None,
  326. inputs_mask: torch.FloatTensor | None = None,
  327. output_attentions: bool | None = False,
  328. ) -> tuple[torch.Tensor]:
  329. attention_outputs = self.attention(
  330. hidden_states,
  331. attention_mask,
  332. inputs,
  333. inputs_mask,
  334. output_attentions,
  335. )
  336. attention_output = attention_outputs[0]
  337. outputs = attention_outputs[1:] # add attentions if we output attention weights
  338. layer_output = apply_chunking_to_forward(
  339. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  340. )
  341. layer_output = layer_output + attention_output # residual connection
  342. outputs = (layer_output,) + outputs
  343. return outputs
  344. def feed_forward_chunk(self, attention_output):
  345. layer_output = self.layernorm(attention_output)
  346. layer_output = self.mlp(layer_output)
  347. return layer_output
  348. class PerceiverEncoder(nn.Module):
  349. """The Perceiver Encoder: a scalable, fully attentional encoder."""
  350. def __init__(self, config, kv_dim=None):
  351. super().__init__()
  352. self.config = config
  353. # Check that we can use multihead-attention with these shapes.
  354. if config.d_latents % config.num_self_attention_heads != 0:
  355. raise ValueError(
  356. f"num_z_channels ({config.d_latents}) must be divisible by"
  357. f" num_self_attend_heads ({config.num_self_attention_heads})."
  358. )
  359. if config.d_latents % config.num_cross_attention_heads != 0:
  360. raise ValueError(
  361. f"num_z_channels ({config.d_latents}) must be divisible by"
  362. f" num_cross_attend_heads ({config.num_cross_attention_heads})."
  363. )
  364. # Construct the cross attention layer.
  365. self.cross_attention = PerceiverLayer(
  366. config,
  367. is_cross_attention=True,
  368. qk_channels=config.qk_channels,
  369. v_channels=config.v_channels,
  370. num_heads=config.num_cross_attention_heads,
  371. q_dim=config.d_latents,
  372. kv_dim=kv_dim,
  373. widening_factor=config.cross_attention_widening_factor,
  374. use_query_residual=config.use_query_residual,
  375. )
  376. # Construct a single block of self-attention layers.
  377. # We get deeper architectures by applying this block more than once.
  378. self_attention_layers = []
  379. for _ in range(config.num_self_attends_per_block):
  380. layer = PerceiverLayer(
  381. config,
  382. is_cross_attention=False,
  383. qk_channels=config.qk_channels,
  384. v_channels=config.v_channels,
  385. num_heads=config.num_self_attention_heads,
  386. q_dim=config.d_latents,
  387. kv_dim=config.d_latents,
  388. widening_factor=config.self_attention_widening_factor,
  389. )
  390. self_attention_layers.append(layer)
  391. self.self_attends = nn.ModuleList(self_attention_layers)
  392. def forward(
  393. self,
  394. hidden_states: torch.Tensor,
  395. attention_mask: torch.FloatTensor | None = None,
  396. inputs: torch.FloatTensor | None = None,
  397. inputs_mask: torch.FloatTensor | None = None,
  398. output_attentions: bool | None = False,
  399. output_hidden_states: bool | None = False,
  400. return_dict: bool | None = True,
  401. ) -> tuple | BaseModelOutputWithCrossAttentions:
  402. all_hidden_states = () if output_hidden_states else None
  403. all_self_attentions = () if output_attentions else None
  404. all_cross_attentions = () if output_attentions else None
  405. # Apply the cross-attention between the latents (hidden_states) and inputs:
  406. layer_outputs = self.cross_attention(
  407. hidden_states,
  408. attention_mask=attention_mask,
  409. inputs=inputs,
  410. inputs_mask=inputs_mask,
  411. output_attentions=output_attentions,
  412. )
  413. hidden_states = layer_outputs[0]
  414. if output_attentions:
  415. all_cross_attentions = all_cross_attentions + (layer_outputs[1],)
  416. # Apply the block of self-attention layers more than once:
  417. for _ in range(self.config.num_blocks):
  418. for i, layer_module in enumerate(self.self_attends):
  419. if output_hidden_states:
  420. all_hidden_states = all_hidden_states + (hidden_states,)
  421. layer_outputs = layer_module(
  422. hidden_states,
  423. attention_mask=attention_mask,
  424. output_attentions=output_attentions,
  425. )
  426. hidden_states = layer_outputs[0]
  427. if output_attentions:
  428. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  429. if output_hidden_states:
  430. all_hidden_states = all_hidden_states + (hidden_states,)
  431. if not return_dict:
  432. return tuple(
  433. v
  434. for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
  435. if v is not None
  436. )
  437. return BaseModelOutputWithCrossAttentions(
  438. last_hidden_state=hidden_states,
  439. hidden_states=all_hidden_states,
  440. attentions=all_self_attentions,
  441. cross_attentions=all_cross_attentions,
  442. )
  443. @auto_docstring
  444. class PerceiverPreTrainedModel(PreTrainedModel):
  445. config: PerceiverConfig
  446. base_model_prefix = "perceiver"
  447. main_input_name = "inputs"
  448. input_modalities = ("image",) # techinically can be anything but HF impl has only image processor
  449. @torch.no_grad()
  450. def _init_weights(self, module):
  451. """Initialize the weights"""
  452. if isinstance(module, (nn.Linear, nn.Conv2d)):
  453. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  454. if module.bias is not None:
  455. init.zeros_(module.bias)
  456. elif hasattr(module, "latents"):
  457. init.normal_(module.latents, mean=0.0, std=self.config.initializer_range)
  458. elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
  459. init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  460. elif isinstance(module, nn.ParameterDict):
  461. for modality in module:
  462. init.normal_(module[modality], mean=0.0, std=self.config.initializer_range)
  463. elif isinstance(module, nn.Embedding):
  464. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  465. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  466. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  467. init.zeros_(module.weight[module.padding_idx])
  468. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
  469. init.zeros_(module.bias)
  470. init.ones_(module.weight)
  471. if getattr(module, "running_mean", None) is not None:
  472. init.zeros_(module.running_mean)
  473. init.ones_(module.running_var)
  474. init.zeros_(module.num_batches_tracked)
  475. @auto_docstring(
  476. custom_intro="""
  477. The Perceiver: a scalable, fully attentional architecture.
  478. <Tip>
  479. Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
  480. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  481. position embeddings to the higher resolution.
  482. </Tip>
  483. """
  484. )
  485. class PerceiverModel(PerceiverPreTrainedModel):
  486. def __init__(
  487. self,
  488. config,
  489. decoder: Optional["PerceiverAbstractDecoder"] = None,
  490. input_preprocessor: PreprocessorType = None,
  491. output_postprocessor: PostprocessorType = None,
  492. ):
  493. r"""
  494. decoder (`PerceiverDecoder`, *optional*):
  495. Decoder module that transforms latent representations into task predictions.
  496. input_preprocessor (`PreprocessorType`, *optional*):
  497. Preprocessor that encodes raw inputs into tensors for the model.
  498. output_postprocessor (`PostprocessorType`, *optional*):
  499. Postprocessor that transforms model outputs into final predictions.
  500. """
  501. super().__init__(config)
  502. self.config = config
  503. self.input_preprocessor = input_preprocessor
  504. self.output_postprocessor = output_postprocessor
  505. self.embeddings = PerceiverEmbeddings(config)
  506. self.encoder = PerceiverEncoder(
  507. config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
  508. )
  509. self.decoder = decoder
  510. # Initialize weights and apply final processing
  511. self.post_init()
  512. def get_input_embeddings(self):
  513. return self.embeddings.latents
  514. def set_input_embeddings(self, value):
  515. self.embeddings.latents = value
  516. @auto_docstring
  517. def forward(
  518. self,
  519. inputs: torch.FloatTensor,
  520. attention_mask: torch.FloatTensor | None = None,
  521. subsampled_output_points: dict[str, torch.Tensor] | None = None,
  522. output_attentions: bool | None = None,
  523. output_hidden_states: bool | None = None,
  524. interpolate_pos_encoding: bool = False,
  525. return_dict: bool | None = None,
  526. **kwargs,
  527. ) -> tuple | PerceiverModelOutput:
  528. r"""
  529. inputs (`torch.FloatTensor`):
  530. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  531. subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
  532. Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
  533. representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.
  534. Examples:
  535. ```python
  536. >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverImageProcessor, PerceiverModel
  537. >>> from transformers.models.perceiver.modeling_perceiver import (
  538. ... PerceiverTextPreprocessor,
  539. ... PerceiverImagePreprocessor,
  540. ... PerceiverClassificationDecoder,
  541. ... )
  542. >>> import torch
  543. >>> import httpx
  544. >>> from io import BytesIO
  545. >>> from PIL import Image
  546. >>> # EXAMPLE 1: using the Perceiver to classify texts
  547. >>> # - we define a TextPreprocessor, which can be used to embed tokens
  548. >>> # - we define a ClassificationDecoder, which can be used to decode the
  549. >>> # final hidden states of the latents to classification logits
  550. >>> # using trainable position embeddings
  551. >>> config = PerceiverConfig()
  552. >>> preprocessor = PerceiverTextPreprocessor(config)
  553. >>> decoder = PerceiverClassificationDecoder(
  554. ... config,
  555. ... num_channels=config.d_latents,
  556. ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
  557. ... use_query_residual=True,
  558. ... )
  559. >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)
  560. >>> # you can then do a forward pass as follows:
  561. >>> tokenizer = PerceiverTokenizer()
  562. >>> text = "hello world"
  563. >>> inputs = tokenizer(text, return_tensors="pt").input_ids
  564. >>> with torch.no_grad():
  565. ... outputs = model(inputs=inputs)
  566. >>> logits = outputs.logits
  567. >>> list(logits.shape)
  568. [1, 2]
  569. >>> # to train, one can train the model using standard cross-entropy:
  570. >>> criterion = torch.nn.CrossEntropyLoss()
  571. >>> labels = torch.tensor([1])
  572. >>> loss = criterion(logits, labels)
  573. >>> # EXAMPLE 2: using the Perceiver to classify images
  574. >>> # - we define an ImagePreprocessor, which can be used to embed images
  575. >>> config = PerceiverConfig(image_size=224)
  576. >>> preprocessor = PerceiverImagePreprocessor(
  577. ... config,
  578. ... prep_type="conv1x1",
  579. ... spatial_downsample=1,
  580. ... out_channels=256,
  581. ... position_encoding_type="trainable",
  582. ... concat_or_add_pos="concat",
  583. ... project_pos_dim=256,
  584. ... trainable_position_encoding_kwargs=dict(
  585. ... num_channels=256,
  586. ... index_dims=config.image_size**2,
  587. ... ),
  588. ... )
  589. >>> model = PerceiverModel(
  590. ... config,
  591. ... input_preprocessor=preprocessor,
  592. ... decoder=PerceiverClassificationDecoder(
  593. ... config,
  594. ... num_channels=config.d_latents,
  595. ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
  596. ... use_query_residual=True,
  597. ... ),
  598. ... )
  599. >>> # you can then do a forward pass as follows:
  600. >>> image_processor = PerceiverImageProcessor()
  601. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  602. >>> with httpx.stream("GET", url) as response:
  603. ... image = Image.open(BytesIO(response.read()))
  604. >>> inputs = image_processor(image, return_tensors="pt").pixel_values
  605. >>> with torch.no_grad():
  606. ... outputs = model(inputs=inputs)
  607. >>> logits = outputs.logits
  608. >>> list(logits.shape)
  609. [1, 2]
  610. >>> # to train, one can train the model using standard cross-entropy:
  611. >>> criterion = torch.nn.CrossEntropyLoss()
  612. >>> labels = torch.tensor([1])
  613. >>> loss = criterion(logits, labels)
  614. ```"""
  615. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  616. output_hidden_states = (
  617. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  618. )
  619. return_dict = return_dict if return_dict is not None else self.config.return_dict
  620. if self.input_preprocessor is not None:
  621. inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
  622. inputs, interpolate_pos_encoding=interpolate_pos_encoding
  623. )
  624. else:
  625. modality_sizes = None
  626. inputs_without_pos = None
  627. if inputs.size()[-1] != self.config.d_model:
  628. raise ValueError(
  629. f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
  630. f" {self.config.d_model}. Make sure to set config.d_model appropriately."
  631. )
  632. batch_size, seq_length, _ = inputs.size()
  633. device = inputs.device
  634. # If no attention mask is provided, make them all ones
  635. if attention_mask is None:
  636. attention_mask = torch.ones((batch_size, seq_length), device=device)
  637. # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  638. extended_attention_mask = self.invert_attention_mask(attention_mask)
  639. embedding_output = self.embeddings(batch_size=batch_size)
  640. encoder_outputs = self.encoder(
  641. embedding_output,
  642. attention_mask=None,
  643. inputs=inputs,
  644. inputs_mask=extended_attention_mask,
  645. output_attentions=output_attentions,
  646. output_hidden_states=output_hidden_states,
  647. return_dict=return_dict,
  648. )
  649. sequence_output = encoder_outputs[0]
  650. logits = None
  651. if self.decoder:
  652. if subsampled_output_points is not None:
  653. output_modality_sizes = {
  654. "audio": subsampled_output_points["audio"].shape[0],
  655. "image": subsampled_output_points["image"].shape[0],
  656. "label": 1,
  657. }
  658. else:
  659. output_modality_sizes = modality_sizes
  660. decoder_query = self.decoder.decoder_query(
  661. inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
  662. )
  663. decoder_outputs = self.decoder(
  664. decoder_query,
  665. z=sequence_output,
  666. query_mask=extended_attention_mask,
  667. output_attentions=output_attentions,
  668. )
  669. logits = decoder_outputs.logits
  670. # add cross-attentions of decoder
  671. if output_attentions and decoder_outputs.cross_attentions is not None:
  672. if return_dict:
  673. encoder_outputs.cross_attentions = (
  674. encoder_outputs.cross_attentions + decoder_outputs.cross_attentions
  675. )
  676. else:
  677. encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions
  678. if self.output_postprocessor:
  679. logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes)
  680. if not return_dict:
  681. if logits is not None:
  682. return (logits, sequence_output) + encoder_outputs[1:]
  683. else:
  684. return (sequence_output,) + encoder_outputs[1:]
  685. return PerceiverModelOutput(
  686. logits=logits,
  687. last_hidden_state=sequence_output,
  688. hidden_states=encoder_outputs.hidden_states,
  689. attentions=encoder_outputs.attentions,
  690. cross_attentions=encoder_outputs.cross_attentions,
  691. )
  692. @auto_docstring(
  693. custom_intro="""
  694. Example use of Perceiver for masked language modeling.
  695. """
  696. )
  697. class PerceiverForMaskedLM(PerceiverPreTrainedModel):
  698. def __init__(self, config: PerceiverConfig):
  699. super().__init__(config)
  700. text_preprocessor = PerceiverTextPreprocessor(config)
  701. trainable_position_encoding_kwargs_decoder = {
  702. "num_channels": text_preprocessor.num_channels,
  703. "index_dims": config.max_position_embeddings,
  704. }
  705. self.perceiver = PerceiverModel(
  706. config,
  707. input_preprocessor=text_preprocessor,
  708. decoder=PerceiverBasicDecoder(
  709. config,
  710. output_num_channels=config.d_latents,
  711. output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand
  712. num_channels=text_preprocessor.num_channels,
  713. qk_channels=8 * 32,
  714. v_channels=text_preprocessor.num_channels,
  715. num_heads=8,
  716. use_query_residual=False,
  717. final_project=False,
  718. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  719. ),
  720. )
  721. self.embedding_decoder = PerceiverEmbeddingDecoder(config)
  722. # Initialize weights and apply final processing
  723. self.post_init()
  724. @auto_docstring
  725. def forward(
  726. self,
  727. inputs: torch.Tensor | None = None,
  728. attention_mask: torch.Tensor | None = None,
  729. output_attentions: bool | None = None,
  730. output_hidden_states: bool | None = None,
  731. labels: torch.Tensor | None = None,
  732. return_dict: bool | None = None,
  733. input_ids: torch.Tensor | None = None,
  734. **kwargs,
  735. ) -> tuple | PerceiverMaskedLMOutput:
  736. r"""
  737. inputs (`torch.FloatTensor`):
  738. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  739. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  740. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  741. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  742. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  743. Examples:
  744. ```python
  745. >>> from transformers import AutoTokenizer, PerceiverForMaskedLM
  746. >>> import torch
  747. >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
  748. >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")
  749. >>> # training
  750. >>> text = "This is an incomplete sentence where some words are missing."
  751. >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt")
  752. >>> # mask " missing."
  753. >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id
  754. >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids
  755. >>> outputs = model(**inputs, labels=labels)
  756. >>> loss = outputs.loss
  757. >>> round(loss.item(), 2)
  758. 19.87
  759. >>> logits = outputs.logits
  760. >>> list(logits.shape)
  761. [1, 2048, 262]
  762. >>> # inference
  763. >>> text = "This is an incomplete sentence where some words are missing."
  764. >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt")
  765. >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space.
  766. >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id
  767. >>> # forward pass
  768. >>> with torch.no_grad():
  769. ... outputs = model(**encoding)
  770. >>> logits = outputs.logits
  771. >>> list(logits.shape)
  772. [1, 2048, 262]
  773. >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
  774. >>> tokenizer.decode(masked_tokens_predictions)
  775. ' missing.'
  776. ```"""
  777. if inputs is not None and input_ids is not None:
  778. raise ValueError("You cannot use both `inputs` and `input_ids`")
  779. elif inputs is None and input_ids is not None:
  780. inputs = input_ids
  781. return_dict = return_dict if return_dict is not None else self.config.return_dict
  782. outputs = self.perceiver(
  783. inputs=inputs,
  784. attention_mask=attention_mask,
  785. output_attentions=output_attentions,
  786. output_hidden_states=output_hidden_states,
  787. return_dict=return_dict,
  788. )
  789. logits = self.embedding_decoder(
  790. outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
  791. )
  792. masked_lm_loss = None
  793. if labels is not None:
  794. loss_fct = CrossEntropyLoss() # -100 index = padding token
  795. masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  796. if not return_dict:
  797. output = (logits,) + outputs[2:]
  798. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  799. return PerceiverMaskedLMOutput(
  800. loss=masked_lm_loss,
  801. logits=logits,
  802. hidden_states=outputs.hidden_states,
  803. attentions=outputs.attentions,
  804. cross_attentions=outputs.cross_attentions,
  805. )
  806. @auto_docstring(
  807. custom_intro="""
  808. Example use of Perceiver for text classification.
  809. """
  810. )
  811. class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
  812. def __init__(self, config):
  813. super().__init__(config)
  814. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  815. self.num_labels = config.num_labels
  816. self.perceiver = PerceiverModel(
  817. config,
  818. input_preprocessor=PerceiverTextPreprocessor(config),
  819. decoder=PerceiverClassificationDecoder(
  820. config,
  821. num_channels=config.d_latents,
  822. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  823. use_query_residual=True,
  824. ),
  825. )
  826. # Initialize weights and apply final processing
  827. self.post_init()
  828. @auto_docstring
  829. def forward(
  830. self,
  831. inputs: torch.Tensor | None = None,
  832. attention_mask: torch.Tensor | None = None,
  833. output_attentions: bool | None = None,
  834. output_hidden_states: bool | None = None,
  835. labels: torch.Tensor | None = None,
  836. return_dict: bool | None = None,
  837. input_ids: torch.Tensor | None = None,
  838. **kwargs,
  839. ) -> tuple | PerceiverClassifierOutput:
  840. r"""
  841. inputs (`torch.FloatTensor`):
  842. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  843. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  844. Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels -
  845. 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels >
  846. 1` a classification loss is computed (Cross-Entropy).
  847. Examples:
  848. ```python
  849. >>> from transformers import AutoTokenizer, PerceiverForSequenceClassification
  850. >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
  851. >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver")
  852. >>> text = "hello world"
  853. >>> inputs = tokenizer(text, return_tensors="pt").input_ids
  854. >>> outputs = model(inputs=inputs)
  855. >>> logits = outputs.logits
  856. >>> list(logits.shape)
  857. [1, 2]
  858. ```"""
  859. if inputs is not None and input_ids is not None:
  860. raise ValueError("You cannot use both `inputs` and `input_ids`")
  861. elif inputs is None and input_ids is not None:
  862. inputs = input_ids
  863. return_dict = return_dict if return_dict is not None else self.config.return_dict
  864. outputs = self.perceiver(
  865. inputs=inputs,
  866. attention_mask=attention_mask,
  867. output_attentions=output_attentions,
  868. output_hidden_states=output_hidden_states,
  869. return_dict=return_dict,
  870. )
  871. logits = outputs.logits if return_dict else outputs[0]
  872. loss = None
  873. if labels is not None:
  874. if self.config.problem_type is None:
  875. if self.num_labels == 1:
  876. self.config.problem_type = "regression"
  877. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  878. self.config.problem_type = "single_label_classification"
  879. else:
  880. self.config.problem_type = "multi_label_classification"
  881. if self.config.problem_type == "regression":
  882. loss_fct = MSELoss()
  883. if self.num_labels == 1:
  884. loss = loss_fct(logits.squeeze(), labels.squeeze())
  885. else:
  886. loss = loss_fct(logits, labels)
  887. elif self.config.problem_type == "single_label_classification":
  888. loss_fct = CrossEntropyLoss()
  889. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  890. elif self.config.problem_type == "multi_label_classification":
  891. loss_fct = BCEWithLogitsLoss()
  892. loss = loss_fct(logits, labels)
  893. if not return_dict:
  894. output = (logits,) + outputs[2:]
  895. return ((loss,) + output) if loss is not None else output
  896. return PerceiverClassifierOutput(
  897. loss=loss,
  898. logits=logits,
  899. hidden_states=outputs.hidden_states,
  900. attentions=outputs.attentions,
  901. cross_attentions=outputs.cross_attentions,
  902. )
  903. @auto_docstring(
  904. custom_intro="""
  905. Example use of Perceiver for image classification, for tasks such as ImageNet.
  906. This model uses learned position embeddings. In other words, this model is not given any privileged information about
  907. the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet.
  908. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  909. (with `prep_type="conv1x1"`) to preprocess the input images, and
  910. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  911. [`PerceiverModel`] into classification logits.
  912. """
  913. )
  914. class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
  915. def __init__(self, config):
  916. super().__init__(config)
  917. trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size**2}
  918. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  919. self.num_labels = config.num_labels
  920. self.perceiver = PerceiverModel(
  921. config,
  922. input_preprocessor=PerceiverImagePreprocessor(
  923. config,
  924. prep_type="conv1x1",
  925. spatial_downsample=1,
  926. out_channels=256,
  927. position_encoding_type="trainable",
  928. concat_or_add_pos="concat",
  929. project_pos_dim=256,
  930. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor,
  931. ),
  932. decoder=PerceiverClassificationDecoder(
  933. config,
  934. num_channels=config.d_latents,
  935. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  936. use_query_residual=True,
  937. ),
  938. )
  939. # Initialize weights and apply final processing
  940. self.post_init()
  941. @auto_docstring
  942. def forward(
  943. self,
  944. inputs: torch.Tensor | None = None,
  945. attention_mask: torch.Tensor | None = None,
  946. output_attentions: bool | None = None,
  947. output_hidden_states: bool | None = None,
  948. labels: torch.Tensor | None = None,
  949. interpolate_pos_encoding: bool = False,
  950. return_dict: bool | None = None,
  951. pixel_values: torch.Tensor | None = None,
  952. **kwargs,
  953. ) -> tuple | PerceiverClassifierOutput:
  954. r"""
  955. inputs (`torch.FloatTensor`):
  956. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  957. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  958. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  959. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  960. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  961. Examples:
  962. ```python
  963. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationLearned
  964. >>> from PIL import Image
  965. >>> import httpx
  966. >>> from io import BytesIO
  967. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  968. >>> with httpx.stream("GET", url) as response:
  969. ... image = Image.open(BytesIO(response.read()))
  970. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-learned")
  971. >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
  972. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  973. >>> outputs = model(inputs=inputs)
  974. >>> logits = outputs.logits
  975. >>> list(logits.shape)
  976. [1, 1000]
  977. >>> # model predicts one of the 1000 ImageNet classes
  978. >>> predicted_class_idx = logits.argmax(-1).item()
  979. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  980. Predicted class: tabby, tabby cat
  981. ```"""
  982. if inputs is not None and pixel_values is not None:
  983. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  984. elif inputs is None and pixel_values is not None:
  985. inputs = pixel_values
  986. return_dict = return_dict if return_dict is not None else self.config.return_dict
  987. outputs = self.perceiver(
  988. inputs=inputs,
  989. attention_mask=attention_mask,
  990. output_attentions=output_attentions,
  991. output_hidden_states=output_hidden_states,
  992. interpolate_pos_encoding=interpolate_pos_encoding,
  993. return_dict=return_dict,
  994. )
  995. logits = outputs.logits if return_dict else outputs[0]
  996. loss = None
  997. if labels is not None:
  998. loss = self.loss_function(labels, logits, self.config)
  999. if not return_dict:
  1000. output = (logits,) + outputs[2:]
  1001. return ((loss,) + output) if loss is not None else output
  1002. return PerceiverClassifierOutput(
  1003. loss=loss,
  1004. logits=logits,
  1005. hidden_states=outputs.hidden_states,
  1006. attentions=outputs.attentions,
  1007. cross_attentions=outputs.cross_attentions,
  1008. )
  1009. @auto_docstring(
  1010. custom_intro="""
  1011. Example use of Perceiver for image classification, for tasks such as ImageNet.
  1012. This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of
  1013. 79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT).
  1014. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  1015. (with `prep_type="pixels"`) to preprocess the input images, and
  1016. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  1017. [`PerceiverModel`] into classification logits.
  1018. """
  1019. )
  1020. class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
  1021. def __init__(self, config):
  1022. super().__init__(config)
  1023. fourier_position_encoding_kwargs_preprocessor = {
  1024. "concat_pos": True,
  1025. "max_resolution": (224, 224),
  1026. "num_bands": 64,
  1027. "sine_only": False,
  1028. }
  1029. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  1030. self.num_labels = config.num_labels
  1031. self.perceiver = PerceiverModel(
  1032. config,
  1033. input_preprocessor=PerceiverImagePreprocessor(
  1034. config,
  1035. prep_type="pixels",
  1036. spatial_downsample=1,
  1037. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1038. ),
  1039. decoder=PerceiverClassificationDecoder(
  1040. config,
  1041. num_channels=config.d_latents,
  1042. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  1043. use_query_residual=True,
  1044. ),
  1045. )
  1046. # Initialize weights and apply final processing
  1047. self.post_init()
  1048. @auto_docstring
  1049. def forward(
  1050. self,
  1051. inputs: torch.Tensor | None = None,
  1052. attention_mask: torch.Tensor | None = None,
  1053. output_attentions: bool | None = None,
  1054. output_hidden_states: bool | None = None,
  1055. labels: torch.Tensor | None = None,
  1056. return_dict: bool | None = None,
  1057. pixel_values: torch.Tensor | None = None,
  1058. **kwargs,
  1059. ) -> tuple | PerceiverClassifierOutput:
  1060. r"""
  1061. inputs (`torch.FloatTensor`):
  1062. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1063. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1064. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1065. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1066. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1067. Examples:
  1068. ```python
  1069. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationFourier
  1070. >>> from PIL import Image
  1071. >>> import httpx
  1072. >>> from io import BytesIO
  1073. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1074. >>> with httpx.stream("GET", url) as response:
  1075. ... image = Image.open(BytesIO(response.read()))
  1076. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-fourier")
  1077. >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")
  1078. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  1079. >>> outputs = model(inputs=inputs)
  1080. >>> logits = outputs.logits
  1081. >>> list(logits.shape)
  1082. [1, 1000]
  1083. >>> # model predicts one of the 1000 ImageNet classes
  1084. >>> predicted_class_idx = logits.argmax(-1).item()
  1085. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1086. Predicted class: tabby, tabby cat
  1087. ```"""
  1088. if inputs is not None and pixel_values is not None:
  1089. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  1090. elif inputs is None and pixel_values is not None:
  1091. inputs = pixel_values
  1092. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1093. outputs = self.perceiver(
  1094. inputs=inputs,
  1095. attention_mask=attention_mask,
  1096. output_attentions=output_attentions,
  1097. output_hidden_states=output_hidden_states,
  1098. return_dict=return_dict,
  1099. )
  1100. logits = outputs.logits if return_dict else outputs[0]
  1101. loss = None
  1102. if labels is not None:
  1103. loss = self.loss_function(labels, logits, self.config)
  1104. if not return_dict:
  1105. output = (logits,) + outputs[2:]
  1106. return ((loss,) + output) if loss is not None else output
  1107. return PerceiverClassifierOutput(
  1108. loss=loss,
  1109. logits=logits,
  1110. hidden_states=outputs.hidden_states,
  1111. attentions=outputs.attentions,
  1112. cross_attentions=outputs.cross_attentions,
  1113. )
  1114. @auto_docstring(
  1115. custom_intro="""
  1116. Example use of Perceiver for image classification, for tasks such as ImageNet.
  1117. This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
  1118. of 82.1 on ImageNet.
  1119. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  1120. (with `prep_type="conv"`) to preprocess the input images, and
  1121. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  1122. [`PerceiverModel`] into classification logits.
  1123. """
  1124. )
  1125. class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
  1126. def __init__(self, config):
  1127. super().__init__(config)
  1128. fourier_position_encoding_kwargs_preprocessor = {
  1129. "concat_pos": True,
  1130. "max_resolution": (56, 56),
  1131. "num_bands": 64,
  1132. "sine_only": False,
  1133. }
  1134. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  1135. self.num_labels = config.num_labels
  1136. self.perceiver = PerceiverModel(
  1137. config,
  1138. input_preprocessor=PerceiverImagePreprocessor(
  1139. config,
  1140. prep_type="conv",
  1141. spatial_downsample=1,
  1142. position_encoding_type="fourier",
  1143. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1144. ),
  1145. decoder=PerceiverClassificationDecoder(
  1146. config,
  1147. num_channels=config.d_latents,
  1148. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  1149. use_query_residual=True,
  1150. ),
  1151. )
  1152. # Initialize weights and apply final processing
  1153. self.post_init()
  1154. @auto_docstring
  1155. def forward(
  1156. self,
  1157. inputs: torch.Tensor | None = None,
  1158. attention_mask: torch.Tensor | None = None,
  1159. output_attentions: bool | None = None,
  1160. output_hidden_states: bool | None = None,
  1161. labels: torch.Tensor | None = None,
  1162. return_dict: bool | None = None,
  1163. pixel_values: torch.Tensor | None = None,
  1164. **kwargs,
  1165. ) -> tuple | PerceiverClassifierOutput:
  1166. r"""
  1167. inputs (`torch.FloatTensor`):
  1168. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1169. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1170. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1171. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1172. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1173. Examples:
  1174. ```python
  1175. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing
  1176. >>> from PIL import Image
  1177. >>> import httpx
  1178. >>> from io import BytesIO
  1179. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1180. >>> with httpx.stream("GET", url) as response:
  1181. ... image = Image.open(BytesIO(response.read()))
  1182. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-conv")
  1183. >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
  1184. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  1185. >>> outputs = model(inputs=inputs)
  1186. >>> logits = outputs.logits
  1187. >>> list(logits.shape)
  1188. [1, 1000]
  1189. >>> # model predicts one of the 1000 ImageNet classes
  1190. >>> predicted_class_idx = logits.argmax(-1).item()
  1191. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1192. Predicted class: tabby, tabby cat
  1193. ```"""
  1194. if inputs is not None and pixel_values is not None:
  1195. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  1196. elif inputs is None and pixel_values is not None:
  1197. inputs = pixel_values
  1198. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1199. outputs = self.perceiver(
  1200. inputs=inputs,
  1201. attention_mask=attention_mask,
  1202. output_attentions=output_attentions,
  1203. output_hidden_states=output_hidden_states,
  1204. return_dict=return_dict,
  1205. )
  1206. logits = outputs.logits if return_dict else outputs[0]
  1207. loss = None
  1208. if labels is not None:
  1209. loss = self.loss_function(labels, logits, self.config)
  1210. if not return_dict:
  1211. output = (logits,) + outputs[2:]
  1212. return ((loss,) + output) if loss is not None else output
  1213. return PerceiverClassifierOutput(
  1214. loss=loss,
  1215. logits=logits,
  1216. hidden_states=outputs.hidden_states,
  1217. attentions=outputs.attentions,
  1218. cross_attentions=outputs.cross_attentions,
  1219. )
  1220. @auto_docstring(
  1221. custom_intro="""
  1222. Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
  1223. [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
  1224. input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
  1225. representation of [`PerceiverModel`].
  1226. As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel
  1227. (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position
  1228. of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation
  1229. using the same encoding used for the input.
  1230. """
  1231. )
  1232. class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
  1233. def __init__(self, config):
  1234. super().__init__(config)
  1235. fourier_position_encoding_kwargs_preprocessor = {
  1236. "num_bands": 64,
  1237. "max_resolution": config.train_size,
  1238. "sine_only": False,
  1239. "concat_pos": True,
  1240. }
  1241. fourier_position_encoding_kwargs_decoder = {
  1242. "concat_pos": True,
  1243. "max_resolution": config.train_size,
  1244. "num_bands": 64,
  1245. "sine_only": False,
  1246. }
  1247. image_preprocessor = PerceiverImagePreprocessor(
  1248. config,
  1249. prep_type="patches",
  1250. spatial_downsample=1,
  1251. conv_after_patching=True,
  1252. conv_after_patching_in_channels=54,
  1253. temporal_downsample=2,
  1254. position_encoding_type="fourier",
  1255. # position_encoding_kwargs
  1256. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1257. )
  1258. self.perceiver = PerceiverModel(
  1259. config,
  1260. input_preprocessor=image_preprocessor,
  1261. decoder=PerceiverOpticalFlowDecoder(
  1262. config,
  1263. num_channels=image_preprocessor.num_channels,
  1264. output_image_shape=config.train_size,
  1265. rescale_factor=100.0,
  1266. # decoder kwargs
  1267. use_query_residual=False,
  1268. output_num_channels=2,
  1269. # We query the decoder using the first frame features
  1270. # rather than a standard decoder position encoding.
  1271. position_encoding_type="fourier",
  1272. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
  1273. ),
  1274. )
  1275. # Initialize weights and apply final processing
  1276. self.post_init()
  1277. @auto_docstring
  1278. def forward(
  1279. self,
  1280. inputs: torch.Tensor | None = None,
  1281. attention_mask: torch.Tensor | None = None,
  1282. output_attentions: bool | None = None,
  1283. output_hidden_states: bool | None = None,
  1284. labels: torch.Tensor | None = None,
  1285. return_dict: bool | None = None,
  1286. **kwargs,
  1287. ) -> tuple | PerceiverClassifierOutput:
  1288. r"""
  1289. inputs (`torch.FloatTensor`):
  1290. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1291. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1292. Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1293. Examples:
  1294. ```python
  1295. >>> from transformers import PerceiverForOpticalFlow
  1296. >>> import torch
  1297. >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
  1298. >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
  1299. >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
  1300. >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
  1301. >>> # the authors train on resolutions of 368 x 496
  1302. >>> patches = torch.randn(1, 2, 27, 368, 496)
  1303. >>> outputs = model(inputs=patches)
  1304. >>> logits = outputs.logits
  1305. >>> list(logits.shape)
  1306. [1, 368, 496, 2]
  1307. ```"""
  1308. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1309. loss = None
  1310. if labels is not None:
  1311. raise NotImplementedError("Optical flow training is not yet supported")
  1312. outputs = self.perceiver(
  1313. inputs=inputs,
  1314. attention_mask=attention_mask,
  1315. output_attentions=output_attentions,
  1316. output_hidden_states=output_hidden_states,
  1317. return_dict=return_dict,
  1318. )
  1319. logits = outputs.logits if return_dict else outputs[0]
  1320. if not return_dict:
  1321. output = (logits,) + outputs[2:]
  1322. return ((loss,) + output) if loss is not None else output
  1323. return PerceiverClassifierOutput(
  1324. loss=loss,
  1325. logits=logits,
  1326. hidden_states=outputs.hidden_states,
  1327. attentions=outputs.attentions,
  1328. cross_attentions=outputs.cross_attentions,
  1329. )
  1330. @auto_docstring(
  1331. custom_intro="""
  1332. Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.
  1333. [`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
  1334. preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
  1335. preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
  1336. each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies
  1337. the Perceiver encoder.
  1338. [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
  1339. [`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are
  1340. created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is
  1341. computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent
  1342. representation. This is determined by the subsampled indices for each modality, which can be provided as additional
  1343. input to the forward pass of [`PerceiverForMultimodalAutoencoding`].
  1344. [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
  1345. modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention
  1346. is performed with the latent representation of [`PerceiverModel`].
  1347. Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an
  1348. actual video. It first splits up the output into the different modalities, and then applies the respective
  1349. postprocessor for each modality.
  1350. Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
  1351. "label" modality), this auto-encoding model becomes a Kinetics 700 video classifier.
  1352. """
  1353. )
  1354. class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
  1355. def __init__(self, config: PerceiverConfig):
  1356. super().__init__(config)
  1357. n_audio_samples = config.num_frames * config.audio_samples_per_frame
  1358. input_preprocessor = PerceiverMultimodalPreprocessor(
  1359. min_padding_size=4,
  1360. modalities={
  1361. "audio": PerceiverAudioPreprocessor(
  1362. config,
  1363. position_encoding_type="fourier",
  1364. fourier_position_encoding_kwargs={
  1365. "num_bands": 192,
  1366. "max_resolution": (n_audio_samples,),
  1367. "sine_only": False,
  1368. "concat_pos": True,
  1369. },
  1370. prep_type="patches",
  1371. samples_per_patch=config.samples_per_patch,
  1372. ),
  1373. "image": PerceiverImagePreprocessor(
  1374. config,
  1375. position_encoding_type="fourier",
  1376. fourier_position_encoding_kwargs={
  1377. "num_bands": 32,
  1378. "max_resolution": (config.num_frames, config.image_size, config.image_size),
  1379. "sine_only": False,
  1380. "concat_pos": True,
  1381. },
  1382. prep_type="patches",
  1383. spatial_downsample=4,
  1384. temporal_downsample=1,
  1385. ),
  1386. "label": PerceiverOneHotPreprocessor(config),
  1387. },
  1388. mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0},
  1389. )
  1390. image_decoder = PerceiverBasicVideoAutoencodingDecoder(
  1391. config,
  1392. # Autoencoding, don't pass inputs to the queries.
  1393. concat_preprocessed_input=False,
  1394. output_shape=config.output_shape,
  1395. output_num_channels=config.output_num_channels,
  1396. use_query_residual=False,
  1397. position_encoding_only=True,
  1398. position_encoding_type="fourier",
  1399. fourier_position_encoding_kwargs={
  1400. "num_bands": 32,
  1401. "max_resolution": (config.num_frames, config.image_size, config.image_size),
  1402. "sine_only": False,
  1403. "concat_pos": True,
  1404. },
  1405. )
  1406. decoder = PerceiverMultimodalDecoder(
  1407. config,
  1408. # Autoencoding, don't pass inputs to the queries.
  1409. concat_preprocessed_input=False,
  1410. # Modality specific decoders are used ONLY to generate queries.
  1411. # All modalties are decoded together using a unified decoder.
  1412. modalities={
  1413. "audio": PerceiverBasicDecoder(
  1414. config,
  1415. # Autoencoding, don't pass inputs to the queries.
  1416. concat_preprocessed_input=False,
  1417. output_index_dims=(n_audio_samples // config.samples_per_patch,),
  1418. output_num_channels=config.output_num_channels,
  1419. use_query_residual=False,
  1420. position_encoding_only=True,
  1421. position_encoding_type="fourier",
  1422. fourier_position_encoding_kwargs={
  1423. "num_bands": 192,
  1424. "max_resolution": (n_audio_samples,),
  1425. "sine_only": False,
  1426. "concat_pos": True,
  1427. },
  1428. ),
  1429. "image": image_decoder,
  1430. "label": PerceiverClassificationDecoder(
  1431. config,
  1432. # Autoencoding, don't pass inputs to the queries.
  1433. concat_preprocessed_input=False,
  1434. use_query_residual=False,
  1435. position_encoding_only=True,
  1436. position_encoding_type="trainable",
  1437. trainable_position_encoding_kwargs={
  1438. "num_channels": config._label_trainable_num_channels,
  1439. "index_dims": 1,
  1440. },
  1441. ),
  1442. },
  1443. num_outputs=None,
  1444. output_num_channels=config.output_num_channels,
  1445. use_query_residual=False,
  1446. )
  1447. output_postprocessor = PerceiverMultimodalPostprocessor(
  1448. modalities={
  1449. "audio": PerceiverAudioPostprocessor(config, in_channels=config.output_num_channels),
  1450. "image": PerceiverProjectionPostprocessor(in_channels=config.output_num_channels, out_channels=3),
  1451. "label": PerceiverClassificationPostprocessor(config, in_channels=config.output_num_channels),
  1452. }
  1453. )
  1454. self.perceiver = PerceiverModel(
  1455. config,
  1456. input_preprocessor=input_preprocessor,
  1457. decoder=decoder,
  1458. output_postprocessor=output_postprocessor,
  1459. )
  1460. # Initialize weights and apply final processing
  1461. self.post_init()
  1462. @auto_docstring
  1463. def forward(
  1464. self,
  1465. inputs: torch.Tensor | None = None,
  1466. attention_mask: torch.Tensor | None = None,
  1467. subsampled_output_points: dict[str, torch.Tensor] | None = None,
  1468. output_attentions: bool | None = None,
  1469. output_hidden_states: bool | None = None,
  1470. labels: torch.Tensor | None = None,
  1471. return_dict: bool | None = None,
  1472. **kwargs,
  1473. ) -> tuple | PerceiverClassifierOutput:
  1474. r"""
  1475. inputs (`torch.FloatTensor`):
  1476. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1477. subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
  1478. Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
  1479. representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.
  1480. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1481. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1482. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1483. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1484. Examples:
  1485. ```python
  1486. >>> from transformers import PerceiverForMultimodalAutoencoding
  1487. >>> import torch
  1488. >>> import numpy as np
  1489. >>> # create multimodal inputs
  1490. >>> images = torch.randn((1, 16, 3, 224, 224))
  1491. >>> audio = torch.randn((1, 30720, 1))
  1492. >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
  1493. >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver")
  1494. >>> # in the Perceiver IO paper, videos are auto-encoded in chunks
  1495. >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries
  1496. >>> nchunks = 128
  1497. >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks
  1498. >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
  1499. >>> # process the first chunk
  1500. >>> chunk_idx = 0
  1501. >>> subsampling = {
  1502. ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
  1503. ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
  1504. ... "label": None,
  1505. ... }
  1506. >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
  1507. >>> logits = outputs.logits
  1508. >>> list(logits["audio"].shape)
  1509. [1, 240]
  1510. >>> list(logits["image"].shape)
  1511. [1, 6272, 3]
  1512. >>> list(logits["label"].shape)
  1513. [1, 700]
  1514. ```"""
  1515. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1516. loss = None
  1517. if labels is not None:
  1518. raise NotImplementedError("Multimodal autoencoding training is not yet supported")
  1519. outputs = self.perceiver(
  1520. inputs=inputs,
  1521. attention_mask=attention_mask,
  1522. subsampled_output_points=subsampled_output_points,
  1523. output_attentions=output_attentions,
  1524. output_hidden_states=output_hidden_states,
  1525. return_dict=return_dict,
  1526. )
  1527. logits = outputs.logits if return_dict else outputs[0]
  1528. if not return_dict:
  1529. output = (logits,) + outputs[2:]
  1530. return ((loss,) + output) if loss is not None else output
  1531. return PerceiverClassifierOutput(
  1532. loss=loss,
  1533. logits=logits,
  1534. hidden_states=outputs.hidden_states,
  1535. attentions=outputs.attentions,
  1536. cross_attentions=outputs.cross_attentions,
  1537. )
  1538. # Below: position encodings
  1539. def build_position_encoding(
  1540. position_encoding_type,
  1541. out_channels=None,
  1542. project_pos_dim=-1,
  1543. trainable_position_encoding_kwargs=None,
  1544. fourier_position_encoding_kwargs=None,
  1545. ):
  1546. """
  1547. Builds the position encoding.
  1548. Args:
  1549. - out_channels: refers to the number of channels of the position encodings.
  1550. - project_pos_dim: if specified, will project the position encodings to this dimension.
  1551. """
  1552. if position_encoding_type == "trainable":
  1553. if not trainable_position_encoding_kwargs:
  1554. raise ValueError("Make sure to pass trainable_position_encoding_kwargs")
  1555. output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs)
  1556. elif position_encoding_type == "fourier":
  1557. # We don't use the index_dims argument, as this is only known during the forward pass
  1558. if not fourier_position_encoding_kwargs:
  1559. raise ValueError("Make sure to pass fourier_position_encoding_kwargs")
  1560. output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs)
  1561. else:
  1562. raise ValueError(f"Unknown position encoding type: {position_encoding_type}.")
  1563. # Optionally, project the position encoding to a target dimension:
  1564. positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()
  1565. return output_pos_enc, positions_projection
  1566. # Below: Perceiver decoders
  1567. class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
  1568. """Perceiver abstract decoder."""
  1569. @abc.abstractmethod
  1570. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1571. raise NotImplementedError
  1572. @property
  1573. @abc.abstractmethod
  1574. def num_query_channels(self):
  1575. raise NotImplementedError
  1576. @abc.abstractmethod
  1577. def forward(self, query, z, query_mask=None):
  1578. raise NotImplementedError
  1579. class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
  1580. """
  1581. Baseline projection decoder (no cross-attention).
  1582. Args:
  1583. config ([`PerceiverConfig`]):
  1584. Model configuration.
  1585. """
  1586. def __init__(self, config):
  1587. super().__init__()
  1588. self.classifier = nn.Linear(config.d_latents, config.num_labels)
  1589. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1590. return None
  1591. def forward(
  1592. self, query: torch.Tensor, z: torch.FloatTensor, query_mask: torch.FloatTensor | None = None
  1593. ) -> torch.FloatTensor:
  1594. # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
  1595. z = torch.mean(z, dim=1)
  1596. # (batch_size, d_latents) -> (batch_size, config.num_labels)
  1597. logits = self.classifier(z)
  1598. return logits
  1599. class PerceiverBasicDecoder(PerceiverAbstractDecoder):
  1600. """
  1601. Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
  1602. cross-attention operation, in which the latents produce keys and values.
  1603. The shape of the output of this class depends on how one defines the output queries (also called decoder queries).
  1604. Args:
  1605. config ([*PerceiverConfig*]):
  1606. Model configuration.
  1607. output_num_channels (`int`, *optional*):
  1608. The number of channels in the output. Will only be used in case *final_project* is set to `True`.
  1609. position_encoding_type (`str`, *optional*, defaults to "trainable"):
  1610. The type of position encoding to use. Can be either "trainable", "fourier", or "none".
  1611. output_index_dims (`int`, *optional*):
  1612. The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
  1613. num_channels (`int`, *optional*, defaults to 128):
  1614. The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
  1615. qk_channels (`int`, *optional*):
  1616. The number of channels of the queries and keys in the cross-attention layer.
  1617. v_channels (`int`, *optional*):
  1618. The number of channels of the values in the cross-attention layer.
  1619. num_heads (`int`, *optional*, defaults to 1):
  1620. The number of attention heads in the cross-attention layer.
  1621. widening_factor (`int`, *optional*, defaults to 1):
  1622. The widening factor of the cross-attention layer.
  1623. use_query_residual (`bool`, *optional*, defaults to `False`):
  1624. Whether to use a residual connection between the query and the output of the cross-attention layer.
  1625. concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
  1626. Whether to concatenate the preprocessed input to the query.
  1627. final_project (`bool`, *optional*, defaults to `True`):
  1628. Whether to project the output of the cross-attention layer to a target dimension.
  1629. position_encoding_only (`bool`, *optional*, defaults to `False`):
  1630. Whether to only use this class to define output queries.
  1631. """
  1632. def __init__(
  1633. self,
  1634. config: PerceiverConfig,
  1635. output_num_channels: int,
  1636. position_encoding_type: str | None = "trainable",
  1637. # The following 2 arguments are ignored if position_encoding_type == 'none':
  1638. output_index_dims: int | None = None,
  1639. num_channels: int | None = 128,
  1640. subsampled_index_dims: int | None = None,
  1641. qk_channels: int | None = None,
  1642. v_channels: int | None = None,
  1643. num_heads: int | None = 1,
  1644. widening_factor: int | None = 1,
  1645. use_query_residual: bool | None = False,
  1646. concat_preprocessed_input: bool | None = False,
  1647. final_project: bool | None = True,
  1648. position_encoding_only: bool | None = False,
  1649. **position_encoding_kwargs,
  1650. ) -> None:
  1651. super().__init__()
  1652. self.output_num_channels = output_num_channels
  1653. # If `none`, the decoder will not construct any position encodings.
  1654. # You should construct your own when querying the decoder.
  1655. self.output_position_encodings = None
  1656. self.position_encoding_type = position_encoding_type
  1657. self.position_encoding_kwargs = position_encoding_kwargs
  1658. if position_encoding_type != "none":
  1659. self.output_position_encodings, self.positions_projection = build_position_encoding(
  1660. position_encoding_type=position_encoding_type, **position_encoding_kwargs
  1661. )
  1662. self.output_index_dims = output_index_dims
  1663. self.num_channels = num_channels
  1664. if subsampled_index_dims is None:
  1665. subsampled_index_dims = output_index_dims
  1666. self.subsampled_index_dims = subsampled_index_dims
  1667. self.concat_preprocessed_input = concat_preprocessed_input
  1668. self.final_project = final_project
  1669. self.position_encoding_only = position_encoding_only
  1670. # for multimodal autoencoding, we don't need the decoder cross-attention and final layer
  1671. # so then we will set position_encoding_only to True
  1672. if not self.position_encoding_only:
  1673. self.decoding_cross_attention = PerceiverLayer(
  1674. config,
  1675. is_cross_attention=True,
  1676. qk_channels=qk_channels,
  1677. v_channels=v_channels,
  1678. num_heads=num_heads,
  1679. q_dim=num_channels,
  1680. kv_dim=config.d_latents,
  1681. widening_factor=widening_factor,
  1682. use_query_residual=use_query_residual,
  1683. )
  1684. self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity()
  1685. @property
  1686. def num_query_channels(self) -> int:
  1687. if self.position_encoding_type == "none": # Queries come from elsewhere
  1688. raise ValueError(
  1689. "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
  1690. )
  1691. if self.position_encoding_only:
  1692. if "project_pos_dim" in self.position_encoding_kwargs:
  1693. return self.position_encoding_kwargs["project_pos_dim"]
  1694. return self.output_position_encodings.output_size()
  1695. if self.final_project:
  1696. return self.output_num_channels
  1697. return self.num_channels
  1698. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1699. if self.position_encoding_type == "none": # Queries come from elsewhere
  1700. raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
  1701. if subsampled_points is not None:
  1702. # subsampled_points are the indices if the inputs would be flattened
  1703. # however, the inputs aren't flattened, that's why we use unravel_index
  1704. # to get the indices for the unflattened array
  1705. # unravel_index returns a tuple (x_idx, y_idx, ...)
  1706. # stack to get the [n, d] tensor of coordinates
  1707. indices = torch.unravel_index(subsampled_points, self.output_index_dims)
  1708. pos = torch.stack(indices, dim=1)
  1709. batch_size = inputs.shape[0]
  1710. # Map these coordinates to [-1, 1]
  1711. pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
  1712. pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
  1713. # Construct the position encoding.
  1714. if self.position_encoding_type == "trainable":
  1715. pos_emb = self.output_position_encodings(batch_size)
  1716. elif self.position_encoding_type == "fourier":
  1717. pos_emb = self.output_position_encodings(
  1718. self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
  1719. )
  1720. # Optionally project them to a target dimension.
  1721. pos_emb = self.positions_projection(pos_emb)
  1722. pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
  1723. else:
  1724. batch_size = inputs.shape[0]
  1725. index_dims = inputs.shape[2:]
  1726. # Construct the position encoding.
  1727. if self.position_encoding_type == "trainable":
  1728. pos_emb = self.output_position_encodings(batch_size)
  1729. elif self.position_encoding_type == "fourier":
  1730. pos_emb = self.output_position_encodings(
  1731. index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
  1732. )
  1733. # Optionally project them to a target dimension.
  1734. pos_emb = self.positions_projection(pos_emb)
  1735. if self.concat_preprocessed_input:
  1736. if inputs_without_pos is None:
  1737. raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
  1738. pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
  1739. return pos_emb
  1740. def forward(
  1741. self,
  1742. query: torch.Tensor,
  1743. z: torch.FloatTensor,
  1744. query_mask: torch.FloatTensor | None = None,
  1745. output_attentions: bool | None = False,
  1746. ) -> PerceiverDecoderOutput:
  1747. # Cross-attention decoding.
  1748. # key, value: B x N x K; query: B x M x K
  1749. # Attention maps -> B x N x M
  1750. # Output -> B x M x K
  1751. cross_attentions = () if output_attentions else None
  1752. layer_outputs = self.decoding_cross_attention(
  1753. query,
  1754. attention_mask=query_mask,
  1755. inputs=z,
  1756. inputs_mask=None,
  1757. output_attentions=output_attentions,
  1758. )
  1759. output = layer_outputs[0]
  1760. if output_attentions:
  1761. cross_attentions = cross_attentions + (layer_outputs[1],)
  1762. logits = self.final_layer(output)
  1763. return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
  1764. class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
  1765. """
  1766. Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output.
  1767. Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of
  1768. shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels).
  1769. Args:
  1770. config ([`PerceiverConfig`]):
  1771. Model configuration.
  1772. """
  1773. def __init__(self, config, **decoder_kwargs):
  1774. super().__init__()
  1775. self.num_labels = config.num_labels
  1776. self.decoder = PerceiverBasicDecoder(
  1777. config,
  1778. output_num_channels=self.num_labels,
  1779. output_index_dims=1, # Predict a single logit array.
  1780. **decoder_kwargs,
  1781. )
  1782. @property
  1783. def num_query_channels(self) -> int:
  1784. return self.decoder.num_query_channels
  1785. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1786. return self.decoder.decoder_query(
  1787. inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
  1788. )
  1789. def forward(
  1790. self,
  1791. query: torch.Tensor,
  1792. z: torch.FloatTensor,
  1793. query_mask: torch.FloatTensor | None = None,
  1794. output_attentions: bool | None = False,
  1795. ) -> PerceiverDecoderOutput:
  1796. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  1797. # B x 1 x num_classes -> B x num_classes
  1798. logits = decoder_outputs.logits[:, 0, :]
  1799. return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
  1800. class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
  1801. """Cross-attention based optical flow decoder."""
  1802. def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
  1803. super().__init__()
  1804. self.output_image_shape = output_image_shape
  1805. self.output_num_channels = output_num_channels
  1806. self.rescale_factor = rescale_factor
  1807. self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs)
  1808. @property
  1809. def num_query_channels(self) -> int:
  1810. return self.decoder.num_query_channels
  1811. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1812. if subsampled_points is not None:
  1813. raise ValueError("FlowDecoder doesn't support subsampling yet.")
  1814. return inputs
  1815. def forward(
  1816. self,
  1817. query: torch.Tensor,
  1818. z: torch.FloatTensor,
  1819. query_mask: torch.FloatTensor | None = None,
  1820. output_attentions: bool | None = False,
  1821. ) -> PerceiverDecoderOutput:
  1822. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  1823. preds = decoder_outputs.logits
  1824. # Output flow and rescale.
  1825. preds /= self.rescale_factor
  1826. preds = preds.reshape([preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]])
  1827. return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions)
  1828. class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
  1829. """
  1830. Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video
  1831. reshaping logic.
  1832. Args:
  1833. config ([*PerceiverConfig*]):
  1834. Model configuration.
  1835. output_shape (`list[int]`):
  1836. Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension.
  1837. position_encoding_type (`str`):
  1838. The type of position encoding to use. Can be either "trainable", "fourier", or "none".
  1839. """
  1840. def __init__(
  1841. self, config: PerceiverConfig, output_shape: list[int], position_encoding_type: str, **decoder_kwargs
  1842. ) -> None:
  1843. super().__init__()
  1844. if len(output_shape) != 4: # B, T, H, W
  1845. raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
  1846. # Build the decoder components:
  1847. self.output_shape = output_shape
  1848. self.output_num_channels = decoder_kwargs["output_num_channels"]
  1849. self.decoder = PerceiverBasicDecoder(
  1850. config,
  1851. output_index_dims=self.output_shape[1:4], # T*H*W
  1852. position_encoding_type=position_encoding_type,
  1853. **decoder_kwargs,
  1854. )
  1855. @property
  1856. def num_query_channels(self) -> int:
  1857. return self.decoder.num_query_channels
  1858. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1859. return self.decoder.decoder_query(
  1860. inputs,
  1861. modality_sizes=modality_sizes,
  1862. inputs_without_pos=inputs_without_pos,
  1863. subsampled_points=subsampled_points,
  1864. )
  1865. def forward(
  1866. self, query: torch.Tensor, z: torch.FloatTensor, query_mask: torch.FloatTensor | None = None
  1867. ) -> PerceiverDecoderOutput:
  1868. decoder_outputs = self.decoder(query, z)
  1869. logits = decoder_outputs.logits
  1870. logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]])
  1871. return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
  1872. def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]:
  1873. """
  1874. Partitions a [B, N, C] tensor into tensors for each modality.
  1875. Args:
  1876. modality_sizes
  1877. dict specifying the size of the modality
  1878. inputs:
  1879. input tensor
  1880. Returns:
  1881. dict mapping name of modality to its associated tensor.
  1882. """
  1883. outputs = {}
  1884. index = 0
  1885. # Apply a predictable ordering to the modalities
  1886. for modality in sorted(modality_sizes.keys()):
  1887. size = modality_sizes[modality]
  1888. inp = inputs[:, index : index + size]
  1889. index += size
  1890. outputs[modality] = inp
  1891. return outputs
  1892. class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
  1893. """
  1894. Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary
  1895. mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that
  1896. modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are
  1897. concatenated along the time dimension.
  1898. Next, there is a shared cross attention operation across all modalities.
  1899. Args:
  1900. config ([*PerceiverConfig*]):
  1901. Model configuration.
  1902. modalities (`dict[str, PerceiverAbstractDecoder]`):
  1903. Dictionary mapping modality name to the decoder of that modality.
  1904. num_outputs (`int`):
  1905. The number of outputs of the decoder.
  1906. output_num_channels (`int`):
  1907. The number of channels in the output.
  1908. min_padding_size (`int`, *optional*, defaults to 2):
  1909. The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
  1910. channels across all modalities plus min_padding_size.
  1911. subsampled_index_dims (`dict[str, PerceiverAbstractDecoder]`, *optional*):
  1912. Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that
  1913. modality.
  1914. """
  1915. def __init__(
  1916. self,
  1917. config: PerceiverConfig,
  1918. modalities: dict[str, PerceiverAbstractDecoder],
  1919. num_outputs: int,
  1920. output_num_channels: int,
  1921. min_padding_size: int | None = 2,
  1922. subsampled_index_dims: dict[str, PerceiverAbstractDecoder] | None = None,
  1923. **decoder_kwargs,
  1924. ) -> None:
  1925. super().__init__()
  1926. self.modalities = nn.ModuleDict(modalities)
  1927. self.subsampled_index_dims = subsampled_index_dims
  1928. self.min_padding_size = min_padding_size
  1929. self.output_num_channels = output_num_channels
  1930. self.num_outputs = num_outputs
  1931. self.decoder = PerceiverBasicDecoder(
  1932. config,
  1933. output_index_dims=(num_outputs,),
  1934. output_num_channels=output_num_channels,
  1935. position_encoding_type="none",
  1936. num_channels=self.num_query_channels,
  1937. **decoder_kwargs,
  1938. )
  1939. self.padding = nn.ParameterDict(
  1940. {
  1941. modality: nn.Parameter(torch.randn(1, self.num_query_channels - decoder.num_query_channels))
  1942. for modality, decoder in modalities.items()
  1943. }
  1944. )
  1945. @property
  1946. def num_query_channels(self) -> int:
  1947. max_channel_size = max(decoder.num_query_channels for _, decoder in self.modalities.items())
  1948. common_channel_size = max_channel_size + self.min_padding_size
  1949. return common_channel_size
  1950. def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None):
  1951. # Partition the flat inputs among the different modalities
  1952. inputs = restructure(modality_sizes, inputs)
  1953. # Obtain modality-specific decoders' queries
  1954. subsampled_points = subsampled_points or {}
  1955. decoder_queries = {}
  1956. for modality, decoder in self.modalities.items():
  1957. # Get input_without_pos for this modality if it exists.
  1958. input_without_pos = None
  1959. if inputs_without_pos is not None:
  1960. input_without_pos = inputs_without_pos.get(modality, None)
  1961. query = decoder.decoder_query(
  1962. inputs=inputs[modality],
  1963. modality_sizes=None,
  1964. inputs_without_pos=input_without_pos,
  1965. subsampled_points=subsampled_points.get(modality, None),
  1966. )
  1967. decoder_queries[modality] = query
  1968. # Pad all queries with trainable position encodings to make them have the same channels
  1969. def embed(modality, x):
  1970. x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]])
  1971. pos = self.padding[modality]
  1972. pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]])
  1973. return torch.cat([x, pos], dim=2)
  1974. # Apply a predictable ordering to the modalities
  1975. return torch.cat(
  1976. [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
  1977. )
  1978. def forward(
  1979. self,
  1980. query: torch.Tensor,
  1981. z: torch.FloatTensor,
  1982. query_mask: torch.FloatTensor | None = None,
  1983. output_attentions: bool | None = False,
  1984. ) -> torch.Tensor:
  1985. # B x 1 x num_classes -> B x num_classes
  1986. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  1987. return decoder_outputs
  1988. # Below: IO pre- and post-processor classes for Perceiver.
  1989. def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor:
  1990. """
  1991. Space to depth transform. Rearranges blocks of spatial data, into depth.
  1992. This function assumes the channels to be first, but will place the channels last after transformation.
  1993. """
  1994. if len(frames.shape) == 4:
  1995. batch_size, num_channels, height, width = frames.shape
  1996. # split up dimensions (height by spatial_block_size, width by spatial_block_size)
  1997. frames = frames.view(
  1998. batch_size,
  1999. num_channels,
  2000. height // spatial_block_size,
  2001. spatial_block_size,
  2002. width // spatial_block_size,
  2003. spatial_block_size,
  2004. )
  2005. # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C)
  2006. frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous()
  2007. # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C)
  2008. frames = frames.view(
  2009. batch_size,
  2010. height // spatial_block_size,
  2011. width // spatial_block_size,
  2012. (spatial_block_size**2) * num_channels,
  2013. )
  2014. return frames
  2015. elif len(frames.shape) == 5:
  2016. batch_size, time, num_channels, height, width = frames.shape
  2017. # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size)
  2018. frames = frames.view(
  2019. batch_size,
  2020. time // temporal_block_size,
  2021. temporal_block_size,
  2022. num_channels,
  2023. height // spatial_block_size,
  2024. spatial_block_size,
  2025. width // spatial_block_size,
  2026. spatial_block_size,
  2027. )
  2028. # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C)
  2029. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  2030. # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C)
  2031. frames = frames.view(
  2032. batch_size,
  2033. time // temporal_block_size,
  2034. height // spatial_block_size,
  2035. width // spatial_block_size,
  2036. temporal_block_size * (spatial_block_size**2) * num_channels,
  2037. )
  2038. return frames
  2039. else:
  2040. raise ValueError(
  2041. "Frames should be of rank 4 (batch, channels, height, width)"
  2042. " or rank 5 (batch, time, channels, height, width)"
  2043. )
  2044. class Conv2dSamePadding(nn.Conv2d):
  2045. """
  2046. Conv2d layer with padding="same" support. Source:
  2047. https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6
  2048. """
  2049. def __init__(self, *args, **kwargs):
  2050. super().__init__(*args, **kwargs)
  2051. self.zero_pad_2d = nn.ZeroPad2d(
  2052. reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]])
  2053. )
  2054. def forward(self, input):
  2055. return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
  2056. class Conv2DDownsample(nn.Module):
  2057. """Downsamples 4x by applying a 2D convolution and doing max pooling."""
  2058. def __init__(
  2059. self,
  2060. num_layers: int = 1,
  2061. in_channels: int = 3,
  2062. out_channels: int = 64,
  2063. use_batchnorm: bool = True,
  2064. ):
  2065. """
  2066. Constructs a Conv2DDownsample model.
  2067. Args:
  2068. in_channels (`int`, *optional*, defaults to 3):
  2069. The number of input channels.
  2070. out_channels (`int`, *optional*, defaults to 64):
  2071. The number of conv output channels.
  2072. use_batchnorm (`bool`, *optional*, defaults to `True`):
  2073. Whether to use batchnorm.
  2074. """
  2075. super().__init__()
  2076. self.conv = Conv2dSamePadding(
  2077. in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False
  2078. )
  2079. self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity()
  2080. self.relu = nn.ReLU()
  2081. self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
  2082. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  2083. out = self.conv(inputs)
  2084. out = self.batchnorm(out)
  2085. out = self.relu(out)
  2086. out = self.max_pool(out)
  2087. return out
  2088. def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
  2089. """
  2090. Generate a Fourier frequency position encoding with linear spacing.
  2091. Args:
  2092. pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`):
  2093. The Tensor containing the position of n points in d dimensional space.
  2094. num_bands (`int`):
  2095. The number of frequency bands (K) to use.
  2096. max_resolution (`tuple[int]`, *optional*, defaults to (224, 224)):
  2097. The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension.
  2098. concat_pos (`bool`, *optional*, defaults to `True`):
  2099. Whether to concatenate the input position encoding to the Fourier features.
  2100. sine_only (`bool`, *optional*, defaults to `False`):
  2101. Whether to use a single phase (sin) or two (sin/cos) for each frequency band.
  2102. Returns:
  2103. `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If
  2104. `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d,
  2105. sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1),
  2106. ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the
  2107. kth frequency band.
  2108. """
  2109. batch_size = pos.shape[0]
  2110. min_freq = 1.0
  2111. # Nyquist frequency at the target resolution:
  2112. freq_bands = torch.stack(
  2113. [torch.linspace(start=min_freq, end=res / 2, steps=num_bands) for res in max_resolution], dim=0
  2114. )
  2115. # Get frequency bands for each spatial dimension.
  2116. # Output is size [n, d * num_bands]
  2117. per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :]
  2118. per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])])
  2119. if sine_only:
  2120. # Output is size [n, d * num_bands]
  2121. per_pos_features = torch.sin(np.pi * (per_pos_features))
  2122. else:
  2123. # Output is size [n, 2 * d * num_bands]
  2124. per_pos_features = torch.cat(
  2125. [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
  2126. )
  2127. # Concatenate the raw input positions.
  2128. if concat_pos:
  2129. # Adds d bands to the encoding.
  2130. per_pos_features = torch.cat([pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1)
  2131. return per_pos_features
  2132. def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
  2133. """
  2134. Generate an array of position indices for an N-D input array.
  2135. Args:
  2136. index_dims (`list[int]`):
  2137. The shape of the index dimensions of the input array.
  2138. output_range (`tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`):
  2139. The min and max values taken by each input index dimension.
  2140. Returns:
  2141. `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`.
  2142. """
  2143. def _linspace(n_xels_per_dim):
  2144. return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
  2145. dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
  2146. array_index_grid = torch.meshgrid(*dim_ranges, indexing="ij")
  2147. return torch.stack(array_index_grid, dim=-1)
  2148. class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
  2149. """Perceiver abstract position encoding."""
  2150. @property
  2151. @abc.abstractmethod
  2152. def num_dimensions(self) -> int:
  2153. raise NotImplementedError
  2154. @abc.abstractmethod
  2155. def output_size(self, *args, **kwargs) -> int:
  2156. raise NotImplementedError
  2157. @abc.abstractmethod
  2158. def forward(self, batch_size, pos):
  2159. raise NotImplementedError
  2160. class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
  2161. """Trainable position encoding."""
  2162. def __init__(self, index_dims, num_channels=128):
  2163. super().__init__()
  2164. self._num_channels = num_channels
  2165. self._index_dims = index_dims
  2166. index_dim = np.prod(index_dims)
  2167. self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))
  2168. @property
  2169. def num_dimensions(self) -> int:
  2170. if isinstance(self._index_dims, int):
  2171. return 1
  2172. return len(self._index_dims)
  2173. def output_size(self, *args, **kwargs) -> int:
  2174. return self._num_channels
  2175. def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  2176. num_positions = position_embeddings.shape[0]
  2177. new_height = new_width = torch_int(num_positions**0.5)
  2178. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  2179. if not torch.jit.is_tracing() and height == new_height and width == new_width:
  2180. return position_embeddings
  2181. position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute(
  2182. 0, 3, 1, 2
  2183. )
  2184. position_embeddings = nn.functional.interpolate(
  2185. position_embeddings,
  2186. size=(height, width),
  2187. mode="bicubic",
  2188. align_corners=False,
  2189. )
  2190. position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
  2191. return position_embeddings
  2192. def forward(
  2193. self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size | None = None
  2194. ) -> torch.Tensor:
  2195. position_embeddings = self.position_embeddings
  2196. if interpolate_pos_encoding:
  2197. height, width = input_size
  2198. position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
  2199. if batch_size is not None:
  2200. position_embeddings = position_embeddings.expand(batch_size, -1, -1)
  2201. return position_embeddings
  2202. def _check_or_build_spatial_positions(pos, index_dims, batch_size):
  2203. """
  2204. Checks or builds spatial position features (x, y, ...).
  2205. Args:
  2206. pos (`torch.FloatTensor`):
  2207. None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
  2208. index_dims (`list[int]`):
  2209. An iterable giving the spatial/index size of the data to be featurized.
  2210. batch_size (`int`):
  2211. The batch size of the data to be featurized.
  2212. Returns:
  2213. `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features.
  2214. """
  2215. if pos is None:
  2216. pos = build_linear_positions(index_dims)
  2217. # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
  2218. # but `torch.broadcast_to` cannot be converted to ONNX
  2219. pos = pos[None].expand((batch_size,) + pos.shape)
  2220. pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
  2221. else:
  2222. # Just a warning label: you probably don't want your spatial features to
  2223. # have a different spatial layout than your pos coordinate system.
  2224. # But feel free to override if you think it'll work!
  2225. if pos.shape[-1] != len(index_dims):
  2226. raise ValueError("Spatial features have the wrong number of dimensions.")
  2227. return pos
  2228. class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
  2229. """Fourier (Sinusoidal) position encoding."""
  2230. def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False):
  2231. super().__init__()
  2232. self.num_bands = num_bands
  2233. self.max_resolution = max_resolution
  2234. self.concat_pos = concat_pos
  2235. self.sine_only = sine_only
  2236. @property
  2237. def num_dimensions(self) -> int:
  2238. return len(self.max_resolution)
  2239. def output_size(self):
  2240. """Returns size of positional encodings last dimension."""
  2241. num_dims = len(self.max_resolution)
  2242. encoding_size = self.num_bands * num_dims
  2243. if not self.sine_only:
  2244. encoding_size *= 2
  2245. if self.concat_pos:
  2246. encoding_size += self.num_dimensions
  2247. return encoding_size
  2248. def forward(
  2249. self,
  2250. index_dims: list[int],
  2251. batch_size: int,
  2252. device: torch.device,
  2253. dtype: torch.dtype,
  2254. pos: torch.FloatTensor | None = None,
  2255. ) -> torch.FloatTensor:
  2256. pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
  2257. fourier_pos_enc = generate_fourier_features(
  2258. pos,
  2259. num_bands=self.num_bands,
  2260. max_resolution=self.max_resolution,
  2261. concat_pos=self.concat_pos,
  2262. sine_only=self.sine_only,
  2263. ).to(device=device, dtype=dtype)
  2264. return fourier_pos_enc
  2265. class AbstractPreprocessor(nn.Module):
  2266. @property
  2267. def num_channels(self) -> int:
  2268. """Returns size of preprocessor output."""
  2269. raise NotImplementedError()
  2270. class PerceiverTextPreprocessor(AbstractPreprocessor):
  2271. """
  2272. Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings.
  2273. The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration.
  2274. Args:
  2275. config ([`PerceiverConfig`]):
  2276. Model configuration.
  2277. """
  2278. def __init__(self, config: PerceiverConfig) -> None:
  2279. super().__init__()
  2280. self.config = config
  2281. self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
  2282. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
  2283. @property
  2284. def num_channels(self) -> int:
  2285. return self.config.d_model
  2286. def forward(
  2287. self,
  2288. inputs: torch.LongTensor,
  2289. pos: torch.Tensor | None = None,
  2290. network_input_is_1d: bool = True,
  2291. interpolate_pos_encoding: bool = False,
  2292. ):
  2293. embeddings_without_pos = self.embeddings(inputs)
  2294. seq_length = inputs.shape[1]
  2295. position_ids = torch.arange(0, seq_length, device=inputs.device)
  2296. embeddings = embeddings_without_pos + self.position_embeddings(position_ids)
  2297. return embeddings, None, embeddings_without_pos
  2298. class PerceiverEmbeddingDecoder(nn.Module):
  2299. """
  2300. Module to decode embeddings (for masked language modeling).
  2301. Args:
  2302. config ([`PerceiverConfig`]):
  2303. Model configuration.
  2304. """
  2305. def __init__(self, config: PerceiverConfig) -> None:
  2306. super().__init__()
  2307. self.config = config
  2308. self.vocab_size = config.vocab_size
  2309. self.bias = nn.Parameter(torch.zeros(self.vocab_size))
  2310. def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
  2311. batch_size, seq_len, d_model = hidden_states.shape
  2312. # Flatten batch dim
  2313. output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
  2314. output = output + self.bias
  2315. return output.reshape([batch_size, seq_len, self.vocab_size])
  2316. class PerceiverMultimodalPostprocessor(nn.Module):
  2317. """
  2318. Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single
  2319. postprocessor.
  2320. Args:
  2321. modalities (`Mapping[str, PostprocessorType]`):
  2322. Dictionary mapping modality name to postprocessor class for that modality.
  2323. input_is_dict (`bool`, *optional*, defaults to `False`):
  2324. If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
  2325. False, input is a tensor which is sliced up during postprocessing by *modality_sizes*.
  2326. """
  2327. def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False):
  2328. super().__init__()
  2329. self.modalities = nn.ModuleDict(modalities)
  2330. self.input_is_dict = input_is_dict
  2331. def forward(
  2332. self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None
  2333. ) -> Mapping[str, torch.Tensor]:
  2334. if not self.input_is_dict:
  2335. # Slice up modalities by their sizes.
  2336. if modality_sizes is None:
  2337. raise ValueError("Modality sizes should be specified if input is not a dictionary.")
  2338. inputs = restructure(modality_sizes=modality_sizes, inputs=inputs)
  2339. outputs = {
  2340. modality: postprocessor(inputs[modality], pos=pos, modality_sizes=None)
  2341. for modality, postprocessor in self.modalities.items()
  2342. }
  2343. return outputs
  2344. class PerceiverClassificationPostprocessor(nn.Module):
  2345. """
  2346. Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits.
  2347. Args:
  2348. config ([*PerceiverConfig*]):
  2349. Model configuration.
  2350. in_channels (`int`):
  2351. Number of channels in the input.
  2352. """
  2353. def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
  2354. super().__init__()
  2355. self.classifier = nn.Linear(in_channels, config.num_labels)
  2356. def forward(self, inputs, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
  2357. logits = self.classifier(inputs)
  2358. return logits[:, 0, :]
  2359. class PerceiverAudioPostprocessor(nn.Module):
  2360. """
  2361. Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features.
  2362. Args:
  2363. config ([*PerceiverConfig*]):
  2364. Model configuration.
  2365. in_channels (`int`):
  2366. Number of channels in the input.
  2367. postproc_type (`str`, *optional*, defaults to `"patches"`):
  2368. Postprocessor type to use. Currently, only "patches" is supported.
  2369. """
  2370. def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
  2371. super().__init__()
  2372. if postproc_type != "patches": # to be supported: 'conv', 'patches', 'pixels'
  2373. raise ValueError("Invalid postproc_type!")
  2374. # Architecture parameters:
  2375. self.classifier = nn.Linear(in_channels, config.samples_per_patch)
  2376. def forward(self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
  2377. logits = self.classifier(inputs)
  2378. return torch.reshape(logits, [inputs.shape[0], -1])
  2379. class PerceiverProjectionPostprocessor(nn.Module):
  2380. """
  2381. Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower
  2382. dimension.
  2383. Args:
  2384. in_channels (`int`):
  2385. Number of channels in the input.
  2386. out_channels (`int`):
  2387. Number of channels in the output.
  2388. """
  2389. def __init__(self, in_channels: int, out_channels: int) -> None:
  2390. super().__init__()
  2391. self.classifier = nn.Linear(in_channels, out_channels)
  2392. def forward(self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
  2393. logits = self.classifier(inputs)
  2394. return logits
  2395. class PerceiverImagePreprocessor(AbstractPreprocessor):
  2396. """
  2397. Image preprocessing for Perceiver Encoder.
  2398. Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
  2399. "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
  2400. position encoding kwargs are set equal to the *out_channels*.
  2401. Args:
  2402. config ([*PerceiverConfig*]):
  2403. Model configuration.
  2404. prep_type (`str`, *optional*, defaults to `"conv"`):
  2405. Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels".
  2406. spatial_downsample (`int`, *optional*, defaults to 4):
  2407. Spatial downsampling factor.
  2408. temporal_downsample (`int`, *optional*, defaults to 1):
  2409. Temporal downsampling factor (only relevant in case a time dimension is present).
  2410. position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
  2411. Position encoding type. Can be "fourier" or "trainable".
  2412. in_channels (`int`, *optional*, defaults to 3):
  2413. Number of channels in the input.
  2414. out_channels (`int`, *optional*, defaults to 64):
  2415. Number of channels in the output.
  2416. conv_after_patching (`bool`, *optional*, defaults to `False`):
  2417. Whether to apply a convolutional layer after patching.
  2418. conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
  2419. Number of channels in the input of the convolutional layer after patching.
  2420. conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
  2421. Whether to use batch normalization in the convolutional layer.
  2422. concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
  2423. How to concatenate the position encoding to the input. Can be "concat" or "add".
  2424. project_pos_dim (`int`, *optional*, defaults to -1):
  2425. Dimension of the position encoding to project to. If -1, no projection is applied.
  2426. **position_encoding_kwargs (`Dict`, *optional*):
  2427. Keyword arguments for the position encoding.
  2428. """
    def __init__(
        self,
        config,
        prep_type="conv",
        spatial_downsample: int = 4,
        temporal_downsample: int = 1,
        position_encoding_type: str = "fourier",
        in_channels: int = 3,
        out_channels: int = 64,
        conv_after_patching: bool = False,
        conv_after_patching_in_channels: int = 54,  # only relevant when conv_after_patching = True
        conv2d_use_batchnorm: bool = True,
        concat_or_add_pos: str = "concat",
        project_pos_dim: int = -1,
        **position_encoding_kwargs,
    ):
        # Validates the options, stores them, and builds (in this order) the optional conv front-end, the
        # position encoding with its optional projection, and the optional projection applied after patching.
        #
        # Raises:
        #     ValueError: for an unknown `prep_type` or `concat_or_add_pos`, for "conv" with a
        #         `spatial_downsample` that is not a power of 4 or a `temporal_downsample` other than 1, and
        #         for "conv1x1" with a `temporal_downsample` other than 1.
        super().__init__()
        self.config = config

        if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
            raise ValueError(f"Prep_type {prep_type} is invalid")

        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")

        self.in_channels = in_channels
        self.prep_type = prep_type
        self.spatial_downsample = spatial_downsample
        self.temporal_downsample = temporal_downsample
        self.position_encoding_type = position_encoding_type
        self.concat_or_add_pos = concat_or_add_pos
        self.conv_after_patching = conv_after_patching
        self.out_channels = out_channels

        if self.prep_type == "conv":
            # Downsampling with conv is currently restricted
            # log base 4 of the factor is an integer exactly when the factor is a power of 4.
            convnet_num_layers = math.log(spatial_downsample, 4)
            convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
            if not convnet_num_layers_is_int or temporal_downsample != 1:
                raise ValueError(
                    "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
                )
            self.convnet = Conv2DDownsample(
                in_channels=in_channels,
                num_layers=int(convnet_num_layers),
                out_channels=out_channels,
                use_batchnorm=conv2d_use_batchnorm,
            )

        elif self.prep_type == "conv1x1":
            if temporal_downsample != 1:
                raise ValueError("Conv1x1 does not downsample in time.")
            self.convnet_1x1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                # spatial_downsample is unconstrained for 1x1 convolutions.
                stride=(spatial_downsample, spatial_downsample),
            )

        # Position embeddings
        self.project_pos_dim = project_pos_dim
        self.position_embeddings, self.positions_projection = build_position_encoding(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

        # Optional convolutional layer after patches.
        # NOTE: despite the name this is an `nn.Linear` applied to the last axis, or an identity when disabled.
        self.conv_after_patches = (
            nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity()
        )
  2495. @property
  2496. def num_channels(self) -> int:
  2497. # Let's assume that the number of resolutions (in the context of image preprocessing)
  2498. # of the input data is 2 or 3 depending on whether we are processing image or video respectively.
  2499. # In this case, for convenience, we will declare is_temporal variable,
  2500. # which will show whether the data has a temporal dimension or not.
  2501. is_temporal = self.position_embeddings.num_dimensions > 2
  2502. # position embedding
  2503. if self.project_pos_dim > 0:
  2504. pos_dim = self.project_pos_dim
  2505. else:
  2506. pos_dim = self.position_embeddings.output_size()
  2507. if self.concat_or_add_pos == "add":
  2508. return pos_dim
  2509. # inputs
  2510. if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
  2511. inp_dim = self.out_channels
  2512. elif self.prep_type == "pixels":
  2513. inp_dim = self.in_channels
  2514. if not is_temporal:
  2515. inp_dim = math.ceil(inp_dim / self.spatial_downsample)
  2516. elif self.prep_type == "patches":
  2517. if self.conv_after_patching:
  2518. inp_dim = self.out_channels
  2519. else:
  2520. inp_dim = self.in_channels * self.spatial_downsample**2
  2521. if is_temporal:
  2522. inp_dim *= self.temporal_downsample
  2523. return inp_dim + pos_dim
  2524. def _build_network_inputs(
  2525. self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
  2526. ):
  2527. """
  2528. Construct the final input, including position encoding.
  2529. This method expects the inputs to always have channels as last dimension.
  2530. """
  2531. batch_size = inputs.shape[0]
  2532. input_size = inputs.shape[1:3]
  2533. index_dims = inputs.shape[1:-1]
  2534. indices = np.prod(index_dims)
  2535. # Flatten input features to a 1D index dimension if necessary.
  2536. if len(inputs.shape) > 3 and network_input_is_1d:
  2537. inputs = torch.reshape(inputs, [batch_size, indices, -1])
  2538. # Construct the position encoding.
  2539. if self.position_encoding_type == "trainable":
  2540. pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
  2541. elif self.position_encoding_type == "fourier":
  2542. pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
  2543. # Optionally project them to a target dimension.
  2544. pos_enc = self.positions_projection(pos_enc)
  2545. if not network_input_is_1d:
  2546. # Reshape pos to match the input feature shape
  2547. # if the network takes non-1D inputs
  2548. sh = inputs.shape
  2549. pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])
  2550. if self.concat_or_add_pos == "concat":
  2551. inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
  2552. elif self.concat_or_add_pos == "add":
  2553. inputs_with_pos = inputs + pos_enc
  2554. return inputs_with_pos, inputs
  2555. def forward(
  2556. self,
  2557. inputs: torch.Tensor,
  2558. pos: torch.Tensor | None = None,
  2559. network_input_is_1d: bool = True,
  2560. interpolate_pos_encoding: bool = False,
  2561. ):
  2562. if self.prep_type == "conv":
  2563. # Convnet image featurization.
  2564. # Downsamples spatially by a factor of 4
  2565. inputs = self.convnet(inputs)
  2566. elif self.prep_type == "conv1x1":
  2567. # map inputs to self.out_channels
  2568. inputs = self.convnet_1x1(inputs)
  2569. elif self.prep_type == "pixels":
  2570. # if requested, downsamples in the crudest way
  2571. if inputs.ndim == 4:
  2572. inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample]
  2573. elif inputs.ndim == 5:
  2574. inputs = inputs[
  2575. :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
  2576. ]
  2577. else:
  2578. raise ValueError("Unsupported data format for pixels.")
  2579. elif self.prep_type == "patches":
  2580. # Space2depth featurization.
  2581. # Video: B x T x C x H x W
  2582. inputs = space_to_depth(
  2583. inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
  2584. )
  2585. if inputs.ndim == 5 and inputs.shape[1] == 1:
  2586. # for flow
  2587. inputs = inputs.squeeze(dim=1)
  2588. # Optionally apply conv layer.
  2589. inputs = self.conv_after_patches(inputs)
  2590. if self.prep_type != "patches":
  2591. # move channels to last dimension, as the _build_network_inputs method below expects this
  2592. if inputs.ndim == 4:
  2593. inputs = inputs.permute(0, 2, 3, 1)
  2594. elif inputs.ndim == 5:
  2595. inputs = inputs.permute(0, 1, 3, 4, 2)
  2596. else:
  2597. raise ValueError("Unsupported data format for conv1x1.")
  2598. inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
  2599. modality_sizes = None # Size for each modality, only needed for multimodal
  2600. return inputs, modality_sizes, inputs_without_pos
  2601. class PerceiverOneHotPreprocessor(AbstractPreprocessor):
  2602. """
  2603. One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.
  2604. Args:
  2605. config ([`PerceiverConfig`]):
  2606. Model configuration.
  2607. """
  2608. def __init__(self, config: PerceiverConfig) -> None:
  2609. super().__init__()
  2610. self.config: PerceiverConfig = config
  2611. @property
  2612. def num_channels(self) -> int:
  2613. return self.config.num_labels
  2614. def forward(self, inputs: torch.Tensor, pos: torch.Tensor | None = None, network_input_is_1d: bool = True):
  2615. # Add a dummy index dimension.
  2616. inputs = inputs[:, None, :]
  2617. # No position encodings, so the 1st (input) and 3rd (inputs_without_pos)
  2618. # outputs are identical.
  2619. return inputs, None, inputs
  2620. class PerceiverAudioPreprocessor(AbstractPreprocessor):
  2621. """
  2622. Audio preprocessing for Perceiver Encoder.
  2623. Args:
  2624. config ([*PerceiverConfig*]):
  2625. Model configuration.
  2626. prep_type (`str`, *optional*, defaults to `"patches"`):
  2627. Preprocessor type to use. Only "patches" is supported.
  2628. samples_per_patch (`int`, *optional*, defaults to 96):
  2629. Number of samples per patch.
  2630. position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
  2631. Type of position encoding to use. Can be "trainable" or "fourier".
  2632. concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
  2633. How to concatenate the position encoding to the input. Can be "concat" or "add".
  2634. out_channels (`int`, *optional*, defaults to 64):
  2635. Number of channels in the output.
  2636. project_pos_dim (`int`, *optional*, defaults to -1):
  2637. Dimension of the position encoding to project to. If -1, no projection is applied.
  2638. **position_encoding_kwargs (`Dict`, *optional*):
  2639. Keyword arguments for the position encoding.
  2640. """
  2641. def __init__(
  2642. self,
  2643. config,
  2644. prep_type: str = "patches",
  2645. samples_per_patch: int = 96,
  2646. position_encoding_type: str = "fourier",
  2647. concat_or_add_pos: str = "concat",
  2648. out_channels=64,
  2649. project_pos_dim=-1,
  2650. **position_encoding_kwargs,
  2651. ):
  2652. super().__init__()
  2653. self.config = config
  2654. if prep_type != "patches":
  2655. raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.")
  2656. if concat_or_add_pos not in ["concat", "add"]:
  2657. raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.")
  2658. self.samples_per_patch = samples_per_patch
  2659. self.position_encoding_type = position_encoding_type
  2660. self.concat_or_add_pos = concat_or_add_pos
  2661. self.project_pos_dim = project_pos_dim
  2662. # Position embeddings
  2663. self.position_embeddings, self.positions_projection = build_position_encoding(
  2664. position_encoding_type=position_encoding_type,
  2665. out_channels=out_channels,
  2666. project_pos_dim=project_pos_dim,
  2667. **position_encoding_kwargs,
  2668. )
  2669. @property
  2670. def num_channels(self) -> int:
  2671. # position embedding
  2672. if self.project_pos_dim > 0:
  2673. pos_dim = self.project_pos_dim
  2674. else:
  2675. pos_dim = self.position_embeddings.output_size()
  2676. if self.concat_or_add_pos == "add":
  2677. return pos_dim
  2678. return self.samples_per_patch + pos_dim
  2679. def _build_network_inputs(self, inputs):
  2680. """Construct the final input, including position encoding."""
  2681. batch_size = inputs.shape[0]
  2682. index_dims = inputs.shape[1:-1]
  2683. # Construct the position encoding.
  2684. if self.position_encoding_type == "trainable":
  2685. pos_enc = self.position_embeddings(batch_size)
  2686. elif self.position_encoding_type == "fourier":
  2687. pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
  2688. # Optionally project them to a target dimension.
  2689. pos_enc = self.positions_projection(pos_enc)
  2690. if self.concat_or_add_pos == "concat":
  2691. inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
  2692. elif self.concat_or_add_pos == "add":
  2693. inputs_with_pos = inputs + pos_enc
  2694. return inputs_with_pos, inputs
  2695. def forward(
  2696. self,
  2697. inputs: torch.Tensor,
  2698. pos: torch.Tensor | None = None,
  2699. network_input_is_1d: bool = True,
  2700. interpolate_pos_encoding: bool = False,
  2701. ):
  2702. inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
  2703. inputs, inputs_without_pos = self._build_network_inputs(inputs)
  2704. modality_sizes = None # Size for each modality, only needed for multimodal
  2705. return inputs, modality_sizes, inputs_without_pos
  2706. class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
  2707. """
  2708. Multimodal preprocessing for Perceiver Encoder.
  2709. Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number
  2710. of channels.
  2711. Args:
  2712. modalities (`Mapping[str, PreprocessorType]`):
  2713. Dict mapping modality name to preprocessor.
  2714. mask_probs (`dict[str, float]`):
  2715. Dict mapping modality name to masking probability of that modality.
  2716. min_padding_size (`int`, *optional*, defaults to 2):
  2717. The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
  2718. channels across all modalities plus min_padding_size.
  2719. """
  2720. def __init__(
  2721. self,
  2722. modalities: Mapping[str, PreprocessorType],
  2723. mask_probs: Mapping[str, float] | None = None,
  2724. min_padding_size: int = 2,
  2725. ):
  2726. super().__init__()
  2727. self.modalities = nn.ModuleDict(modalities)
  2728. self.min_padding_size = min_padding_size
  2729. self.mask_probs = mask_probs if mask_probs is not None else {}
  2730. self.padding = nn.ParameterDict(
  2731. {
  2732. modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels))
  2733. for modality, preprocessor in modalities.items()
  2734. }
  2735. )
  2736. self.mask = nn.ParameterDict(
  2737. {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()}
  2738. )
  2739. @property
  2740. def num_channels(self) -> int:
  2741. max_channel_size = max(processor.num_channels for _, processor in self.modalities.items())
  2742. common_channel_size = max_channel_size + self.min_padding_size
  2743. return common_channel_size
  2744. def forward(
  2745. self,
  2746. inputs: Mapping[str, torch.Tensor],
  2747. pos: torch.Tensor | None = None,
  2748. network_input_is_1d: bool = True,
  2749. interpolate_pos_encoding: bool = False,
  2750. ) -> PreprocessorOutputType:
  2751. padded = {}
  2752. modality_sizes = {}
  2753. inputs_without_pos = {}
  2754. for modality, preprocessor in self.modalities.items():
  2755. # preprocess each modality using the respective preprocessor.
  2756. output, _, inputs_without_pos[modality] = preprocessor(
  2757. inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d
  2758. )
  2759. # pad to the same common_channel_size.
  2760. batch_size, num_samples, num_channels = output.shape
  2761. pos_enc = self.padding[modality].expand(batch_size, -1, -1)
  2762. padding = torch.broadcast_to(
  2763. pos_enc,
  2764. [batch_size, num_samples, self.num_channels - num_channels],
  2765. )
  2766. output_padded = torch.cat([output, padding], dim=2)
  2767. # mask if required
  2768. if modality in self.mask_probs:
  2769. mask_token = self.mask[modality].expand(batch_size, -1, -1)
  2770. mask_prob = self.mask_probs[modality]
  2771. mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob))
  2772. mask = torch.unsqueeze(mask, dim=2).to(mask_token.device)
  2773. output_padded = (1 - mask) * output_padded + mask * mask_token
  2774. padded[modality] = output_padded
  2775. modality_sizes[modality] = output_padded.shape[1]
  2776. # Apply a predictable ordering to the modalities
  2777. padded_ls = [padded[k] for k in sorted(padded.keys())]
  2778. # Finally, concatenate along the time dimension
  2779. final_inputs = torch.cat(padded_ls, dim=1)
  2780. return final_inputs, modality_sizes, inputs_without_pos
# Names exported by a star-import of this module. The preprocessor classes
# (e.g. `PerceiverAudioPreprocessor`) are not listed here.
__all__ = [
    "PerceiverForImageClassificationConvProcessing",
    "PerceiverForImageClassificationFourier",
    "PerceiverForImageClassificationLearned",
    "PerceiverForMaskedLM",
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "PerceiverForSequenceClassification",
    "PerceiverLayer",
    "PerceiverModel",
    "PerceiverPreTrainedModel",
]