test_streams.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """ Testing
  2. """
  3. import platform
  4. import os
  5. import random
  6. import sys
  7. import zlib
  8. from io import BytesIO
  9. from tempfile import mkstemp
  10. from contextlib import contextmanager
  11. import numpy as np
  12. from numpy.testing import assert_, assert_equal
  13. from pytest import raises as assert_raises
  14. import pytest
  15. from scipy.io.matlab._streams import (make_stream,
  16. GenericStream, ZlibInputStream,
  17. _read_into, _read_string, BLOCK_SIZE)
  18. @contextmanager
  19. def setup_test_file():
  20. val = b'a\x00string'
  21. fd, fname = mkstemp()
  22. with os.fdopen(fd, 'wb') as fs:
  23. fs.write(val)
  24. with open(fname, 'rb') as fs:
  25. gs = BytesIO(val)
  26. cs = BytesIO(val)
  27. yield fs, gs, cs
  28. os.unlink(fname)
  29. def test_make_stream():
  30. with setup_test_file() as (fs, gs, cs):
  31. # test stream initialization
  32. assert_(isinstance(make_stream(gs), GenericStream))
  33. def test_tell_seek():
  34. with setup_test_file() as (fs, gs, cs):
  35. for s in (fs, gs, cs):
  36. st = make_stream(s)
  37. res = st.seek(0)
  38. assert_equal(res, 0)
  39. assert_equal(st.tell(), 0)
  40. res = st.seek(5)
  41. assert_equal(res, 0)
  42. assert_equal(st.tell(), 5)
  43. res = st.seek(2, 1)
  44. assert_equal(res, 0)
  45. assert_equal(st.tell(), 7)
  46. res = st.seek(-2, 2)
  47. assert_equal(res, 0)
  48. assert_equal(st.tell(), 6)
  49. def test_read():
  50. with setup_test_file() as (fs, gs, cs):
  51. for s in (fs, gs, cs):
  52. st = make_stream(s)
  53. st.seek(0)
  54. res = st.read(-1)
  55. assert_equal(res, b'a\x00string')
  56. st.seek(0)
  57. res = st.read(4)
  58. assert_equal(res, b'a\x00st')
  59. # read into
  60. st.seek(0)
  61. res = _read_into(st, 4)
  62. assert_equal(res, b'a\x00st')
  63. res = _read_into(st, 4)
  64. assert_equal(res, b'ring')
  65. assert_raises(OSError, _read_into, st, 2)
  66. # read alloc
  67. st.seek(0)
  68. res = _read_string(st, 4)
  69. assert_equal(res, b'a\x00st')
  70. res = _read_string(st, 4)
  71. assert_equal(res, b'ring')
  72. assert_raises(OSError, _read_string, st, 2)
  73. class TestZlibInputStream:
  74. def _get_data(self, size):
  75. data = random.randbytes(size)
  76. compressed_data = zlib.compress(data)
  77. stream = BytesIO(compressed_data)
  78. return stream, len(compressed_data), data
  79. def test_read(self):
  80. SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1,
  81. BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1]
  82. READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1,
  83. BLOCK_SIZE, BLOCK_SIZE+1]
  84. def check(size, read_size):
  85. compressed_stream, compressed_data_len, data = self._get_data(size)
  86. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  87. data2 = b''
  88. so_far = 0
  89. while True:
  90. block = stream.read(min(read_size,
  91. size - so_far))
  92. if not block:
  93. break
  94. so_far += len(block)
  95. data2 += block
  96. assert_equal(data, data2)
  97. for size in SIZES:
  98. for read_size in READ_SIZES:
  99. check(size, read_size)
  100. def test_read_max_length(self):
  101. data = random.randbytes(1234)
  102. compressed_data = zlib.compress(data)
  103. compressed_stream = BytesIO(compressed_data + b"abbacaca")
  104. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  105. stream.read(len(data))
  106. assert_equal(compressed_stream.tell(), len(compressed_data))
  107. assert_raises(OSError, stream.read, 1)
  108. def test_read_bad_checksum(self):
  109. data = random.randbytes(10)
  110. compressed_data = zlib.compress(data)
  111. # break checksum
  112. compressed_data = (compressed_data[:-1]
  113. + bytes([(compressed_data[-1] + 1) & 255]))
  114. compressed_stream = BytesIO(compressed_data)
  115. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  116. assert_raises(zlib.error, stream.read, len(data))
  117. def test_seek(self):
  118. compressed_stream, compressed_data_len, data = self._get_data(1024)
  119. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  120. stream.seek(123)
  121. p = 123
  122. assert_equal(stream.tell(), p)
  123. d1 = stream.read(11)
  124. assert_equal(d1, data[p:p+11])
  125. stream.seek(321, 1)
  126. p = 123+11+321
  127. assert_equal(stream.tell(), p)
  128. d2 = stream.read(21)
  129. assert_equal(d2, data[p:p+21])
  130. stream.seek(641, 0)
  131. p = 641
  132. assert_equal(stream.tell(), p)
  133. d3 = stream.read(11)
  134. assert_equal(d3, data[p:p+11])
  135. assert_raises(OSError, stream.seek, 10, 2)
  136. assert_raises(OSError, stream.seek, -1, 1)
  137. assert_raises(ValueError, stream.seek, 1, 123)
  138. stream.seek(10000, 1)
  139. assert_raises(OSError, stream.read, 12)
  140. def test_seek_bad_checksum(self):
  141. data = random.randbytes(10)
  142. compressed_data = zlib.compress(data)
  143. # break checksum
  144. compressed_data = (compressed_data[:-1]
  145. + bytes([(compressed_data[-1] + 1) & 255]))
  146. compressed_stream = BytesIO(compressed_data)
  147. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  148. assert_raises(zlib.error, stream.seek, len(data))
  149. def test_all_data_read(self):
  150. compressed_stream, compressed_data_len, data = self._get_data(1024)
  151. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  152. assert_(not stream.all_data_read())
  153. stream.seek(512)
  154. assert_(not stream.all_data_read())
  155. stream.seek(1024)
  156. assert_(stream.all_data_read())
  157. @pytest.mark.skipif(
  158. (platform.system() == 'Windows' and sys.version_info >= (3, 14)),
  159. reason='gh-23185')
  160. def test_all_data_read_overlap(self):
  161. COMPRESSION_LEVEL = 6
  162. data = np.arange(33707000, dtype=np.uint8)
  163. compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
  164. compressed_data_len = len(compressed_data)
  165. # check that part of the checksum overlaps
  166. assert_(compressed_data_len == BLOCK_SIZE + 2)
  167. compressed_stream = BytesIO(compressed_data)
  168. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  169. assert_(not stream.all_data_read())
  170. stream.seek(len(data))
  171. assert_(stream.all_data_read())
  172. @pytest.mark.skipif(
  173. (platform.system() == 'Windows' and sys.version_info >= (3, 14)),
  174. reason='gh-23185')
  175. def test_all_data_read_bad_checksum(self):
  176. COMPRESSION_LEVEL = 6
  177. data = np.arange(33707000, dtype=np.uint8)
  178. compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
  179. compressed_data_len = len(compressed_data)
  180. # check that part of the checksum overlaps
  181. assert_(compressed_data_len == BLOCK_SIZE + 2)
  182. # break checksum
  183. compressed_data = (compressed_data[:-1]
  184. + bytes([(compressed_data[-1] + 1) & 255]))
  185. compressed_stream = BytesIO(compressed_data)
  186. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  187. assert_(not stream.all_data_read())
  188. stream.seek(len(data))
  189. assert_raises(zlib.error, stream.all_data_read)