fer2013.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import csv
  2. import pathlib
  3. from typing import Any, Callable, Optional, Union
  4. import torch
  5. from PIL import Image
  6. from .utils import check_integrity, verify_str_arg
  7. from .vision import VisionDataset
  8. class FER2013(VisionDataset):
  9. """`FER2013
  10. <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
  11. .. note::
  12. This dataset can return test labels only if ``fer2013.csv`` OR
  13. ``icml_face_data.csv`` are present in ``root/fer2013/``. If only
  14. ``train.csv`` and ``test.csv`` are present, the test labels are set to
  15. ``None``.
  16. Args:
  17. root (str or ``pathlib.Path``): Root directory of dataset where directory
  18. ``root/fer2013`` exists. This directory may contain either
  19. ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
  20. ``test.csv``. Precendence is given in that order, i.e. if
  21. ``fer2013.csv`` is present then the rest of the files will be
  22. ignored. All these (combinations of) files contain the same data and
  23. are supported for convenience, but only ``fer2013.csv`` and
  24. ``icml_face_data.csv`` are able to return non-None test labels.
  25. split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
  26. transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
  27. version. E.g, ``transforms.RandomCrop``
  28. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  29. """
  30. _RESOURCES = {
  31. "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
  32. "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
  33. # The fer2013.csv and icml_face_data.csv files contain both train and
  34. # tests instances, and unlike test.csv they contain the labels for the
  35. # test instances. We give these 2 files precedence over train.csv and
  36. # test.csv. And yes, they both contain the same data, but with different
  37. # column names (note the spaces) and ordering:
  38. # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
  39. # ==> fer2013.csv <==
  40. # emotion,pixels,Usage
  41. #
  42. # ==> icml_face_data.csv <==
  43. # emotion, Usage, pixels
  44. #
  45. # ==> train.csv <==
  46. # emotion,pixels
  47. #
  48. # ==> test.csv <==
  49. # pixels
  50. "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
  51. "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
  52. }
  53. def __init__(
  54. self,
  55. root: Union[str, pathlib.Path],
  56. split: str = "train",
  57. transform: Optional[Callable] = None,
  58. target_transform: Optional[Callable] = None,
  59. ) -> None:
  60. self._split = verify_str_arg(split, "split", ("train", "test"))
  61. super().__init__(root, transform=transform, target_transform=target_transform)
  62. base_folder = pathlib.Path(self.root) / "fer2013"
  63. use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
  64. use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
  65. file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
  66. data_file = base_folder / file_name
  67. if not check_integrity(str(data_file), md5=md5):
  68. raise RuntimeError(
  69. f"{file_name} not found in {base_folder} or corrupted. "
  70. f"You can download it from "
  71. f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
  72. )
  73. pixels_key = " pixels" if use_icml_file else "pixels"
  74. usage_key = " Usage" if use_icml_file else "Usage"
  75. def get_img(row):
  76. return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)
  77. def get_label(row):
  78. if use_fer_file or use_icml_file or self._split == "train":
  79. return int(row["emotion"])
  80. else:
  81. return None
  82. with open(data_file, newline="") as file:
  83. rows = (row for row in csv.DictReader(file))
  84. if use_fer_file or use_icml_file:
  85. valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
  86. rows = (row for row in rows if row[usage_key] in valid_keys)
  87. self._samples = [(get_img(row), get_label(row)) for row in rows]
  88. def __len__(self) -> int:
  89. return len(self._samples)
  90. def __getitem__(self, idx: int) -> tuple[Any, Any]:
  91. image_tensor, target = self._samples[idx]
  92. image = Image.fromarray(image_tensor.numpy())
  93. if self.transform is not None:
  94. image = self.transform(image)
  95. if self.target_transform is not None:
  96. target = self.target_transform(target)
  97. return image, target
  98. def extra_repr(self) -> str:
  99. return f"split={self._split}"