tensorflow.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. """
  2. Note on TensorFlow layers:
  3. unfortunately, the official instructions for creating TF layers change constantly,
  4. and have changed too many times to keep track of which version is compatible with what.
  5. Layers in einops==0.7.0 (and several prior versions)
  6. are compatible with TF 2.13.
  7. Layers in einops==0.8.0 were re-implemented
  8. according to the official instructions for TF 2.16.
  9. """
  10. from typing import Dict, Optional, cast
  11. import tensorflow as tf
  12. from tensorflow.keras.layers import Layer
  13. from . import RearrangeMixin, ReduceMixin
  14. from ._einmix import _EinmixMixin
  15. __author__ = "Alex Rogozhnikov"
  class Rearrange(RearrangeMixin, Layer):
      """Keras layer that applies a rearrange recipe to its input.

      ``self.pattern``, ``self.axes_lengths`` and ``self._apply_recipe`` are not
      defined in this class; presumably they are set up by ``RearrangeMixin``.
      """

      def build(self, input_shape):
          # The layer owns no parameters, so there is nothing to create here.
          return None

      def call(self, inputs):
          # All of the work is done by the recipe prepared by the mixin.
          transformed = self._apply_recipe(inputs)
          return transformed

      def get_config(self):
          # Constructor kwargs: the pattern first, then every named axis length
          # (a later axis-length entry with the same key wins, as before).
          config = {"pattern": self.pattern}
          config.update(self.axes_lengths)
          return config
  class Reduce(ReduceMixin, Layer):
      """Keras layer that applies a reduce recipe (pattern + reduction) to its input.

      ``self.pattern``, ``self.reduction``, ``self.axes_lengths`` and
      ``self._apply_recipe`` are not defined in this class; presumably they are
      set up by ``ReduceMixin``.
      """

      def build(self, input_shape):
          # The layer owns no parameters, so there is nothing to create here.
          return None

      def call(self, inputs):
          # All of the work is done by the recipe prepared by the mixin.
          reduced = self._apply_recipe(inputs)
          return reduced

      def get_config(self):
          # Constructor kwargs: pattern, reduction, then every named axis length
          # (a later axis-length entry with the same key wins, as before).
          config = {"pattern": self.pattern, "reduction": self.reduction}
          config.update(self.axes_lengths)
          return config
  class EinMix(_EinmixMixin, Layer):
      """Keras layer computing an einsum of the input with a learnable weight.

      The layer can add an optional learnable bias to the result. It can also
      apply optional ``Rearrange`` steps before and after the einsum.

      ``self.einsum_pattern``, ``self.pattern``, ``self.weight_shape``,
      ``self.bias_shape`` and ``self.axes_lengths`` are not defined in this
      class; presumably they come from ``_EinmixMixin``, which is also expected
      to call the two ``_create_*`` hooks below.
      """

      def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
          # Invoked during __init__; following TF guidance, variables are only
          # created in build(), so just stash the specification for later.
          self._params = [weight_shape, weight_bound, bias_shape, bias_bound]

      def _create_rearrange_layers(
          self,
          pre_reshape_pattern: Optional[str],
          pre_reshape_lengths: Optional[Dict],
          post_reshape_pattern: Optional[str],
          post_reshape_lengths: Optional[Dict],
      ):
          # Each side is either a Rearrange sub-layer or None (meaning "skip").
          # The lengths dict is assumed non-None whenever its pattern is given.
          self.pre_rearrange = (
              None
              if pre_reshape_pattern is None
              else Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))
          )
          self.post_rearrange = (
              None
              if post_reshape_pattern is None
              else Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))
          )

      def build(self, input_shape):
          weight_shape, weight_bound, bias_shape, bias_bound = self._params

          def symmetric_uniform(bound):
              # Uniform initializer over [-bound, bound].
              return tf.random_uniform_initializer(-bound, bound)

          self.weight = self.add_weight(
              shape=weight_shape,
              initializer=symmetric_uniform(weight_bound),
              trainable=True,
          )
          # No bias_shape means the layer has no bias at all.
          self.bias = (
              None
              if bias_shape is None
              else self.add_weight(
                  shape=bias_shape,
                  initializer=symmetric_uniform(bias_bound),
                  trainable=True,
              )
          )

      def call(self, inputs):
          x = inputs if self.pre_rearrange is None else self.pre_rearrange(inputs)
          out = tf.einsum(self.einsum_pattern, x, self.weight)
          if self.bias is not None:
              out = out + self.bias
          if self.post_rearrange is not None:
              out = self.post_rearrange(out)
          return out

      def get_config(self):
          # Constructor kwargs: pattern, weight/bias shapes, then named axis lengths.
          config = {
              "pattern": self.pattern,
              "weight_shape": self.weight_shape,
              "bias_shape": self.bias_shape,
          }
          config.update(self.axes_lengths)
          return config
  78. }