oneflow.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from typing import Dict, Optional, cast
  2. import oneflow as flow
  3. from . import RearrangeMixin, ReduceMixin
  4. from ._einmix import _EinmixMixin
  5. __author__ = "Tianhe Ren & Depeng Liang"
  6. class Rearrange(RearrangeMixin, flow.nn.Module):
  7. def forward(self, input):
  8. return self._apply_recipe(input)
  9. class Reduce(ReduceMixin, flow.nn.Module):
  10. def forward(self, input):
  11. return self._apply_recipe(input)
  12. class EinMix(_EinmixMixin, flow.nn.Module):
  13. def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
  14. self.weight = flow.nn.Parameter(
  15. flow.zeros(weight_shape).uniform_(-weight_bound, weight_bound), requires_grad=True
  16. )
  17. if bias_shape is not None:
  18. self.bias = flow.nn.Parameter(flow.zeros(bias_shape).uniform_(-bias_bound, bias_bound), requires_grad=True)
  19. else:
  20. self.bias = None
  21. def _create_rearrange_layers(
  22. self,
  23. pre_reshape_pattern: Optional[str],
  24. pre_reshape_lengths: Optional[Dict],
  25. post_reshape_pattern: Optional[str],
  26. post_reshape_lengths: Optional[Dict],
  27. ):
  28. self.pre_rearrange = None
  29. if pre_reshape_pattern is not None:
  30. self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))
  31. self.post_rearrange = None
  32. if post_reshape_pattern is not None:
  33. self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))
  34. def forward(self, input):
  35. if self.pre_rearrange is not None:
  36. input = self.pre_rearrange(input)
  37. result = flow.einsum(self.einsum_pattern, input, self.weight)
  38. if self.bias is not None:
  39. result += self.bias
  40. if self.post_rearrange is not None:
  41. result = self.post_rearrange(result)
  42. return result