compression.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import base64
  2. import logging
  3. import time
  4. import numpy as np
  5. from ray import cloudpickle as pickle
  6. from ray.rllib.utils.annotations import DeveloperAPI
  7. logger = logging.getLogger(__name__)
  8. try:
  9. import lz4.frame
  10. LZ4_ENABLED = True
  11. except ImportError:
  12. logger.warning(
  13. "lz4 not available, disabling sample compression. "
  14. "This will significantly impact RLlib performance. "
  15. "To install lz4, run `pip install lz4`."
  16. )
  17. LZ4_ENABLED = False
  18. @DeveloperAPI
  19. def compression_supported():
  20. return LZ4_ENABLED
  21. @DeveloperAPI
  22. def pack(data):
  23. if LZ4_ENABLED:
  24. data = pickle.dumps(data)
  25. data = lz4.frame.compress(data)
  26. # TODO(ekl) we shouldn't need to base64 encode this data, but this
  27. # seems to not survive a transfer through the object store if we don't.
  28. data = base64.b64encode(data).decode("ascii")
  29. return data
  30. @DeveloperAPI
  31. def pack_if_needed(data):
  32. if isinstance(data, np.ndarray):
  33. data = pack(data)
  34. return data
  35. @DeveloperAPI
  36. def unpack(data):
  37. if LZ4_ENABLED:
  38. data = base64.b64decode(data)
  39. data = lz4.frame.decompress(data)
  40. data = pickle.loads(data)
  41. return data
  42. @DeveloperAPI
  43. def unpack_if_needed(data):
  44. if is_compressed(data):
  45. data = unpack(data)
  46. return data
  47. @DeveloperAPI
  48. def is_compressed(data):
  49. return isinstance(data, bytes) or isinstance(data, str)
# Sample output of the benchmark below, as recorded on an
# Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz:
# Compression speed: 753.664 MB/s
# Compression ratio: 87.4839812046
# Decompression speed: 910.9504 MB/s
  54. if __name__ == "__main__":
  55. size = 32 * 80 * 80 * 4
  56. data = np.ones(size).reshape((32, 80, 80, 4))
  57. count = 0
  58. start = time.time()
  59. while time.time() - start < 1:
  60. pack(data)
  61. count += 1
  62. compressed = pack(data)
  63. print("Compression speed: {} MB/s".format(count * size * 4 / 1e6))
  64. print("Compression ratio: {}".format(round(size * 4 / len(compressed), 2)))
  65. count = 0
  66. start = time.time()
  67. while time.time() - start < 1:
  68. unpack(compressed)
  69. count += 1
  70. print("Decompression speed: {} MB/s".format(count * size * 4 / 1e6))