| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- from dataclasses import field
- from typing import Dict, Optional, cast
- import flax.linen as nn
- import jax
- import jax.numpy as jnp
- from . import RearrangeMixin, ReduceMixin
- from ._einmix import _EinmixMixin
- __author__ = "Alex Rogozhnikov"
class Reduce(nn.Module):
    """Flax layer that applies an einops reduction recipe to its input.

    Attributes:
        pattern: einops pattern handed to ``ReduceMixin``.
        reduction: reduction name handed to ``ReduceMixin``.
        sizes: extra axis lengths, forwarded to ``ReduceMixin`` as keyword
            arguments.
    """

    pattern: str
    reduction: str
    sizes: dict = field(default_factory=lambda: {})

    def setup(self):
        # The mixin is constructed here; __call__ only delegates to it.
        axis_lengths = self.sizes
        self.reducer = ReduceMixin(self.pattern, self.reduction, **axis_lengths)

    def __call__(self, input):
        reducer = self.reducer
        return reducer._apply_recipe(input)
class Rearrange(nn.Module):
    """Flax layer that applies an einops rearrangement recipe to its input.

    Attributes:
        pattern: einops pattern handed to ``RearrangeMixin``.
        sizes: extra axis lengths, forwarded to ``RearrangeMixin`` as keyword
            arguments.
    """

    pattern: str
    sizes: dict = field(default_factory=lambda: {})

    def setup(self):
        # The mixin is constructed here; __call__ only delegates to it.
        axis_lengths = self.sizes
        self.rearranger = RearrangeMixin(self.pattern, **axis_lengths)

    def __call__(self, input):
        rearranger = self.rearranger
        return rearranger._apply_recipe(input)
def _symmetric_uniform(bound):
    """Return a JAX-style initializer sampling uniformly from ``[-bound, bound)``.

    ``jax.nn.initializers.uniform(scale)`` samples from ``[0, scale)``, which
    would make every parameter non-negative; this helper centres the range on
    zero instead.
    """

    def init(key, shape, dtype=float):
        # ``dtype=float`` resolves to JAX's default float dtype (float32, or
        # float64 when x64 is enabled), like the stock initializers.
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

    return init


class EinMix(nn.Module, _EinmixMixin):
    """Flax layer for einops ``EinMix``: an einsum of the input with a learned weight.

    ``setup`` hands the attributes to ``_EinmixMixin.initialize_einmix``, which
    is expected to call the ``_create_parameters`` / ``_create_rearrange_layers``
    hooks below and to set ``self.einsum_pattern`` used by ``__call__``.
    NOTE(review): the mixin is not visible here — confirm that contract.

    Attributes:
        pattern: einops-style pattern describing the mixing.
        weight_shape: axes description of the learned weight.
        bias_shape: axes description of the optional bias; ``None`` disables it.
        sizes: axis lengths, forwarded to the mixin as ``axes_lengths``.
    """

    pattern: str
    weight_shape: str
    bias_shape: Optional[str] = None
    sizes: dict = field(default_factory=dict)

    def setup(self):
        self.initialize_einmix(
            pattern=self.pattern,
            weight_shape=self.weight_shape,
            bias_shape=self.bias_shape,
            axes_lengths=self.sizes,
        )

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        """Register ``weight`` (and ``bias`` if requested) as Flax parameters.

        Args:
            weight_shape: concrete shape of the weight parameter.
            weight_bound: weights are drawn uniformly from
                ``[-weight_bound, weight_bound)``.
            bias_shape: concrete shape of the bias, or ``None`` for no bias.
            bias_bound: bias is drawn uniformly from ``[-bias_bound, bias_bound)``.
        """
        # Symmetric range: the stock ``jax.nn.initializers.uniform`` would only
        # produce values in [0, bound).
        self.weight = self.param("weight", _symmetric_uniform(weight_bound), weight_shape)
        if bias_shape is not None:
            self.bias = self.param("bias", _symmetric_uniform(bias_bound), bias_shape)
        else:
            self.bias = None

    def _create_rearrange_layers(
        self,
        pre_reshape_pattern: Optional[str],
        pre_reshape_lengths: Optional[Dict],
        post_reshape_pattern: Optional[str],
        post_reshape_lengths: Optional[Dict],
    ):
        """Create optional ``Rearrange`` submodules applied before/after the einsum.

        A pattern of ``None`` means no rearrangement on that side. Missing
        lengths (``None``) are treated as "no extra axis lengths" rather than
        being unpacked as ``**None`` inside ``Rearrange.setup``.
        """
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=pre_reshape_lengths or {})

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, sizes=post_reshape_lengths or {})

    def __call__(self, input):
        """Apply pre-rearrange, einsum with ``weight``, optional bias, post-rearrange."""
        if self.pre_rearrange is not None:
            input = self.pre_rearrange(input)
        result = jnp.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            # assumes bias_shape broadcasts against the einsum output — TODO confirm
            result = result + self.bias
        if self.post_rearrange is not None:
            result = self.post_rearrange(result)
        return result
|