| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
2773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302 |
- # Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Perceiver model."""
- import abc
- import math
- from collections.abc import Callable, Mapping
- from dataclasses import dataclass
- from functools import reduce
- from operator import __add__
- from typing import Any, Optional
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_outputs import BaseModelOutputWithCrossAttentions
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import ModelOutput, auto_docstring, logging, torch_int
- from .configuration_perceiver import PerceiverConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)

# Type aliases for the Perceiver pre/post-processing hooks.
# `PreprocessorType` and `PostprocessorType` annotate the optional hooks accepted by `PerceiverModel.__init__`.
ModalitySizeType = Mapping[str, int]  # modality name -> integer size
# A preprocessor returns a 3-tuple of tensors whose middle element may be None.
PreprocessorOutputType = tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]
PreprocessorType = Callable[..., PreprocessorOutputType]
# Postprocessors are not constrained beyond being callable.
PostprocessorType = Callable[..., Any]
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions.
    """
)
class PerceiverModelOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    """

    # Every field defaults to None, so only the outputs that were actually computed need to be filled in.
    # Only `logits` is described in the docstring above; descriptions of the remaining fields are presumably
    # supplied by `auto_docstring` from shared argument docs — TODO confirm.
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver decoder outputs, with potential cross-attentions.
    """
)
class PerceiverDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Output of the basic decoder.
    """

    # Both fields default to None; `cross_attentions` is presumably only populated when attention weights are
    # requested — TODO confirm against the decoder implementations.
    logits: torch.FloatTensor | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver's masked language model outputs.
    """
)
class PerceiverMaskedLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Masked language modeling (MLM) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    """

    # Every field defaults to None; `loss` stays None unless labels were provided (see docstring).
    # The hidden-state / attention fields are presumably described by `auto_docstring` — TODO confirm.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal
    autoencoding.
    """
)
class PerceiverClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    """

    # Every field defaults to None; `loss` stays None unless labels were provided (see docstring).
    # NOTE(review): this class is shared by optical-flow / multimodal heads per the intro above, so `logits`
    # may not always have the classification shape documented here — confirm per head.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
class PerceiverEmbeddings(nn.Module):
    """Learned latent array that seeds the Perceiver encoder.

    Holds one ``(config.num_latents, config.d_latents)`` parameter, sampled from a standard normal at
    construction, and shares it across every element of the batch.
    """

    def __init__(self, config):
        super().__init__()
        initial_latents = torch.randn(config.num_latents, config.d_latents)
        self.latents = nn.Parameter(initial_latents)

    def forward(self, batch_size: int):
        """Return the latents broadcast to ``(batch_size, num_latents, d_latents)``.

        ``expand`` returns a view, so the batch copies share storage with the parameter.
        """
        batched = self.latents.unsqueeze(0)
        return batched.expand(batch_size, -1, -1)  # Thanks, Phil Wang
- class PerceiverSelfAttention(nn.Module):
- """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder."""
- def __init__(
- self,
- config,
- is_cross_attention=False,
- qk_channels=None,
- v_channels=None,
- num_heads=1,
- q_dim=None,
- kv_dim=None,
- ):
- super().__init__()
- self.num_heads = num_heads
- # Q and K must have the same number of channels.
- # Default to preserving Q's input's shape.
- if qk_channels is None:
- qk_channels = q_dim
- # V's num_channels determines the shape of the output of QKV-attention.
- # Default to the same number of channels used in the key-query operation.
- if v_channels is None:
- v_channels = qk_channels
- if qk_channels % num_heads != 0:
- raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
- if v_channels % num_heads != 0:
- raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")
- self.qk_channels = qk_channels
- self.v_channels = v_channels
- self.qk_channels_per_head = self.qk_channels // num_heads
- self.v_channels_per_head = self.v_channels // num_heads
- # Layer normalization
- self.layernorm1 = nn.LayerNorm(q_dim)
- self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity()
- # Projection matrices
- self.query = nn.Linear(q_dim, qk_channels)
- self.key = nn.Linear(kv_dim, qk_channels)
- self.value = nn.Linear(kv_dim, v_channels)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- def transpose_for_scores(self, x, channels_per_head):
- new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
- x = x.view(*new_x_shape)
- return x.permute(0, 2, 1, 3)
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.FloatTensor | None = None,
- inputs: torch.FloatTensor | None = None,
- inputs_mask: torch.FloatTensor | None = None,
- output_attentions: bool | None = False,
- ) -> tuple[torch.Tensor]:
- hidden_states = self.layernorm1(hidden_states)
- inputs = self.layernorm2(inputs)
- # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module,
- # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to.
- is_cross_attention = inputs is not None
- queries = self.query(hidden_states)
- if is_cross_attention:
- keys = self.key(inputs)
- values = self.value(inputs)
- attention_mask = inputs_mask
- else:
- keys = self.key(hidden_states)
- values = self.value(hidden_states)
- # Reshape channels for multi-head attention.
- # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head)
- queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
- keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
- values = self.transpose_for_scores(values, self.v_channels_per_head)
- # Take the dot product between the queries and keys to get the raw attention scores.
- attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
- batch_size, num_heads, seq_len, q_head_dim = queries.shape
- _, _, _, v_head_dim = values.shape
- hiddens = self.num_heads * v_head_dim
- attention_scores = attention_scores / math.sqrt(q_head_dim)
- if attention_mask is not None:
- # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, values)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
- context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
- return outputs
class PerceiverSelfOutput(nn.Module):
    """Linear projection applied to the attention context.

    Args:
        config: Not read by this module.
        input_channels (`int`): Width of the incoming attention context.
        output_channels (`int`): Width of the projected output.
    """

    def __init__(self, config, input_channels, output_channels):
        super().__init__()
        self.dense = nn.Linear(in_features=input_channels, out_features=output_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` from ``input_channels`` to ``output_channels`` features."""
        return self.dense(hidden_states)
class PerceiverAttention(nn.Module):
    """Attention module, including a dense block.

    Wraps :class:`PerceiverSelfAttention` with an output projection and an optional residual connection to the
    queries (``hidden_states`` as passed in, i.e. before the attention's internal LayerNorm).

    Args:
        config: Model configuration. ``cross_attention_shape_for_attention`` is read when a cross-attention
            module is built without explicit ``qk_channels``.
        is_cross_attention (`bool`, *optional*, defaults to `False`): Build a cross-attention module.
        qk_channels (`int`, *optional*): Query/key width. Defaults to ``q_dim``, or, for cross-attention, to
            ``q_dim`` / ``kv_dim`` according to ``config.cross_attention_shape_for_attention`` (``"q"`` / ``"kv"``).
        v_channels (`int`, *optional*): Value width. Defaults to ``qk_channels``.
        num_heads (`int`, *optional*, defaults to 1): Number of attention heads.
        q_dim (`int`): Feature size of the queries.
        kv_dim (`int`): Feature size of the keys/values source.
        use_query_residual (`bool`, *optional*, defaults to `True`): Add the queries back onto the output.
            The output width is ``q_dim`` for cross-attention and ``v_channels`` for self-attention, so in the
            self-attention case the residual requires ``v_channels == q_dim``.

    Raises:
        ValueError: If ``cross_attention_shape_for_attention`` is consulted and is neither ``"q"`` nor ``"kv"``.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        use_query_residual=True,
    ):
        super().__init__()
        # MultiHead attention: resolve default channel sizes.
        if qk_channels is None:
            if not is_cross_attention:
                qk_channels = q_dim
            elif config.cross_attention_shape_for_attention == "q":
                qk_channels = q_dim
            elif config.cross_attention_shape_for_attention == "kv":
                qk_channels = kv_dim
            else:
                raise ValueError(
                    f"Unknown value {config.cross_attention_shape_for_attention} for "
                    "cross_attention_shape_for_attention."
                )
        if v_channels is None:
            v_channels = qk_channels
        self.self = PerceiverSelfAttention(
            config,
            is_cross_attention=is_cross_attention,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
        )
        # Dense block: cross-attention projects back to the query width so the result lines up with the
        # latents; self-attention keeps the value width.
        output_channels = q_dim if is_cross_attention else v_channels
        self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
        self.use_query_residual = use_query_residual

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        inputs: torch.FloatTensor | None = None,
        inputs_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """Run attention, project the context and optionally add the query residual.

        Returns:
            `tuple`: ``(attention_output,)`` followed by the attention probabilities when
            ``output_attentions`` is set.
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )

        # Output projection
        attention_output = self.output(self_outputs[0])

        # Optionally include a residual to the original queries.
        # Consider omitting the residual if the semantics of query and output
        # are different, e.g. if queries are positions and outputs are pixels.
        if self.use_query_residual:
            attention_output = attention_output + hidden_states

        return (attention_output,) + self_outputs[1:]  # add attentions if we output them
class PerceiverMLP(nn.Module):
    """A Transformer-style dense module to follow attention.

    ``input_size -> widening_factor * input_size -> activation -> input_size``.

    Args:
        config: Configuration whose ``hidden_act`` is either a key of ``ACT2FN`` or a callable.
        input_size (`int`): Input and output feature size.
        widening_factor (`int`): Multiplier for the hidden width.
    """

    def __init__(self, config, input_size, widening_factor):
        super().__init__()
        hidden_size = widening_factor * input_size
        self.dense1 = nn.Linear(input_size, hidden_size)
        activation = config.hidden_act
        # String names are resolved through the activation registry; callables are used as-is.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.dense2 = nn.Linear(hidden_size, input_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``dense1``, the activation, then ``dense2``."""
        widened = self.intermediate_act_fn(self.dense1(hidden_states))
        return self.dense2(widened)
class PerceiverLayer(nn.Module):
    """One Perceiver layer: attention followed by a pre-LayerNorm MLP with a residual connection.

    Args:
        config: Perceiver configuration; ``chunk_size_feed_forward`` is read here and the config is passed on
            to the attention and MLP submodules.
        is_cross_attention, qk_channels, v_channels, num_heads, q_dim, kv_dim, use_query_residual:
            Forwarded to :class:`PerceiverAttention`.
        widening_factor (`int`, *optional*, defaults to 4): Hidden-width multiplier of the MLP.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        widening_factor=4,
        use_query_residual=True,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # The feed-forward is chunked along dimension 1 (the time/sequence axis).
        self.seq_len_dim = 1
        self.attention = PerceiverAttention(
            config,
            is_cross_attention=is_cross_attention,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
            use_query_residual=use_query_residual,
        )
        self.layernorm = nn.LayerNorm(q_dim)
        self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        inputs: torch.FloatTensor | None = None,
        inputs_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """Run attention, then the residual MLP.

        Returns:
            `tuple`: ``(layer_output,)`` followed by any extra outputs of the attention module (its attention
            probabilities when ``output_attentions`` is set).
        """
        attention_output, *extra_outputs = self.attention(
            hidden_states,
            attention_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )
        feed_forward_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        # Residual connection around the MLP.
        return (feed_forward_output + attention_output, *extra_outputs)

    def feed_forward_chunk(self, attention_output):
        """LayerNorm followed by the MLP; invoked per chunk by ``apply_chunking_to_forward``."""
        normalized = self.layernorm(attention_output)
        return self.mlp(normalized)
- class PerceiverEncoder(nn.Module):
- """The Perceiver Encoder: a scalable, fully attentional encoder."""
- def __init__(self, config, kv_dim=None):
- super().__init__()
- self.config = config
- # Check that we can use multihead-attention with these shapes.
- if config.d_latents % config.num_self_attention_heads != 0:
- raise ValueError(
- f"num_z_channels ({config.d_latents}) must be divisible by"
- f" num_self_attend_heads ({config.num_self_attention_heads})."
- )
- if config.d_latents % config.num_cross_attention_heads != 0:
- raise ValueError(
- f"num_z_channels ({config.d_latents}) must be divisible by"
- f" num_cross_attend_heads ({config.num_cross_attention_heads})."
- )
- # Construct the cross attention layer.
- self.cross_attention = PerceiverLayer(
- config,
- is_cross_attention=True,
- qk_channels=config.qk_channels,
- v_channels=config.v_channels,
- num_heads=config.num_cross_attention_heads,
- q_dim=config.d_latents,
- kv_dim=kv_dim,
- widening_factor=config.cross_attention_widening_factor,
- use_query_residual=config.use_query_residual,
- )
- # Construct a single block of self-attention layers.
- # We get deeper architectures by applying this block more than once.
- self_attention_layers = []
- for _ in range(config.num_self_attends_per_block):
- layer = PerceiverLayer(
- config,
- is_cross_attention=False,
- qk_channels=config.qk_channels,
- v_channels=config.v_channels,
- num_heads=config.num_self_attention_heads,
- q_dim=config.d_latents,
- kv_dim=config.d_latents,
- widening_factor=config.self_attention_widening_factor,
- )
- self_attention_layers.append(layer)
- self.self_attends = nn.ModuleList(self_attention_layers)
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.FloatTensor | None = None,
- inputs: torch.FloatTensor | None = None,
- inputs_mask: torch.FloatTensor | None = None,
- output_attentions: bool | None = False,
- output_hidden_states: bool | None = False,
- return_dict: bool | None = True,
- ) -> tuple | BaseModelOutputWithCrossAttentions:
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
- all_cross_attentions = () if output_attentions else None
- # Apply the cross-attention between the latents (hidden_states) and inputs:
- layer_outputs = self.cross_attention(
- hidden_states,
- attention_mask=attention_mask,
- inputs=inputs,
- inputs_mask=inputs_mask,
- output_attentions=output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions:
- all_cross_attentions = all_cross_attentions + (layer_outputs[1],)
- # Apply the block of self-attention layers more than once:
- for _ in range(self.config.num_blocks):
- for i, layer_module in enumerate(self.self_attends):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- layer_outputs = layer_module(
- hidden_states,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
- if v is not None
- )
- return BaseModelOutputWithCrossAttentions(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- cross_attentions=all_cross_attentions,
- )
@auto_docstring
class PerceiverPreTrainedModel(PreTrainedModel):
    config: PerceiverConfig
    base_model_prefix = "perceiver"
    main_input_name = "inputs"
    input_modalities = ("image",)  # technically can be anything, but the HF implementation only has an image processor

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Linear/Conv2d/Embedding weights, the latent array, trainable position embeddings and every entry of an
        ``nn.ParameterDict`` are drawn from ``N(0, config.initializer_range)``. Linear/Conv2d biases and the
        embedding padding row are zeroed. LayerNorm/BatchNorm2d get unit weight and zero bias (when those
        affine parameters exist), and BatchNorm running statistics are reset.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif hasattr(module, "latents"):
            init.normal_(module.latents, mean=0.0, std=std)
        elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
            init.normal_(module.position_embeddings, mean=0.0, std=std)
        elif isinstance(module, nn.ParameterDict):
            for modality in module:
                init.normal_(module[modality], mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norm layers built without affine parameters (e.g. LayerNorm(elementwise_affine=False),
            # BatchNorm2d(affine=False)) have `weight`/`bias` set to None; skip them instead of crashing.
            if module.bias is not None:
                init.zeros_(module.bias)
            if module.weight is not None:
                init.ones_(module.weight)
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
@auto_docstring(
    custom_intro="""
    The Perceiver: a scalable, fully attentional architecture.

    <Tip>

    Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """
)
class PerceiverModel(PerceiverPreTrainedModel):
    def __init__(
        self,
        config,
        decoder: Optional["PerceiverAbstractDecoder"] = None,
        input_preprocessor: PreprocessorType = None,
        output_postprocessor: PostprocessorType = None,
    ):
        r"""
        decoder (`PerceiverAbstractDecoder`, *optional*):
            Decoder module that transforms latent representations into task predictions.
        input_preprocessor (`PreprocessorType`, *optional*):
            Preprocessor that encodes raw inputs into tensors for the model.
        output_postprocessor (`PostprocessorType`, *optional*):
            Postprocessor that transforms model outputs into final predictions.
        """
        super().__init__(config)
        self.config = config

        self.input_preprocessor = input_preprocessor
        self.output_postprocessor = output_postprocessor
        self.embeddings = PerceiverEmbeddings(config)
        # The encoder cross-attends from the latents to the (preprocessed) inputs. With a preprocessor the key/value
        # width is whatever the preprocessor emits; without one, raw inputs must already be `config.d_model` wide
        # (this is validated in `forward`).
        self.encoder = PerceiverEncoder(
            config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
        )
        self.decoder = decoder

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The "input embeddings" of the Perceiver are its latent array, not a token embedding table.
        return self.embeddings.latents

    def set_input_embeddings(self, value):
        self.embeddings.latents = value

    @auto_docstring
    def forward(
        self,
        inputs: torch.FloatTensor,
        attention_mask: torch.FloatTensor | None = None,
        subsampled_output_points: dict[str, torch.Tensor] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PerceiverModelOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
            Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
            representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.

        Examples:

        ```python
        >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverImageProcessor, PerceiverModel
        >>> from transformers.models.perceiver.modeling_perceiver import (
        ...     PerceiverTextPreprocessor,
        ...     PerceiverImagePreprocessor,
        ...     PerceiverClassificationDecoder,
        ... )
        >>> import torch
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> # EXAMPLE 1: using the Perceiver to classify texts
        >>> # - we define a TextPreprocessor, which can be used to embed tokens
        >>> # - we define a ClassificationDecoder, which can be used to decode the
        >>> # final hidden states of the latents to classification logits
        >>> # using trainable position embeddings
        >>> config = PerceiverConfig()
        >>> preprocessor = PerceiverTextPreprocessor(config)
        >>> decoder = PerceiverClassificationDecoder(
        ...     config,
        ...     num_channels=config.d_latents,
        ...     trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        ...     use_query_residual=True,
        ... )
        >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)

        >>> # you can then do a forward pass as follows:
        >>> tokenizer = PerceiverTokenizer()
        >>> text = "hello world"
        >>> inputs = tokenizer(text, return_tensors="pt").input_ids

        >>> with torch.no_grad():
        ...     outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 2]

        >>> # to train, one can train the model using standard cross-entropy:
        >>> criterion = torch.nn.CrossEntropyLoss()

        >>> labels = torch.tensor([1])
        >>> loss = criterion(logits, labels)

        >>> # EXAMPLE 2: using the Perceiver to classify images
        >>> # - we define an ImagePreprocessor, which can be used to embed images
        >>> config = PerceiverConfig(image_size=224)
        >>> preprocessor = PerceiverImagePreprocessor(
        ...     config,
        ...     prep_type="conv1x1",
        ...     spatial_downsample=1,
        ...     out_channels=256,
        ...     position_encoding_type="trainable",
        ...     concat_or_add_pos="concat",
        ...     project_pos_dim=256,
        ...     trainable_position_encoding_kwargs=dict(
        ...         num_channels=256,
        ...         index_dims=config.image_size**2,
        ...     ),
        ... )

        >>> model = PerceiverModel(
        ...     config,
        ...     input_preprocessor=preprocessor,
        ...     decoder=PerceiverClassificationDecoder(
        ...         config,
        ...         num_channels=config.d_latents,
        ...         trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        ...         use_query_residual=True,
        ...     ),
        ... )

        >>> # you can then do a forward pass as follows:
        >>> image_processor = PerceiverImageProcessor()
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> inputs = image_processor(image, return_tensors="pt").pixel_values

        >>> with torch.no_grad():
        ...     outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 2]

        >>> # to train, one can train the model using standard cross-entropy:
        >>> criterion = torch.nn.CrossEntropyLoss()

        >>> labels = torch.tensor([1])
        >>> loss = criterion(logits, labels)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if self.input_preprocessor is not None:
            # The preprocessor embeds the raw inputs and also reports per-modality sizes and the inputs without
            # position encodings; both are needed later to build decoder queries.
            inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
                inputs, interpolate_pos_encoding=interpolate_pos_encoding
            )
        else:
            modality_sizes = None
            inputs_without_pos = None
            # Without a preprocessor the raw inputs feed the encoder directly, whose key/value width was fixed to
            # `config.d_model` in `__init__`.
            if inputs.size()[-1] != self.config.d_model:
                raise ValueError(
                    f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
                    f" {self.config.d_model}. Make sure to set config.d_model appropriately."
                )

        batch_size, seq_length, _ = inputs.size()
        device = inputs.device

        # If no attention mask is provided, make them all ones
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
        # Convert the 0/1 padding mask into the form consumed by the attention layers (see `invert_attention_mask`).
        extended_attention_mask = self.invert_attention_mask(attention_mask)

        # The latent array is the same learned tensor for every example; expand it to the batch.
        embedding_output = self.embeddings(batch_size=batch_size)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=None,
            inputs=inputs,
            inputs_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        logits = None
        if self.decoder is not None:
            if subsampled_output_points is not None:
                # NOTE(review): these keys match the multimodal (audio/image/label) setup only; other decoders are
                # presumably never called with subsampled points — confirm before generalizing.
                output_modality_sizes = {
                    "audio": subsampled_output_points["audio"].shape[0],
                    "image": subsampled_output_points["image"].shape[0],
                    "label": 1,
                }
            else:
                output_modality_sizes = modality_sizes
            # Queries are built from the *input* modality sizes; the subsampled points are passed separately.
            decoder_query = self.decoder.decoder_query(
                inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
            )
            decoder_outputs = self.decoder(
                decoder_query,
                z=sequence_output,
                query_mask=extended_attention_mask,
                output_attentions=output_attentions,
            )
            logits = decoder_outputs.logits

            # add cross-attentions of decoder
            if output_attentions and decoder_outputs.cross_attentions is not None:
                if return_dict:
                    encoder_outputs.cross_attentions = (
                        encoder_outputs.cross_attentions + decoder_outputs.cross_attentions
                    )
                else:
                    encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions

            if self.output_postprocessor is not None:
                logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes)

        if not return_dict:
            if logits is not None:
                return (logits, sequence_output) + encoder_outputs[1:]
            else:
                return (sequence_output,) + encoder_outputs[1:]

        return PerceiverModelOutput(
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for masked language modeling.
    """
)
class PerceiverForMaskedLM(PerceiverPreTrainedModel):
    def __init__(self, config: PerceiverConfig):
        super().__init__(config)

        text_preprocessor = PerceiverTextPreprocessor(config)

        # The decoder queries one position per token slot, so the maximum sequence length must be fixed up front.
        trainable_position_encoding_kwargs_decoder = {
            "num_channels": text_preprocessor.num_channels,
            "index_dims": config.max_position_embeddings,
        }

        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=text_preprocessor,
            decoder=PerceiverBasicDecoder(
                config,
                output_num_channels=config.d_latents,
                output_index_dims=config.max_position_embeddings,  # we need to define the seq_len of the inputs beforehand
                num_channels=text_preprocessor.num_channels,
                qk_channels=8 * 32,
                v_channels=text_preprocessor.num_channels,
                num_heads=8,
                use_query_residual=False,
                final_project=False,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
            ),
        )
        # Maps decoder outputs back to vocabulary logits using the text preprocessor's embedding layer.
        self.embedding_decoder = PerceiverEmbeddingDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | PerceiverMaskedLMOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only computed
            for the tokens with labels in `[0, ..., config.vocab_size - 1]`.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, PerceiverForMaskedLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
        >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")

        >>> # training
        >>> text = "This is an incomplete sentence where some words are missing."
        >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt")
        >>> # mask " missing."
        >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id
        >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> round(loss.item(), 2)
        19.87

        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 2048, 262]

        >>> # inference
        >>> text = "This is an incomplete sentence where some words are missing."
        >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt")

        >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space.
        >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id

        >>> # forward pass
        >>> with torch.no_grad():
        ...     outputs = model(**encoding)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 2048, 262]

        >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
        >>> tokenizer.decode(masked_tokens_predictions)
        ' missing.'
        ```"""
        # `input_ids` is accepted as an alias of `inputs` for tokenizer-produced batches; both at once is ambiguous.
        if inputs is not None and input_ids is not None:
            raise ValueError("You cannot use both `inputs` and `input_ids`")
        elif inputs is None and input_ids is not None:
            inputs = input_ids

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The Perceiver decoder output (first tuple element / `.logits`) is projected onto the vocabulary.
        logits = self.embedding_decoder(
            outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # labels equal to -100 are ignored (default ignore_index)
            # Keep labels on the same device as the logits, and use `reshape` so non-contiguous label tensors
            # (e.g. slices of a larger batch) work too; `view` would raise on them.
            labels = labels.to(logits.device)
            masked_lm_loss = loss_fct(logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            # outputs[:2] are the perceiver's (decoder logits, latent sequence output); the MLM logits replace both.
            output = (logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return PerceiverMaskedLMOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for text classification.
    """
)
class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # A single trainable query (index_dims=1) is decoded into one vector of class logits.
        decoder_query_kwargs = {"num_channels": config.d_latents, "index_dims": 1}

        self.num_labels = config.num_labels
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverTextPreprocessor(config),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=decoder_query_kwargs,
                use_query_residual=True,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _classification_loss(self, logits, labels):
        """Compute the loss matching `config.problem_type`, inferring and storing the problem type on first use.

        Returns `None` when `config.problem_type` holds a value none of the branches below recognise.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type
        if problem_type == "regression":
            mse = MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze())
            return mse(logits, labels)
        if problem_type == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels -
            1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels >
            1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, PerceiverForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
        >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver")

        >>> text = "hello world"
        >>> inputs = tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 2]
        ```"""
        # `input_ids` is an alias for `inputs`; refuse to guess when both are given.
        if inputs is not None and input_ids is not None:
            raise ValueError("You cannot use both `inputs` and `input_ids`")
        if inputs is None and input_ids is not None:
            inputs = input_ids

        use_return_dict = self.config.return_dict if return_dict is None else return_dict

        perceiver_outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_return_dict,
        )
        logits = perceiver_outputs.logits if use_return_dict else perceiver_outputs[0]

        loss = None if labels is None else self._classification_loss(logits, labels)

        if not use_return_dict:
            # Drop the perceiver's (logits, sequence_output) pair; keep any hidden states / attentions after them.
            extras = perceiver_outputs[2:]
            if loss is None:
                return (logits,) + extras
            return (loss, logits) + extras

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=perceiver_outputs.hidden_states,
            attentions=perceiver_outputs.attentions,
            cross_attentions=perceiver_outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for image classification, for tasks such as ImageNet.

    This model uses learned position embeddings. In other words, this model is not given any privileged information about
    the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet.

    [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
    (with `prep_type="conv1x1"`) to preprocess the input images, and
    [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
    [`PerceiverModel`] into classification logits.
    """
)
class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # One learned 256-channel position embedding per pixel of an image_size x image_size input.
        pixel_position_kwargs = {"num_channels": 256, "index_dims": config.image_size**2}
        # A single trainable query decodes the latents into the class logits.
        class_query_kwargs = {"num_channels": config.d_latents, "index_dims": 1}

        self.num_labels = config.num_labels

        # Built in the same order as the original call expression evaluated them: preprocessor, then decoder.
        image_preprocessor = PerceiverImagePreprocessor(
            config,
            prep_type="conv1x1",
            spatial_downsample=1,
            out_channels=256,
            position_encoding_type="trainable",
            concat_or_add_pos="concat",
            project_pos_dim=256,
            trainable_position_encoding_kwargs=pixel_position_kwargs,
        )
        classification_decoder = PerceiverClassificationDecoder(
            config,
            num_channels=config.d_latents,
            trainable_position_encoding_kwargs=class_query_kwargs,
            use_query_residual=True,
        )
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=image_preprocessor,
            decoder=classification_decoder,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        pixel_values: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationLearned
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-learned")
        >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")

        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
        >>> outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 1000]

        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: tabby, tabby cat
        ```"""
        # `pixel_values` is an alias for `inputs`; if `inputs` is absent, fall back to it (possibly still None).
        if inputs is not None and pixel_values is not None:
            raise ValueError("You cannot use both `inputs` and `pixel_values`")
        if inputs is None:
            inputs = pixel_values

        return_dict = self.config.return_dict if return_dict is None else return_dict

        model_outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        logits = model_outputs.logits if return_dict else model_outputs[0]

        loss = self.loss_function(labels, logits, self.config) if labels is not None else None

        if not return_dict:
            # Skip the perceiver's (logits, sequence_output) pair and keep the optional extras after it.
            tail = (logits,) + model_outputs[2:]
            return tail if loss is None else (loss,) + tail

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            cross_attentions=model_outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for image classification, for tasks such as ImageNet.

    This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of
    79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT).

    [`PerceiverForImageClassificationFourier`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
    (with `prep_type="pixels"`) to preprocess the input images, and
    [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
    [`PerceiverModel`] into classification logits.
    """
)
class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Fixed Fourier features for a 224 x 224 pixel grid, concatenated to the raw pixels.
        fourier_position_encoding_kwargs_preprocessor = {
            "concat_pos": True,
            "max_resolution": (224, 224),
            "num_bands": 64,
            "sine_only": False,
        }
        # A single trainable query decodes the latents into the class logits.
        trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}

        self.num_labels = config.num_labels
        self.perceiver = PerceiverModel(
            config,
            # NOTE(review): `position_encoding_type` is not passed, so the preprocessor's default is used; presumably
            # that default is "fourier" (matching the kwargs below) — confirm against PerceiverImagePreprocessor.
            input_preprocessor=PerceiverImagePreprocessor(
                config,
                prep_type="pixels",
                spatial_downsample=1,
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
            ),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        pixel_values: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationFourier
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-fourier")
        >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")

        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
        >>> outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 1000]

        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: tabby, tabby cat
        ```"""
        # `pixel_values` is accepted as an alias of `inputs`; passing both is ambiguous.
        if inputs is not None and pixel_values is not None:
            raise ValueError("You cannot use both `inputs` and `pixel_values`")
        elif inputs is None and pixel_values is not None:
            inputs = pixel_values

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            # outputs[:2] are the perceiver's (logits, sequence_output); keep only the optional extras after them.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for image classification, for tasks such as ImageNet.

    This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
    of 82.1 on ImageNet.

    [`PerceiverForImageClassificationConvProcessing`] uses
    [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with `prep_type="conv"`) to preprocess the input
    images, and [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent
    representation of [`PerceiverModel`] into classification logits.
    """
)
class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Fourier features for a 56 x 56 grid, presumably the feature-map size after the conv preprocessing of a
        # 224 x 224 image — TODO confirm against PerceiverImagePreprocessor's "conv" downsampling.
        fourier_position_encoding_kwargs_preprocessor = {
            "concat_pos": True,
            "max_resolution": (56, 56),
            "num_bands": 64,
            "sine_only": False,
        }
        # A single trainable query decodes the latents into the class logits.
        trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}

        self.num_labels = config.num_labels
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverImagePreprocessor(
                config,
                prep_type="conv",
                spatial_downsample=1,
                position_encoding_type="fourier",
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
            ),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        pixel_values: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-conv")
        >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")

        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
        >>> outputs = model(inputs=inputs)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 1000]

        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: tabby, tabby cat
        ```"""
        # `pixel_values` is accepted as an alias of `inputs`; passing both is ambiguous.
        if inputs is not None and pixel_values is not None:
            raise ValueError("You cannot use both `inputs` and `pixel_values`")
        elif inputs is None and pixel_values is not None:
            inputs = pixel_values

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            # outputs[:2] are the perceiver's (logits, sequence_output); keep only the optional extras after them.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
    [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
    input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
    representation of [`PerceiverModel`].

    As input, one concatenates 2 subsequent frames along the channel dimension and extracts a 3 x 3 patch around each
    pixel (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the
    position of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent
    representation using the same encoding used for the input.
    """
)
class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Encoder and decoder use identical Fourier settings over the training resolution, so the decoder can query
        # the latents with the same positional encoding used for the input.
        fourier_position_encoding_kwargs_preprocessor = {
            "num_bands": 64,
            "max_resolution": config.train_size,
            "sine_only": False,
            "concat_pos": True,
        }
        fourier_position_encoding_kwargs_decoder = {
            "concat_pos": True,
            "max_resolution": config.train_size,
            "num_bands": 64,
            "sine_only": False,
        }

        image_preprocessor = PerceiverImagePreprocessor(
            config,
            prep_type="patches",
            spatial_downsample=1,
            conv_after_patching=True,
            # 54 = 3 x 3 patch x 3 colour channels x 2 frames (see the class description).
            conv_after_patching_in_channels=54,
            temporal_downsample=2,
            position_encoding_type="fourier",
            # position_encoding_kwargs
            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
        )

        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=image_preprocessor,
            decoder=PerceiverOpticalFlowDecoder(
                config,
                num_channels=image_preprocessor.num_channels,
                output_image_shape=config.train_size,
                rescale_factor=100.0,
                # decoder kwargs
                use_query_residual=False,
                # Two output channels per pixel: the flow vector.
                output_num_channels=2,
                # We query the decoder using the first frame features
                # rather than a standard decoder position encoding.
                position_encoding_type="fourier",
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        labels (`torch.Tensor`, *optional*):
            Not supported yet: optical flow training is not implemented, so passing `labels` raises a
            `NotImplementedError`.

        Examples:

        ```python
        >>> from transformers import PerceiverForOpticalFlow
        >>> import torch

        >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")

        >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
        >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
        >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
        >>> # the authors train on resolutions of 368 x 496
        >>> patches = torch.randn(1, 2, 27, 368, 496)
        >>> outputs = model(inputs=patches)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 368, 496, 2]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Fail fast, before running the (expensive) forward pass, when a training loss is requested.
        if labels is not None:
            raise NotImplementedError("Optical flow training is not yet supported")

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]

        if not return_dict:
            # No loss is ever computed here, so the tuple never carries a leading loss entry.
            return (logits,) + outputs[2:]

        return PerceiverClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.
    [`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
    preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
    preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
    each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies
    the Perceiver encoder.
    [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
    [`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are
    created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is
    computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent
    representation. This is determined by the subsampled indices for each modality, which can be provided as additional
    input to the forward pass of [`PerceiverForMultimodalAutoencoding`].
    [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
    modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention
    is performed with the latent representation of [`PerceiverModel`].
    Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor`] is used to turn this tensor into an
    actual video. It first splits up the output into the different modalities, and then applies the respective
    postprocessor for each modality.
    Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
    "label" modality), this auto-encoding model becomes a Kinetics 700 video classifier.
    """
)
class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
    def __init__(self, config: PerceiverConfig):
        super().__init__(config)

        # Total number of raw audio samples in one clip.
        n_audio_samples = config.num_frames * config.audio_samples_per_frame

        input_preprocessor = PerceiverMultimodalPreprocessor(
            min_padding_size=4,
            modalities={
                "audio": PerceiverAudioPreprocessor(
                    config,
                    position_encoding_type="fourier",
                    fourier_position_encoding_kwargs={
                        "num_bands": 192,
                        "max_resolution": (n_audio_samples,),
                        "sine_only": False,
                        "concat_pos": True,
                    },
                    prep_type="patches",
                    samples_per_patch=config.samples_per_patch,
                ),
                "image": PerceiverImagePreprocessor(
                    config,
                    position_encoding_type="fourier",
                    fourier_position_encoding_kwargs={
                        "num_bands": 32,
                        "max_resolution": (config.num_frames, config.image_size, config.image_size),
                        "sine_only": False,
                        "concat_pos": True,
                    },
                    prep_type="patches",
                    spatial_downsample=4,
                    temporal_downsample=1,
                ),
                "label": PerceiverOneHotPreprocessor(config),
            },
            # The label modality is always masked; image and audio never are.
            mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0},
        )

        image_decoder = PerceiverBasicVideoAutoencodingDecoder(
            config,
            # Autoencoding, don't pass inputs to the queries.
            concat_preprocessed_input=False,
            output_shape=config.output_shape,
            output_num_channels=config.output_num_channels,
            use_query_residual=False,
            position_encoding_only=True,
            position_encoding_type="fourier",
            fourier_position_encoding_kwargs={
                "num_bands": 32,
                "max_resolution": (config.num_frames, config.image_size, config.image_size),
                "sine_only": False,
                "concat_pos": True,
            },
        )

        decoder = PerceiverMultimodalDecoder(
            config,
            # Autoencoding, don't pass inputs to the queries.
            concat_preprocessed_input=False,
            # Modality specific decoders are used ONLY to generate queries.
            # All modalities are decoded together using a unified decoder.
            modalities={
                "audio": PerceiverBasicDecoder(
                    config,
                    # Autoencoding, don't pass inputs to the queries.
                    concat_preprocessed_input=False,
                    output_index_dims=(n_audio_samples // config.samples_per_patch,),
                    output_num_channels=config.output_num_channels,
                    use_query_residual=False,
                    position_encoding_only=True,
                    position_encoding_type="fourier",
                    fourier_position_encoding_kwargs={
                        "num_bands": 192,
                        "max_resolution": (n_audio_samples,),
                        "sine_only": False,
                        "concat_pos": True,
                    },
                ),
                "image": image_decoder,
                "label": PerceiverClassificationDecoder(
                    config,
                    # Autoencoding, don't pass inputs to the queries.
                    concat_preprocessed_input=False,
                    use_query_residual=False,
                    position_encoding_only=True,
                    position_encoding_type="trainable",
                    trainable_position_encoding_kwargs={
                        "num_channels": config._label_trainable_num_channels,
                        "index_dims": 1,
                    },
                ),
            },
            num_outputs=None,
            output_num_channels=config.output_num_channels,
            use_query_residual=False,
        )

        output_postprocessor = PerceiverMultimodalPostprocessor(
            modalities={
                "audio": PerceiverAudioPostprocessor(config, in_channels=config.output_num_channels),
                "image": PerceiverProjectionPostprocessor(in_channels=config.output_num_channels, out_channels=3),
                "label": PerceiverClassificationPostprocessor(config, in_channels=config.output_num_channels),
            }
        )

        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=input_preprocessor,
            decoder=decoder,
            output_postprocessor=output_postprocessor,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        subsampled_output_points: dict[str, torch.Tensor] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PerceiverClassifierOutput:
        r"""
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
            Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
            representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.
        labels (`torch.Tensor`, *optional*):
            Not supported yet: multimodal autoencoding training is not implemented, so passing any value raises
            `NotImplementedError`.

        Examples:

        ```python
        >>> from transformers import PerceiverForMultimodalAutoencoding
        >>> import torch
        >>> import numpy as np
        >>> # create multimodal inputs
        >>> images = torch.randn((1, 16, 3, 224, 224))
        >>> audio = torch.randn((1, 30720, 1))
        >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
        >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver")
        >>> # in the Perceiver IO paper, videos are auto-encoded in chunks
        >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries
        >>> nchunks = 128
        >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks
        >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
        >>> # process the first chunk
        >>> chunk_idx = 0
        >>> subsampling = {
        ...     "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
        ...     "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
        ...     "label": None,
        ... }
        >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
        >>> logits = outputs.logits
        >>> list(logits["audio"].shape)
        [1, 240]
        >>> list(logits["image"].shape)
        [1, 6272, 3]
        >>> list(logits["label"].shape)
        [1, 700]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # `loss` always stays None: any `labels` are rejected right below.
        loss = None
        if labels is not None:
            raise NotImplementedError("Multimodal autoencoding training is not yet supported")

        # NOTE(review): **kwargs is accepted but not forwarded to self.perceiver.
        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            subsampled_output_points=subsampled_output_points,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
- # Below: position encodings
def build_position_encoding(
    position_encoding_type,
    out_channels=None,
    project_pos_dim=-1,
    trainable_position_encoding_kwargs=None,
    fourier_position_encoding_kwargs=None,
):
    """
    Builds the position encoding module and an optional projection applied to its output.

    Args:
        position_encoding_type (`str`):
            Either "trainable" or "fourier". Any other value raises a `ValueError`.
        out_channels (`int`, *optional*):
            The number of channels of the position encodings. Required when `project_pos_dim > 0`, because it is
            the input size of the projection layer.
        project_pos_dim (`int`, *optional*, defaults to -1):
            If positive, the position encodings are projected to this many channels with an `nn.Linear`;
            otherwise an `nn.Identity` is returned as the projection.
        trainable_position_encoding_kwargs (`dict`, *optional*):
            Keyword arguments for `PerceiverTrainablePositionEncoding`; must be non-empty for "trainable".
        fourier_position_encoding_kwargs (`dict`, *optional*):
            Keyword arguments for `PerceiverFourierPositionEncoding`; must be non-empty for "fourier".

    Returns:
        `tuple`: `(position_encoding_module, positions_projection)`.

    Raises:
        ValueError: for an unknown encoding type, missing encoding kwargs, or `project_pos_dim > 0` without
        `out_channels`.
    """
    # Validate every argument before constructing any module, so misconfiguration fails fast with a clear message
    # instead of an opaque error from nn.Linear(None, ...).
    if position_encoding_type == "trainable":
        if not trainable_position_encoding_kwargs:
            raise ValueError("Make sure to pass trainable_position_encoding_kwargs")
    elif position_encoding_type == "fourier":
        if not fourier_position_encoding_kwargs:
            raise ValueError("Make sure to pass fourier_position_encoding_kwargs")
    else:
        raise ValueError(f"Unknown position encoding type: {position_encoding_type}.")

    if project_pos_dim > 0 and out_channels is None:
        raise ValueError(
            f"`out_channels` must be provided to project position encodings (got project_pos_dim={project_pos_dim})."
        )

    if position_encoding_type == "trainable":
        output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs)
    else:
        # We don't use the index_dims argument, as this is only known during the forward pass
        output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs)

    # Optionally, project the position encoding to a target dimension:
    positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()

    return output_pos_enc, positions_projection
- # Below: Perceiver decoders
- class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
- """Perceiver abstract decoder."""
- @abc.abstractmethod
- def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
- raise NotImplementedError
- @property
- @abc.abstractmethod
- def num_query_channels(self):
- raise NotImplementedError
- @abc.abstractmethod
- def forward(self, query, z, query_mask=None):
- raise NotImplementedError
- class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
- """
- Baseline projection decoder (no cross-attention).
- Args:
- config ([`PerceiverConfig`]):
- Model configuration.
- """
- def __init__(self, config):
- super().__init__()
- self.classifier = nn.Linear(config.d_latents, config.num_labels)
- def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
- return None
- def forward(
- self, query: torch.Tensor, z: torch.FloatTensor, query_mask: torch.FloatTensor | None = None
- ) -> torch.FloatTensor:
- # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
- z = torch.mean(z, dim=1)
- # (batch_size, d_latents) -> (batch_size, config.num_labels)
- logits = self.classifier(z)
- return logits
class PerceiverBasicDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
    cross-attention operation, in which the latents produce keys and values.

    The shape of the output of this class depends on how one defines the output queries (also called decoder queries).

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        output_num_channels (`int`):
            The number of channels in the output: the output size of the final projection when *final_project* is
            `True` (and then also the value reported by `num_query_channels`).
        position_encoding_type (`str`, *optional*, defaults to "trainable"):
            The type of position encoding to use. Can be either "trainable", "fourier", or "none". With "none", no
            position encodings (and no `positions_projection`) are built; queries must be supplied by the caller.
        output_index_dims (`int` or `tuple[int]`, *optional*):
            The index dimensions (shape, without batch and channel axes) of the output queries. Used to unravel and
            normalize `subsampled_points` in `decoder_query`, and as the index dims of Fourier encodings when
            subsampling.
        num_channels (`int`, *optional*, defaults to 128):
            The number of channels of the decoder queries. Always used as the query dimension of the cross-attention
            layer and the input size of the final projection, including when `position_encoding_type == "none"`.
        subsampled_index_dims (`int` or `tuple[int]`, *optional*):
            Stored as `self.subsampled_index_dims` (defaults to `output_index_dims`); not otherwise read in this class.
        qk_channels (`int`, *optional*):
            The number of channels of the queries and keys in the cross-attention layer.
        v_channels (`int`, *optional*):
            The number of channels of the values in the cross-attention layer.
        num_heads (`int`, *optional*, defaults to 1):
            The number of attention heads in the cross-attention layer.
        widening_factor (`int`, *optional*, defaults to 1):
            The widening factor of the cross-attention layer.
        use_query_residual (`bool`, *optional*, defaults to `False`):
            Whether to use a residual connection between the query and the output of the cross-attention layer.
        concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
            Whether to concatenate the preprocessed input (`inputs_without_pos`) to the query.
        final_project (`bool`, *optional*, defaults to `True`):
            Whether to project the output of the cross-attention layer to `output_num_channels`.
        position_encoding_only (`bool`, *optional*, defaults to `False`):
            Whether to only use this class to define output queries. If `True`, the cross-attention layer and final
            layer are not created, so `forward` cannot be used.
        position_encoding_kwargs:
            Extra keyword arguments forwarded to `build_position_encoding` (e.g. `trainable_position_encoding_kwargs`,
            `fourier_position_encoding_kwargs`, `project_pos_dim`).
    """

    def __init__(
        self,
        config: PerceiverConfig,
        output_num_channels: int,
        position_encoding_type: str | None = "trainable",
        # output_index_dims is only consulted when this decoder builds its own queries; num_channels is always used as
        # the cross-attention query dimension and the final projection's input size.
        output_index_dims: int | tuple[int, ...] | None = None,
        num_channels: int | None = 128,
        subsampled_index_dims: int | tuple[int, ...] | None = None,
        qk_channels: int | None = None,
        v_channels: int | None = None,
        num_heads: int | None = 1,
        widening_factor: int | None = 1,
        use_query_residual: bool | None = False,
        concat_preprocessed_input: bool | None = False,
        final_project: bool | None = True,
        position_encoding_only: bool | None = False,
        **position_encoding_kwargs,
    ) -> None:
        super().__init__()

        self.output_num_channels = output_num_channels
        # If `none`, the decoder will not construct any position encodings.
        # You should construct your own when querying the decoder.
        self.output_position_encodings = None
        self.position_encoding_type = position_encoding_type
        self.position_encoding_kwargs = position_encoding_kwargs
        if position_encoding_type != "none":
            # Returns the encoding module plus a projection (nn.Identity unless project_pos_dim > 0).
            self.output_position_encodings, self.positions_projection = build_position_encoding(
                position_encoding_type=position_encoding_type, **position_encoding_kwargs
            )

        self.output_index_dims = output_index_dims
        self.num_channels = num_channels
        if subsampled_index_dims is None:
            subsampled_index_dims = output_index_dims
        self.subsampled_index_dims = subsampled_index_dims
        self.concat_preprocessed_input = concat_preprocessed_input
        self.final_project = final_project
        self.position_encoding_only = position_encoding_only

        # for multimodal autoencoding, we don't need the decoder cross-attention and final layer
        # so then we will set position_encoding_only to True
        if not self.position_encoding_only:
            self.decoding_cross_attention = PerceiverLayer(
                config,
                is_cross_attention=True,
                qk_channels=qk_channels,
                v_channels=v_channels,
                num_heads=num_heads,
                q_dim=num_channels,
                kv_dim=config.d_latents,
                widening_factor=widening_factor,
                use_query_residual=use_query_residual,
            )
            self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity()

    @property
    def num_query_channels(self) -> int:
        """Channel size of the queries this decoder produces (raises for position_encoding_type "none")."""
        if self.position_encoding_type == "none":  # Queries come from elsewhere
            raise ValueError(
                "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
            )
        if self.position_encoding_only:
            # Query-only mode: the queries are the (optionally projected) position encodings.
            if "project_pos_dim" in self.position_encoding_kwargs:
                return self.position_encoding_kwargs["project_pos_dim"]
            return self.output_position_encodings.output_size()
        if self.final_project:
            return self.output_num_channels
        return self.num_channels

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        """
        Build decoder queries from this decoder's position encodings.

        With `subsampled_points`, only the encodings at those flat indices (into `output_index_dims`) are built;
        otherwise index dims are taken from `inputs.shape[2:]`. If `concat_preprocessed_input` is set,
        `inputs_without_pos` is concatenated in front of the encodings along the last axis.
        """
        if self.position_encoding_type == "none":  # Queries come from elsewhere
            raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
        if subsampled_points is not None:
            # subsampled_points are the indices if the inputs would be flattened
            # however, the inputs aren't flattened, that's why we use unravel_index
            # to get the indices for the unflattened array
            # unravel_index returns a tuple (x_idx, y_idx, ...)
            # stack to get the [n, d] tensor of coordinates
            indices = torch.unravel_index(subsampled_points, self.output_index_dims)
            pos = torch.stack(indices, dim=1)
            batch_size = inputs.shape[0]
            # Map integer coordinates i in [0, n) to -1 + 2 * i / n, i.e. into [-1, 1).
            # NOTE(review): torch.tensor(...) is created on the default device; assumes `pos` lives there too.
            pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
            pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
            # Construct the position encoding.
            if self.position_encoding_type == "trainable":
                # NOTE(review): the trainable encoding ignores `pos` and returns the full set of encodings,
                # i.e. subsampling has no effect here — confirm this is intended.
                pos_emb = self.output_position_encodings(batch_size)
            elif self.position_encoding_type == "fourier":
                pos_emb = self.output_position_encodings(
                    self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
                )

            # Optionally project them to a target dimension.
            pos_emb = self.positions_projection(pos_emb)
            # Flatten all index dims into a single query axis: (batch, num_queries, channels).
            pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
        else:
            batch_size = inputs.shape[0]
            index_dims = inputs.shape[2:]

            # Construct the position encoding.
            if self.position_encoding_type == "trainable":
                pos_emb = self.output_position_encodings(batch_size)
            elif self.position_encoding_type == "fourier":
                pos_emb = self.output_position_encodings(
                    index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
                )

            # Optionally project them to a target dimension.
            pos_emb = self.positions_projection(pos_emb)

        if self.concat_preprocessed_input:
            if inputs_without_pos is None:
                raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
            pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)

        return pos_emb

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> PerceiverDecoderOutput:
        """Cross-attend from `query` to the latents `z`, then apply the final layer."""
        # Cross-attention decoding.
        # key, value: B x N x K; query: B x M x K
        # Attention maps -> B x N x M
        # Output -> B x M x K
        cross_attentions = () if output_attentions else None

        layer_outputs = self.decoding_cross_attention(
            query,
            attention_mask=query_mask,
            inputs=z,
            inputs_mask=None,
            output_attentions=output_attentions,
        )
        output = layer_outputs[0]

        if output_attentions:
            cross_attentions = cross_attentions + (layer_outputs[1],)

        logits = self.final_layer(output)

        return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention based classification decoder: a thin wrapper around [`PerceiverBasicDecoder`] that emits one
    logit vector per example.

    The wrapped decoder is configured with a single output query position (`output_index_dims=1`) and
    `config.num_labels` output channels, so the Perceiver latents of shape (batch_size, num_latents, d_latents) are
    decoded into logits of shape (batch_size, num_labels).

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
        decoder_kwargs:
            Extra keyword arguments forwarded unchanged to [`PerceiverBasicDecoder`].
    """

    def __init__(self, config, **decoder_kwargs):
        super().__init__()

        self.num_labels = config.num_labels
        self.decoder = PerceiverBasicDecoder(
            config,
            output_num_channels=self.num_labels,
            output_index_dims=1,  # one query position -> one logit vector per example
            **decoder_kwargs,
        )

    @property
    def num_query_channels(self) -> int:
        return self.decoder.num_query_channels

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        # Query construction is fully delegated to the wrapped decoder.
        return self.decoder.decoder_query(
            inputs,
            modality_sizes=modality_sizes,
            inputs_without_pos=inputs_without_pos,
            subsampled_points=subsampled_points,
        )

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> PerceiverDecoderOutput:
        decoded = self.decoder(query, z, output_attentions=output_attentions)
        # Drop the singleton query axis: (batch, 1, num_labels) -> (batch, num_labels).
        single_query_logits = decoded.logits[:, 0]
        return PerceiverDecoderOutput(logits=single_query_logits, cross_attentions=decoded.cross_attentions)
class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention based optical flow decoder.

    The preprocessed inputs themselves serve as decoder queries. The decoded predictions are divided by
    `rescale_factor` and reshaped to `(batch_size, *output_image_shape, channels)`.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
        output_image_shape (sequence of `int`):
            Spatial shape the per-query predictions are reshaped into.
        output_num_channels (`int`, *optional*, defaults to 2):
            Number of output channels of the wrapped decoder.
        rescale_factor (`float`, *optional*, defaults to 100.0):
            The decoded predictions are divided by this value.
        decoder_kwargs:
            Extra keyword arguments forwarded to [`PerceiverBasicDecoder`].
    """

    def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
        super().__init__()
        self.output_image_shape = output_image_shape
        self.output_num_channels = output_num_channels
        self.rescale_factor = rescale_factor
        self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs)

    @property
    def num_query_channels(self) -> int:
        return self.decoder.num_query_channels

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        # Queries are the inputs as-is; subsampled decoding is not available for flow.
        if subsampled_points is None:
            return inputs
        raise ValueError("FlowDecoder doesn't support subsampling yet.")

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> PerceiverDecoderOutput:
        decoded = self.decoder(query, z, output_attentions=output_attentions)
        # Undo the training-time flow scaling.
        flow = decoded.logits / self.rescale_factor
        target_shape = [flow.shape[0], *self.output_image_shape, flow.shape[-1]]
        flow = flow.reshape(target_shape)
        return PerceiverDecoderOutput(logits=flow, cross_attentions=decoded.cross_attentions)
class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video
    reshaping logic.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        output_shape (`list[int]` or `tuple[int]`):
            Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension. The
            (num_frames, height, width) part is used as the wrapped decoder's output index dims; the batch size of
            the reshaped output is inferred from the decoder output.
        position_encoding_type (`str`):
            The type of position encoding to use. Can be either "trainable", "fourier", or "none".
        decoder_kwargs:
            Forwarded to [`PerceiverBasicDecoder`]; must contain `output_num_channels`.
    """

    def __init__(
        self, config: PerceiverConfig, output_shape: list[int], position_encoding_type: str, **decoder_kwargs
    ) -> None:
        super().__init__()
        if len(output_shape) != 4:  # B, T, H, W
            raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
        # Build the decoder components:
        self.output_shape = output_shape
        self.output_num_channels = decoder_kwargs["output_num_channels"]

        self.decoder = PerceiverBasicDecoder(
            config,
            output_index_dims=self.output_shape[1:4],  # (T, H, W)
            position_encoding_type=position_encoding_type,
            **decoder_kwargs,
        )

    @property
    def num_query_channels(self) -> int:
        return self.decoder.num_query_channels

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        return self.decoder.decoder_query(
            inputs,
            modality_sizes=modality_sizes,
            inputs_without_pos=inputs_without_pos,
            subsampled_points=subsampled_points,
        )

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> PerceiverDecoderOutput:
        """
        Decode the latents and reshape the per-query predictions into a video.

        `output_attentions` is forwarded to the wrapped decoder, matching the other Perceiver decoders; `query_mask`
        is accepted for signature compatibility and not used.
        """
        decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
        logits = decoder_outputs.logits
        # (batch, T*H*W, C) -> (batch, T, H, W, C). The batch axis is inferred (-1) rather than taken from
        # output_shape[0], and unpacking works for both list and tuple output_shape. Whenever the old reshape to
        # output_shape + [C] succeeded, -1 resolves to exactly output_shape[0].
        logits = torch.reshape(logits, [-1, *self.output_shape[1:], logits.shape[-1]])
        return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]:
    """
    Split a [B, N, C] tensor along its second axis into one slice per modality.

    Modalities are laid out back to back in sorted (alphabetical) name order, which is the same order
    `PerceiverMultimodalDecoder` uses when concatenating its queries.

    Args:
        modality_sizes:
            Mapping from modality name to the number of positions that modality occupies along axis 1.
        inputs:
            Tensor to split; slicing is done on axis 1.

    Returns:
        dict mapping each modality name to its slice of `inputs`.
    """
    slices = {}
    start = 0
    for name in sorted(modality_sizes):
        stop = start + modality_sizes[name]
        slices[name] = inputs[:, start:stop]
        start = stop
    return slices
class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
    """
    Decoder that composes several uni-modal decoders behind one shared cross-attention.

    Each entry of *modalities* maps a modality name to a decoder used only to build that modality's queries. The
    per-modality queries are flattened to (batch, positions, channels), right-padded with a trainable,
    modality-specific vector up to a common channel size, and concatenated along the positions axis in sorted
    modality-name order. A single shared [`PerceiverBasicDecoder`] then cross-attends from these queries to the
    latents.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        modalities (`dict[str, PerceiverAbstractDecoder]`):
            Dictionary mapping modality name to the decoder used to build that modality's queries.
        num_outputs (`int`):
            Passed to the shared decoder as `output_index_dims=(num_outputs,)`.
        output_num_channels (`int`):
            The number of channels in the output.
        min_padding_size (`int`, *optional*, defaults to 2):
            Extra channels added on top of the widest per-modality query; every modality therefore receives at least
            this many trainable padding channels.
        subsampled_index_dims (*optional*):
            Stored on the module as `subsampled_index_dims`; not read by this class.
        decoder_kwargs:
            Extra keyword arguments forwarded to the shared [`PerceiverBasicDecoder`].
    """

    def __init__(
        self,
        config: PerceiverConfig,
        modalities: dict[str, PerceiverAbstractDecoder],
        num_outputs: int,
        output_num_channels: int,
        min_padding_size: int | None = 2,
        subsampled_index_dims: dict[str, PerceiverAbstractDecoder] | None = None,
        **decoder_kwargs,
    ) -> None:
        super().__init__()
        # Keep the assignment order (modalities, decoder, padding): it fixes the submodule/parameter order in the
        # state dict and the order in which modules are constructed.
        self.modalities = nn.ModuleDict(modalities)
        self.subsampled_index_dims = subsampled_index_dims
        self.min_padding_size = min_padding_size
        self.output_num_channels = output_num_channels
        self.num_outputs = num_outputs

        shared_channels = self.num_query_channels
        self.decoder = PerceiverBasicDecoder(
            config,
            output_index_dims=(num_outputs,),
            output_num_channels=output_num_channels,
            position_encoding_type="none",
            num_channels=shared_channels,
            **decoder_kwargs,
        )

        padding_params = {}
        for name, modality_decoder in modalities.items():
            pad_width = shared_channels - modality_decoder.num_query_channels
            padding_params[name] = nn.Parameter(torch.randn(1, pad_width))
        self.padding = nn.ParameterDict(padding_params)

    @property
    def num_query_channels(self) -> int:
        widest = max(modality_decoder.num_query_channels for modality_decoder in self.modalities.values())
        return widest + self.min_padding_size

    def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None):
        # Split the flat inputs back into one tensor per modality.
        per_modality_inputs = restructure(modality_sizes, inputs)
        points_by_modality = subsampled_points or {}

        queries = {}
        for name, modality_decoder in self.modalities.items():
            modality_input_without_pos = None if inputs_without_pos is None else inputs_without_pos.get(name, None)
            queries[name] = modality_decoder.decoder_query(
                inputs=per_modality_inputs[name],
                modality_sizes=None,
                inputs_without_pos=modality_input_without_pos,
                subsampled_points=points_by_modality.get(name, None),
            )

        # Flatten each modality's query to (batch, positions, channels) and pad its channels with the
        # modality-specific trainable vector so all modalities can be concatenated along the positions axis.
        target_channels = self.num_query_channels
        padded_queries = []
        for name in sorted(self.modalities.keys()):
            query = queries[name]
            query = torch.reshape(query, [query.shape[0], np.prod(query.shape[1:-1]), query.shape[-1]])
            pad = torch.broadcast_to(
                self.padding[name], [query.shape[0], query.shape[1], target_channels - query.shape[2]]
            )
            padded_queries.append(torch.cat([query, pad], dim=2))
        return torch.cat(padded_queries, dim=1)

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> torch.Tensor:
        # All modalities share one cross-attention; the shared decoder's output is returned unchanged.
        # `query_mask` is accepted for signature compatibility and not forwarded.
        return self.decoder(query, z, output_attentions=output_attentions)
- # Below: IO pre- and post-processor classes for Perceiver.
def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor:
    """
    Move non-overlapping spatial (and, for video, temporal) blocks into the channel dimension.

    The input is channels-first; the output is channels-last:

    - rank 4: `(batch, channels, height, width)` -> `(batch, height // s, width // s, s * s * channels)`
    - rank 5: `(batch, time, channels, height, width)` ->
      `(batch, time // t, height // s, width // s, t * s * s * channels)`

    where `s = spatial_block_size` and `t = temporal_block_size` (`t` is ignored for rank-4 input). Along the last
    axis, values are ordered as (temporal offset, row offset, column offset, channel), channel varying fastest.

    Raises:
        ValueError: if `frames` is neither rank 4 nor rank 5.
    """
    s = spatial_block_size
    rank = len(frames.shape)

    if rank == 4:
        batch, channels, height, width = frames.shape
        blocks = frames.view(batch, channels, height // s, s, width // s, s)
        # -> (batch, H//s, W//s, row offset, col offset, channels)
        blocks = blocks.permute(0, 2, 4, 3, 5, 1).contiguous()
        return blocks.view(batch, height // s, width // s, s * s * channels)

    if rank == 5:
        t = temporal_block_size
        batch, time, channels, height, width = frames.shape
        blocks = frames.view(batch, time // t, t, channels, height // s, s, width // s, s)
        # -> (batch, T//t, H//s, W//s, time offset, row offset, col offset, channels)
        blocks = blocks.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
        return blocks.view(batch, time // t, height // s, width // s, t * s * s * channels)

    raise ValueError(
        "Frames should be of rank 4 (batch, channels, height, width)"
        " or rank 5 (batch, time, channels, height, width)"
    )
class Conv2dSamePadding(nn.Conv2d):
    """
    Conv2d layer with padding="same" support. Source:
    https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6

    Before the convolution, the input is zero-padded by `(k - 1) // 2` on the left/top and `k // 2` on the
    right/bottom for each kernel size `k`. The padding depends only on the kernel size; stride and dilation are not
    taken into account.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_h, kernel_w = self.kernel_size
        # nn.ZeroPad2d takes (left, right, top, bottom); the extra pixel of an even kernel goes right/bottom.
        self.zero_pad_2d = nn.ZeroPad2d(
            ((kernel_w - 1) // 2, kernel_w // 2, (kernel_h - 1) // 2, kernel_h // 2)
        )

    def forward(self, input):
        padded = self.zero_pad_2d(input)
        return self._conv_forward(padded, self.weight, self.bias)
class Conv2DDownsample(nn.Module):
    """
    Convolutional image stem: a 7x7 stride-2 convolution with "same"-style padding, optional batch norm, a ReLU
    and a 3x3 stride-2 max pool. The two stride-2 stages together downsample spatially by roughly 4x.
    """

    def __init__(
        self,
        num_layers: int = 1,
        in_channels: int = 3,
        out_channels: int = 64,
        use_batchnorm: bool = True,
    ):
        """
        Constructs a Conv2DDownsample model.

        Args:
            num_layers (`int`, *optional*, defaults to 1):
                Accepted for API compatibility but not read; exactly one conv/pool stage is always built.
            in_channels (`int`, *optional*, defaults to 3):
                The number of input channels.
            out_channels (`int`, *optional*, defaults to 64):
                The number of conv output channels.
            use_batchnorm (`bool`, *optional*, defaults to `True`):
                Whether to use batchnorm; when `False` an identity module takes its place.
        """
        super().__init__()
        self.conv = Conv2dSamePadding(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=7,
            stride=2,
            bias=False,
        )
        if use_batchnorm:
            self.batchnorm = nn.BatchNorm2d(num_features=out_channels)
        else:
            self.batchnorm = nn.Identity()
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # conv -> norm -> relu -> max-pool, applied in sequence.
        for stage in (self.conv, self.batchnorm, self.relu, self.max_pool):
            inputs = stage(inputs)
        return inputs
def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
    """
    Generate a Fourier frequency position encoding with linear spacing.

    Args:
        pos (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`):
            The Tensor containing the position of n points in d dimensional space. Only the first batch element is
            used to compute the Fourier features, i.e. positions are assumed to be identical across the batch.
        num_bands (`int`):
            The number of frequency bands (K) to use.
        max_resolution (`tuple[int]`, *optional*, defaults to (224, 224)):
            The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each
            dimension; its length must equal `dim`.
        concat_pos (`bool`, *optional*, defaults to `True`):
            Whether to concatenate the input position encoding to the Fourier features.
        sine_only (`bool`, *optional*, defaults to `False`):
            Whether to use a single phase (sin) or two (sin/cos) for each frequency band.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings,
        always batched regardless of `concat_pos`. If `concat_pos` is `True` and `sine_only` is `False`, output
        dimensions are ordered as: [dim_1, dim_2, ..., dim_d, sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ...,
        sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1), ..., cos(pi*f_K*dim_1), ...,
        cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the kth frequency band.
        `n_channels` is `d * K * (1 if sine_only else 2) + (d if concat_pos else 0)`.
    """
    batch_size = pos.shape[0]

    min_freq = 1.0
    # Frequencies are linearly spaced from 1 up to the Nyquist frequency (res / 2) of each dimension.
    # They are created on `pos.device` so that a non-CPU `pos` does not cause a device mismatch below.
    freq_bands = torch.stack(
        [
            torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device)
            for res in max_resolution
        ],
        dim=0,
    )  # (d, K)

    # Positions are shared across the batch, so the features are computed once from the first element:
    # (n, d, 1) * (1, d, K) -> (n, d, K) -> (n, d * K), grouped per input dimension.
    per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :]
    per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])])

    if sine_only:
        # Output is size [n, d * num_bands]
        per_pos_features = torch.sin(np.pi * per_pos_features)
    else:
        # Output is size [n, 2 * d * num_bands]: all sines first, then all cosines.
        per_pos_features = torch.cat(
            [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
        )

    # Broadcast (as a view) to the batch in every case so the result always has the documented 3D shape;
    # previously the batch dimension was only added when `concat_pos` was True.
    per_pos_features = per_pos_features.expand(batch_size, -1, -1)

    if concat_pos:
        # Prepend the raw input positions (adds d channels).
        per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
    return per_pos_features
def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
    """
    Build a dense grid of evenly spaced coordinates covering an N-D index space.

    Args:
        index_dims (`list[int]`):
            Number of grid points along each of the N index dimensions.
        output_range (`tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`):
            Inclusive lower and upper coordinate bound, shared by every dimension.

    Returns:
        `torch.FloatTensor` (float32) of shape `(index_dims[0], ..., index_dims[-1], N)`; the last axis holds the
        coordinates of each grid point, laid out with matrix ("ij") indexing.
    """
    axes = []
    for size in index_dims:
        axis = torch.linspace(start=output_range[0], end=output_range[1], steps=size, dtype=torch.float32)
        axes.append(axis)
    grids = torch.meshgrid(*axes, indexing="ij")
    return torch.stack(grids, dim=-1)
class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
    """
    Interface shared by the Perceiver position encodings.

    Concrete encodings report how many index dimensions they cover (`num_dimensions`), the size of the channel
    dimension they produce (`output_size`), and implement `forward`. The class cannot be instantiated directly.
    """

    @property
    @abc.abstractmethod
    def num_dimensions(self) -> int:
        """Number of index dimensions covered by the encoding."""
        raise NotImplementedError()

    @abc.abstractmethod
    def output_size(self, *args, **kwargs) -> int:
        """Size of the last (channel) dimension of the produced encoding."""
        raise NotImplementedError()

    @abc.abstractmethod
    def forward(self, batch_size, pos):
        """Compute the position encoding."""
        raise NotImplementedError()
class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
    """
    Learned position encoding: one trainable `num_channels`-dim vector per position of a flattened index grid.

    Args:
        index_dims (`int` or sequence of `int`):
            Shape of the index grid; the parameter table has `prod(index_dims)` rows.
        num_channels (`int`, *optional*, defaults to 128):
            Size of each position vector.
    """

    def __init__(self, index_dims, num_channels=128):
        super().__init__()
        self._num_channels = num_channels
        self._index_dims = index_dims
        num_positions = np.prod(index_dims)
        self.position_embeddings = nn.Parameter(torch.randn(num_positions, num_channels))

    @property
    def num_dimensions(self) -> int:
        dims = self._index_dims
        return 1 if isinstance(dims, int) else len(dims)

    def output_size(self, *args, **kwargs) -> int:
        return self._num_channels

    def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Bicubically resize a `(num_positions, num_channels)` table, viewed as a square grid of side
        `sqrt(num_positions)`, into a `(height * width, num_channels)` table.

        The table is returned untouched when the requested size already matches the grid, except while tracing:
        then interpolation always runs so the exported model supports dynamic input shapes.
        """
        side = torch_int(position_embeddings.shape[0] ** 0.5)
        needs_resize = torch.jit.is_tracing() or height != side or width != side
        if not needs_resize:
            return position_embeddings

        channels = self._num_channels
        # (side * side, C) -> (1, C, side, side): an NCHW "image" for interpolate.
        grid = position_embeddings.reshape(1, side, side, channels).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=(height, width), mode="bicubic", align_corners=False)
        # (1, C, height, width) -> (height * width, C)
        return grid.reshape(1, channels, -1).permute(0, 2, 1).squeeze(0)

    def forward(
        self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size | None = None
    ) -> torch.Tensor:
        """
        Args:
            batch_size (`int` or `None`):
                When not `None`, the table is expanded (as a view) to `(batch_size, num_positions, num_channels)`.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Resize the table to `input_size` before returning it.
            input_size (`torch.Size`, *optional*):
                `(height, width)` target grid; required when `interpolate_pos_encoding` is `True`.
        """
        embeddings = self.position_embeddings
        if interpolate_pos_encoding:
            target_height, target_width = input_size
            embeddings = self.interpolate_pos_encoding(embeddings, target_height, target_width)
        return embeddings if batch_size is None else embeddings.expand(batch_size, -1, -1)
- def _check_or_build_spatial_positions(pos, index_dims, batch_size):
- """
- Checks or builds spatial position features (x, y, ...).
- Args:
- pos (`torch.FloatTensor`):
- None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
- index_dims (`list[int]`):
- An iterable giving the spatial/index size of the data to be featurized.
- batch_size (`int`):
- The batch size of the data to be featurized.
- Returns:
- `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features.
- """
- if pos is None:
- pos = build_linear_positions(index_dims)
- # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
- # but `torch.broadcast_to` cannot be converted to ONNX
- pos = pos[None].expand((batch_size,) + pos.shape)
- pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
- else:
- # Just a warning label: you probably don't want your spatial features to
- # have a different spatial layout than your pos coordinate system.
- # But feel free to override if you think it'll work!
- if pos.shape[-1] != len(index_dims):
- raise ValueError("Spatial features have the wrong number of dimensions.")
- return pos
class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
    """
    Fixed (non-trainable) sinusoidal position encoding built from linearly spaced frequency bands.

    Args:
        num_bands (`int`):
            Number of frequency bands per index dimension.
        max_resolution (`tuple[int]`):
            Resolution of each index dimension; its length defines `num_dimensions`.
        concat_pos (`bool`, *optional*, defaults to `True`):
            Whether the raw positions are concatenated in front of the Fourier features.
        sine_only (`bool`, *optional*, defaults to `False`):
            Use only sine features instead of sine and cosine.
    """

    def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False):
        super().__init__()
        self.num_bands = num_bands
        self.max_resolution = max_resolution
        self.concat_pos = concat_pos
        self.sine_only = sine_only

    @property
    def num_dimensions(self) -> int:
        return len(self.max_resolution)

    def output_size(self):
        """Returns size of positional encodings last dimension."""
        phases = 1 if self.sine_only else 2
        size = phases * self.num_bands * len(self.max_resolution)
        if self.concat_pos:
            size += self.num_dimensions
        return size

    def forward(
        self,
        index_dims: list[int],
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        pos: torch.FloatTensor | None = None,
    ) -> torch.FloatTensor:
        # Build a default coordinate grid when `pos` is None, otherwise validate the given positions.
        positions = _check_or_build_spatial_positions(pos, index_dims, batch_size)
        features = generate_fourier_features(
            positions,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
            concat_pos=self.concat_pos,
            sine_only=self.sine_only,
        )
        return features.to(device=device, dtype=dtype)
class AbstractPreprocessor(nn.Module):
    """Base class for Perceiver input preprocessors; subclasses must override `num_channels`."""

    @property
    def num_channels(self) -> int:
        """Returns size of preprocessor output."""
        raise NotImplementedError
class PerceiverTextPreprocessor(AbstractPreprocessor):
    """
    Text preprocessing for Perceiver Encoder: embeds token ids and adds learned absolute position embeddings.

    The embedding size is `config.d_model`.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration (reads `vocab_size`, `d_model` and `max_position_embeddings`).
    """

    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

    @property
    def num_channels(self) -> int:
        return self.config.d_model

    def forward(
        self,
        inputs: torch.LongTensor,
        pos: torch.Tensor | None = None,
        network_input_is_1d: bool = True,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Returns `(embeddings_with_positions, None, token_embeddings)`. `pos`, `network_input_is_1d` and
        `interpolate_pos_encoding` are accepted for interface compatibility and are not used.
        """
        token_embeddings = self.embeddings(inputs)
        # One learned vector per position 0 .. seq_len - 1 (seq_len = inputs.shape[1]), broadcast over the batch.
        positions = torch.arange(inputs.shape[1], device=inputs.device)
        with_positions = token_embeddings + self.position_embeddings(positions)
        return with_positions, None, token_embeddings
class PerceiverEmbeddingDecoder(nn.Module):
    """
    Decodes hidden states to vocabulary logits (for masked language modeling) by multiplying with the transpose
    of an embedding matrix and adding a learned per-token bias.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration (reads `vocab_size`).
    """

    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.bias = nn.Parameter(torch.zeros(self.vocab_size))

    def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: tensor of shape `(batch_size, seq_len, d_model)`.
            embedding_layer: object exposing a `.weight` of shape `(vocab_size, d_model)` (e.g. an `nn.Embedding`),
                despite the `torch.Tensor` annotation.

        Returns:
            Logits of shape `(batch_size, seq_len, vocab_size)`.
        """
        batch_size, seq_len, d_model = hidden_states.shape
        # Merge batch and sequence axes so a single 2D matmul handles every position.
        flat_states = hidden_states.reshape([-1, d_model])
        logits = flat_states @ embedding_layer.weight.t() + self.bias
        return logits.reshape([batch_size, seq_len, self.vocab_size])
class PerceiverMultimodalPostprocessor(nn.Module):
    """
    Multimodal postprocessing for Perceiver: runs one modality-specific postprocessor per modality and collects
    the results in a dict keyed by modality name.

    Args:
        modalities (`Mapping[str, PostprocessorType]`):
            Dictionary mapping modality name to postprocessor class for that modality.
        input_is_dict (`bool`, *optional*, defaults to `False`):
            If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
            False, input is a tensor which is sliced up during postprocessing by *modality_sizes*.
    """

    def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False):
        super().__init__()
        self.modalities = nn.ModuleDict(modalities)
        self.input_is_dict = input_is_dict

    def forward(
        self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None
    ) -> Mapping[str, torch.Tensor]:
        """
        Raises:
            `ValueError`: if the input is not a dict and `modality_sizes` is not given.
        """
        if self.input_is_dict:
            per_modality = inputs
        elif modality_sizes is None:
            raise ValueError("Modality sizes should be specified if input is not a dictionary.")
        else:
            # Split the flat tensor back into one chunk per modality.
            per_modality = restructure(modality_sizes=modality_sizes, inputs=inputs)

        outputs = {}
        for name, postprocessor in self.modalities.items():
            outputs[name] = postprocessor(per_modality[name], pos=pos, modality_sizes=None)
        return outputs
class PerceiverClassificationPostprocessor(nn.Module):
    """
    Classification postprocessing for Perceiver: projects the decoder output to `config.num_labels` logits and
    keeps only the first output position.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        in_channels (`int`):
            Number of channels in the input.
    """

    def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
        super().__init__()
        self.classifier = nn.Linear(in_channels, config.num_labels)

    def forward(self, inputs, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
        # Project every position, then select index 0 along dim 1 -> (batch, num_labels).
        projected = self.classifier(inputs)
        first_position = projected[:, 0, :]
        return first_position
class PerceiverAudioPostprocessor(nn.Module):
    """
    Audio postprocessing for Perceiver: maps each decoder output position to `config.samples_per_patch` audio
    samples and flattens the patches per batch element.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        in_channels (`int`):
            Number of channels in the input.
        postproc_type (`str`, *optional*, defaults to `"patches"`):
            Postprocessor type to use. Currently, only "patches" is supported.
    """

    def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
        super().__init__()
        # Only the "patches" variant is implemented ('conv' and 'pixels' are not).
        if postproc_type not in ("patches",):
            raise ValueError("Invalid postproc_type!")
        self.classifier = nn.Linear(in_channels, config.samples_per_patch)

    def forward(self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
        batch_size = inputs.shape[0]
        # Presumably (batch, num_patches, in_channels) -> (batch, num_patches * samples_per_patch).
        samples = self.classifier(inputs)
        return torch.reshape(samples, [batch_size, -1])
- class PerceiverProjectionPostprocessor(nn.Module):
- """
- Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower
- dimension.
- Args:
- in_channels (`int`):
- Number of channels in the input.
- out_channels (`int`):
- Number of channels in the output.
- """
- def __init__(self, in_channels: int, out_channels: int) -> None:
- super().__init__()
- self.classifier = nn.Linear(in_channels, out_channels)
- def forward(self, inputs: torch.Tensor, pos: torch.Tensor | None = None, modality_sizes=None) -> torch.Tensor:
- logits = self.classifier(inputs)
- return logits
class PerceiverImagePreprocessor(AbstractPreprocessor):
    """
    Image preprocessing for Perceiver Encoder.

    Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
    "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
    position encoding kwargs are set equal to the *out_channels*.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        prep_type (`str`, *optional*, defaults to `"conv"`):
            Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels". "pixels" subsamples the input with
            stride `temporal_downsample` over time and `spatial_downsample` over height/width, keeping all channels.
        spatial_downsample (`int`, *optional*, defaults to 4):
            Spatial downsampling factor.
        temporal_downsample (`int`, *optional*, defaults to 1):
            Temporal downsampling factor (only relevant in case a time dimension is present).
        position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
            Position encoding type. Can be "fourier" or "trainable".
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input.
        out_channels (`int`, *optional*, defaults to 64):
            Number of channels in the output.
        conv_after_patching (`bool`, *optional*, defaults to `False`):
            Whether to apply a convolutional layer after patching.
        conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
            Number of channels in the input of the convolutional layer after patching.
        conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
            Whether to use batch normalization in the convolutional layer.
        concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
            How to concatenate the position encoding to the input. Can be "concat" or "add".
        project_pos_dim (`int`, *optional*, defaults to -1):
            Dimension of the position encoding to project to. If -1, no projection is applied.
        **position_encoding_kwargs (`Dict`, *optional*):
            Keyword arguments for the position encoding.
    """

    def __init__(
        self,
        config,
        prep_type="conv",
        spatial_downsample: int = 4,
        temporal_downsample: int = 1,
        position_encoding_type: str = "fourier",
        in_channels: int = 3,
        out_channels: int = 64,
        conv_after_patching: bool = False,
        conv_after_patching_in_channels: int = 54,  # only relevant when conv_after_patching = True
        conv2d_use_batchnorm: bool = True,
        concat_or_add_pos: str = "concat",
        project_pos_dim: int = -1,
        **position_encoding_kwargs,
    ):
        super().__init__()
        self.config = config

        if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
            raise ValueError(f"Prep_type {prep_type} is invalid")

        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")

        self.in_channels = in_channels
        self.prep_type = prep_type
        self.spatial_downsample = spatial_downsample
        self.temporal_downsample = temporal_downsample
        self.position_encoding_type = position_encoding_type
        self.concat_or_add_pos = concat_or_add_pos
        self.conv_after_patching = conv_after_patching
        self.out_channels = out_channels

        if self.prep_type == "conv":
            # Downsampling with conv is currently restricted
            convnet_num_layers = math.log(spatial_downsample, 4)
            convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
            if not convnet_num_layers_is_int or temporal_downsample != 1:
                raise ValueError(
                    "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
                )
            # NOTE(review): Conv2DDownsample as currently written ignores `num_layers` and always applies a single
            # conv + max-pool stage, so powers of 4 other than 4 pass this check but are not honoured.
            self.convnet = Conv2DDownsample(
                in_channels=in_channels,
                num_layers=int(convnet_num_layers),
                out_channels=out_channels,
                use_batchnorm=conv2d_use_batchnorm,
            )

        elif self.prep_type == "conv1x1":
            if temporal_downsample != 1:
                raise ValueError("Conv1x1 does not downsample in time.")
            self.convnet_1x1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                # spatial_downsample is unconstrained for 1x1 convolutions.
                stride=(spatial_downsample, spatial_downsample),
            )

        # Position embeddings
        self.project_pos_dim = project_pos_dim
        self.position_embeddings, self.positions_projection = build_position_encoding(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

        # Optional convolutional layer after patches.
        self.conv_after_patches = (
            nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity()
        )

    @property
    def num_channels(self) -> int:
        """Size of the last dimension of the tensor returned by `forward` (features + position encoding)."""
        # Let's assume that the number of resolutions (in the context of image preprocessing)
        # of the input data is 2 or 3 depending on whether we are processing image or video respectively.
        # In this case, for convenience, we will declare is_temporal variable,
        # which will show whether the data has a temporal dimension or not.
        is_temporal = self.position_embeddings.num_dimensions > 2

        # position embedding
        if self.project_pos_dim > 0:
            pos_dim = self.project_pos_dim
        else:
            pos_dim = self.position_embeddings.output_size()
        if self.concat_or_add_pos == "add":
            return pos_dim

        # inputs
        if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
            inp_dim = self.out_channels
        elif self.prep_type == "pixels":
            # Pixel subsampling only strides over time/space, so every input channel is kept.
            inp_dim = self.in_channels
        else:
            # "patches" without a conv afterwards: space-to-depth stacks each block into the channel axis.
            inp_dim = self.in_channels * self.spatial_downsample**2
            if is_temporal:
                inp_dim *= self.temporal_downsample

        return inp_dim + pos_dim

    def _build_network_inputs(
        self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
    ):
        """
        Construct the final input, including position encoding.

        This method expects the inputs to always have channels as last dimension.

        Returns:
            `(inputs_with_pos, inputs_without_pos)`. When `network_input_is_1d` is True and the input has more than
            3 dims, both are flattened to `(batch_size, prod(index_dims), channels)`.
        """
        batch_size = inputs.shape[0]
        input_size = inputs.shape[1:3]
        index_dims = inputs.shape[1:-1]
        indices = np.prod(index_dims)

        # Flatten input features to a 1D index dimension if necessary.
        if len(inputs.shape) > 3 and network_input_is_1d:
            inputs = torch.reshape(inputs, [batch_size, indices, -1])

        # Construct the position encoding.
        if self.position_encoding_type == "trainable":
            pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
        elif self.position_encoding_type == "fourier":
            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)

        # Optionally project them to a target dimension.
        pos_enc = self.positions_projection(pos_enc)

        if not network_input_is_1d:
            # Reshape pos to match the input feature shape
            # if the network takes non-1D inputs
            sh = inputs.shape
            pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])
        if self.concat_or_add_pos == "concat":
            inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
        elif self.concat_or_add_pos == "add":
            inputs_with_pos = inputs + pos_enc
        return inputs_with_pos, inputs

    def forward(
        self,
        inputs: torch.Tensor,
        pos: torch.Tensor | None = None,
        network_input_is_1d: bool = True,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Featurize channels-first images `(batch, channels, height, width)` or videos
        `(batch, time, channels, height, width)` and attach position encodings.

        Returns:
            `(inputs_with_pos, None, inputs_without_pos)`; the `None` is the (unused) modality sizes.
        """
        if self.prep_type == "conv":
            # Convnet image featurization.
            # Downsamples spatially by a factor of 4
            inputs = self.convnet(inputs)

        elif self.prep_type == "conv1x1":
            # map inputs to self.out_channels
            inputs = self.convnet_1x1(inputs)

        elif self.prep_type == "pixels":
            # if requested, downsamples in the crudest way (strided subsampling)
            if inputs.ndim == 4:
                # (batch, channels, height, width): stride over height and width only. The previous
                # `inputs[::s, ::s]` subsampled the batch and channel axes instead.
                inputs = inputs[:, :, :: self.spatial_downsample, :: self.spatial_downsample]
            elif inputs.ndim == 5:
                # (batch, time, channels, height, width)
                inputs = inputs[
                    :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
                ]
            else:
                raise ValueError("Unsupported data format for pixels.")

        elif self.prep_type == "patches":
            # Space2depth featurization.
            # Video: B x T x C x H x W
            inputs = space_to_depth(
                inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
            )

            if inputs.ndim == 5 and inputs.shape[1] == 1:
                # for flow
                inputs = inputs.squeeze(dim=1)

            # Optionally apply conv layer.
            inputs = self.conv_after_patches(inputs)

        if self.prep_type != "patches":
            # move channels to last dimension, as the _build_network_inputs method below expects this
            if inputs.ndim == 4:
                inputs = inputs.permute(0, 2, 3, 1)
            elif inputs.ndim == 5:
                inputs = inputs.permute(0, 1, 3, 4, 2)
            else:
                raise ValueError("Unsupported data format for conv1x1.")

        inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
        modality_sizes = None  # Size for each modality, only needed for multimodal
        return inputs, modality_sizes, inputs_without_pos
class PerceiverOneHotPreprocessor(AbstractPreprocessor):
    """
    One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """

    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        self.config: PerceiverConfig = config

    @property
    def num_channels(self) -> int:
        return self.config.num_labels

    def forward(
        self,
        inputs: torch.Tensor,
        pos: torch.Tensor | None = None,
        network_input_is_1d: bool = True,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Args:
            inputs (`torch.Tensor`):
                Presumably one-hot labels of shape `(batch_size, num_labels)` — TODO confirm against callers.
            pos, network_input_is_1d, interpolate_pos_encoding:
                Accepted so this preprocessor shares the other preprocessors' call signature
                (`interpolate_pos_encoding` was previously missing and raised a TypeError when passed); unused.

        Returns:
            `(inputs[:, None, :], None, inputs[:, None, :])`.
        """
        # Add a dummy index dimension.
        inputs = inputs[:, None, :]

        # No position encodings, so the 1st (input) and 3rd (inputs_without_pos)
        # outputs are identical.
        return inputs, None, inputs
class PerceiverAudioPreprocessor(AbstractPreprocessor):
    """
    Audio preprocessing for Perceiver Encoder: cuts the raw signal into patches of `samples_per_patch` samples and
    attaches a position encoding to every patch.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        prep_type (`str`, *optional*, defaults to `"patches"`):
            Preprocessor type to use. Only "patches" is supported.
        samples_per_patch (`int`, *optional*, defaults to 96):
            Number of samples per patch.
        position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of position encoding to use. Can be "trainable" or "fourier".
        concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
            How to concatenate the position encoding to the input. Can be "concat" or "add".
        out_channels (`int`, *optional*, defaults to 64):
            Number of channels in the output (forwarded to the position-encoding builder).
        project_pos_dim (`int`, *optional*, defaults to -1):
            Dimension of the position encoding to project to. If -1, no projection is applied.
        **position_encoding_kwargs (`Dict`, *optional*):
            Keyword arguments for the position encoding.
    """

    def __init__(
        self,
        config,
        prep_type: str = "patches",
        samples_per_patch: int = 96,
        position_encoding_type: str = "fourier",
        concat_or_add_pos: str = "concat",
        out_channels=64,
        project_pos_dim=-1,
        **position_encoding_kwargs,
    ):
        super().__init__()
        self.config = config

        # Validate the configuration before building anything.
        if prep_type != "patches":
            raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.")
        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.")

        self.samples_per_patch = samples_per_patch
        self.position_encoding_type = position_encoding_type
        self.concat_or_add_pos = concat_or_add_pos
        self.project_pos_dim = project_pos_dim

        # Position embeddings (plus an optional projection of them)
        self.position_embeddings, self.positions_projection = build_position_encoding(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

    @property
    def num_channels(self) -> int:
        pos_dim = self.project_pos_dim if self.project_pos_dim > 0 else self.position_embeddings.output_size()
        # "add" keeps only the position-encoding width; "concat" appends it to the patch samples.
        return pos_dim if self.concat_or_add_pos == "add" else self.samples_per_patch + pos_dim

    def _build_network_inputs(self, inputs):
        """Construct the final input, including position encoding."""
        batch_size, index_dims = inputs.shape[0], inputs.shape[1:-1]

        if self.position_encoding_type == "fourier":
            encoding = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
        elif self.position_encoding_type == "trainable":
            encoding = self.position_embeddings(batch_size)

        # Optionally project them to a target dimension.
        encoding = self.positions_projection(encoding)

        if self.concat_or_add_pos == "add":
            combined = inputs + encoding
        elif self.concat_or_add_pos == "concat":
            combined = torch.cat([inputs, encoding], dim=-1)
        return combined, inputs

    def forward(
        self,
        inputs: torch.Tensor,
        pos: torch.Tensor | None = None,
        network_input_is_1d: bool = True,
        interpolate_pos_encoding: bool = False,
    ):
        """Returns `(inputs_with_pos, None, patches_without_pos)`; the `None` is the (unused) modality sizes."""
        batch_size = inputs.shape[0]
        # (batch, ...) -> (batch, num_patches, samples_per_patch)
        patches = torch.reshape(inputs, [batch_size, -1, self.samples_per_patch])
        with_pos, without_pos = self._build_network_inputs(patches)
        return with_pos, None, without_pos
class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
    """
    Multimodal preprocessing for Perceiver Encoder.

    Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number
    of channels.

    Args:
        modalities (`Mapping[str, PreprocessorType]`):
            Dict mapping modality name to preprocessor.
        mask_probs (`dict[str, float]`):
            Dict mapping modality name to masking probability of that modality.
        min_padding_size (`int`, *optional*, defaults to 2):
            The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
            channels across all modalities plus min_padding_size.
    """

    def __init__(
        self,
        modalities: Mapping[str, PreprocessorType],
        mask_probs: Mapping[str, float] | None = None,
        min_padding_size: int = 2,
    ):
        super().__init__()
        self.modalities = nn.ModuleDict(modalities)
        self.min_padding_size = min_padding_size
        self.mask_probs = mask_probs if mask_probs is not None else {}
        # One trainable padding vector per modality, filling the gap up to the common channel size.
        self.padding = nn.ParameterDict(
            {
                modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels))
                for modality, preprocessor in modalities.items()
            }
        )
        # One trainable mask token per modality that has a masking probability.
        self.mask = nn.ParameterDict(
            {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()}
        )

    @property
    def num_channels(self) -> int:
        max_channel_size = max(processor.num_channels for _, processor in self.modalities.items())
        common_channel_size = max_channel_size + self.min_padding_size
        return common_channel_size

    def forward(
        self,
        inputs: Mapping[str, torch.Tensor],
        pos: torch.Tensor | None = None,
        network_input_is_1d: bool = True,
        interpolate_pos_encoding: bool = False,
    ) -> PreprocessorOutputType:
        """
        Returns:
            `(final_inputs, modality_sizes, inputs_without_pos)`: the padded (and optionally masked) modalities
            concatenated along dim 1 in sorted-name order, the number of positions each modality contributes, and
            the per-modality inputs without position encoding.
        """
        padded = {}
        modality_sizes = {}
        inputs_without_pos = {}
        # Loop-invariant: the property rescans every sub-preprocessor on each access.
        common_channel_size = self.num_channels
        for modality, preprocessor in self.modalities.items():
            # preprocess each modality using the respective preprocessor.
            output, _, inputs_without_pos[modality] = preprocessor(
                inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d
            )

            # pad to the same common_channel_size.
            batch_size, num_samples, num_channels = output.shape
            pos_enc = self.padding[modality].expand(batch_size, -1, -1)

            padding = torch.broadcast_to(
                pos_enc,
                [batch_size, num_samples, common_channel_size - num_channels],
            )
            output_padded = torch.cat([output, padding], dim=2)

            # mask if required
            if modality in self.mask_probs:
                mask_token = self.mask[modality].expand(batch_size, -1, -1)
                mask_prob = self.mask_probs[modality]
                mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob))
                # Match the features' dtype (and the token's device): a float32 mask would otherwise promote
                # half/bfloat16 features to float32 in the blend below.
                mask = torch.unsqueeze(mask, dim=2).to(device=mask_token.device, dtype=output_padded.dtype)
                output_padded = (1 - mask) * output_padded + mask * mask_token

            padded[modality] = output_padded
            modality_sizes[modality] = output_padded.shape[1]

        # Apply a predictable ordering to the modalities
        padded_ls = [padded[k] for k in sorted(padded.keys())]

        # Finally, concatenate along the time dimension
        final_inputs = torch.cat(padded_ls, dim=1)

        return final_inputs, modality_sizes, inputs_without_pos
# Names exported by `from <this module> import *`; every entry must be defined elsewhere in this module.
- __all__ = [
- "PerceiverForImageClassificationConvProcessing",
- "PerceiverForImageClassificationFourier",
- "PerceiverForImageClassificationLearned",
- "PerceiverForMaskedLM",
- "PerceiverForMultimodalAutoencoding",
- "PerceiverForOpticalFlow",
- "PerceiverForSequenceClassification",
- "PerceiverLayer",
- "PerceiverModel",
- "PerceiverPreTrainedModel",
- ]
|