# test__datasource.py: tests for numpy.lib._datasource.
import os
import urllib.request as urllib_request
from shutil import rmtree
from tempfile import NamedTemporaryFile, mkdtemp, mkstemp
from urllib.error import URLError
from urllib.parse import urlparse

import pytest

import numpy.lib._datasource as datasource
from numpy.testing import assert_, assert_equal, assert_raises
  10. def urlopen_stub(url, data=None):
  11. '''Stub to replace urlopen for testing.'''
  12. if url == valid_httpurl():
  13. tmpfile = NamedTemporaryFile(prefix='urltmp_')
  14. return tmpfile
  15. else:
  16. raise URLError('Name or service not known')
  17. # setup and teardown
  18. old_urlopen = None
  19. def setup_module():
  20. global old_urlopen
  21. old_urlopen = urllib_request.urlopen
  22. urllib_request.urlopen = urlopen_stub
  23. def teardown_module():
  24. urllib_request.urlopen = old_urlopen
  25. # A valid website for more robust testing
  26. http_path = 'http://www.google.com/'
  27. http_file = 'index.html'
  28. http_fakepath = 'http://fake.abc.web/site/'
  29. http_fakefile = 'fake.txt'
  30. malicious_files = ['/etc/shadow', '../../shadow',
  31. '..\\system.dat', 'c:\\windows\\system.dat']
  32. magic_line = b'three is the magic number'
  33. # Utility functions used by many tests
  34. def valid_textfile(filedir):
  35. # Generate and return a valid temporary file.
  36. fd, path = mkstemp(suffix='.txt', prefix='dstmp_', dir=filedir, text=True)
  37. os.close(fd)
  38. return path
  39. def invalid_textfile(filedir):
  40. # Generate and return an invalid filename.
  41. fd, path = mkstemp(suffix='.txt', prefix='dstmp_', dir=filedir)
  42. os.close(fd)
  43. os.remove(path)
  44. return path
  45. def valid_httpurl():
  46. return http_path + http_file
  47. def invalid_httpurl():
  48. return http_fakepath + http_fakefile
  49. def valid_baseurl():
  50. return http_path
  51. def invalid_baseurl():
  52. return http_fakepath
  53. def valid_httpfile():
  54. return http_file
  55. def invalid_httpfile():
  56. return http_fakefile
  57. class TestDataSourceOpen:
  58. def test_ValidHTTP(self, tmp_path):
  59. ds = datasource.DataSource(tmp_path)
  60. fh = ds.open(valid_httpurl())
  61. assert_(fh)
  62. fh.close()
  63. def test_InvalidHTTP(self, tmp_path):
  64. ds = datasource.DataSource(tmp_path)
  65. url = invalid_httpurl()
  66. assert_raises(OSError, ds.open, url)
  67. try:
  68. ds.open(url)
  69. except OSError as e:
  70. # Regression test for bug fixed in r4342.
  71. assert_(e.errno is None)
  72. def test_InvalidHTTPCacheURLError(self, tmp_path):
  73. ds = datasource.DataSource(tmp_path)
  74. assert_raises(URLError, ds._cache, invalid_httpurl())
  75. def test_ValidFile(self, tmp_path):
  76. ds = datasource.DataSource(tmp_path)
  77. local_file = valid_textfile(tmp_path)
  78. fh = ds.open(local_file)
  79. assert_(fh)
  80. fh.close()
  81. def test_InvalidFile(self, tmp_path):
  82. ds = datasource.DataSource(tmp_path)
  83. invalid_file = invalid_textfile(tmp_path)
  84. assert_raises(OSError, ds.open, invalid_file)
  85. def test_ValidGzipFile(self, tmp_path):
  86. try:
  87. import gzip
  88. except ImportError:
  89. # We don't have the gzip capabilities to test.
  90. pytest.skip()
  91. # Test datasource's internal file_opener for Gzip files.
  92. ds = datasource.DataSource(tmp_path)
  93. filepath = os.path.join(tmp_path, 'foobar.txt.gz')
  94. fp = gzip.open(filepath, 'w')
  95. fp.write(magic_line)
  96. fp.close()
  97. fp = ds.open(filepath)
  98. result = fp.readline()
  99. fp.close()
  100. assert_equal(magic_line, result)
  101. def test_ValidBz2File(self, tmp_path):
  102. try:
  103. import bz2
  104. except ImportError:
  105. # We don't have the bz2 capabilities to test.
  106. pytest.skip()
  107. # Test datasource's internal file_opener for BZip2 files.
  108. ds = datasource.DataSource(tmp_path)
  109. filepath = os.path.join(tmp_path, 'foobar.txt.bz2')
  110. fp = bz2.BZ2File(filepath, 'w')
  111. fp.write(magic_line)
  112. fp.close()
  113. fp = ds.open(filepath)
  114. result = fp.readline()
  115. fp.close()
  116. assert_equal(magic_line, result)
  117. class TestDataSourceExists:
  118. def test_ValidHTTP(self, tmp_path):
  119. ds = datasource.DataSource(tmp_path)
  120. assert_(ds.exists(valid_httpurl()))
  121. def test_InvalidHTTP(self, tmp_path):
  122. ds = datasource.DataSource(tmp_path)
  123. assert_equal(ds.exists(invalid_httpurl()), False)
  124. def test_ValidFile(self, tmp_path):
  125. # Test valid file in destpath
  126. ds = datasource.DataSource(tmp_path)
  127. tmpfile = valid_textfile(tmp_path)
  128. assert_(ds.exists(tmpfile))
  129. # Test valid local file not in destpath
  130. localdir = mkdtemp()
  131. tmpfile = valid_textfile(localdir)
  132. assert_(ds.exists(tmpfile))
  133. rmtree(localdir)
  134. def test_InvalidFile(self, tmp_path):
  135. ds = datasource.DataSource(tmp_path)
  136. tmpfile = invalid_textfile(tmp_path)
  137. assert_equal(ds.exists(tmpfile), False)
  138. class TestDataSourceAbspath:
  139. def test_ValidHTTP(self, tmp_path):
  140. ds = datasource.DataSource(tmp_path)
  141. _, netloc, upath, _, _, _ = urlparse(valid_httpurl())
  142. local_path = os.path.join(tmp_path, netloc,
  143. upath.strip(os.sep).strip('/'))
  144. assert_equal(local_path, ds.abspath(valid_httpurl()))
  145. def test_ValidFile(self, tmp_path):
  146. ds = datasource.DataSource(tmp_path)
  147. tmpfile = valid_textfile(tmp_path)
  148. tmpfilename = os.path.split(tmpfile)[-1]
  149. # Test with filename only
  150. assert_equal(tmpfile, ds.abspath(tmpfilename))
  151. # Test filename with complete path
  152. assert_equal(tmpfile, ds.abspath(tmpfile))
  153. def test_InvalidHTTP(self, tmp_path):
  154. ds = datasource.DataSource(tmp_path)
  155. _, netloc, upath, _, _, _ = urlparse(invalid_httpurl())
  156. invalidhttp = os.path.join(tmp_path, netloc,
  157. upath.strip(os.sep).strip('/'))
  158. assert_(invalidhttp != ds.abspath(valid_httpurl()))
  159. def test_InvalidFile(self, tmp_path):
  160. ds = datasource.DataSource(tmp_path)
  161. invalidfile = valid_textfile(tmp_path)
  162. tmpfile = valid_textfile(tmp_path)
  163. tmpfilename = os.path.split(tmpfile)[-1]
  164. # Test with filename only
  165. assert_(invalidfile != ds.abspath(tmpfilename))
  166. # Test filename with complete path
  167. assert_(invalidfile != ds.abspath(tmpfile))
  168. def test_sandboxing(self, tmp_path):
  169. ds = datasource.DataSource(tmp_path)
  170. tmpfile = valid_textfile(tmp_path)
  171. tmpfilename = os.path.split(tmpfile)[-1]
  172. path = lambda x: os.path.abspath(ds.abspath(x))
  173. assert_(path(valid_httpurl()).startswith(str(tmp_path)))
  174. assert_(path(invalid_httpurl()).startswith(str(tmp_path)))
  175. assert_(path(tmpfile).startswith(str(tmp_path)))
  176. assert_(path(tmpfilename).startswith(str(tmp_path)))
  177. for fn in malicious_files:
  178. assert_(path(http_path + fn).startswith(str(tmp_path)))
  179. assert_(path(fn).startswith(str(tmp_path)))
  180. def test_windows_os_sep(self, tmp_path):
  181. orig_os_sep = os.sep
  182. try:
  183. os.sep = '\\'
  184. self.test_ValidHTTP(tmp_path)
  185. self.test_ValidFile(tmp_path)
  186. self.test_InvalidHTTP(tmp_path)
  187. self.test_InvalidFile(tmp_path)
  188. self.test_sandboxing(tmp_path)
  189. finally:
  190. os.sep = orig_os_sep
  191. class TestRepositoryAbspath:
  192. def test_ValidHTTP(self, tmp_path):
  193. repos = datasource.Repository(valid_baseurl(), tmp_path)
  194. _, netloc, upath, _, _, _ = urlparse(valid_httpurl())
  195. local_path = os.path.join(repos._destpath, netloc,
  196. upath.strip(os.sep).strip('/'))
  197. filepath = repos.abspath(valid_httpfile())
  198. assert_equal(local_path, filepath)
  199. def test_sandboxing(self, tmp_path):
  200. repos = datasource.Repository(valid_baseurl(), tmp_path)
  201. path = lambda x: os.path.abspath(repos.abspath(x))
  202. assert_(path(valid_httpfile()).startswith(str(tmp_path)))
  203. for fn in malicious_files:
  204. assert_(path(http_path + fn).startswith(str(tmp_path)))
  205. assert_(path(fn).startswith(str(tmp_path)))
  206. def test_windows_os_sep(self, tmp_path):
  207. orig_os_sep = os.sep
  208. try:
  209. os.sep = '\\'
  210. self.test_ValidHTTP(tmp_path)
  211. self.test_sandboxing(tmp_path)
  212. finally:
  213. os.sep = orig_os_sep
  214. class TestRepositoryExists:
  215. def test_ValidFile(self, tmp_path):
  216. # Create local temp file
  217. repos = datasource.Repository(valid_baseurl(), tmp_path)
  218. tmpfile = valid_textfile(tmp_path)
  219. assert_(repos.exists(tmpfile))
  220. def test_InvalidFile(self, tmp_path):
  221. repos = datasource.Repository(valid_baseurl(), tmp_path)
  222. tmpfile = invalid_textfile(tmp_path)
  223. assert_equal(repos.exists(tmpfile), False)
  224. def test_RemoveHTTPFile(self, tmp_path):
  225. repos = datasource.Repository(valid_baseurl(), tmp_path)
  226. assert_(repos.exists(valid_httpurl()))
  227. def test_CachedHTTPFile(self, tmp_path):
  228. localfile = valid_httpurl()
  229. # Create a locally cached temp file with an URL based
  230. # directory structure. This is similar to what Repository.open
  231. # would do.
  232. repos = datasource.Repository(valid_baseurl(), tmp_path)
  233. _, netloc, _, _, _, _ = urlparse(localfile)
  234. local_path = os.path.join(repos._destpath, netloc)
  235. os.mkdir(local_path, 0o0700)
  236. tmpfile = valid_textfile(local_path)
  237. assert_(repos.exists(tmpfile))
  238. class TestOpenFunc:
  239. def test_DataSourceOpen(self, tmp_path):
  240. local_file = valid_textfile(tmp_path)
  241. # Test case where destpath is passed in
  242. fp = datasource.open(local_file, destpath=tmp_path)
  243. assert_(fp)
  244. fp.close()
  245. # Test case where default destpath is used
  246. fp = datasource.open(local_file)
  247. assert_(fp)
  248. fp.close()
  249. def test_del_attr_handling():
  250. # DataSource __del__ can be called
  251. # even if __init__ fails when the
  252. # Exception object is caught by the
  253. # caller as happens in refguide_check
  254. # is_deprecated() function
  255. ds = datasource.DataSource()
  256. # simulate failed __init__ by removing key attribute
  257. # produced within __init__ and expected by __del__
  258. del ds._istmpdest
  259. # should not raise an AttributeError if __del__
  260. # gracefully handles failed __init__:
  261. ds.__del__()