transpiler.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module for transpiling Kornia to other frameworks."""
  18. from types import ModuleType
  19. import kornia
  20. from kornia.core.external import ivy
  21. def to_jax() -> ModuleType:
  22. """Convert Kornia to JAX.
  23. Transpiles the Kornia library to JAX using [ivy](https://github.com/ivy-llc/ivy). The transpilation process
  24. occurs lazily, so the transpilation on a given kornia function/class will only occur when it's called or
  25. instantiated for the first time. This will make any functions/classes slow when being used for the first time,
  26. but any subsequent uses should be as fast as expected.
  27. Return:
  28. The Kornia library transpiled to JAX
  29. Example:
  30. .. highlight:: python
  31. .. code-block:: python
  32. import kornia
  33. jax_kornia = kornia.to_jax()
  34. import jax
  35. input = jax.random.normal(jax.random.key(42), shape=(2, 3, 4, 5))
  36. gray = jax_kornia.color.gray.rgb_to_grayscale(input)
  37. """
  38. return ivy.transpile(
  39. kornia,
  40. source="torch",
  41. target="jax",
  42. ) # type: ignore
  43. def to_numpy() -> ModuleType:
  44. """Convert Kornia to NumPy.
  45. Transpiles the Kornia library to NumPy using [ivy](https://github.com/ivy-llc/ivy). The transpilation process
  46. occurs lazily, so the transpilation on a given kornia function/class will only occur when it's called or
  47. instantiated for the first time. This will make any functions/classes slow when being used for the first time,
  48. but any subsequent uses should be as fast as expected.
  49. Return:
  50. The Kornia library transpiled to NumPy
  51. Example:
  52. .. highlight:: python
  53. .. code-block:: python
  54. import kornia
  55. np_kornia = kornia.to_numpy()
  56. import numpy as np
  57. input = np.random.normal(size=(2, 3, 4, 5))
  58. gray = np_kornia.color.gray.rgb_to_grayscale(input)
  59. Note:
  60. Ivy does not currently support transpiling trainable modules to NumPy.
  61. """
  62. return ivy.transpile(
  63. kornia,
  64. source="torch",
  65. target="numpy",
  66. ) # type: ignore
  67. def to_tensorflow() -> ModuleType:
  68. """Convert Kornia to TensorFlow.
  69. Transpiles the Kornia library to TensorFlow using [ivy](https://github.com/ivy-llc/ivy). The transpilation process
  70. occurs lazily, so the transpilation on a given kornia function/class will only occur when it's called or
  71. instantiated for the first time. This will make any functions/classes slow when being used for the first time,
  72. but any subsequent uses should be as fast as expected.
  73. Return:
  74. The Kornia library transpiled to TensorFlow
  75. Example:
  76. .. highlight:: python
  77. .. code-block:: python
  78. import kornia
  79. tf_kornia = kornia.to_tensorflow()
  80. import tensorflow as tf
  81. input = tf.random.normal((2, 3, 4, 5))
  82. gray = tf_kornia.color.gray.rgb_to_grayscale(input)
  83. """
  84. return ivy.transpile(
  85. kornia,
  86. source="torch",
  87. target="tensorflow",
  88. ) # type: ignore