usps.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Union
  4. import numpy as np
  5. from ..utils import _Image_fromarray
  6. from .utils import download_url
  7. from .vision import VisionDataset
  8. class USPS(VisionDataset):
  9. """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
  10. The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
  11. The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
  12. and make pixel values in ``[0, 255]``.
  13. Args:
  14. root (str or ``pathlib.Path``): Root directory of dataset to store``USPS`` data files.
  15. train (bool, optional): If True, creates dataset from ``usps.bz2``,
  16. otherwise from ``usps.t.bz2``.
  17. transform (callable, optional): A function/transform that takes in a PIL image
  18. and returns a transformed version. E.g, ``transforms.RandomCrop``
  19. target_transform (callable, optional): A function/transform that takes in the
  20. target and transforms it.
  21. download (bool, optional): If true, downloads the dataset from the internet and
  22. puts it in root directory. If dataset is already downloaded, it is not
  23. downloaded again.
  24. """
  25. split_list = {
  26. "train": [
  27. "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
  28. "usps.bz2",
  29. "ec16c51db3855ca6c91edd34d0e9b197",
  30. ],
  31. "test": [
  32. "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
  33. "usps.t.bz2",
  34. "8ea070ee2aca1ac39742fdd1ef5ed118",
  35. ],
  36. }
  37. def __init__(
  38. self,
  39. root: Union[str, Path],
  40. train: bool = True,
  41. transform: Optional[Callable] = None,
  42. target_transform: Optional[Callable] = None,
  43. download: bool = False,
  44. ) -> None:
  45. super().__init__(root, transform=transform, target_transform=target_transform)
  46. split = "train" if train else "test"
  47. url, filename, checksum = self.split_list[split]
  48. full_path = os.path.join(self.root, filename)
  49. if download and not os.path.exists(full_path):
  50. download_url(url, self.root, filename, md5=checksum)
  51. import bz2
  52. with bz2.open(full_path) as fp:
  53. raw_data = [line.decode().split() for line in fp.readlines()]
  54. tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
  55. imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
  56. imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
  57. targets = [int(d[0]) - 1 for d in raw_data]
  58. self.data = imgs
  59. self.targets = targets
  60. def __getitem__(self, index: int) -> tuple[Any, Any]:
  61. """
  62. Args:
  63. index (int): Index
  64. Returns:
  65. tuple: (image, target) where target is index of the target class.
  66. """
  67. img, target = self.data[index], int(self.targets[index])
  68. # doing this so that it is consistent with all other datasets
  69. # to return a PIL Image
  70. img = _Image_fromarray(img, mode="L")
  71. if self.transform is not None:
  72. img = self.transform(img)
  73. if self.target_transform is not None:
  74. target = self.target_transform(target)
  75. return img, target
  76. def __len__(self) -> int:
  77. return len(self.data)