flax.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from dataclasses import field
  2. from typing import Dict, Optional, cast
  3. import flax.linen as nn
  4. import jax
  5. import jax.numpy as jnp
  6. from . import RearrangeMixin, ReduceMixin
  7. from ._einmix import _EinmixMixin
# Module-level author metadata.
__author__ = "Alex Rogozhnikov"
  9. class Reduce(nn.Module):
  10. pattern: str
  11. reduction: str
  12. sizes: dict = field(default_factory=dict)
  13. def setup(self):
  14. self.reducer = ReduceMixin(self.pattern, self.reduction, **self.sizes)
  15. def __call__(self, input):
  16. return self.reducer._apply_recipe(input)
  17. class Rearrange(nn.Module):
  18. pattern: str
  19. sizes: dict = field(default_factory=dict)
  20. def setup(self):
  21. self.rearranger = RearrangeMixin(self.pattern, **self.sizes)
  22. def __call__(self, input):
  23. return self.rearranger._apply_recipe(input)
  24. class EinMix(nn.Module, _EinmixMixin):
  25. pattern: str
  26. weight_shape: str
  27. bias_shape: Optional[str] = None
  28. sizes: dict = field(default_factory=dict)
  29. def setup(self):
  30. self.initialize_einmix(
  31. pattern=self.pattern,
  32. weight_shape=self.weight_shape,
  33. bias_shape=self.bias_shape,
  34. axes_lengths=self.sizes,
  35. )
  36. def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
  37. self.weight = self.param("weight", jax.nn.initializers.uniform(weight_bound), weight_shape)
  38. if bias_shape is not None:
  39. self.bias = self.param("bias", jax.nn.initializers.uniform(bias_bound), bias_shape)
  40. else:
  41. self.bias = None
  42. def _create_rearrange_layers(
  43. self,
  44. pre_reshape_pattern: Optional[str],
  45. pre_reshape_lengths: Optional[Dict],
  46. post_reshape_pattern: Optional[str],
  47. post_reshape_lengths: Optional[Dict],
  48. ):
  49. self.pre_rearrange = None
  50. if pre_reshape_pattern is not None:
  51. self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=cast(dict, pre_reshape_lengths))
  52. self.post_rearrange = None
  53. if post_reshape_pattern is not None:
  54. self.post_rearrange = Rearrange(post_reshape_pattern, sizes=cast(dict, post_reshape_lengths))
  55. def __call__(self, input):
  56. if self.pre_rearrange is not None:
  57. input = self.pre_rearrange(input)
  58. result = jnp.einsum(self.einsum_pattern, input, self.weight)
  59. if self.bias is not None:
  60. result += self.bias
  61. if self.post_rearrange is not None:
  62. result = self.post_rearrange(result)
  63. return result