The overall workflow works perfectly

yichael 1 month ago
parent
commit
0c79e6fbe6
100 changed files with 17224 additions and 0 deletions
  1. 9 0
      bring-cmd-window-foreground.ps1
  2. 4 0
      python/LightGlue/.flake8
  3. 1 0
      python/LightGlue/.gitattributes
  4. 24 0
      python/LightGlue/.github/workflows/code-quality.yml
  5. 166 0
      python/LightGlue/.gitignore
  6. 201 0
      python/LightGlue/LICENSE
  7. 183 0
      python/LightGlue/README.md
  8. binary
      python/LightGlue/assets/DSC_0410.JPG
  9. binary
      python/LightGlue/assets/DSC_0411.JPG
  10. 718 0
      python/LightGlue/assets/architecture.svg
  11. binary
      python/LightGlue/assets/benchmark.png
  12. binary
      python/LightGlue/assets/benchmark_cpu.png
  13. binary
      python/LightGlue/assets/easy_hard.jpg
  14. binary
      python/LightGlue/assets/sacre_coeur1.jpg
  15. binary
      python/LightGlue/assets/sacre_coeur2.jpg
  16. 1499 0
      python/LightGlue/assets/teaser.svg
  17. 255 0
      python/LightGlue/benchmark.py
  18. 77 0
      python/LightGlue/demo.ipynb
  19. 7 0
      python/LightGlue/lightglue/__init__.py
  20. 775 0
      python/LightGlue/lightglue/aliked.py
  21. 55 0
      python/LightGlue/lightglue/disk.py
  22. 41 0
      python/LightGlue/lightglue/dog_hardnet.py
  23. 667 0
      python/LightGlue/lightglue/lightglue.py
  24. 216 0
      python/LightGlue/lightglue/sift.py
  25. 227 0
      python/LightGlue/lightglue/superpoint.py
  26. 165 0
      python/LightGlue/lightglue/utils.py
  27. 203 0
      python/LightGlue/lightglue/viz2d.py
  28. 30 0
      python/LightGlue/pyproject.toml
  29. 6 0
      python/LightGlue/requirements.txt
  30. 1 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/INSTALLER
  31. 29 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/LICENSE
  32. 128 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/METADATA
  33. 380 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/RECORD
  34. 0 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/REQUESTED
  35. 5 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/WHEEL
  36. 1 0
      python/py/Lib/site-packages/torchvision-0.26.0.dist-info/top_level.txt
  37. binary
      python/py/Lib/site-packages/torchvision/_C.pyd
  38. 73 0
      python/py/Lib/site-packages/torchvision/__init__.py
  39. 51 0
      python/py/Lib/site-packages/torchvision/_internally_replaced_utils.py
  40. 225 0
      python/py/Lib/site-packages/torchvision/_meta_registrations.py
  41. 33 0
      python/py/Lib/site-packages/torchvision/_utils.py
  42. 147 0
      python/py/Lib/site-packages/torchvision/datasets/__init__.py
  43. 520 0
      python/py/Lib/site-packages/torchvision/datasets/_optical_flow.py
  44. 1223 0
      python/py/Lib/site-packages/torchvision/datasets/_stereo_matching.py
  45. 241 0
      python/py/Lib/site-packages/torchvision/datasets/caltech.py
  46. 210 0
      python/py/Lib/site-packages/torchvision/datasets/celeba.py
  47. 167 0
      python/py/Lib/site-packages/torchvision/datasets/cifar.py
  48. 222 0
      python/py/Lib/site-packages/torchvision/datasets/cityscapes.py
  49. 93 0
      python/py/Lib/site-packages/torchvision/datasets/clevr.py
  50. 111 0
      python/py/Lib/site-packages/torchvision/datasets/coco.py
  51. 67 0
      python/py/Lib/site-packages/torchvision/datasets/country211.py
  52. 105 0
      python/py/Lib/site-packages/torchvision/datasets/dtd.py
  53. 71 0
      python/py/Lib/site-packages/torchvision/datasets/eurosat.py
  54. 67 0
      python/py/Lib/site-packages/torchvision/datasets/fakedata.py
  55. 120 0
      python/py/Lib/site-packages/torchvision/datasets/fer2013.py
  56. 120 0
      python/py/Lib/site-packages/torchvision/datasets/fgvc_aircraft.py
  57. 176 0
      python/py/Lib/site-packages/torchvision/datasets/flickr.py
  58. 225 0
      python/py/Lib/site-packages/torchvision/datasets/flowers102.py
  59. 337 0
      python/py/Lib/site-packages/torchvision/datasets/folder.py
  60. 98 0
      python/py/Lib/site-packages/torchvision/datasets/food101.py
  61. 103 0
      python/py/Lib/site-packages/torchvision/datasets/gtsrb.py
  62. 152 0
      python/py/Lib/site-packages/torchvision/datasets/hmdb51.py
  63. 222 0
      python/py/Lib/site-packages/torchvision/datasets/imagenet.py
  64. 104 0
      python/py/Lib/site-packages/torchvision/datasets/imagenette.py
  65. 245 0
      python/py/Lib/site-packages/torchvision/datasets/inaturalist.py
  66. 237 0
      python/py/Lib/site-packages/torchvision/datasets/kinetics.py
  67. 158 0
      python/py/Lib/site-packages/torchvision/datasets/kitti.py
  68. 268 0
      python/py/Lib/site-packages/torchvision/datasets/lfw.py
  69. 168 0
      python/py/Lib/site-packages/torchvision/datasets/lsun.py
  70. 560 0
      python/py/Lib/site-packages/torchvision/datasets/mnist.py
  71. 94 0
      python/py/Lib/site-packages/torchvision/datasets/moving_mnist.py
  72. 107 0
      python/py/Lib/site-packages/torchvision/datasets/omniglot.py
  73. 135 0
      python/py/Lib/site-packages/torchvision/datasets/oxford_iiit_pet.py
  74. 134 0
      python/py/Lib/site-packages/torchvision/datasets/pcam.py
  75. 230 0
      python/py/Lib/site-packages/torchvision/datasets/phototour.py
  76. 176 0
      python/py/Lib/site-packages/torchvision/datasets/places365.py
  77. 89 0
      python/py/Lib/site-packages/torchvision/datasets/rendered_sst2.py
  78. 3 0
      python/py/Lib/site-packages/torchvision/datasets/samplers/__init__.py
  79. 173 0
      python/py/Lib/site-packages/torchvision/datasets/samplers/clip_sampler.py
  80. 126 0
      python/py/Lib/site-packages/torchvision/datasets/sbd.py
  81. 114 0
      python/py/Lib/site-packages/torchvision/datasets/sbu.py
  82. 92 0
      python/py/Lib/site-packages/torchvision/datasets/semeion.py
  83. 105 0
      python/py/Lib/site-packages/torchvision/datasets/stanford_cars.py
  84. 174 0
      python/py/Lib/site-packages/torchvision/datasets/stl10.py
  85. 81 0
      python/py/Lib/site-packages/torchvision/datasets/sun397.py
  86. 130 0
      python/py/Lib/site-packages/torchvision/datasets/svhn.py
  87. 131 0
      python/py/Lib/site-packages/torchvision/datasets/ucf101.py
  88. 96 0
      python/py/Lib/site-packages/torchvision/datasets/usps.py
  89. 468 0
      python/py/Lib/site-packages/torchvision/datasets/utils.py
  90. 384 0
      python/py/Lib/site-packages/torchvision/datasets/video_utils.py
  91. 111 0
      python/py/Lib/site-packages/torchvision/datasets/vision.py
  92. 224 0
      python/py/Lib/site-packages/torchvision/datasets/voc.py
  93. 196 0
      python/py/Lib/site-packages/torchvision/datasets/widerface.py
  94. 76 0
      python/py/Lib/site-packages/torchvision/extension.py
  95. binary
      python/py/Lib/site-packages/torchvision/image.pyd
  96. 56 0
      python/py/Lib/site-packages/torchvision/io/__init__.py
  97. 527 0
      python/py/Lib/site-packages/torchvision/io/image.py
  98. binary
      python/py/Lib/site-packages/torchvision/jpeg8.dll
  99. binary
      python/py/Lib/site-packages/torchvision/libjpeg.dll
  100. binary
      python/py/Lib/site-packages/torchvision/libpng16.dll

+ 9 - 0
bring-cmd-window-foreground.ps1

@@ -0,0 +1,9 @@
+Add-Type @"
+using System;
+using System.Runtime.InteropServices;
+public static class ConsoleForegroundHelper {
+  [DllImport("user32.dll")] public static extern bool SetForegroundWindow(IntPtr windowHandle);
+  [DllImport("kernel32.dll")] public static extern IntPtr GetConsoleWindow();
+}
+"@
+[void][ConsoleForegroundHelper]::SetForegroundWindow([ConsoleForegroundHelper]::GetConsoleWindow())
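
For readers more at home in Python, the language used by the rest of this commit, the same two Win32 calls can be reached through `ctypes`. This is a minimal sketch and not part of the commit: the function name `bring_console_to_foreground` is hypothetical, and it simply returns `False` on non-Windows platforms or when the process has no console. Windows may decline the request depending on its foreground-activation rules, so the return value is worth checking.

```python
import ctypes
import sys


def bring_console_to_foreground() -> bool:
    """Ask Windows to focus this process's console window (hypothetical helper)."""
    if sys.platform != "win32":
        return False  # the Win32 API is unavailable elsewhere

    kernel32 = ctypes.windll.kernel32
    user32 = ctypes.windll.user32

    # Declare handle-sized types so 64-bit window handles are not truncated.
    kernel32.GetConsoleWindow.restype = ctypes.c_void_p
    user32.SetForegroundWindow.argtypes = [ctypes.c_void_p]
    user32.SetForegroundWindow.restype = ctypes.c_int

    hwnd = kernel32.GetConsoleWindow()
    if not hwnd:
        return False  # no console is attached to this process
    return bool(user32.SetForegroundWindow(hwnd))
```

Calling `bring_console_to_foreground()` from a long-running script would be one way to surface the console once processing finishes, if that is what the PowerShell helper is used for here.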

+ 4 - 0
python/LightGlue/.flake8

@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203
+exclude = .git,__pycache__,build,.venv/

+ 1 - 0
python/LightGlue/.gitattributes

@@ -0,0 +1 @@
+*.ipynb linguist-documentation

+ 24 - 0
python/LightGlue/.github/workflows/code-quality.yml

@@ -0,0 +1,24 @@
+name: Format and Lint Checks
+on:
+  push:
+    branches:
+      - main
+    paths:
+      - '*.py'
+  pull_request:
+    types: [ assigned, opened, synchronize, reopened ]
+jobs:
+  check:
+    name: Format and Lint Checks
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+          cache: 'pip'
+      - run: python -m pip install --upgrade pip
+      - run: python -m pip install .[dev]
+      - run: python -m flake8 .
+      - run: python -m isort . --check-only --diff
+      - run: python -m black . --check --diff
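
The three check commands at the end of this workflow can also be run locally before pushing. The following is a small sketch under the assumption that `flake8`, `isort` and `black` are already installed in the active environment (for example via `python -m pip install .[dev]`, as the workflow does); `build_commands` and `run_checks` are hypothetical helper names.

```python
import subprocess
import sys

# The three checks from the workflow's final `run:` steps.
CHECKS = [
    ["flake8", "."],
    ["isort", ".", "--check-only", "--diff"],
    ["black", ".", "--check", "--diff"],
]


def build_commands(python=sys.executable):
    """Prefix each check with `<python> -m`, as the workflow does."""
    return [[python, "-m", *check] for check in CHECKS]


def run_checks(cwd="."):
    """Run every check in `cwd` and return the names of the tools that failed."""
    failed = []
    for cmd in build_commands():
        if subprocess.run(cmd, cwd=cwd).returncode != 0:
            failed.append(cmd[2])
    return failed
```

Running `run_checks("python/LightGlue")` from the repository root returns an empty list when all three tools pass.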

+ 166 - 0
python/LightGlue/.gitignore

@@ -0,0 +1,166 @@
+/data/
+/outputs/
+/lightglue/weights/
+*-checkpoint.ipynb
+*.pth
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/

+ 201 - 0
python/LightGlue/LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2023 ETH Zurich
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 183 - 0
python/LightGlue/README.md

@@ -0,0 +1,183 @@
+<p align="center">
+  <h1 align="center"><ins>LightGlue</ins> ⚡️<br>Local Feature Matching at Light Speed</h1>
+  <p align="center">
+    <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
+    ·
+    <a href="https://psarlin.com/">Paul-Edouard&nbsp;Sarlin</a>
+    ·
+    <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc&nbsp;Pollefeys</a>
+  </p>
+  <h2 align="center">
+    <p>ICCV 2023</p>
+    <a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> | 
+    <a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a> | 
+    <a href="https://huggingface.co/spaces/ETH-CVG/LightGlue" align="center">🤗 Demo </a> | 
+    <a href="https://psarlin.com/doc/LightGlue_ICCV2023_poster_compressed.pdf" align="center">Poster</a> | 
+    <a href="https://github.com/cvg/glue-factory" align="center"> ⚙️ Train your own</a>
+  </h2>
+
+</p>
+<p align="center">
+    <a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
+    <br>
+    <em>LightGlue is a deep neural network that matches sparse local features across image pairs.<br>An adaptive mechanism makes it fast for easy pairs (top) and reduces the computational complexity for difficult ones (bottom).</em>
+</p>
+
+##
+
+This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
+
+We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
+The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/).
+
+LightGlue is now part of 🤗 [Hugging Face Transformers](https://huggingface.co/docs/transformers/main/en/model_doc/lightglue) (credit to [@sbucaille](https://huggingface.co/stevenbucaille)!). It enables easy inference in a few lines of Python code, using `pip install transformers` ([model card](https://huggingface.co/ETH-CVG/lightglue_superpoint)).
+
+## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb) [![](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ETH-CVG/LightGlue) 
+
+Install this repo using pip:
+
+```bash
+git clone https://github.com/cvg/LightGlue.git && cd LightGlue
+python -m pip install -e .
+```
+
+We provide a [demo notebook](demo.ipynb) which shows how to perform feature extraction and matching on an image pair.
+
+Here is a minimal script to match two images:
+
+```python
+from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
+from lightglue.utils import load_image, rbd
+
+# SuperPoint+LightGlue
+extractor = SuperPoint(max_num_keypoints=2048).eval().cuda()  # load the extractor
+matcher = LightGlue(features='superpoint').eval().cuda()  # load the matcher
+
+# or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue
+extractor = DISK(max_num_keypoints=2048).eval().cuda()  # load the extractor
+matcher = LightGlue(features='disk').eval().cuda()  # load the matcher
+
+# load each image as a torch.Tensor on GPU with shape (3,H,W), normalized in [0,1]
+image0 = load_image('path/to/image_0.jpg').cuda()
+image1 = load_image('path/to/image_1.jpg').cuda()
+
+# extract local features
+feats0 = extractor.extract(image0)  # auto-resize the image, disable with resize=None
+feats1 = extractor.extract(image1)
+
+# match the features
+matches01 = matcher({'image0': feats0, 'image1': feats1})
+feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]  # remove batch dimension
+matches = matches01['matches']  # indices with shape (K,2)
+points0 = feats0['keypoints'][matches[..., 0]]  # coordinates in image #0, shape (K,2)
+points1 = feats1['keypoints'][matches[..., 1]]  # coordinates in image #1, shape (K,2)
+```
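
The last three lines of the script above rely on index gathering: each row of `matches` holds one index into the keypoints of image 0 and one into those of image 1. A toy illustration with plain Python lists (made-up coordinates, no PyTorch required) shows what the tensor indexing does:

```python
# Hypothetical keypoint coordinates (x, y); real ones come from the extractor.
keypoints0 = [(10.0, 20.0), (30.0, 40.0), (50.0, 60.0)]
keypoints1 = [(11.0, 21.0), (52.0, 61.0)]

# Each row pairs an index into keypoints0 with an index into keypoints1,
# mirroring the (K, 2) layout of matches01['matches'].
matches = [(0, 0), (2, 1)]

# Equivalent of feats0['keypoints'][matches[..., 0]] and its counterpart.
points0 = [keypoints0[i] for i, _ in matches]
points1 = [keypoints1[j] for _, j in matches]

for p0, p1 in zip(points0, points1):
    print(p0, "<->", p1)
```

Keypoint 1 of image 0 has no partner here, so it does not appear in `points0`; unmatched keypoints are simply absent from `matches`.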
+
+We also provide a convenience method to match a pair of images:
+
+```python
+from lightglue import match_pair
+feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
+```
+
+##
+
+<p align="center">
+  <a href="https://arxiv.org/abs/2306.13643"><img src="assets/teaser.svg" alt="Logo" width=50%></a>
+  <br>
+  <em>LightGlue can adjust its depth (number of layers) and width (number of keypoints) per image pair, with a marginal impact on accuracy.</em>
+</p>
+
+## Advanced configuration
+
+<details>
+<summary>[Detail of all parameters - click to expand]</summary>
+
+- ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
+- ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
+- ```mp```: Enable mixed precision inference. Default: False (off)
+- ```depth_confidence```: Controls the early stopping. A lower value stops more often at earlier layers. Default: 0.95, disable with -1.
+- ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
+- ```filter_threshold```: Match confidence. Increase this value to obtain fewer, but stronger matches. Default: 0.1
+
+</details>
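
To make the roles of `depth_confidence` and `filter_threshold` concrete, here is a toy sketch in plain Python. It is not the library's implementation: the helper names, the fixed per-point cutoff of 0.9 and all scores are assumptions chosen for illustration. It only mirrors the behaviour described in the parameter list: stop early once enough points look confident, and keep only matches whose score exceeds the threshold.

```python
def should_stop(point_confidences, depth_confidence=0.95, point_cutoff=0.9):
    """Stop early when the share of confident points exceeds depth_confidence.

    point_cutoff is a made-up constant; a negative depth_confidence
    disables early stopping, matching the documented "-1" convention.
    """
    if depth_confidence < 0 or not point_confidences:
        return False
    confident = sum(c >= point_cutoff for c in point_confidences)
    return confident / len(point_confidences) > depth_confidence


def filter_matches(scored_matches, filter_threshold=0.1):
    """Keep (index0, index1) pairs whose score exceeds filter_threshold."""
    return [(i, j) for i, j, score in scored_matches if score > filter_threshold]


easy_pair = [0.99, 0.97, 0.95, 0.93]  # every point is confident
hard_pair = [0.99, 0.40, 0.95, 0.20]  # half of the points are uncertain
print(should_stop(easy_pair), should_stop(hard_pair))

scored = [(0, 3, 0.92), (1, 5, 0.05), (2, 1, 0.30)]
print(filter_matches(scored))
print(filter_matches(scored, filter_threshold=0.5))
```

Lowering `depth_confidence` makes the stopping test easier to pass, which is why lower values stop at earlier layers; raising `filter_threshold` shrinks the list of returned matches.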
+
+The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
+```python
+extractor = SuperPoint(max_num_keypoints=None)
+matcher = LightGlue(features='superpoint', depth_confidence=-1, width_confidence=-1)
+```
+
+To increase the speed with a small drop in accuracy, decrease the number of keypoints and lower the adaptive thresholds:
+```python
+extractor = SuperPoint(max_num_keypoints=1024)
+matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
+```
+
+The maximum speed is obtained with a combination of:
+- [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
+- PyTorch compilation, available when ```torch >= 2.0```:
+```python
+matcher = matcher.eval().cuda()
+matcher.compile(mode='reduce-overhead')
+```
+For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning (large overhead). For larger input sizes, it automatically falls back to eager mode with point pruning. Adaptive depth is supported for any input size.
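
The dispatch rule in the paragraph above can be summarised in a few lines. This is an illustrative sketch only: `plan_execution` and the returned keys are hypothetical names, and the 1536 limit is taken from the text, where it is described as experimentally determined.

```python
COMPILED_KEYPOINT_LIMIT = 1536  # value quoted in the README text


def plan_execution(num_keypoints, limit=COMPILED_KEYPOINT_LIMIT):
    """Describe how a compiled matcher would handle an input of this size."""
    use_compiled = num_keypoints < limit
    return {
        "mode": "compiled" if use_compiled else "eager",
        # Pruning is skipped in the compiled path because of its overhead.
        "point_pruning": not use_compiled,
        # Adaptive depth is described as available for any input size.
        "adaptive_depth": True,
    }


for n in (512, 1535, 1536, 4096):
    print(n, plan_execution(n))
```

Treating exactly 1536 keypoints as the eager case is this sketch's reading of "fewer than 1536"; the text does not spell out the boundary.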
+
+## Benchmark
+
+
+<p align="center">
+  <a><img src="assets/benchmark.png" alt="Logo" width=80%></a>
+  <br>
+  <em>Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue. </em>
+</p>
+
+<p align="center">
+  <a><img src="assets/benchmark_cpu.png" alt="Logo" width=80%></a>
+  <br>
+  <em>Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints. </em>
+</p>
+
+Obtain the same plots for your setup using our [benchmark script](benchmark.py):
+```
+python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile]
+```
+
+<details>
+<summary>[Performance tip - click to expand]</summary>
+
+Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits.
+Point pruning is thus enabled only when there are more than N keypoints in an image, where N is hardware-dependent.
+We provide defaults optimized for current hardware (RTX 30xx GPUs).
+We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`.
+
+</details>
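One way to turn benchmark timings into a value for N is sketched below. The helper `pick_pruning_threshold` and the shape of its input are hypothetical and not part of LightGlue or its benchmark script; the numbers in the demo are placeholders, not measurements.

```python
def pick_pruning_threshold(timings):
    """Pick N so that pruning is enabled only for more than N keypoints.

    timings maps a keypoint count to a pair
    (ms_without_pruning, ms_with_pruning) measured on your hardware.
    Returns the largest measured count at which pruning does not pay off,
    0 if pruning was faster at every measured count, or None if pruning
    was not faster even at the largest count.
    """
    if not timings:
        raise ValueError("timings must not be empty")
    counts = sorted(timings)
    # Walk from the largest count down while pruning is still faster.
    for i in range(len(counts) - 1, -1, -1):
        no_prune_ms, prune_ms = timings[counts[i]]
        if prune_ms >= no_prune_ms:
            # Pruning stops helping here; None if it never helped at all.
            return counts[i] if i < len(counts) - 1 else None
    return 0


if __name__ == "__main__":
    # Placeholder values for illustration only, not real measurements.
    example = {512: (5.0, 7.0), 1024: (8.0, 9.0),
               2048: (16.0, 14.0), 4096: (40.0, 30.0)}
    n = pick_pruning_threshold(example)
    print(n)
    # The README suggests applying the result like this:
    # LightGlue.pruning_keypoint_thresholds['cuda'] = n
```

The function assumes that pruning becomes worthwhile above some size and stays worthwhile beyond it; if your measurements are noisy around the crossover, averaging several runs before calling it is advisable.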
+
+## Training and evaluation
+
+With [Glue Factory](https://github.com/cvg/glue-factory), you can train LightGlue with your own local features, on your own dataset!
+You can also evaluate it and other baselines on standard benchmarks like HPatches and MegaDepth.
+
+## Other links
+- [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
+- [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange (ONNX) format with support for TensorRT and OpenVINO.
+- [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
+- [kornia](https://kornia.readthedocs.io) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher).
+
+## BibTeX citation
+If you use any ideas from the paper or code from this repo, please consider citing:
+
+```txt
+@inproceedings{lindenberger2023lightglue,
+  author    = {Philipp Lindenberger and
+               Paul-Edouard Sarlin and
+               Marc Pollefeys},
+  title     = {{LightGlue: Local Feature Matching at Light Speed}},
+  booktitle = {ICCV},
+  year      = {2023}
+}
+```
+
+
+## License
+The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well, but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license.

binary
python/LightGlue/assets/DSC_0410.JPG


binary
python/LightGlue/assets/DSC_0411.JPG


File diff not shown because the file is too large
+ 718 - 0
python/LightGlue/assets/architecture.svg


binary
python/LightGlue/assets/benchmark.png


binary
python/LightGlue/assets/benchmark_cpu.png


binary
python/LightGlue/assets/easy_hard.jpg


binary
python/LightGlue/assets/sacre_coeur1.jpg


binary
python/LightGlue/assets/sacre_coeur2.jpg


+ 1499 - 0
python/LightGlue/assets/teaser.svg

@@ -0,0 +1,1499 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="351.50156pt" height="237.315312pt" viewBox="0 0 351.50156 237.315312" xmlns="http://www.w3.org/2000/svg" version="1.1">
+ <metadata>
+  <rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
+   <cc:Work>
+    <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
+    <dc:date>2023-06-25T11:23:59.261938</dc:date>
+    <dc:format>image/svg+xml</dc:format>
+    <dc:creator>
+     <cc:Agent>
+      <dc:title>Matplotlib v3.7.1, https://matplotlib.org/</dc:title>
+     </cc:Agent>
+    </dc:creator>
+   </cc:Work>
+  </rdf:RDF>
+ </metadata>
+ <defs>
+  <style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
+ </defs>
+ <g id="figure_1">
+  <g id="patch_1">
+   <path d="M 0 237.315312 
+L 351.50156 237.315312 
+L 351.50156 0 
+L 0 0 
+z
+" style="fill: #ffffff"/>
+  </g>
+  <g id="axes_1">
+   <g id="patch_2">
+    <path d="M 38.242188 202.12 
+L 351.50156 202.12 
+L 351.50156 0 
+L 38.242188 0 
+z
+" style="fill: #f2f2f2"/>
+   </g>
+   <g id="PathCollection_1">
+    <defs>
+     <path id="md5bda44a6b" d="M 0 2.738613 
+C 0.726289 2.738613 1.422928 2.450055 1.936492 1.936492 
+C 2.450055 1.422928 2.738613 0.726289 2.738613 0 
+C 2.738613 -0.726289 2.450055 -1.422928 1.936492 -1.936492 
+C 1.422928 -2.450055 0.726289 -2.738613 0 -2.738613 
+C -0.726289 -2.738613 -1.422928 -2.450055 -1.936492 -1.936492 
+C -2.450055 -1.422928 -2.738613 -0.726289 -2.738613 0 
+C -2.738613 0.726289 -2.450055 1.422928 -1.936492 1.936492 
+C -1.422928 2.450055 -0.726289 2.738613 0 2.738613 
+z
+"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#md5bda44a6b" x="117.273002" y="77.281176" style="fill: #0000ff"/>
+    </g>
+   </g>
+   <g id="PathCollection_2">
+    <defs>
+     <path id="m3541600ca9" d="M -0 3.872983 
+L 3.872983 -3.872983 
+L -3.872983 -3.872983 
+z
+"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#m3541600ca9" x="113.203664" y="196.175294" style="fill: #008000"/>
+    </g>
+   </g>
+   <g id="PathCollection_3">
+    <defs>
+     <path id="mee49ddcd29" d="M 0 2.738613 
+C 0.726289 2.738613 1.422928 2.450055 1.936492 1.936492 
+C 2.450055 1.422928 2.738613 0.726289 2.738613 0 
+C 2.738613 -0.726289 2.450055 -1.422928 1.936492 -1.936492 
+C 1.422928 -2.450055 0.726289 -2.738613 0 -2.738613 
+C -0.726289 -2.738613 -1.422928 -2.450055 -1.936492 -1.936492 
+C -2.450055 -1.422928 -2.738613 -0.726289 -2.738613 0 
+C -2.738613 0.726289 -2.450055 1.422928 -1.936492 1.936492 
+C -1.422928 2.450055 -0.726289 2.738613 0 2.738613 
+z
+"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#mee49ddcd29" x="68.806591" y="41.612941"/>
+    </g>
+   </g>
+   <g id="PathCollection_4">
+    <defs>
+     <path id="m3986887d56" d="M 0 2.738613 
+C 0.726289 2.738613 1.422928 2.450055 1.936492 1.936492 
+C 2.450055 1.422928 2.738613 0.726289 2.738613 0 
+C 2.738613 -0.726289 2.450055 -1.422928 1.936492 -1.936492 
+C 1.422928 -2.450055 0.726289 -2.738613 0 -2.738613 
+C -0.726289 -2.738613 -1.422928 -2.450055 -1.936492 -1.936492 
+C -2.450055 -1.422928 -2.738613 -0.726289 -2.738613 0 
+C -2.738613 0.726289 -2.450055 1.422928 -1.936492 1.936492 
+C -1.422928 2.450055 -0.726289 2.738613 0 2.738613 
+z
+"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#m3986887d56" x="52.800495" y="34.479294" style="fill: #800080"/>
+    </g>
+   </g>
+   <g id="PathCollection_5">
+    <defs>
+     <path id="m73cb4f1908" d="M 0 -5.91608 
+L -1.328243 -1.828169 
+L -5.626526 -1.828169 
+L -2.149142 0.698298 
+L -3.477384 4.786209 
+L -0 2.259741 
+L 3.477384 4.786209 
+L 2.149142 0.698298 
+L 5.626526 -1.828169 
+L 1.328243 -1.828169 
+z
+"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#m73cb4f1908" x="289.703869" y="47.557647" style="fill: #ff0000"/>
+    </g>
+   </g>
+   <g id="matplotlib.axis_1">
+    <g id="xtick_1">
+     <g id="line2d_1">
+      <defs>
+       <path id="m69d2a2ec97" d="M 0 0 
+L 0 3.5 
+" style="stroke: #000000; stroke-width: 0.8"/>
+      </defs>
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="38.242188" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_1">
+      <!-- 0 -->
+      <g transform="translate(35.060938 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-30" d="M 2034 4250 
+Q 1547 4250 1301 3770 
+Q 1056 3291 1056 2328 
+Q 1056 1369 1301 889 
+Q 1547 409 2034 409 
+Q 2525 409 2770 889 
+Q 3016 1369 3016 2328 
+Q 3016 3291 2770 3770 
+Q 2525 4250 2034 4250 
+z
+M 2034 4750 
+Q 2819 4750 3233 4129 
+Q 3647 3509 3647 2328 
+Q 3647 1150 3233 529 
+Q 2819 -91 2034 -91 
+Q 1250 -91 836 529 
+Q 422 1150 422 2328 
+Q 422 3509 836 4129 
+Q 1250 4750 2034 4750 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-30"/>
+      </g>
+     </g>
+    </g>
+    <g id="xtick_2">
+     <g id="line2d_2">
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="93.563757" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_2">
+      <!-- 10 -->
+      <g transform="translate(87.201257 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-31" d="M 794 531 
+L 1825 531 
+L 1825 4091 
+L 703 3866 
+L 703 4441 
+L 1819 4666 
+L 2450 4666 
+L 2450 531 
+L 3481 531 
+L 3481 0 
+L 794 0 
+L 794 531 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-31"/>
+       <use xlink:href="#DejaVuSans-30" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="xtick_3">
+     <g id="line2d_3">
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="148.885327" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_3">
+      <!-- 20 -->
+      <g transform="translate(142.522827 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-32" d="M 1228 531 
+L 3431 531 
+L 3431 0 
+L 469 0 
+L 469 531 
+Q 828 903 1448 1529 
+Q 2069 2156 2228 2338 
+Q 2531 2678 2651 2914 
+Q 2772 3150 2772 3378 
+Q 2772 3750 2511 3984 
+Q 2250 4219 1831 4219 
+Q 1534 4219 1204 4116 
+Q 875 4013 500 3803 
+L 500 4441 
+Q 881 4594 1212 4672 
+Q 1544 4750 1819 4750 
+Q 2544 4750 2975 4387 
+Q 3406 4025 3406 3419 
+Q 3406 3131 3298 2873 
+Q 3191 2616 2906 2266 
+Q 2828 2175 2409 1742 
+Q 1991 1309 1228 531 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-32"/>
+       <use xlink:href="#DejaVuSans-30" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="xtick_4">
+     <g id="line2d_4">
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="204.206897" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_4">
+      <!-- 30 -->
+      <g transform="translate(197.844397 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-33" d="M 2597 2516 
+Q 3050 2419 3304 2112 
+Q 3559 1806 3559 1356 
+Q 3559 666 3084 287 
+Q 2609 -91 1734 -91 
+Q 1441 -91 1130 -33 
+Q 819 25 488 141 
+L 488 750 
+Q 750 597 1062 519 
+Q 1375 441 1716 441 
+Q 2309 441 2620 675 
+Q 2931 909 2931 1356 
+Q 2931 1769 2642 2001 
+Q 2353 2234 1838 2234 
+L 1294 2234 
+L 1294 2753 
+L 1863 2753 
+Q 2328 2753 2575 2939 
+Q 2822 3125 2822 3475 
+Q 2822 3834 2567 4026 
+Q 2313 4219 1838 4219 
+Q 1578 4219 1281 4162 
+Q 984 4106 628 3988 
+L 628 4550 
+Q 988 4650 1302 4700 
+Q 1616 4750 1894 4750 
+Q 2613 4750 3031 4423 
+Q 3450 4097 3450 3541 
+Q 3450 3153 3228 2886 
+Q 3006 2619 2597 2516 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-33"/>
+       <use xlink:href="#DejaVuSans-30" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="xtick_5">
+     <g id="line2d_5">
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="259.528467" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_5">
+      <!-- 40 -->
+      <g transform="translate(253.165967 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-34" d="M 2419 4116 
+L 825 1625 
+L 2419 1625 
+L 2419 4116 
+z
+M 2253 4666 
+L 3047 4666 
+L 3047 1625 
+L 3713 1625 
+L 3713 1100 
+L 3047 1100 
+L 3047 0 
+L 2419 0 
+L 2419 1100 
+L 313 1100 
+L 313 1709 
+L 2253 4666 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-34"/>
+       <use xlink:href="#DejaVuSans-30" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="xtick_6">
+     <g id="line2d_6">
+      <g>
+       <use xlink:href="#m69d2a2ec97" x="314.850037" y="202.12" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_6">
+      <!-- 50 -->
+      <g transform="translate(308.487537 216.718437) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-35" d="M 691 4666 
+L 3169 4666 
+L 3169 4134 
+L 1269 4134 
+L 1269 2991 
+Q 1406 3038 1543 3061 
+Q 1681 3084 1819 3084 
+Q 2600 3084 3056 2656 
+Q 3513 2228 3513 1497 
+Q 3513 744 3044 326 
+Q 2575 -91 1722 -91 
+Q 1428 -91 1123 -41 
+Q 819 9 494 109 
+L 494 744 
+Q 775 591 1075 516 
+Q 1375 441 1709 441 
+Q 2250 441 2565 725 
+Q 2881 1009 2881 1497 
+Q 2881 1984 2565 2268 
+Q 2250 2553 1709 2553 
+Q 1456 2553 1204 2497 
+Q 953 2441 691 2322 
+L 691 4666 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-35"/>
+       <use xlink:href="#DejaVuSans-30" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="text_7">
+     <!-- Image Pairs Per Second -->
+     <g transform="translate(106.824218 234.195781) scale(0.15 -0.15)">
+      <defs>
+       <path id="DejaVuSans-49" d="M 628 4666 
+L 1259 4666 
+L 1259 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-6d" d="M 3328 2828 
+Q 3544 3216 3844 3400 
+Q 4144 3584 4550 3584 
+Q 5097 3584 5394 3201 
+Q 5691 2819 5691 2113 
+L 5691 0 
+L 5113 0 
+L 5113 2094 
+Q 5113 2597 4934 2840 
+Q 4756 3084 4391 3084 
+Q 3944 3084 3684 2787 
+Q 3425 2491 3425 1978 
+L 3425 0 
+L 2847 0 
+L 2847 2094 
+Q 2847 2600 2669 2842 
+Q 2491 3084 2119 3084 
+Q 1678 3084 1418 2786 
+Q 1159 2488 1159 1978 
+L 1159 0 
+L 581 0 
+L 581 3500 
+L 1159 3500 
+L 1159 2956 
+Q 1356 3278 1631 3431 
+Q 1906 3584 2284 3584 
+Q 2666 3584 2933 3390 
+Q 3200 3197 3328 2828 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-61" d="M 2194 1759 
+Q 1497 1759 1228 1600 
+Q 959 1441 959 1056 
+Q 959 750 1161 570 
+Q 1363 391 1709 391 
+Q 2188 391 2477 730 
+Q 2766 1069 2766 1631 
+L 2766 1759 
+L 2194 1759 
+z
+M 3341 1997 
+L 3341 0 
+L 2766 0 
+L 2766 531 
+Q 2569 213 2275 61 
+Q 1981 -91 1556 -91 
+Q 1019 -91 701 211 
+Q 384 513 384 1019 
+Q 384 1609 779 1909 
+Q 1175 2209 1959 2209 
+L 2766 2209 
+L 2766 2266 
+Q 2766 2663 2505 2880 
+Q 2244 3097 1772 3097 
+Q 1472 3097 1187 3025 
+Q 903 2953 641 2809 
+L 641 3341 
+Q 956 3463 1253 3523 
+Q 1550 3584 1831 3584 
+Q 2591 3584 2966 3190 
+Q 3341 2797 3341 1997 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-67" d="M 2906 1791 
+Q 2906 2416 2648 2759 
+Q 2391 3103 1925 3103 
+Q 1463 3103 1205 2759 
+Q 947 2416 947 1791 
+Q 947 1169 1205 825 
+Q 1463 481 1925 481 
+Q 2391 481 2648 825 
+Q 2906 1169 2906 1791 
+z
+M 3481 434 
+Q 3481 -459 3084 -895 
+Q 2688 -1331 1869 -1331 
+Q 1566 -1331 1297 -1286 
+Q 1028 -1241 775 -1147 
+L 775 -588 
+Q 1028 -725 1275 -790 
+Q 1522 -856 1778 -856 
+Q 2344 -856 2625 -561 
+Q 2906 -266 2906 331 
+L 2906 616 
+Q 2728 306 2450 153 
+Q 2172 0 1784 0 
+Q 1141 0 747 490 
+Q 353 981 353 1791 
+Q 353 2603 747 3093 
+Q 1141 3584 1784 3584 
+Q 2172 3584 2450 3431 
+Q 2728 3278 2906 2969 
+L 2906 3500 
+L 3481 3500 
+L 3481 434 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-65" d="M 3597 1894 
+L 3597 1613 
+L 953 1613 
+Q 991 1019 1311 708 
+Q 1631 397 2203 397 
+Q 2534 397 2845 478 
+Q 3156 559 3463 722 
+L 3463 178 
+Q 3153 47 2828 -22 
+Q 2503 -91 2169 -91 
+Q 1331 -91 842 396 
+Q 353 884 353 1716 
+Q 353 2575 817 3079 
+Q 1281 3584 2069 3584 
+Q 2775 3584 3186 3129 
+Q 3597 2675 3597 1894 
+z
+M 3022 2063 
+Q 3016 2534 2758 2815 
+Q 2500 3097 2075 3097 
+Q 1594 3097 1305 2825 
+Q 1016 2553 972 2059 
+L 3022 2063 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-20" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-50" d="M 1259 4147 
+L 1259 2394 
+L 2053 2394 
+Q 2494 2394 2734 2622 
+Q 2975 2850 2975 3272 
+Q 2975 3691 2734 3919 
+Q 2494 4147 2053 4147 
+L 1259 4147 
+z
+M 628 4666 
+L 2053 4666 
+Q 2838 4666 3239 4311 
+Q 3641 3956 3641 3272 
+Q 3641 2581 3239 2228 
+Q 2838 1875 2053 1875 
+L 1259 1875 
+L 1259 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-69" d="M 603 3500 
+L 1178 3500 
+L 1178 0 
+L 603 0 
+L 603 3500 
+z
+M 603 4863 
+L 1178 4863 
+L 1178 4134 
+L 603 4134 
+L 603 4863 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-72" d="M 2631 2963 
+Q 2534 3019 2420 3045 
+Q 2306 3072 2169 3072 
+Q 1681 3072 1420 2755 
+Q 1159 2438 1159 1844 
+L 1159 0 
+L 581 0 
+L 581 3500 
+L 1159 3500 
+L 1159 2956 
+Q 1341 3275 1631 3429 
+Q 1922 3584 2338 3584 
+Q 2397 3584 2469 3576 
+Q 2541 3569 2628 3553 
+L 2631 2963 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-73" d="M 2834 3397 
+L 2834 2853 
+Q 2591 2978 2328 3040 
+Q 2066 3103 1784 3103 
+Q 1356 3103 1142 2972 
+Q 928 2841 928 2578 
+Q 928 2378 1081 2264 
+Q 1234 2150 1697 2047 
+L 1894 2003 
+Q 2506 1872 2764 1633 
+Q 3022 1394 3022 966 
+Q 3022 478 2636 193 
+Q 2250 -91 1575 -91 
+Q 1294 -91 989 -36 
+Q 684 19 347 128 
+L 347 722 
+Q 666 556 975 473 
+Q 1284 391 1588 391 
+Q 1994 391 2212 530 
+Q 2431 669 2431 922 
+Q 2431 1156 2273 1281 
+Q 2116 1406 1581 1522 
+L 1381 1569 
+Q 847 1681 609 1914 
+Q 372 2147 372 2553 
+Q 372 3047 722 3315 
+Q 1072 3584 1716 3584 
+Q 2034 3584 2315 3537 
+Q 2597 3491 2834 3397 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-53" d="M 3425 4513 
+L 3425 3897 
+Q 3066 4069 2747 4153 
+Q 2428 4238 2131 4238 
+Q 1616 4238 1336 4038 
+Q 1056 3838 1056 3469 
+Q 1056 3159 1242 3001 
+Q 1428 2844 1947 2747 
+L 2328 2669 
+Q 3034 2534 3370 2195 
+Q 3706 1856 3706 1288 
+Q 3706 609 3251 259 
+Q 2797 -91 1919 -91 
+Q 1588 -91 1214 -16 
+Q 841 59 441 206 
+L 441 856 
+Q 825 641 1194 531 
+Q 1563 422 1919 422 
+Q 2459 422 2753 634 
+Q 3047 847 3047 1241 
+Q 3047 1584 2836 1778 
+Q 2625 1972 2144 2069 
+L 1759 2144 
+Q 1053 2284 737 2584 
+Q 422 2884 422 3419 
+Q 422 4038 858 4394 
+Q 1294 4750 2059 4750 
+Q 2388 4750 2728 4690 
+Q 3069 4631 3425 4513 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-63" d="M 3122 3366 
+L 3122 2828 
+Q 2878 2963 2633 3030 
+Q 2388 3097 2138 3097 
+Q 1578 3097 1268 2742 
+Q 959 2388 959 1747 
+Q 959 1106 1268 751 
+Q 1578 397 2138 397 
+Q 2388 397 2633 464 
+Q 2878 531 3122 666 
+L 3122 134 
+Q 2881 22 2623 -34 
+Q 2366 -91 2075 -91 
+Q 1284 -91 818 406 
+Q 353 903 353 1747 
+Q 353 2603 823 3093 
+Q 1294 3584 2113 3584 
+Q 2378 3584 2631 3529 
+Q 2884 3475 3122 3366 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-6f" d="M 1959 3097 
+Q 1497 3097 1228 2736 
+Q 959 2375 959 1747 
+Q 959 1119 1226 758 
+Q 1494 397 1959 397 
+Q 2419 397 2687 759 
+Q 2956 1122 2956 1747 
+Q 2956 2369 2687 2733 
+Q 2419 3097 1959 3097 
+z
+M 1959 3584 
+Q 2709 3584 3137 3096 
+Q 3566 2609 3566 1747 
+Q 3566 888 3137 398 
+Q 2709 -91 1959 -91 
+Q 1206 -91 779 398 
+Q 353 888 353 1747 
+Q 353 2609 779 3096 
+Q 1206 3584 1959 3584 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-6e" d="M 3513 2113 
+L 3513 0 
+L 2938 0 
+L 2938 2094 
+Q 2938 2591 2744 2837 
+Q 2550 3084 2163 3084 
+Q 1697 3084 1428 2787 
+Q 1159 2491 1159 1978 
+L 1159 0 
+L 581 0 
+L 581 3500 
+L 1159 3500 
+L 1159 2956 
+Q 1366 3272 1645 3428 
+Q 1925 3584 2291 3584 
+Q 2894 3584 3203 3211 
+Q 3513 2838 3513 2113 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-64" d="M 2906 2969 
+L 2906 4863 
+L 3481 4863 
+L 3481 0 
+L 2906 0 
+L 2906 525 
+Q 2725 213 2448 61 
+Q 2172 -91 1784 -91 
+Q 1150 -91 751 415 
+Q 353 922 353 1747 
+Q 353 2572 751 3078 
+Q 1150 3584 1784 3584 
+Q 2172 3584 2448 3432 
+Q 2725 3281 2906 2969 
+z
+M 947 1747 
+Q 947 1113 1208 752 
+Q 1469 391 1925 391 
+Q 2381 391 2643 752 
+Q 2906 1113 2906 1747 
+Q 2906 2381 2643 2742 
+Q 2381 3103 1925 3103 
+Q 1469 3103 1208 2742 
+Q 947 2381 947 1747 
+z
+" transform="scale(0.015625)"/>
+      </defs>
+      <use xlink:href="#DejaVuSans-49"/>
+      <use xlink:href="#DejaVuSans-6d" x="29.492188"/>
+      <use xlink:href="#DejaVuSans-61" x="126.904297"/>
+      <use xlink:href="#DejaVuSans-67" x="188.183594"/>
+      <use xlink:href="#DejaVuSans-65" x="251.660156"/>
+      <use xlink:href="#DejaVuSans-20" x="313.183594"/>
+      <use xlink:href="#DejaVuSans-50" x="344.970703"/>
+      <use xlink:href="#DejaVuSans-61" x="400.773438"/>
+      <use xlink:href="#DejaVuSans-69" x="462.052734"/>
+      <use xlink:href="#DejaVuSans-72" x="489.835938"/>
+      <use xlink:href="#DejaVuSans-73" x="530.949219"/>
+      <use xlink:href="#DejaVuSans-20" x="583.048828"/>
+      <use xlink:href="#DejaVuSans-50" x="614.835938"/>
+      <use xlink:href="#DejaVuSans-65" x="671.513672"/>
+      <use xlink:href="#DejaVuSans-72" x="733.037109"/>
+      <use xlink:href="#DejaVuSans-20" x="774.150391"/>
+      <use xlink:href="#DejaVuSans-53" x="805.9375"/>
+      <use xlink:href="#DejaVuSans-65" x="869.414062"/>
+      <use xlink:href="#DejaVuSans-63" x="930.9375"/>
+      <use xlink:href="#DejaVuSans-6f" x="985.917969"/>
+      <use xlink:href="#DejaVuSans-6e" x="1047.099609"/>
+      <use xlink:href="#DejaVuSans-64" x="1110.478516"/>
+     </g>
+    </g>
+   </g>
+   <g id="matplotlib.axis_2">
+    <g id="ytick_1">
+     <g id="line2d_7">
+      <defs>
+       <path id="m433b6a5b4b" d="M 0 0 
+L -3.5 0 
+" style="stroke: #000000; stroke-width: 0.8"/>
+      </defs>
+      <g>
+       <use xlink:href="#m433b6a5b4b" x="38.242188" y="184.285882" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_8">
+      <!-- 64 -->
+      <g transform="translate(18.517188 188.085101) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-36" d="M 2113 2584 
+Q 1688 2584 1439 2293 
+Q 1191 2003 1191 1497 
+Q 1191 994 1439 701 
+Q 1688 409 2113 409 
+Q 2538 409 2786 701 
+Q 3034 994 3034 1497 
+Q 3034 2003 2786 2293 
+Q 2538 2584 2113 2584 
+z
+M 3366 4563 
+L 3366 3988 
+Q 3128 4100 2886 4159 
+Q 2644 4219 2406 4219 
+Q 1781 4219 1451 3797 
+Q 1122 3375 1075 2522 
+Q 1259 2794 1537 2939 
+Q 1816 3084 2150 3084 
+Q 2853 3084 3261 2657 
+Q 3669 2231 3669 1497 
+Q 3669 778 3244 343 
+Q 2819 -91 2113 -91 
+Q 1303 -91 875 529 
+Q 447 1150 447 2328 
+Q 447 3434 972 4092 
+Q 1497 4750 2381 4750 
+Q 2619 4750 2861 4703 
+Q 3103 4656 3366 4563 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-36"/>
+       <use xlink:href="#DejaVuSans-34" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="ytick_2">
+     <g id="line2d_8">
+      <g>
+       <use xlink:href="#m433b6a5b4b" x="38.242188" y="124.838824" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_9">
+      <!-- 65 -->
+      <g transform="translate(18.517188 128.638042) scale(0.1 -0.1)">
+       <use xlink:href="#DejaVuSans-36"/>
+       <use xlink:href="#DejaVuSans-35" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="ytick_3">
+     <g id="line2d_9">
+      <g>
+       <use xlink:href="#m433b6a5b4b" x="38.242188" y="65.391765" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_10">
+      <!-- 66 -->
+      <g transform="translate(18.517188 69.190983) scale(0.1 -0.1)">
+       <use xlink:href="#DejaVuSans-36"/>
+       <use xlink:href="#DejaVuSans-36" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="ytick_4">
+     <g id="line2d_10">
+      <g>
+       <use xlink:href="#m433b6a5b4b" x="38.242188" y="5.944706" style="stroke: #000000; stroke-width: 0.8"/>
+      </g>
+     </g>
+     <g id="text_11">
+      <!-- 67 -->
+      <g transform="translate(18.517188 9.743925) scale(0.1 -0.1)">
+       <defs>
+        <path id="DejaVuSans-37" d="M 525 4666 
+L 3525 4666 
+L 3525 4397 
+L 1831 0 
+L 1172 0 
+L 2766 4134 
+L 525 4134 
+L 525 4666 
+z
+" transform="scale(0.015625)"/>
+       </defs>
+       <use xlink:href="#DejaVuSans-36"/>
+       <use xlink:href="#DejaVuSans-37" x="63.623047"/>
+      </g>
+     </g>
+    </g>
+    <g id="text_12">
+     <!-- Relative Pose Accuracy [%] -->
+     <g transform="translate(11.397656 203.038906) rotate(-90) scale(0.15 -0.15)">
+      <defs>
+       <path id="DejaVuSans-52" d="M 2841 2188 
+Q 3044 2119 3236 1894 
+Q 3428 1669 3622 1275 
+L 4263 0 
+L 3584 0 
+L 2988 1197 
+Q 2756 1666 2539 1819 
+Q 2322 1972 1947 1972 
+L 1259 1972 
+L 1259 0 
+L 628 0 
+L 628 4666 
+L 2053 4666 
+Q 2853 4666 3247 4331 
+Q 3641 3997 3641 3322 
+Q 3641 2881 3436 2590 
+Q 3231 2300 2841 2188 
+z
+M 1259 4147 
+L 1259 2491 
+L 2053 2491 
+Q 2509 2491 2742 2702 
+Q 2975 2913 2975 3322 
+Q 2975 3731 2742 3939 
+Q 2509 4147 2053 4147 
+L 1259 4147 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-6c" d="M 603 4863 
+L 1178 4863 
+L 1178 0 
+L 603 0 
+L 603 4863 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-74" d="M 1172 4494 
+L 1172 3500 
+L 2356 3500 
+L 2356 3053 
+L 1172 3053 
+L 1172 1153 
+Q 1172 725 1289 603 
+Q 1406 481 1766 481 
+L 2356 481 
+L 2356 0 
+L 1766 0 
+Q 1100 0 847 248 
+Q 594 497 594 1153 
+L 594 3053 
+L 172 3053 
+L 172 3500 
+L 594 3500 
+L 594 4494 
+L 1172 4494 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-76" d="M 191 3500 
+L 800 3500 
+L 1894 563 
+L 2988 3500 
+L 3597 3500 
+L 2284 0 
+L 1503 0 
+L 191 3500 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-41" d="M 2188 4044 
+L 1331 1722 
+L 3047 1722 
+L 2188 4044 
+z
+M 1831 4666 
+L 2547 4666 
+L 4325 0 
+L 3669 0 
+L 3244 1197 
+L 1141 1197 
+L 716 0 
+L 50 0 
+L 1831 4666 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-75" d="M 544 1381 
+L 544 3500 
+L 1119 3500 
+L 1119 1403 
+Q 1119 906 1312 657 
+Q 1506 409 1894 409 
+Q 2359 409 2629 706 
+Q 2900 1003 2900 1516 
+L 2900 3500 
+L 3475 3500 
+L 3475 0 
+L 2900 0 
+L 2900 538 
+Q 2691 219 2414 64 
+Q 2138 -91 1772 -91 
+Q 1169 -91 856 284 
+Q 544 659 544 1381 
+z
+M 1991 3584 
+L 1991 3584 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-79" d="M 2059 -325 
+Q 1816 -950 1584 -1140 
+Q 1353 -1331 966 -1331 
+L 506 -1331 
+L 506 -850 
+L 844 -850 
+Q 1081 -850 1212 -737 
+Q 1344 -625 1503 -206 
+L 1606 56 
+L 191 3500 
+L 800 3500 
+L 1894 763 
+L 2988 3500 
+L 3597 3500 
+L 2059 -325 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-5b" d="M 550 4863 
+L 1875 4863 
+L 1875 4416 
+L 1125 4416 
+L 1125 -397 
+L 1875 -397 
+L 1875 -844 
+L 550 -844 
+L 550 4863 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-25" d="M 4653 2053 
+Q 4381 2053 4226 1822 
+Q 4072 1591 4072 1178 
+Q 4072 772 4226 539 
+Q 4381 306 4653 306 
+Q 4919 306 5073 539 
+Q 5228 772 5228 1178 
+Q 5228 1588 5073 1820 
+Q 4919 2053 4653 2053 
+z
+M 4653 2450 
+Q 5147 2450 5437 2106 
+Q 5728 1763 5728 1178 
+Q 5728 594 5436 251 
+Q 5144 -91 4653 -91 
+Q 4153 -91 3862 251 
+Q 3572 594 3572 1178 
+Q 3572 1766 3864 2108 
+Q 4156 2450 4653 2450 
+z
+M 1428 4353 
+Q 1159 4353 1004 4120 
+Q 850 3888 850 3481 
+Q 850 3069 1003 2837 
+Q 1156 2606 1428 2606 
+Q 1700 2606 1854 2837 
+Q 2009 3069 2009 3481 
+Q 2009 3884 1853 4118 
+Q 1697 4353 1428 4353 
+z
+M 4250 4750 
+L 4750 4750 
+L 1831 -91 
+L 1331 -91 
+L 4250 4750 
+z
+M 1428 4750 
+Q 1922 4750 2215 4408 
+Q 2509 4066 2509 3481 
+Q 2509 2891 2217 2550 
+Q 1925 2209 1428 2209 
+Q 931 2209 642 2551 
+Q 353 2894 353 3481 
+Q 353 4063 643 4406 
+Q 934 4750 1428 4750 
+z
+" transform="scale(0.015625)"/>
+       <path id="DejaVuSans-5d" d="M 1947 4863 
+L 1947 -844 
+L 622 -844 
+L 622 -397 
+L 1369 -397 
+L 1369 4416 
+L 622 4416 
+L 622 4863 
+L 1947 4863 
+z
+" transform="scale(0.015625)"/>
+      </defs>
+      <use xlink:href="#DejaVuSans-52"/>
+      <use xlink:href="#DejaVuSans-65" x="64.982422"/>
+      <use xlink:href="#DejaVuSans-6c" x="126.505859"/>
+      <use xlink:href="#DejaVuSans-61" x="154.289062"/>
+      <use xlink:href="#DejaVuSans-74" x="215.568359"/>
+      <use xlink:href="#DejaVuSans-69" x="254.777344"/>
+      <use xlink:href="#DejaVuSans-76" x="282.560547"/>
+      <use xlink:href="#DejaVuSans-65" x="341.740234"/>
+      <use xlink:href="#DejaVuSans-20" x="403.263672"/>
+      <use xlink:href="#DejaVuSans-50" x="435.050781"/>
+      <use xlink:href="#DejaVuSans-6f" x="491.728516"/>
+      <use xlink:href="#DejaVuSans-73" x="552.910156"/>
+      <use xlink:href="#DejaVuSans-65" x="605.009766"/>
+      <use xlink:href="#DejaVuSans-20" x="666.533203"/>
+      <use xlink:href="#DejaVuSans-41" x="698.320312"/>
+      <use xlink:href="#DejaVuSans-63" x="764.978516"/>
+      <use xlink:href="#DejaVuSans-63" x="819.958984"/>
+      <use xlink:href="#DejaVuSans-75" x="874.939453"/>
+      <use xlink:href="#DejaVuSans-72" x="938.318359"/>
+      <use xlink:href="#DejaVuSans-61" x="979.431641"/>
+      <use xlink:href="#DejaVuSans-63" x="1040.710938"/>
+      <use xlink:href="#DejaVuSans-79" x="1095.691406"/>
+      <use xlink:href="#DejaVuSans-20" x="1154.871094"/>
+      <use xlink:href="#DejaVuSans-5b" x="1186.658203"/>
+      <use xlink:href="#DejaVuSans-25" x="1225.671875"/>
+      <use xlink:href="#DejaVuSans-5d" x="1320.691406"/>
+     </g>
+    </g>
+   </g>
+   <g id="patch_3">
+    <path d="M 38.242188 202.12 
+L 38.242188 0 
+" style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/>
+   </g>
+   <g id="patch_4">
+    <path d="M 351.50156 202.12 
+L 351.50156 0 
+" style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/>
+   </g>
+   <g id="patch_5">
+    <path d="M 38.242188 202.12 
+L 351.50156 202.12 
+" style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/>
+   </g>
+   <g id="patch_6">
+    <path d="M 38.242188 0 
+L 351.50156 0 
+" style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/>
+   </g>
+   <g id="text_13">
+    <!-- SuperGlue -->
+    <g style="fill: #0000ff" transform="translate(73.036283 100.678833) scale(0.15 -0.15)">
+     <defs>
+      <path id="DejaVuSans-70" d="M 1159 525 
+L 1159 -1331 
+L 581 -1331 
+L 581 3500 
+L 1159 3500 
+L 1159 2969 
+Q 1341 3281 1617 3432 
+Q 1894 3584 2278 3584 
+Q 2916 3584 3314 3078 
+Q 3713 2572 3713 1747 
+Q 3713 922 3314 415 
+Q 2916 -91 2278 -91 
+Q 1894 -91 1617 61 
+Q 1341 213 1159 525 
+z
+M 3116 1747 
+Q 3116 2381 2855 2742 
+Q 2594 3103 2138 3103 
+Q 1681 3103 1420 2742 
+Q 1159 2381 1159 1747 
+Q 1159 1113 1420 752 
+Q 1681 391 2138 391 
+Q 2594 391 2855 752 
+Q 3116 1113 3116 1747 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-47" d="M 3809 666 
+L 3809 1919 
+L 2778 1919 
+L 2778 2438 
+L 4434 2438 
+L 4434 434 
+Q 4069 175 3628 42 
+Q 3188 -91 2688 -91 
+Q 1594 -91 976 548 
+Q 359 1188 359 2328 
+Q 359 3472 976 4111 
+Q 1594 4750 2688 4750 
+Q 3144 4750 3555 4637 
+Q 3966 4525 4313 4306 
+L 4313 3634 
+Q 3963 3931 3569 4081 
+Q 3175 4231 2741 4231 
+Q 1884 4231 1454 3753 
+Q 1025 3275 1025 2328 
+Q 1025 1384 1454 906 
+Q 1884 428 2741 428 
+Q 3075 428 3337 486 
+Q 3600 544 3809 666 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-53"/>
+     <use xlink:href="#DejaVuSans-75" x="63.476562"/>
+     <use xlink:href="#DejaVuSans-70" x="126.855469"/>
+     <use xlink:href="#DejaVuSans-65" x="190.332031"/>
+     <use xlink:href="#DejaVuSans-72" x="251.855469"/>
+     <use xlink:href="#DejaVuSans-47" x="292.96875"/>
+     <use xlink:href="#DejaVuSans-6c" x="370.458984"/>
+     <use xlink:href="#DejaVuSans-75" x="398.242188"/>
+     <use xlink:href="#DejaVuSans-65" x="461.621094"/>
+    </g>
+   </g>
+   <g id="text_14">
+    <!-- SGMNet -->
+    <g style="fill: #008000" transform="translate(87.993899 188.055763) scale(0.15 -0.15)">
+     <defs>
+      <path id="DejaVuSans-4d" d="M 628 4666 
+L 1569 4666 
+L 2759 1491 
+L 3956 4666 
+L 4897 4666 
+L 4897 0 
+L 4281 0 
+L 4281 4097 
+L 3078 897 
+L 2444 897 
+L 1241 4097 
+L 1241 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-4e" d="M 628 4666 
+L 1478 4666 
+L 3547 763 
+L 3547 4666 
+L 4159 4666 
+L 4159 0 
+L 3309 0 
+L 1241 3903 
+L 1241 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-53"/>
+     <use xlink:href="#DejaVuSans-47" x="63.476562"/>
+     <use xlink:href="#DejaVuSans-4d" x="140.966797"/>
+     <use xlink:href="#DejaVuSans-4e" x="227.246094"/>
+     <use xlink:href="#DejaVuSans-65" x="302.050781"/>
+     <use xlink:href="#DejaVuSans-74" x="363.574219"/>
+    </g>
+   </g>
+   <g id="text_15">
+    <!-- LoFTR -->
+    <g transform="translate(46.195263 63.010597) scale(0.15 -0.15)">
+     <defs>
+      <path id="DejaVuSans-4c" d="M 628 4666 
+L 1259 4666 
+L 1259 531 
+L 3531 531 
+L 3531 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-46" d="M 628 4666 
+L 3309 4666 
+L 3309 4134 
+L 1259 4134 
+L 1259 2759 
+L 3109 2759 
+L 3109 2228 
+L 1259 2228 
+L 1259 0 
+L 628 0 
+L 628 4666 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-54" d="M -19 4666 
+L 3928 4666 
+L 3928 4134 
+L 2272 4134 
+L 2272 0 
+L 1638 0 
+L 1638 4134 
+L -19 4134 
+L -19 4666 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-6f" x="53.962891"/>
+     <use xlink:href="#DejaVuSans-46" x="115.144531"/>
+     <use xlink:href="#DejaVuSans-54" x="170.914062"/>
+     <use xlink:href="#DejaVuSans-52" x="231.998047"/>
+    </g>
+   </g>
+   <g id="text_16">
+    <!-- MatchFormer -->
+    <g style="fill: #800080" transform="translate(42.800495 23.359763) scale(0.15 -0.15)">
+     <defs>
+      <path id="DejaVuSans-68" d="M 3513 2113 
+L 3513 0 
+L 2938 0 
+L 2938 2094 
+Q 2938 2591 2744 2837 
+Q 2550 3084 2163 3084 
+Q 1697 3084 1428 2787 
+Q 1159 2491 1159 1978 
+L 1159 0 
+L 581 0 
+L 581 4863 
+L 1159 4863 
+L 1159 2956 
+Q 1366 3272 1645 3428 
+Q 1925 3584 2291 3584 
+Q 2894 3584 3203 3211 
+Q 3513 2838 3513 2113 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-4d"/>
+     <use xlink:href="#DejaVuSans-61" x="86.279297"/>
+     <use xlink:href="#DejaVuSans-74" x="147.558594"/>
+     <use xlink:href="#DejaVuSans-63" x="186.767578"/>
+     <use xlink:href="#DejaVuSans-68" x="241.748047"/>
+     <use xlink:href="#DejaVuSans-46" x="305.126953"/>
+     <use xlink:href="#DejaVuSans-6f" x="359.021484"/>
+     <use xlink:href="#DejaVuSans-72" x="420.203125"/>
+     <use xlink:href="#DejaVuSans-6d" x="459.566406"/>
+     <use xlink:href="#DejaVuSans-65" x="556.978516"/>
+     <use xlink:href="#DejaVuSans-72" x="618.501953"/>
+    </g>
+   </g>
+   <g id="text_17">
+    <!-- L=3 -->
+    <g style="fill: #ff0000" transform="translate(318.963638 198.045257) scale(0.1 -0.1)">
+     <defs>
+      <path id="DejaVuSans-3d" d="M 678 2906 
+L 4684 2906 
+L 4684 2381 
+L 678 2381 
+L 678 2906 
+z
+M 678 1631 
+L 4684 1631 
+L 4684 1100 
+L 678 1100 
+L 678 1631 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-3d" x="55.712891"/>
+     <use xlink:href="#DejaVuSans-33" x="139.501953"/>
+    </g>
+   </g>
+   <g id="text_18">
+    <!-- L=5 -->
+    <g style="fill: #ff0000" transform="translate(228.688766 138.598199) scale(0.1 -0.1)">
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-3d" x="55.712891"/>
+     <use xlink:href="#DejaVuSans-35" x="139.501953"/>
+    </g>
+   </g>
+   <g id="text_19">
+    <!-- L=7 -->
+    <g style="fill: #ff0000" transform="translate(171.91046 67.261728) scale(0.1 -0.1)">
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-3d" x="55.712891"/>
+     <use xlink:href="#DejaVuSans-37" x="139.501953"/>
+    </g>
+   </g>
+   <g id="text_20">
+    <!-- L=9 -->
+    <g style="fill: #ff0000" transform="translate(145.090048 37.538199) scale(0.1 -0.1)">
+     <defs>
+      <path id="DejaVuSans-39" d="M 703 97 
+L 703 672 
+Q 941 559 1184 500 
+Q 1428 441 1663 441 
+Q 2288 441 2617 861 
+Q 2947 1281 2994 2138 
+Q 2813 1869 2534 1725 
+Q 2256 1581 1919 1581 
+Q 1219 1581 811 2004 
+Q 403 2428 403 3163 
+Q 403 3881 828 4315 
+Q 1253 4750 1959 4750 
+Q 2769 4750 3195 4129 
+Q 3622 3509 3622 2328 
+Q 3622 1225 3098 567 
+Q 2575 -91 1691 -91 
+Q 1453 -91 1209 -44 
+Q 966 3 703 97 
+z
+M 1959 2075 
+Q 2384 2075 2632 2365 
+Q 2881 2656 2881 3163 
+Q 2881 3666 2632 3958 
+Q 2384 4250 1959 4250 
+Q 1534 4250 1286 3958 
+Q 1038 3666 1038 3163 
+Q 1038 2656 1286 2365 
+Q 1534 2075 1959 2075 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-3d" x="55.712891"/>
+     <use xlink:href="#DejaVuSans-39" x="139.501953"/>
+    </g>
+   </g>
+   <g id="text_21">
+    <!-- fixed-depth -->
+    <g style="fill: #ff0000" transform="translate(225.255342 166.790662) scale(0.12 -0.12)">
+     <defs>
+      <path id="DejaVuSans-66" d="M 2375 4863 
+L 2375 4384 
+L 1825 4384 
+Q 1516 4384 1395 4259 
+Q 1275 4134 1275 3809 
+L 1275 3500 
+L 2222 3500 
+L 2222 3053 
+L 1275 3053 
+L 1275 0 
+L 697 0 
+L 697 3053 
+L 147 3053 
+L 147 3500 
+L 697 3500 
+L 697 3744 
+Q 697 4328 969 4595 
+Q 1241 4863 1831 4863 
+L 2375 4863 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-78" d="M 3513 3500 
+L 2247 1797 
+L 3578 0 
+L 2900 0 
+L 1881 1375 
+L 863 0 
+L 184 0 
+L 1544 1831 
+L 300 3500 
+L 978 3500 
+L 1906 2253 
+L 2834 3500 
+L 3513 3500 
+z
+" transform="scale(0.015625)"/>
+      <path id="DejaVuSans-2d" d="M 313 2009 
+L 1997 2009 
+L 1997 1497 
+L 313 1497 
+L 313 2009 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-66"/>
+     <use xlink:href="#DejaVuSans-69" x="35.205078"/>
+     <use xlink:href="#DejaVuSans-78" x="62.988281"/>
+     <use xlink:href="#DejaVuSans-65" x="119.042969"/>
+     <use xlink:href="#DejaVuSans-64" x="180.566406"/>
+     <use xlink:href="#DejaVuSans-2d" x="244.042969"/>
+     <use xlink:href="#DejaVuSans-64" x="280.126953"/>
+     <use xlink:href="#DejaVuSans-65" x="343.603516"/>
+     <use xlink:href="#DejaVuSans-70" x="405.126953"/>
+     <use xlink:href="#DejaVuSans-74" x="468.603516"/>
+     <use xlink:href="#DejaVuSans-68" x="507.8125"/>
+    </g>
+   </g>
+   <g id="text_22">
+    <!-- adaptive -->
+    <g style="fill: #ff0000" transform="translate(283.083817 125.177721) scale(0.12 -0.12)">
+     <use xlink:href="#DejaVuSans-61"/>
+     <use xlink:href="#DejaVuSans-64" x="61.279297"/>
+     <use xlink:href="#DejaVuSans-61" x="124.755859"/>
+     <use xlink:href="#DejaVuSans-70" x="186.035156"/>
+     <use xlink:href="#DejaVuSans-74" x="249.511719"/>
+     <use xlink:href="#DejaVuSans-69" x="288.720703"/>
+     <use xlink:href="#DejaVuSans-76" x="316.503906"/>
+     <use xlink:href="#DejaVuSans-65" x="375.683594"/>
+    </g>
+   </g>
+   <g id="text_23">
+    <!-- optimized -->
+    <g style="fill: #ff0000" transform="translate(260.043244 64.675772) scale(0.12 -0.12)">
+     <defs>
+      <path id="DejaVuSans-7a" d="M 353 3500 
+L 3084 3500 
+L 3084 2975 
+L 922 459 
+L 3084 459 
+L 3084 0 
+L 275 0 
+L 275 525 
+L 2438 3041 
+L 353 3041 
+L 353 3500 
+z
+" transform="scale(0.015625)"/>
+     </defs>
+     <use xlink:href="#DejaVuSans-6f"/>
+     <use xlink:href="#DejaVuSans-70" x="61.181641"/>
+     <use xlink:href="#DejaVuSans-74" x="124.658203"/>
+     <use xlink:href="#DejaVuSans-69" x="163.867188"/>
+     <use xlink:href="#DejaVuSans-6d" x="191.650391"/>
+     <use xlink:href="#DejaVuSans-69" x="289.0625"/>
+     <use xlink:href="#DejaVuSans-7a" x="316.845703"/>
+     <use xlink:href="#DejaVuSans-65" x="369.335938"/>
+     <use xlink:href="#DejaVuSans-64" x="430.859375"/>
+    </g>
+   </g>
+   <g id="text_24">
+    <!-- LightGlue -->
+    <g style="fill: #ff0000" transform="translate(253.72379 21.69671) scale(0.15 -0.15)">
+     <use xlink:href="#DejaVuSans-4c"/>
+     <use xlink:href="#DejaVuSans-69" x="55.712891"/>
+     <use xlink:href="#DejaVuSans-67" x="83.496094"/>
+     <use xlink:href="#DejaVuSans-68" x="146.972656"/>
+     <use xlink:href="#DejaVuSans-74" x="210.351562"/>
+     <use xlink:href="#DejaVuSans-47" x="249.560547"/>
+     <use xlink:href="#DejaVuSans-6c" x="327.050781"/>
+     <use xlink:href="#DejaVuSans-75" x="354.833984"/>
+     <use xlink:href="#DejaVuSans-65" x="418.212891"/>
+    </g>
+   </g>
+   <g id="line2d_11">
+    <path d="M 337.2777 184.285882 
+L 247.002828 124.838824 
+L 190.224522 53.502353 
+L 163.40411 23.778824 
+" clip-path="url(#pb46ed2897c)" style="fill: none; stroke: #ff0000; stroke-width: 2; stroke-linecap: square"/>
+    <defs>
+     <path id="m8759e5a643" d="M 0 3 
+C 0.795609 3 1.55874 2.683901 2.12132 2.12132 
+C 2.683901 1.55874 3 0.795609 3 0 
+C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 
+C 1.55874 -2.683901 0.795609 -3 0 -3 
+C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 
+C -2.683901 -1.55874 -3 -0.795609 -3 0 
+C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 
+C -1.55874 2.683901 -0.795609 3 0 3 
+z
+" style="stroke: #ff0000"/>
+    </defs>
+    <g clip-path="url(#pb46ed2897c)">
+     <use xlink:href="#m8759e5a643" x="337.2777" y="184.285882" style="fill: #ff0000; stroke: #ff0000"/>
+     <use xlink:href="#m8759e5a643" x="247.002828" y="124.838824" style="fill: #ff0000; stroke: #ff0000"/>
+     <use xlink:href="#m8759e5a643" x="190.224522" y="53.502353" style="fill: #ff0000; stroke: #ff0000"/>
+     <use xlink:href="#m8759e5a643" x="163.40411" y="23.778824" style="fill: #ff0000; stroke: #ff0000"/>
+    </g>
+   </g>
+   <g id="line2d_12">
+    <path d="M 296.754196 112.949412 
+L 241.630312 71.336471 
+L 214.425531 47.557647 
+L 194.077595 29.723529 
+L 163.121578 23.778824 
+" clip-path="url(#pb46ed2897c)" style="fill: none; stroke-dasharray: 7.4,3.2; stroke-dashoffset: 0; stroke: #ff0000; stroke-width: 2"/>
+   </g>
+  </g>
+ </g>
+ <defs>
+  <clipPath id="pb46ed2897c">
+   <rect x="38.242188" y="0" width="313.259373" height="202.12"/>
+  </clipPath>
+ </defs>
+</svg>

+ 255 - 0
python/LightGlue/benchmark.py

@@ -0,0 +1,255 @@
+# Benchmark script for LightGlue on real images
+import argparse
+import time
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch._dynamo
+
+from lightglue import LightGlue, SuperPoint
+from lightglue.utils import load_image
+
+torch.set_grad_enabled(False)
+
+
+def measure(matcher, data, device=torch.device("cuda"), r=100):
+    timings = np.zeros((r, 1))
+    if device.type == "cuda":
+        starter = torch.cuda.Event(enable_timing=True)
+        ender = torch.cuda.Event(enable_timing=True)
+    # warmup
+    for _ in range(10):
+        _ = matcher(data)
+    # measurements
+    with torch.no_grad():
+        for rep in range(r):
+            if device.type == "cuda":
+                starter.record()
+                _ = matcher(data)
+                ender.record()
+                # sync gpu
+                torch.cuda.synchronize()
+                curr_time = starter.elapsed_time(ender)
+            else:
+                start = time.perf_counter()
+                _ = matcher(data)
+                curr_time = (time.perf_counter() - start) * 1e3
+            timings[rep] = curr_time
+    mean_syn = np.sum(timings) / r
+    std_syn = np.std(timings)
+    return {"mean": mean_syn, "std": std_syn}
+
+
+def print_as_table(d, title, cnames):
+    print()
+    header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
+    print(header)
+    print("-" * len(header))
+    for k, l in d.items():
+        print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
+    parser.add_argument(
+        "--device",
+        choices=["auto", "cuda", "cpu", "mps"],
+        default="auto",
+        help="device to benchmark on",
+    )
+    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
+    parser.add_argument(
+        "--no_flash", action="store_true", help="disable FlashAttention"
+    )
+    parser.add_argument(
+        "--no_prune_thresholds",
+        action="store_true",
+        help="disable pruning thresholds (i.e. always do pruning)",
+    )
+    parser.add_argument(
+        "--add_superglue",
+        action="store_true",
+        help="add SuperGlue to the benchmark (requires hloc)",
+    )
+    parser.add_argument(
+        "--measure", default="time", choices=["time", "log-time", "throughput"]
+    )
+    parser.add_argument(
+        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
+    )
+    parser.add_argument(
+        "--num_keypoints",
+        nargs="+",
+        type=int,
+        default=[256, 512, 1024, 2048, 4096],
+        help="number of keypoints (list separated by spaces)",
+    )
+    parser.add_argument(
+        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
+    )
+    parser.add_argument(
+        "--save", default=None, type=str, help="path where figure should be saved"
+    )
+    args = parser.parse_intermixed_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if args.device != "auto":
+        device = torch.device(args.device)
+
+    print("Running benchmark on device:", device)
+
+    images = Path("assets")
+    inputs = {
+        "easy": (
+            load_image(images / "DSC_0411.JPG"),
+            load_image(images / "DSC_0410.JPG"),
+        ),
+        "difficult": (
+            load_image(images / "sacre_coeur1.jpg"),
+            load_image(images / "sacre_coeur2.jpg"),
+        ),
+    }
+
+    configs = {
+        "LightGlue-full": {
+            "depth_confidence": -1,
+            "width_confidence": -1,
+        },
+        # 'LG-prune': {
+        #     'width_confidence': -1,
+        # },
+        # 'LG-depth': {
+        #     'depth_confidence': -1,
+        # },
+        "LightGlue-adaptive": {},
+    }
+
+    if args.compile:
+        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
+
+    sg_configs = {
+        # 'SuperGlue': {},
+        "SuperGlue-fast": {"sinkhorn_iterations": 5}
+    }
+
+    torch.set_float32_matmul_precision(args.matmul_precision)
+
+    results = {k: defaultdict(list) for k, v in inputs.items()}
+
+    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
+    extractor = extractor.eval().to(device)
+    figsize = (len(inputs) * 4.5, 4.5)
+    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
+    axes = axes if len(inputs) > 1 else [axes]
+    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
+
+    for title, ax in zip(inputs.keys(), axes):
+        ax.set_xscale("log", base=2)
+        bases = [2**x for x in range(7, 16)]
+        ax.set_xticks(bases, bases)
+        ax.grid(which="major")
+        if args.measure == "log-time":
+            ax.set_yscale("log")
+            yticks = [10**x for x in range(6)]
+            ax.set_yticks(yticks, yticks)
+            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
+            mlabel = [
+                10**x * i if i in [2, 5] else None
+                for x in range(6)
+                for i in range(2, 10)
+            ]
+            ax.set_yticks(mpos, mlabel, minor=True)
+            ax.grid(which="minor", linewidth=0.2)
+        ax.set_title(title)
+
+        ax.set_xlabel("# keypoints")
+        if args.measure == "throughput":
+            ax.set_ylabel("Throughput [pairs/s]")
+        else:
+            ax.set_ylabel("Latency [ms]")
+
+    for name, conf in configs.items():
+        print("Run benchmark for:", name)
+        torch.cuda.empty_cache()
+        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
+        if args.no_prune_thresholds:
+            matcher.pruning_keypoint_thresholds = {
+                k: -1 for k in matcher.pruning_keypoint_thresholds
+            }
+        matcher = matcher.eval().to(device)
+        if name.endswith("compile"):
+            import torch._dynamo
+
+            torch._dynamo.reset()  # avoid buffer overflow
+            matcher.compile()
+        for pair_name, ax in zip(inputs.keys(), axes):
+            image0, image1 = [x.to(device) for x in inputs[pair_name]]
+            runtimes = []
+            for num_kpts in args.num_keypoints:
+                extractor.conf.max_num_keypoints = num_kpts
+                feats0 = extractor.extract(image0)
+                feats1 = extractor.extract(image1)
+                runtime = measure(
+                    matcher,
+                    {"image0": feats0, "image1": feats1},
+                    device=device,
+                    r=args.repeat,
+                )["mean"]
+                results[pair_name][name].append(
+                    1000 / runtime if args.measure == "throughput" else runtime
+                )
+            ax.plot(
+                args.num_keypoints, results[pair_name][name], label=name, marker="o"
+            )
+        del matcher, feats0, feats1
+
+    if args.add_superglue:
+        from hloc.matchers.superglue import SuperGlue
+
+        for name, conf in sg_configs.items():
+            print("Run benchmark for:", name)
+            matcher = SuperGlue(conf)
+            matcher = matcher.eval().to(device)
+            for pair_name, ax in zip(inputs.keys(), axes):
+                image0, image1 = [x.to(device) for x in inputs[pair_name]]
+                runtimes = []
+                for num_kpts in args.num_keypoints:
+                    extractor.conf.max_num_keypoints = num_kpts
+                    feats0 = extractor.extract(image0)
+                    feats1 = extractor.extract(image1)
+                    data = {
+                        "image0": image0[None],
+                        "image1": image1[None],
+                        **{k + "0": v for k, v in feats0.items()},
+                        **{k + "1": v for k, v in feats1.items()},
+                    }
+                    data["scores0"] = data["keypoint_scores0"]
+                    data["scores1"] = data["keypoint_scores1"]
+                    data["descriptors0"] = (
+                        data["descriptors0"].transpose(-1, -2).contiguous()
+                    )
+                    data["descriptors1"] = (
+                        data["descriptors1"].transpose(-1, -2).contiguous()
+                    )
+                    runtime = measure(matcher, data, device=device, r=args.repeat)[
+                        "mean"
+                    ]
+                    results[pair_name][name].append(
+                        1000 / runtime if args.measure == "throughput" else runtime
+                    )
+                ax.plot(
+                    args.num_keypoints, results[pair_name][name], label=name, marker="o"
+                )
+            del matcher, data, image0, image1, feats0, feats1
+
+    for name, runtimes in results.items():
+        print_as_table(runtimes, name, args.num_keypoints)
+
+    axes[0].legend()
+    fig.tight_layout()
+    if args.save:
+        plt.savefig(args.save, dpi=fig.dpi)
+    plt.show()
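
The `measure` function in benchmark.py above warms up, then times repeated calls and reports mean and standard deviation in milliseconds. A minimal sketch of the same pattern for the CPU branch only, using just the standard library. The helper name `time_callable` and its `repeats`/`warmup` parameters are assumptions for illustration and are not part of LightGlue.

```python
import statistics
import time


def time_callable(fn, repeats=100, warmup=10):
    """Return mean/std latency of fn in milliseconds (hypothetical helper)."""
    # Warm-up calls are discarded so one-off costs (lazy init, caches)
    # do not skew the measurements.
    for _ in range(warmup):
        fn()
    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean": statistics.fmean(timings_ms),
        # Population standard deviation, matching np.std's default (ddof=0).
        "std": statistics.pstdev(timings_ms),
    }


stats = time_callable(lambda: sum(range(10_000)), repeats=20, warmup=3)
print(f"mean={stats['mean']:.3f} ms, std={stats['std']:.3f} ms")
```

On a GPU, wall-clock timing like this would also need a device synchronization before reading the clock, which is why the original uses CUDA events followed by `torch.cuda.synchronize()`.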

File diff data was not shown because the file is too large
+ 77 - 0
python/LightGlue/demo.ipynb


+ 7 - 0
python/LightGlue/lightglue/__init__.py

@@ -0,0 +1,7 @@
+from .aliked import ALIKED  # noqa
+from .disk import DISK  # noqa
+from .dog_hardnet import DoGHardNet  # noqa
+from .lightglue import LightGlue  # noqa
+from .sift import SIFT  # noqa
+from .superpoint import SuperPoint  # noqa
+from .utils import match_pair  # noqa
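
A note on aliked.py, listed next: it gates the `indexing` keyword of `torch.meshgrid` on `torch.__version__ >= "1.10"`. That is a string comparison, so it is lexicographic and `"1.9.0"` compares as newer than `"1.10"`. A minimal sketch of a numeric comparison, assuming only the first two dot-separated components matter. The helper `version_tuple` is hypothetical and not part of torch or LightGlue.

```python
def version_tuple(version: str, parts: int = 2):
    """Parse the leading numeric components of a version string (hypothetical)."""
    numbers = []
    for piece in version.split(".")[:parts]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as "0a0" or "0+cu118"
            digits += ch
        numbers.append(int(digits or 0))
    return tuple(numbers)


# Lexicographic string comparison: "9" sorts after "1", so this is True.
assert ("1.9.0" >= "1.10") is True
# Integer tuples give the intended ordering.
assert version_tuple("1.9.0") < (1, 10)
assert version_tuple("2.1.0+cu118") >= (1, 10)
```

With recent torch releases the original check still takes the intended branch, so this only matters for installations older than 1.10.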

+ 775 - 0
python/LightGlue/lightglue/aliked.py

@@ -0,0 +1,775 @@
+# BSD 3-Clause License
+
+# Copyright (c) 2022, Zhao Xiaoming
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Authors:
+# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
+# Code from https://github.com/Shiaoming/ALIKED
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from kornia.color import grayscale_to_rgb
+from torch import nn
+from torch.nn.modules.utils import _pair
+from torchvision.models import resnet
+
+from .utils import Extractor, ImagePreprocessor
+
+
+def get_patches(
+    tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
+) -> torch.Tensor:
+    c, h, w = tensor.shape
+    corner = (required_corners - ps / 2 + 1).long()
+    corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
+    corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
+    offset = torch.arange(0, ps)
+
+    kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
+    x, y = torch.meshgrid(offset, offset, **kw)
+    patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
+    patches = patches.to(corner) + corner[None, None]
+    pts = patches.reshape(-1, 2)
+    sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
+    sampled = sampled.reshape(ps, ps, -1, c)
+    assert sampled.shape[:3] == patches.shape[:3]
+    return sampled.permute(2, 3, 0, 1)
+
+
+def simple_nms(scores: torch.Tensor, nms_radius: int):
+    """Fast Non-maximum suppression to remove nearby points"""
+
+    zeros = torch.zeros_like(scores)
+    max_mask = scores == torch.nn.functional.max_pool2d(
+        scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+    )
+
+    for _ in range(2):
+        supp_mask = (
+            torch.nn.functional.max_pool2d(
+                max_mask.float(),
+                kernel_size=nms_radius * 2 + 1,
+                stride=1,
+                padding=nms_radius,
+            )
+            > 0
+        )
+        supp_scores = torch.where(supp_mask, zeros, scores)
+        new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
+            supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
+        max_mask = max_mask | (new_max_mask & (~supp_mask))
+    return torch.where(max_mask, scores, zeros)
+
+
+class DKD(nn.Module):
+    def __init__(
+        self,
+        radius: int = 2,
+        top_k: int = 0,
+        scores_th: float = 0.2,
+        n_limit: int = 20000,
+    ):
+        """
+        Args:
+            radius: soft detection radius, kernel size is (2 * radius + 1)
+            top_k: top_k > 0: return top k keypoints
+            scores_th: top_k <= 0 threshold mode:
+                scores_th > 0: return keypoints with scores>scores_th
+                else: return keypoints with scores > scores.mean()
+            n_limit: max number of keypoints in threshold mode
+        """
+        super().__init__()
+        self.radius = radius
+        self.top_k = top_k
+        self.scores_th = scores_th
+        self.n_limit = n_limit
+        self.kernel_size = 2 * self.radius + 1
+        self.temperature = 0.1  # tuned temperature
+        self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
+        # local xy grid
+        x = torch.linspace(-self.radius, self.radius, self.kernel_size)
+        # (kernel_size*kernel_size) x 2 : (w,h)
+        kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
+        self.hw_grid = (
+            torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
+        )
+
+    def forward(
+        self,
+        scores_map: torch.Tensor,
+        sub_pixel: bool = True,
+        image_size: Optional[torch.Tensor] = None,
+    ):
+        """
+        :param scores_map: Bx1xHxW
+        :param sub_pixel: whether to use sub-pixel keypoint detection
+        :param image_size: optional Bx2 (w, h) per image, used to mask the borders
+        :return: keypoints (list[Nx2], in -1~1), kptscores (list[N]), scoredispersitys
+        """
+        b, c, h, w = scores_map.shape
+        scores_nograd = scores_map.detach()
+        nms_scores = simple_nms(scores_nograd, self.radius)
+
+        # remove border
+        nms_scores[:, :, : self.radius, :] = 0
+        nms_scores[:, :, :, : self.radius] = 0
+        if image_size is not None:
+            for i in range(scores_map.shape[0]):
+                w, h = image_size[i].long()
+                nms_scores[i, :, h.item() - self.radius :, :] = 0
+                nms_scores[i, :, :, w.item() - self.radius :] = 0
+        else:
+            nms_scores[:, :, -self.radius :, :] = 0
+            nms_scores[:, :, :, -self.radius :] = 0
+
+        # detect keypoints without grad
+        if self.top_k > 0:
+            topk = torch.topk(nms_scores.view(b, -1), self.top_k)
+            indices_keypoints = [topk.indices[i] for i in range(b)]  # B x top_k
+        else:
+            if self.scores_th > 0:
+                masks = nms_scores > self.scores_th
+                if masks.sum() == 0:
+                    th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
+                    masks = nms_scores > th.reshape(b, 1, 1, 1)
+            else:
+                th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
+                masks = nms_scores > th.reshape(b, 1, 1, 1)
+            masks = masks.reshape(b, -1)
+
+            indices_keypoints = []  # list, B x (any size)
+            scores_view = scores_nograd.reshape(b, -1)
+            for mask, scores in zip(masks, scores_view):
+                indices = mask.nonzero()[:, 0]
+                if len(indices) > self.n_limit:
+                    kpts_sc = scores[indices]
+                    sort_idx = kpts_sc.sort(descending=True)[1]
+                    sel_idx = sort_idx[: self.n_limit]
+                    indices = indices[sel_idx]
+                indices_keypoints.append(indices)
+
+        wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
+
+        keypoints = []
+        scoredispersitys = []
+        kptscores = []
+        if sub_pixel:
+            # detect soft keypoints with grad backpropagation
+            patches = self.unfold(scores_map)  # B x (kernel**2) x (H*W)
+            self.hw_grid = self.hw_grid.to(scores_map)  # to device
+            for b_idx in range(b):
+                patch = patches[b_idx].t()  # (H*W) x (kernel**2)
+                indices_kpt = indices_keypoints[
+                    b_idx
+                ]  # one dimension vector, say its size is M
+                patch_scores = patch[indices_kpt]  # M x (kernel**2)
+                keypoints_xy_nms = torch.stack(
+                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
+                    dim=1,
+                )  # Mx2
+
+                # max is detached to prevent undesired backprop loops in the graph
+                max_v = patch_scores.max(dim=1).values.detach()[:, None]
+                x_exp = (
+                    (patch_scores - max_v) / self.temperature
+                ).exp()  # M * (kernel**2), in [0, 1]
+
+                # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
+                xy_residual = (
+                    x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
+                )  # Soft-argmax, Mx2
+
+                hw_grid_dist2 = (
+                    torch.norm(
+                        (self.hw_grid[None, :, :] - xy_residual[:, None, :])
+                        / self.radius,
+                        dim=-1,
+                    )
+                    ** 2
+                )
+                scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
+
+                # compute result keypoints
+                keypoints_xy = keypoints_xy_nms + xy_residual
+                keypoints_xy = keypoints_xy / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)
+
+                kptscore = torch.nn.functional.grid_sample(
+                    scores_map[b_idx].unsqueeze(0),
+                    keypoints_xy.view(1, 1, -1, 2),
+                    mode="bilinear",
+                    align_corners=True,
+                )[
+                    0, 0, 0, :
+                ]  # N (channel 0 of the Bx1xHxW score map)
+
+                keypoints.append(keypoints_xy)
+                scoredispersitys.append(scoredispersity)
+                kptscores.append(kptscore)
+        else:
+            for b_idx in range(b):
+                indices_kpt = indices_keypoints[
+                    b_idx
+                ]  # one dimension vector, say its size is M
+                # To avoid warning: UserWarning: __floordiv__ is deprecated
+                keypoints_xy_nms = torch.stack(
+                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
+                    dim=1,
+                )  # Mx2
+                keypoints_xy = keypoints_xy_nms / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)
+                kptscore = torch.nn.functional.grid_sample(
+                    scores_map[b_idx].unsqueeze(0),
+                    keypoints_xy.view(1, 1, -1, 2),
+                    mode="bilinear",
+                    align_corners=True,
+                )[
+                    0, 0, 0, :
+                ]  # N (channel 0 of the Bx1xHxW score map)
+                keypoints.append(keypoints_xy)
+                scoredispersitys.append(kptscore)  # for jit.script compatibility
+                kptscores.append(kptscore)
+
+        return keypoints, kptscores, scoredispersitys
+
+
+class InputPadder(object):
+    """Pads images such that dimensions are divisible by `divis_by` (default 8)"""
+
+    def __init__(self, h: int, w: int, divis_by: int = 8):
+        self.ht = h
+        self.wd = w
+        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
+        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
+        self._pad = [
+            pad_wd // 2,
+            pad_wd - pad_wd // 2,
+            pad_ht // 2,
+            pad_ht - pad_ht // 2,
+        ]
+
+    def pad(self, x: torch.Tensor):
+        assert x.ndim == 4
+        return F.pad(x, self._pad, mode="replicate")
+
+    def unpad(self, x: torch.Tensor):
+        assert x.ndim == 4
+        ht = x.shape[-2]
+        wd = x.shape[-1]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0] : c[1], c[2] : c[3]]
+
+
+class DeformableConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        bias=False,
+        mask=False,
+    ):
+        super(DeformableConv2d, self).__init__()
+
+        self.padding = padding
+        self.mask = mask
+
+        self.channel_num = (
+            3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
+        )
+        self.offset_conv = nn.Conv2d(
+            in_channels,
+            self.channel_num,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            bias=True,
+        )
+
+        self.regular_conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        h, w = x.shape[2:]
+        max_offset = max(h, w) / 4.0
+
+        out = self.offset_conv(x)
+        if self.mask:
+            o1, o2, mask = torch.chunk(out, 3, dim=1)
+            offset = torch.cat((o1, o2), dim=1)
+            mask = torch.sigmoid(mask)
+        else:
+            offset = out
+            mask = None
+        offset = offset.clamp(-max_offset, max_offset)
+        x = torchvision.ops.deform_conv2d(
+            input=x,
+            offset=offset,
+            weight=self.regular_conv.weight,
+            bias=self.regular_conv.bias,
+            padding=self.padding,
+            mask=mask,
+        )
+        return x
+
+
+def get_conv(
+    inplanes,
+    planes,
+    kernel_size=3,
+    stride=1,
+    padding=1,
+    bias=False,
+    conv_type="conv",
+    mask=False,
+):
+    if conv_type == "conv":
+        conv = nn.Conv2d(
+            inplanes,
+            planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+    elif conv_type == "dcn":
+        conv = DeformableConv2d(
+            inplanes,
+            planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=_pair(padding),
+            bias=bias,
+            mask=mask,
+        )
+    else:
+        raise TypeError(f"Unsupported conv_type: {conv_type!r}")
+    return conv
+
+
+class ConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        gate: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        conv_type: str = "conv",
+        mask: bool = False,
+    ):
+        super().__init__()
+        if gate is None:
+            self.gate = nn.ReLU(inplace=True)
+        else:
+            self.gate = gate
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self.conv1 = get_conv(
+            in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
+        )
+        self.bn1 = norm_layer(out_channels)
+        self.conv2 = get_conv(
+            out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
+        )
+        self.bn2 = norm_layer(out_channels)
+
+    def forward(self, x):
+        x = self.gate(self.bn1(self.conv1(x)))  # B x in_channels x H x W
+        x = self.gate(self.bn2(self.conv2(x)))  # B x out_channels x H x W
+        return x
+
+
+# modified based on torchvision\models\resnet.py#27->BasicBlock
+class ResBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        gate: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        conv_type: str = "conv",
+        mask: bool = False,
+    ) -> None:
+        super(ResBlock, self).__init__()
+        if gate is None:
+            self.gate = nn.ReLU(inplace=True)
+        else:
+            self.gate = gate
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError("ResBlock only supports groups=1 and base_width=64")
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in ResBlock")
+        # Unlike torchvision's BasicBlock, conv1 always uses stride 1 here:
+        # `stride` is stored but never applied, so any spatial downsampling
+        # must happen outside the block.
+        self.conv1 = get_conv(
+            inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
+        )
+        self.bn1 = norm_layer(planes)
+        self.conv2 = get_conv(
+            planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
+        )
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.gate(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.gate(out)
+
+        return out
+
+
+class SDDH(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        kernel_size: int = 3,
+        n_pos: int = 8,
+        gate=nn.ReLU(),
+        conv2D=False,
+        mask=False,
+    ):
+        super(SDDH, self).__init__()
+        self.kernel_size = kernel_size
+        self.n_pos = n_pos
+        self.conv2D = conv2D
+        self.mask = mask
+
+        self.get_patches_func = get_patches
+
+        # estimate offsets
+        self.channel_num = 3 * n_pos if mask else 2 * n_pos
+        self.offset_conv = nn.Sequential(
+            nn.Conv2d(
+                dims,
+                self.channel_num,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=0,
+                bias=True,
+            ),
+            gate,
+            nn.Conv2d(
+                self.channel_num,
+                self.channel_num,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=True,
+            ),
+        )
+
+        # sampled feature conv
+        self.sf_conv = nn.Conv2d(
+            dims, dims, kernel_size=1, stride=1, padding=0, bias=False
+        )
+
+        # convM
+        if not conv2D:
+            # deformable desc weights
+            agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
+            self.register_parameter("agg_weights", agg_weights)
+        else:
+            self.convM = nn.Conv2d(
+                dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
+            )
+
+    def forward(self, x, keypoints):
+        # x: [B,C,H,W]
+        # keypoints: list, [[N_kpts,2], ...] (w,h)
+        b, c, h, w = x.shape
+        wh = torch.tensor([[w - 1, h - 1]], device=x.device)
+        max_offset = max(h, w) / 4.0
+
+        offsets = []
+        descriptors = []
+        # get offsets for each keypoint
+        for ib in range(b):
+            xi, kptsi = x[ib], keypoints[ib]
+            kptsi_wh = (kptsi / 2 + 0.5) * wh
+            N_kpts = len(kptsi)
+
+            if self.kernel_size > 1:
+                patch = self.get_patches_func(
+                    xi, kptsi_wh.long(), self.kernel_size
+                )  # [N_kpts, C, K, K]
+            else:
+                kptsi_wh_long = kptsi_wh.long()
+                patch = (
+                    xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
+                    .permute(1, 0)
+                    .reshape(N_kpts, c, 1, 1)
+                )
+
+            offset = self.offset_conv(patch).clamp(
+                -max_offset, max_offset
+            )  # [N_kpts, 2*n_pos, 1, 1]
+            if self.mask:
+                offset = (
+                    offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
+                )  # [N_kpts, n_pos, 3]
+                # Read the mask logits before dropping the last channel.
+                mask_weight = torch.sigmoid(offset[:, :, -1])  # [N_kpts, n_pos]
+                offset = offset[:, :, :-1]  # [N_kpts, n_pos, 2]
+            else:
+                offset = (
+                    offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
+                )  # [N_kpts, n_pos, 2]
+            offsets.append(offset)  # for visualization
+
+            # get sample positions
+            pos = kptsi_wh.unsqueeze(1) + offset  # [N_kpts, n_pos, 2]
+            pos = 2.0 * pos / wh[None] - 1
+            pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
+
+            # sample features
+            features = F.grid_sample(
+                xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
+            )  # [1,C,(N_kpts*n_pos),1]
+            features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
+                1, 0, 2, 3
+            )  # [N_kpts, C, n_pos, 1]
+            if self.mask:
+                features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
+
+            features = torch.selu_(self.sf_conv(features)).squeeze(
+                -1
+            )  # [N_kpts, C, n_pos]
+            # convM
+            if not self.conv2D:
+                descs = torch.einsum(
+                    "ncp,pcd->nd", features, self.agg_weights
+                )  # [N_kpts, C]
+            else:
+                features = features.reshape(N_kpts, -1)[
+                    :, :, None, None
+                ]  # [N_kpts, C*n_pos, 1, 1]
+                # Squeeze only the two trailing 1x1 dims so N_kpts == 1 is kept.
+                descs = self.convM(features).squeeze(-1).squeeze(-1)  # [N_kpts, C]
+
+            # normalize
+            descs = F.normalize(descs, p=2.0, dim=1)
+            descriptors.append(descs)
+
+        return descriptors, offsets
+
+
+class ALIKED(Extractor):
+    default_conf = {
+        "model_name": "aliked-n16",
+        "max_num_keypoints": -1,
+        "detection_threshold": 0.2,
+        "nms_radius": 2,
+    }
+
+    checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
+
+    n_limit_max = 20000
+
+    # c1, c2, c3, c4, dim, K, M
+    cfgs = {
+        "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
+        "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
+        "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
+        "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
+    }
+    preprocess_conf = {
+        "resize": 1024,
+    }
+
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)  # Update with default configuration.
+        conf = self.conf
+        c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
+        conv_types = ["conv", "conv", "dcn", "dcn"]
+        conv2D = False
+        mask = False
+
+        # build model
+        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
+        self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
+        self.norm = nn.BatchNorm2d
+        self.gate = nn.SELU(inplace=True)
+        self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
+        self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
+        self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
+        self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
+
+        self.conv1 = resnet.conv1x1(c1, dim // 4)
+        self.conv2 = resnet.conv1x1(c2, dim // 4)
+        self.conv3 = resnet.conv1x1(c3, dim // 4)
+        self.conv4 = resnet.conv1x1(dim, dim // 4)
+        self.upsample2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=True
+        )
+        self.upsample4 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=True
+        )
+        self.upsample8 = nn.Upsample(
+            scale_factor=8, mode="bilinear", align_corners=True
+        )
+        self.upsample32 = nn.Upsample(
+            scale_factor=32, mode="bilinear", align_corners=True
+        )
+        self.score_head = nn.Sequential(
+            resnet.conv1x1(dim, 8),
+            self.gate,
+            resnet.conv3x3(8, 4),
+            self.gate,
+            resnet.conv3x3(4, 4),
+            self.gate,
+            resnet.conv3x3(4, 1),
+        )
+        self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
+        self.dkd = DKD(
+            radius=conf.nms_radius,
+            top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
+            scores_th=conf.detection_threshold,
+            n_limit=(
+                conf.max_num_keypoints
+                if conf.max_num_keypoints > 0
+                else self.n_limit_max
+            ),
+        )
+
+        state_dict = torch.hub.load_state_dict_from_url(
+            self.checkpoint_url.format(conf.model_name), map_location="cpu"
+        )
+        self.load_state_dict(state_dict, strict=True)
+
+    def get_resblock(self, c_in, c_out, conv_type, mask):
+        return ResBlock(
+            c_in,
+            c_out,
+            1,
+            nn.Conv2d(c_in, c_out, 1),
+            gate=self.gate,
+            norm_layer=self.norm,
+            conv_type=conv_type,
+            mask=mask,
+        )
+
+    def extract_dense_map(self, image):
+        # Pad the image so both dimensions are divisible by 32 (2**5).
+        div_by = 2**5
+        padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
+        image = padder.pad(image)
+
+        # ================================== feature encoder
+        x1 = self.block1(image)  # B x c1 x H x W
+        x2 = self.pool2(x1)
+        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
+        x3 = self.pool4(x2)
+        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
+        x4 = self.pool4(x3)
+        x4 = self.block4(x4)  # B x dim x H/32 x W/32
+        # ================================== feature aggregation
+        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
+        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
+        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
+        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32
+        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
+        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
+        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
+        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
+        # ================================== score head
+        score_map = torch.sigmoid(self.score_head(x1234))
+        feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
+
+        # Unpads images
+        feature_map = padder.unpad(feature_map)
+        score_map = padder.unpad(score_map)
+
+        return feature_map, score_map
+
+    def describe(
+        self, keypoints: torch.Tensor, img: torch.Tensor, **conf
+    ) -> torch.Tensor:
+        """Extract descriptors for a set of keypoints."""
+        if img.dim() == 3:
+            img = img[None]  # add batch dim
+        assert img.dim() == 4 and img.shape[0] == 1
+        w, h = img.shape[-2:][::-1]
+        wh = torch.tensor([w - 1, h - 1], device=img.device)
+        img, _ = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+        keypoints_n = 2.0 * keypoints / wh[None, None] - 1  # [-1, 1]
+        # Extract dense features on resized img
+        feature_map, _ = self.extract_dense_map(img)
+        return torch.stack(self.desc_head(feature_map, keypoints_n)[0])
+
+    def forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 1:
+            image = grayscale_to_rgb(image)
+        feature_map, score_map = self.extract_dense_map(image)
+        keypoints, kptscores, scoredispersitys = self.dkd(
+            score_map, image_size=data.get("image_size")
+        )
+        descriptors, offsets = self.desc_head(feature_map, keypoints)
+
+        _, _, h, w = image.shape
+        wh = torch.tensor([w - 1, h - 1], device=image.device)
+        # no padding required
+        # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
+        return {
+            "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0,  # B x N x 2
+            "descriptors": torch.stack(descriptors),  # B x N x D
+            "keypoint_scores": torch.stack(kptscores),  # B x N
+        }
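Note on coordinate conventions in `aliked.py`: `ALIKED.forward` maps detector keypoints from the normalized range [-1, 1] to pixels with `wh * (kpts + 1) / 2`, where `wh = [w - 1, h - 1]`, and `describe` applies the inverse. A minimal pure-Python sketch of that round trip follows; the helper names `to_pixels` and `to_normalized` are hypothetical and do not appear in the commit.

```python
def to_pixels(kpts_norm, width, height):
    """Map (x, y) keypoints from [-1, 1] to pixel coordinates.

    Mirrors `wh * (kpts + 1) / 2` with wh = (width - 1, height - 1),
    so -1 lands on pixel 0 and +1 on the last pixel index.
    """
    return [
        ((x + 1) / 2 * (width - 1), (y + 1) / 2 * (height - 1))
        for x, y in kpts_norm
    ]


def to_normalized(kpts_px, width, height):
    """Inverse mapping, mirroring `2 * kpts / wh - 1` in `describe`."""
    return [
        (2 * x / (width - 1) - 1, 2 * y / (height - 1) - 1)
        for x, y in kpts_px
    ]


corners = [(-1.0, -1.0), (1.0, 1.0), (0.0, 0.0)]
pixels = to_pixels(corners, width=640, height=480)
restored = to_normalized(pixels, width=640, height=480)
print(pixels)
print(restored)
```

Dividing by `width - 1` rather than `width` matches the `align_corners=True` setting used by the `grid_sample` and `Upsample` calls in the same file.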

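Note on `InputPadder`: `extract_dense_map` pads the image to a multiple of 32 before the encoder and crops the outputs afterwards. The sketch below reproduces only the padding arithmetic from `InputPadder.__init__` and the crop bounds from `unpad`, without tensors; `symmetric_padding` and `unpad_bounds` are hypothetical names used for illustration.

```python
def symmetric_padding(h: int, w: int, divis_by: int = 8):
    """Return (left, right, top, bottom) padding for an h x w image.

    Same arithmetic as InputPadder.__init__: the outer modulo keeps the
    padding at zero when a dimension is already a multiple of divis_by.
    """
    pad_ht = (((h // divis_by) + 1) * divis_by - h) % divis_by
    pad_wd = (((w // divis_by) + 1) * divis_by - w) % divis_by
    return (pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2)


def unpad_bounds(padded_h: int, padded_w: int, pad):
    """Return (row_start, row_stop, col_start, col_stop) that undo the padding."""
    left, right, top, bottom = pad
    return (top, padded_h - bottom, left, padded_w - right)


h, w = 480, 650
pad = symmetric_padding(h, w, divis_by=32)
padded_h = h + pad[2] + pad[3]
padded_w = w + pad[0] + pad[1]
bounds = unpad_bounds(padded_h, padded_w, pad)
print(pad, (padded_h, padded_w), bounds)
```

When the required padding is odd, the extra pixel goes to the right or bottom edge, because the left and top amounts use floor division.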
+ 55 - 0
python/LightGlue/lightglue/disk.py

@@ -0,0 +1,55 @@
+import kornia
+import torch
+
+from .utils import Extractor
+
+
+class DISK(Extractor):
+    default_conf = {
+        "weights": "depth",
+        "max_num_keypoints": None,
+        "desc_dim": 128,
+        "nms_window_size": 5,
+        "detection_threshold": 0.0,
+        "pad_if_not_divisible": True,
+    }
+
+    preprocess_conf = {
+        "resize": 1024,
+        "grayscale": False,
+    }
+
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf) -> None:
+        super().__init__(**conf)  # Update with default configuration.
+        self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
+
+    def forward(self, data: dict) -> dict:
+        """Compute keypoints, scores, descriptors for image"""
+        for key in self.required_data_keys:
+            assert key in data, f"Missing key {key} in data"
+        image = data["image"]
+        if image.shape[1] == 1:
+            image = kornia.color.grayscale_to_rgb(image)
+        features = self.model(
+            image,
+            n=self.conf.max_num_keypoints,
+            window_size=self.conf.nms_window_size,
+            score_threshold=self.conf.detection_threshold,
+            pad_if_not_divisible=self.conf.pad_if_not_divisible,
+        )
+        keypoints = [f.keypoints for f in features]
+        scores = [f.detection_scores for f in features]
+        descriptors = [f.descriptors for f in features]
+        del features
+
+        keypoints = torch.stack(keypoints, 0)
+        scores = torch.stack(scores, 0)
+        descriptors = torch.stack(descriptors, 0)
+
+        return {
+            "keypoints": keypoints.to(image).contiguous(),
+            "keypoint_scores": scores.to(image).contiguous(),
+            "descriptors": descriptors.to(image).contiguous(),
+        }

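Note on `default_conf`: each extractor declares its defaults as a class-level dictionary and accepts overrides as keyword arguments. The base class `Extractor` lives in `utils.py` and is not shown in this part of the diff, so the sketch below assumes it merges settings the same way `LightGlue.__init__` does, with `SimpleNamespace(**{**default_conf, **conf})`. The helper `merge_conf` is hypothetical.

```python
from types import SimpleNamespace

# Copied from DISK.default_conf in the diff above.
DISK_DEFAULTS = {
    "weights": "depth",
    "max_num_keypoints": None,
    "desc_dim": 128,
    "nms_window_size": 5,
    "detection_threshold": 0.0,
    "pad_if_not_divisible": True,
}


def merge_conf(defaults: dict, **overrides) -> SimpleNamespace:
    """Overlay keyword overrides on the defaults (later keys win).

    Assumption: this is the merge pattern seen in LightGlue.__init__;
    the real Extractor base class may validate or nest keys differently.
    """
    return SimpleNamespace(**{**defaults, **overrides})


conf = merge_conf(DISK_DEFAULTS, max_num_keypoints=2048)
print(conf.max_num_keypoints, conf.nms_window_size)
```

The merged namespace is what makes attribute access such as `self.conf.nms_window_size` work inside `forward`, and the defaults dictionary itself is left unmodified.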
+ 41 - 0
python/LightGlue/lightglue/dog_hardnet.py

@@ -0,0 +1,41 @@
+import torch
+from kornia.color import rgb_to_grayscale
+from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
+
+from .sift import SIFT
+
+
+class DoGHardNet(SIFT):
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)
+        self.laf_desc = LAFDescriptor(HardNet(True)).eval()
+
+    def forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+        device = image.device
+        self.laf_desc = self.laf_desc.to(device)
+        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
+        pred = []
+        if "image_size" in data.keys():
+            im_size = data.get("image_size").long()
+        else:
+            im_size = None
+        for k in range(len(image)):
+            img = image[k]
+            if im_size is not None:
+                w, h = data["image_size"][k]
+                img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
+            p = self.extract_single_image(img)
+            lafs = laf_from_center_scale_ori(
+                p["keypoints"].reshape(1, -1, 2),
+                6.0 * p["scales"].reshape(1, -1, 1, 1),
+                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
+            ).to(device)
+            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
+            pred.append(p)
+        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+        return pred

+ 667 - 0
python/LightGlue/lightglue/lightglue.py

@@ -0,0 +1,667 @@
+import warnings
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+    from flash_attn.modules.mha import FlashCrossAttention
+except ModuleNotFoundError:
+    FlashCrossAttention = None
+
+if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
+    FLASH_AVAILABLE = True
+else:
+    FLASH_AVAILABLE = False
+
+torch.backends.cudnn.deterministic = True
+
+
+AMP_CUSTOM_FWD_F32 = (
+    torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+    if hasattr(torch, "amp") and hasattr(torch.amp, "custom_fwd")
+    else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+)
+
+
+@AMP_CUSTOM_FWD_F32
+def normalize_keypoints(
+    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    if size is None:
+        size = 1 + kpts.max(-2).values - kpts.min(-2).values
+    elif not isinstance(size, torch.Tensor):
+        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
+    size = size.to(kpts)
+    shift = size / 2
+    scale = size.max(-1).values / 2
+    kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
+    return kpts
+
+
+def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    if length <= x.shape[-2]:
+        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
+    pad = torch.ones(
+        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
+    )
+    y = torch.cat([x, pad], dim=-2)
+    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
+    mask[..., : x.shape[-2], :] = True
+    return y, mask
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    x = x.unflatten(-1, (-1, 2))
+    x1, x2 = x.unbind(dim=-1)
+    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
+
+
+def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+    return (t * freqs[0]) + (rotate_half(t) * freqs[1])
+
+
+class LearnableFourierPositionalEncoding(nn.Module):
+    def __init__(
+        self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0
+    ) -> None:
+        super().__init__()
+        F_dim = F_dim if F_dim is not None else dim
+        self.gamma = gamma
+        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
+        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """encode position vector"""
+        projected = self.Wr(x)
+        cosines, sines = torch.cos(projected), torch.sin(projected)
+        emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
+        return emb.repeat_interleave(2, dim=-1)
+
+
+class TokenConfidence(nn.Module):
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
+
+    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+        """get confidence tokens"""
+        return (
+            self.token(desc0.detach()).squeeze(-1),
+            self.token(desc1.detach()).squeeze(-1),
+        )
+
+
+class Attention(nn.Module):
+    def __init__(self, allow_flash: bool) -> None:
+        super().__init__()
+        if allow_flash and not FLASH_AVAILABLE:
+            warnings.warn(
+                "FlashAttention is not available. For optimal speed, "
+                "consider installing torch >= 2.0 or flash-attn.",
+                stacklevel=2,
+            )
+        self.enable_flash = allow_flash and FLASH_AVAILABLE
+        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
+        if allow_flash and FlashCrossAttention:
+            self.flash_ = FlashCrossAttention()
+        if self.has_sdp:
+            torch.backends.cuda.enable_flash_sdp(allow_flash)
+
+    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if q.shape[-2] == 0 or k.shape[-2] == 0:
+            return q.new_zeros((*q.shape[:-1], v.shape[-1]))
+        if self.enable_flash and q.device.type == "cuda":
+            # use torch 2.0 scaled_dot_product_attention with flash
+            if self.has_sdp:
+                args = [x.half().contiguous() for x in [q, k, v]]
+                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
+                return v if mask is None else v.nan_to_num()
+            else:
+                assert mask is None
+                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
+                m = self.flash_(q.half(), torch.stack([k, v], 2).half())
+                return m.transpose(-2, -3).to(q.dtype).clone()
+        elif self.has_sdp:
+            args = [x.contiguous() for x in [q, k, v]]
+            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
+            return v if mask is None else v.nan_to_num()
+        else:
+            s = q.shape[-1] ** -0.5
+            sim = torch.einsum("...id,...jd->...ij", q, k) * s
+            if mask is not None:
+                # masked_fill is out-of-place, so the result must be assigned.
+                sim = sim.masked_fill(~mask, -float("inf"))
+            attn = F.softmax(sim, -1)
+            return torch.einsum("...ij,...jd->...id", attn, v)
+
+
+class SelfBlock(nn.Module):
+    def __init__(
+        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        assert self.embed_dim % num_heads == 0
+        self.head_dim = self.embed_dim // num_heads
+        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
+        self.inner_attn = Attention(flash)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.ffn = nn.Sequential(
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+            nn.GELU(),
+            nn.Linear(2 * embed_dim, embed_dim),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        encoding: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qkv = self.Wqkv(x)
+        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
+        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+        q = apply_cached_rotary_emb(encoding, q)
+        k = apply_cached_rotary_emb(encoding, k)
+        context = self.inner_attn(q, k, v, mask=mask)
+        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
+        return x + self.ffn(torch.cat([x, message], -1))
+
+
+class CrossBlock(nn.Module):
+    def __init__(
+        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.heads = num_heads
+        dim_head = embed_dim // num_heads
+        self.scale = dim_head**-0.5
+        inner_dim = dim_head * num_heads
+        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
+        self.ffn = nn.Sequential(
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+            nn.GELU(),
+            nn.Linear(2 * embed_dim, embed_dim),
+        )
+        if flash and FLASH_AVAILABLE:
+            self.flash = Attention(True)
+        else:
+            self.flash = None
+
+    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
+        return func(x0), func(x1)
+
+    def forward(
+        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        qk0, qk1 = self.map_(self.to_qk, x0, x1)
+        v0, v1 = self.map_(self.to_v, x0, x1)
+        qk0, qk1, v0, v1 = map(
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
+            (qk0, qk1, v0, v1),
+        )
+        if self.flash is not None and qk0.device.type == "cuda":
+            m0 = self.flash(qk0, qk1, v1, mask)
+            m1 = self.flash(
+                qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
+            )
+        else:
+            qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
+            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
+            if mask is not None:
+                sim = sim.masked_fill(~mask, -float("inf"))
+            attn01 = F.softmax(sim, dim=-1)
+            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
+            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
+            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
+            if mask is not None:
+                m0, m1 = m0.nan_to_num(), m1.nan_to_num()
+        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
+        m0, m1 = self.map_(self.to_out, m0, m1)
+        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
+        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
+        return x0, x1
+
+
+class TransformerLayer(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.self_attn = SelfBlock(*args, **kwargs)
+        self.cross_attn = CrossBlock(*args, **kwargs)
+
+    def forward(
+        self,
+        desc0,
+        desc1,
+        encoding0,
+        encoding1,
+        mask0: Optional[torch.Tensor] = None,
+        mask1: Optional[torch.Tensor] = None,
+    ):
+        if mask0 is not None and mask1 is not None:
+            return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
+        else:
+            desc0 = self.self_attn(desc0, encoding0)
+            desc1 = self.self_attn(desc1, encoding1)
+            return self.cross_attn(desc0, desc1)
+
+    # This part is compiled and allows padding inputs
+    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
+        mask = mask0 & mask1.transpose(-1, -2)
+        mask0 = mask0 & mask0.transpose(-1, -2)
+        mask1 = mask1 & mask1.transpose(-1, -2)
+        desc0 = self.self_attn(desc0, encoding0, mask0)
+        desc1 = self.self_attn(desc1, encoding1, mask1)
+        return self.cross_attn(desc0, desc1, mask)
+
+
+def sigmoid_log_double_softmax(
+    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
+) -> torch.Tensor:
+    """create the log assignment matrix from logits and similarity"""
+    b, m, n = sim.shape
+    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
+    scores0 = F.log_softmax(sim, 2)
+    scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+    scores = sim.new_full((b, m + 1, n + 1), 0)
+    scores[:, :m, :n] = scores0 + scores1 + certainties
+    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
+    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
+    return scores
+
+
+class MatchAssignment(nn.Module):
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.dim = dim
+        self.matchability = nn.Linear(dim, 1, bias=True)
+        self.final_proj = nn.Linear(dim, dim, bias=True)
+
+    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+        """build assignment matrix from descriptors"""
+        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
+        _, _, d = mdesc0.shape
+        mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
+        sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
+        z0 = self.matchability(desc0)
+        z1 = self.matchability(desc1)
+        scores = sigmoid_log_double_softmax(sim, z0, z1)
+        return scores, sim
+
+    def get_matchability(self, desc: torch.Tensor):
+        return torch.sigmoid(self.matchability(desc)).squeeze(-1)
+
+
+def filter_matches(scores: torch.Tensor, th: float):
+    """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
+    max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+    m0, m1 = max0.indices, max1.indices
+    indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
+    indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
+    mutual0 = indices0 == m1.gather(1, m0)
+    mutual1 = indices1 == m0.gather(1, m1)
+    max0_exp = max0.values.exp()
+    zero = max0_exp.new_tensor(0)
+    mscores0 = torch.where(mutual0, max0_exp, zero)
+    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
+    valid0 = mutual0 & (mscores0 > th)
+    valid1 = mutual1 & valid0.gather(1, m1)
+    m0 = torch.where(valid0, m0, -1)
+    m1 = torch.where(valid1, m1, -1)
+    return m0, m1, mscores0, mscores1
+
+
+class LightGlue(nn.Module):
+    default_conf = {
+        "name": "lightglue",  # just for interfacing
+        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
+        "descriptor_dim": 256,
+        "add_scale_ori": False,
+        "n_layers": 9,
+        "num_heads": 4,
+        "flash": True,  # enable FlashAttention if available.
+        "mp": False,  # enable mixed precision
+        "depth_confidence": 0.95,  # early stopping, disable with -1
+        "width_confidence": 0.99,  # point pruning, disable with -1
+        "filter_threshold": 0.1,  # match threshold
+        "weights": None,
+    }
+
+    # Point pruning involves an overhead (gather).
+    # Therefore, we only activate it if there are enough keypoints.
+    pruning_keypoint_thresholds = {
+        "cpu": -1,
+        "mps": -1,
+        "cuda": 1024,
+        "flash": 1536,
+    }
+
+    required_data_keys = ["image0", "image1"]
+
+    version = "v0.1_arxiv"
+    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}.pth"
+
+    features = {
+        "superpoint": {
+            "weights": "superpoint_lightglue",
+            "input_dim": 256,
+        },
+        "disk": {
+            "weights": "disk_lightglue",
+            "input_dim": 128,
+        },
+        "aliked": {
+            "weights": "aliked_lightglue",
+            "input_dim": 128,
+        },
+        "raco-aliked": {
+            "weights": "raco_aliked_lightglue",
+            "input_dim": 128,
+        },
+        "sift": {
+            "weights": "sift_lightglue",
+            "input_dim": 128,
+            "add_scale_ori": True,
+        },
+        "doghardnet": {
+            "weights": "doghardnet_lightglue",
+            "input_dim": 128,
+            "add_scale_ori": True,
+        },
+    }
+
+    def __init__(self, features="superpoint", **conf) -> None:
+        super().__init__()
+        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+        if features is not None:
+            if features not in self.features:
+                raise ValueError(
+                    f"Unsupported features: {features} not in "
+                    f"{{{','.join(self.features)}}}"
+                )
+            for k, v in self.features[features].items():
+                setattr(conf, k, v)
+
+        if conf.input_dim != conf.descriptor_dim:
+            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
+        else:
+            self.input_proj = nn.Identity()
+
+        head_dim = conf.descriptor_dim // conf.num_heads
+        self.posenc = LearnableFourierPositionalEncoding(
+            2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
+        )
+
+        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
+
+        self.transformers = nn.ModuleList(
+            [TransformerLayer(d, h, conf.flash) for _ in range(n)]
+        )
+
+        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
+        self.token_confidence = nn.ModuleList(
+            [TokenConfidence(d) for _ in range(n - 1)]
+        )
+        self.register_buffer(
+            "confidence_thresholds",
+            torch.Tensor(
+                [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
+            ),
+        )
+
+        state_dict = None
+        if features is not None:
+            fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
+            state_dict = torch.hub.load_state_dict_from_url(
+                self.url.format(self.version, self.conf.weights),
+                file_name=fname,
+            )
+            self.load_state_dict(state_dict, strict=False)
+        elif conf.weights is not None:
+            path = Path(__file__).parent
+            path = path / "weights/{}.pth".format(self.conf.weights)
+            state_dict = torch.load(str(path), map_location="cpu")
+
+        if state_dict:
+            # rename old state dict entries
+            for i in range(self.conf.n_layers):
+                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
+                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
+                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+            self.load_state_dict(state_dict, strict=False)
+
+        # static lengths LightGlue is compiled for (only used with torch.compile)
+        self.static_lengths = None
+
+    def compile(
+        self, mode="reduce-overhead", static_lengths=(256, 512, 768, 1024, 1280, 1536)
+    ):
+        if self.conf.width_confidence != -1:
+            warnings.warn(
+                "Point pruning is partially disabled for compiled forward.",
+                stacklevel=2,
+            )
+
+        torch._inductor.cudagraph_mark_step_begin()
+        for i in range(self.conf.n_layers):
+            self.transformers[i].masked_forward = torch.compile(
+                self.transformers[i].masked_forward, mode=mode, fullgraph=True
+            )
+
+        self.static_lengths = static_lengths
+
+    def forward(self, data: dict) -> dict:
+        """
+        Match keypoints and descriptors between two images
+
+        Input (dict):
+            image0: dict
+                keypoints: [B x M x 2]
+                descriptors: [B x M x D]
+                image: [B x C x H x W] or image_size: [B x 2]
+            image1: dict
+                keypoints: [B x N x 2]
+                descriptors: [B x N x D]
+                image: [B x C x H x W] or image_size: [B x 2]
+        Output (dict):
+            matches0: [B x M]
+            matching_scores0: [B x M]
+            matches1: [B x N]
+            matching_scores1: [B x N]
+            matches: List[[Si x 2]]
+            scores: List[[Si]]
+            stop: int
+            prune0: [B x M]
+            prune1: [B x N]
+        """
+        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
+            return self._forward(data)
+
+    def _forward(self, data: dict) -> dict:
+        for key in self.required_data_keys:
+            assert key in data, f"Missing key {key} in data"
+        data0, data1 = data["image0"], data["image1"]
+        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
+        b, m, _ = kpts0.shape
+        b, n, _ = kpts1.shape
+        device = kpts0.device
+        size0, size1 = data0.get("image_size"), data1.get("image_size")
+        kpts0 = normalize_keypoints(kpts0, size0).clone()
+        kpts1 = normalize_keypoints(kpts1, size1).clone()
+
+        if self.conf.add_scale_ori:
+            kpts0 = torch.cat(
+                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+            )
+            kpts1 = torch.cat(
+                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+            )
+        desc0 = data0["descriptors"].detach().contiguous()
+        desc1 = data1["descriptors"].detach().contiguous()
+
+        assert desc0.shape[-1] == self.conf.input_dim
+        assert desc1.shape[-1] == self.conf.input_dim
+
+        if torch.is_autocast_enabled():
+            desc0 = desc0.half()
+            desc1 = desc1.half()
+
+        mask0, mask1 = None, None
+        c = max(m, n)
+        do_compile = self.static_lengths and c <= max(self.static_lengths)
+        if do_compile:
+            kn = min([k for k in self.static_lengths if k >= c])
+            desc0, mask0 = pad_to_length(desc0, kn)
+            desc1, mask1 = pad_to_length(desc1, kn)
+            kpts0, _ = pad_to_length(kpts0, kn)
+            kpts1, _ = pad_to_length(kpts1, kn)
+        desc0 = self.input_proj(desc0)
+        desc1 = self.input_proj(desc1)
+        # cache positional embeddings
+        encoding0 = self.posenc(kpts0)
+        encoding1 = self.posenc(kpts1)
+
+        # GNN + final_proj + assignment
+        do_early_stop = self.conf.depth_confidence > 0
+        do_point_pruning = self.conf.width_confidence > 0 and not do_compile
+        pruning_th = self.pruning_min_kpts(device)
+        if do_point_pruning:
+            ind0 = torch.arange(0, m, device=device)[None]
+            ind1 = torch.arange(0, n, device=device)[None]
+            # We store the index of the layer at which pruning is detected.
+            prune0 = torch.ones_like(ind0)
+            prune1 = torch.ones_like(ind1)
+        token0, token1 = None, None
+        for i in range(self.conf.n_layers):
+            if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
+                break
+            desc0, desc1 = self.transformers[i](
+                desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
+            )
+            if i == self.conf.n_layers - 1:
+                continue  # no early stopping or adaptive width at last layer
+
+            if do_early_stop:
+                token0, token1 = self.token_confidence[i](desc0, desc1)
+                if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
+                    break
+            if do_point_pruning and desc0.shape[-2] > pruning_th:
+                scores0 = self.log_assignment[i].get_matchability(desc0)
+                prunemask0 = self.get_pruning_mask(token0, scores0, i)
+                keep0 = torch.where(prunemask0)[1]
+                ind0 = ind0.index_select(1, keep0)
+                desc0 = desc0.index_select(1, keep0)
+                encoding0 = encoding0.index_select(-2, keep0)
+                prune0[:, ind0] += 1
+            if do_point_pruning and desc1.shape[-2] > pruning_th:
+                scores1 = self.log_assignment[i].get_matchability(desc1)
+                prunemask1 = self.get_pruning_mask(token1, scores1, i)
+                keep1 = torch.where(prunemask1)[1]
+                ind1 = ind1.index_select(1, keep1)
+                desc1 = desc1.index_select(1, keep1)
+                encoding1 = encoding1.index_select(-2, keep1)
+                prune1[:, ind1] += 1
+
+        if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
+            m0 = desc0.new_full((b, m), -1, dtype=torch.long)
+            m1 = desc1.new_full((b, n), -1, dtype=torch.long)
+            mscores0 = desc0.new_zeros((b, m))
+            mscores1 = desc1.new_zeros((b, n))
+            matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
+            mscores = desc0.new_empty((b, 0))
+            if not do_point_pruning:
+                prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+                prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+            return {
+                "matches0": m0,
+                "matches1": m1,
+                "matching_scores0": mscores0,
+                "matching_scores1": mscores1,
+                "stop": i + 1,
+                "matches": matches,
+                "scores": mscores,
+                "prune0": prune0,
+                "prune1": prune1,
+            }
+
+        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]  # remove padding
+        scores, _ = self.log_assignment[i](desc0, desc1)
+        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
+        matches, mscores = [], []
+        for k in range(b):
+            valid = m0[k] > -1
+            m_indices_0 = torch.where(valid)[0]
+            m_indices_1 = m0[k][valid]
+            if do_point_pruning:
+                m_indices_0 = ind0[k, m_indices_0]
+                m_indices_1 = ind1[k, m_indices_1]
+            matches.append(torch.stack([m_indices_0, m_indices_1], -1))
+            mscores.append(mscores0[k][valid])
+
+        # TODO: Remove when hloc switches to the compact format.
+        if do_point_pruning:
+            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
+            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
+            m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
+            m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
+            mscores0_ = torch.zeros((b, m), device=mscores0.device)
+            mscores1_ = torch.zeros((b, n), device=mscores1.device)
+            mscores0_[:, ind0] = mscores0
+            mscores1_[:, ind1] = mscores1
+            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
+        else:
+            prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+            prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+
+        return {
+            "matches0": m0,
+            "matches1": m1,
+            "matching_scores0": mscores0,
+            "matching_scores1": mscores1,
+            "stop": i + 1,
+            "matches": matches,
+            "scores": mscores,
+            "prune0": prune0,
+            "prune1": prune1,
+        }
+
+    def confidence_threshold(self, layer_index: int) -> float:
+        """scaled confidence threshold"""
+        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
+        return np.clip(threshold, 0, 1)
+
+    def get_pruning_mask(
+        self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
+    ) -> torch.Tensor:
+        """mask points which should be removed"""
+        keep = scores > (1 - self.conf.width_confidence)
+        if confidences is not None:  # Low-confidence points are never pruned.
+            keep |= confidences <= self.confidence_thresholds[layer_index]
+        return keep
+
+    def check_if_stop(
+        self,
+        confidences0: torch.Tensor,
+        confidences1: torch.Tensor,
+        layer_index: int,
+        num_points: int,
+    ) -> torch.Tensor:
+        """evaluate stopping condition"""
+        confidences = torch.cat([confidences0, confidences1], -1)
+        threshold = self.confidence_thresholds[layer_index]
+        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
+        return ratio_confident > self.conf.depth_confidence
+
+    def pruning_min_kpts(self, device: torch.device):
+        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
+            return self.pruning_keypoint_thresholds["flash"]
+        else:
+            return self.pruning_keypoint_thresholds[device.type]
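
Review note on `confidence_threshold` and `check_if_stop` above: the following is a minimal pure-Python sketch of the per-layer threshold schedule and the early-stopping ratio, assuming the defaults from `default_conf` (`n_layers=9`, `depth_confidence=0.95`). The helper names `layer_threshold` and `should_stop` are illustrative only and are not part of LightGlue; the code in the diff works on tensors rather than lists.

```python
import math


def layer_threshold(layer_index: int, n_layers: int = 9) -> float:
    """Threshold starts at 0.9 for layer 0 and decays towards 0.8."""
    threshold = 0.8 + 0.1 * math.exp(-4.0 * layer_index / n_layers)
    return min(max(threshold, 0.0), 1.0)  # same role as np.clip(threshold, 0, 1)


def should_stop(confidences, layer_index, depth_confidence=0.95, n_layers=9):
    """Stop once the share of confident points exceeds depth_confidence."""
    threshold = layer_threshold(layer_index, n_layers)
    n_unconfident = sum(1 for c in confidences if c < threshold)
    ratio_confident = 1.0 - n_unconfident / len(confidences)
    return ratio_confident > depth_confidence
```

Because the threshold is highest at the first layer, early layers need more confident predictions before inference is allowed to exit; setting `depth_confidence` to -1 disables the check in the real model.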

+ 216 - 0
python/LightGlue/lightglue/sift.py

@@ -0,0 +1,216 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from kornia.color import rgb_to_grayscale
+from packaging import version
+
+try:
+    import pycolmap
+except ImportError:
+    pycolmap = None
+
+from .utils import Extractor
+
+
+def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
+    h, w = image_shape
+    ij = np.round(points - 0.5).astype(int).T[::-1]
+
+    # Remove duplicate points (identical coordinates).
+    # Pick highest scale or score
+    s = scales if scores is None else scores
+    buffer = np.zeros((h, w))
+    np.maximum.at(buffer, tuple(ij), s)
+    keep = np.where(buffer[tuple(ij)] == s)[0]
+
+    # Pick lowest angle (arbitrary).
+    ij = ij[:, keep]
+    buffer[:] = np.inf
+    o_abs = np.abs(angles[keep])
+    np.minimum.at(buffer, tuple(ij), o_abs)
+    mask = buffer[tuple(ij)] == o_abs
+    ij = ij[:, mask]
+    keep = keep[mask]
+
+    if nms_radius > 0:
+        # Apply NMS on the remaining points
+        buffer[:] = 0
+        buffer[tuple(ij)] = s[keep]  # scores or scale
+
+        local_max = torch.nn.functional.max_pool2d(
+            torch.from_numpy(buffer).unsqueeze(0),
+            kernel_size=nms_radius * 2 + 1,
+            stride=1,
+            padding=nms_radius,
+        ).squeeze(0)
+        is_local_max = buffer == local_max.numpy()
+        keep = keep[is_local_max[tuple(ij)]]
+    return keep
+
+
+def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
+    x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
+    x.clip_(min=eps).sqrt_()
+    return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
+
+
+def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> tuple:
+    """
+    Detect keypoints and compute descriptors with an OpenCV detector.
+    Args:
+        features: OpenCV-based keypoint detector and descriptor
+        image: Grayscale image of uint8 data type
+    Returns:
+        points: array of keypoint (x, y) coordinates
+        scores: 1D array of detector responses
+        scales, angles: 1D arrays of keypoint sizes and orientations (radians)
+        descriptors: array of descriptors, one row per keypoint
+    """
+    detections, descriptors = features.detectAndCompute(image, None)
+    points = np.array([k.pt for k in detections], dtype=np.float32)
+    scores = np.array([k.response for k in detections], dtype=np.float32)
+    scales = np.array([k.size for k in detections], dtype=np.float32)
+    angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
+    return points, scores, scales, angles, descriptors
+
+
+class SIFT(Extractor):
+    default_conf = {
+        "rootsift": True,
+        "nms_radius": 0,  # None to disable filtering entirely.
+        "max_num_keypoints": 4096,
+        "backend": "opencv",  # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
+        "detection_threshold": 0.0066667,  # from COLMAP
+        "edge_threshold": 10,
+        "first_octave": -1,  # only used by pycolmap, the default of COLMAP
+        "num_octaves": 4,
+    }
+
+    preprocess_conf = {
+        "resize": 1024,
+    }
+
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)  # Update with default configuration.
+        backend = self.conf.backend
+        if backend.startswith("pycolmap"):
+            if pycolmap is None:
+                raise ImportError(
+                    "Cannot find module pycolmap: install it with pip "
+                    "or use backend=opencv."
+                )
+            options = {
+                "peak_threshold": self.conf.detection_threshold,
+                "edge_threshold": self.conf.edge_threshold,
+                "first_octave": self.conf.first_octave,
+                "num_octaves": self.conf.num_octaves,
+                "normalization": pycolmap.Normalization.L2,  # L1_ROOT is buggy.
+            }
+            device = (
+                "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
+            )
+            if (
+                backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ) and version.parse(pycolmap.__version__) < version.parse("0.5.0"):
+                warnings.warn(
+                    "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
+                    "consider upgrading pycolmap or use the CUDA version.",
+                    stacklevel=1,
+                )
+            else:
+                options["max_num_features"] = self.conf.max_num_keypoints
+            self.sift = pycolmap.Sift(options=options, device=device)
+        elif backend == "opencv":
+            self.sift = cv2.SIFT_create(
+                contrastThreshold=self.conf.detection_threshold,
+                nfeatures=self.conf.max_num_keypoints,
+                edgeThreshold=self.conf.edge_threshold,
+                nOctaveLayers=self.conf.num_octaves,
+            )
+        else:
+            backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
+            raise ValueError(
+                f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
+            )
+
+    def extract_single_image(self, image: torch.Tensor):
+        image_np = image.cpu().numpy().squeeze(0)
+
+        if self.conf.backend.startswith("pycolmap"):
+            if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
+                detections, descriptors = self.sift.extract(image_np)
+                scores = None  # Scores are not exposed by COLMAP anymore.
+            else:
+                detections, scores, descriptors = self.sift.extract(image_np)
+            keypoints = detections[:, :2]  # Keep only (x, y).
+            scales, angles = detections[:, -2:].T
+            if scores is not None and (
+                self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ):
+                # Set the scores as a combination of abs. response and scale.
+                scores = np.abs(scores) * scales
+        elif self.conf.backend == "opencv":
+            # TODO: Check if opencv keypoints are already in corner convention
+            keypoints, scores, scales, angles, descriptors = run_opencv_sift(
+                self.sift, (image_np * 255.0).astype(np.uint8)
+            )
+        pred = {
+            "keypoints": keypoints,
+            "scales": scales,
+            "oris": angles,
+            "descriptors": descriptors,
+        }
+        if scores is not None:
+            pred["keypoint_scores"] = scores
+
+        # pycolmap sometimes returns points outside the image; remove them.
+        if self.conf.backend.startswith("pycolmap"):
+            is_inside = (
+                pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
+            ).all(-1)
+            pred = {k: v[is_inside] for k, v in pred.items()}
+
+        if self.conf.nms_radius is not None:
+            keep = filter_dog_point(
+                pred["keypoints"],
+                pred["scales"],
+                pred["oris"],
+                image_np.shape,
+                self.conf.nms_radius,
+                scores=pred.get("keypoint_scores"),
+            )
+            pred = {k: v[keep] for k, v in pred.items()}
+
+        pred = {k: torch.from_numpy(v) for k, v in pred.items()}
+        if scores is not None:
+            # Keep the k keypoints with highest score
+            num_points = self.conf.max_num_keypoints
+            if num_points is not None and len(pred["keypoints"]) > num_points:
+                indices = torch.topk(pred["keypoint_scores"], num_points).indices
+                pred = {k: v[indices] for k, v in pred.items()}
+
+        return pred
+
+    def forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+        device = image.device
+        image = image.cpu()
+        pred = []
+        for k in range(len(image)):
+            img = image[k]
+            if "image_size" in data.keys():
+                # avoid extracting points in padded areas
+                w, h = data["image_size"][k]
+                img = img[:, :h, :w]
+            p = self.extract_single_image(img)
+            pred.append(p)
+        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+        if self.conf.rootsift:
+            pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
+        return pred
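
Review note on `sift_to_rootsift` above: a small pure-Python sketch of the same three steps (L1 normalization, element-wise square root, L2 normalization) applied to one descriptor. The function name `rootsift` and the list-based interface are assumptions made for illustration; the code in the diff applies the same steps to batched tensors along the last dimension.

```python
import math


def rootsift(descriptor, eps=1e-6):
    """L1-normalize, take the element-wise square root, then L2-normalize."""
    l1 = max(sum(abs(v) for v in descriptor), eps)
    # Clamp to eps before the square root, mirroring x.clip_(min=eps).sqrt_().
    rooted = [math.sqrt(max(v / l1, eps)) for v in descriptor]
    l2 = max(math.sqrt(sum(v * v for v in rooted)), eps)
    return [v / l2 for v in rooted]
```

For a non-negative descriptor the square roots of L1-normalized entries already have squared sum close to 1, so the final L2 step mostly corrects for the `eps` clamp.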

+ 227 - 0
python/LightGlue/lightglue/superpoint.py

@@ -0,0 +1,227 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+#  Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+#  Unpublished Copyright (c) 2020
+#  Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE:  All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law.  Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY.  Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure  of  this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of  COMPANY.   ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC  PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE  OF THIS
+# SOURCE CODE  WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES.  THE RECEIPT OR POSSESSION OF  THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT  MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+#  Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+# Adapted by Remi Pautrat, Philipp Lindenberger
+
+import torch
+from kornia.color import rgb_to_grayscale
+from torch import nn
+
+from .utils import Extractor
+
+
+def simple_nms(scores, nms_radius: int):
+    """Fast Non-maximum suppression to remove nearby points"""
+    assert nms_radius >= 0
+
+    def max_pool(x):
+        return torch.nn.functional.max_pool2d(
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
+
+    zeros = torch.zeros_like(scores)
+    max_mask = scores == max_pool(scores)
+    for _ in range(2):
+        supp_mask = max_pool(max_mask.float()) > 0
+        supp_scores = torch.where(supp_mask, zeros, scores)
+        new_max_mask = supp_scores == max_pool(supp_scores)
+        max_mask = max_mask | (new_max_mask & (~supp_mask))
+    return torch.where(max_mask, scores, zeros)
+
+
+def top_k_keypoints(keypoints, scores, k):
+    if k >= len(keypoints):
+        return keypoints, scores
+    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+    return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+    """Interpolate descriptors at keypoint locations"""
+    b, c, h, w = descriptors.shape
+    keypoints = keypoints - s / 2 + 0.5
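+    keypoints /= torch.tensor(
+        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+    ).to(keypoints)[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    # Compare parsed version numbers: as strings, "1.10" would sort below "1.3".
+    torch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
+    args = {"align_corners": True} if torch_version >= (1, 3) else {}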
+    descriptors = torch.nn.functional.grid_sample(
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
+    descriptors = torch.nn.functional.normalize(
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
+    return descriptors
+
+
+class SuperPoint(Extractor):
+    """SuperPoint Convolutional Detector and Descriptor
+
+    SuperPoint: Self-Supervised Interest Point Detection and
+    Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
+    Rabinovich. In CVPRW, 2018. https://arxiv.org/abs/1712.07629
+
+    """
+
+    default_conf = {
+        "descriptor_dim": 256,
+        "nms_radius": 4,
+        "max_num_keypoints": None,
+        "detection_threshold": 0.0005,
+        "remove_borders": 4,
+    }
+
+    preprocess_conf = {
+        "resize": 1024,
+    }
+
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)  # Update with default configuration.
+        self.relu = nn.ReLU(inplace=True)
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+        self.convDb = nn.Conv2d(
+            c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
+        )
+
+        url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"  # noqa
+        self.load_state_dict(torch.hub.load_state_dict_from_url(url))
+
+        if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
+            raise ValueError("max_num_keypoints must be positive or None")
+
+    def forward(self, data: dict) -> dict:
+        """Compute keypoints, scores, descriptors for image"""
+        for key in self.required_data_keys:
+            assert key in data, f"Missing key {key} in data"
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+
+        # Shared Encoder
+        x = self.relu(self.conv1a(image))
+        x = self.relu(self.conv1b(x))
+        x = self.pool(x)
+        x = self.relu(self.conv2a(x))
+        x = self.relu(self.conv2b(x))
+        x = self.pool(x)
+        x = self.relu(self.conv3a(x))
+        x = self.relu(self.conv3b(x))
+        x = self.pool(x)
+        x = self.relu(self.conv4a(x))
+        x = self.relu(self.conv4b(x))
+
+        # Compute the dense keypoint scores
+        cPa = self.relu(self.convPa(x))
+        scores = self.convPb(cPa)
+        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+        b, _, h, w = scores.shape
+        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+        scores = simple_nms(scores, self.conf.nms_radius)
+
+        # Discard keypoints near the image borders
+        if self.conf.remove_borders:
+            pad = self.conf.remove_borders
+            scores[:, :pad] = -1
+            scores[:, :, :pad] = -1
+            scores[:, -pad:] = -1
+            scores[:, :, -pad:] = -1
+
+        # Extract keypoints
+        best_kp = torch.where(scores > self.conf.detection_threshold)
+        scores = scores[best_kp]
+
+        # Separate into batches
+        keypoints = [
+            torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
+        ]
+        scores = [scores[best_kp[0] == i] for i in range(b)]
+
+        # Keep the k keypoints with highest score
+        if self.conf.max_num_keypoints is not None:
+            keypoints, scores = list(
+                zip(
+                    *[
+                        top_k_keypoints(k, s, self.conf.max_num_keypoints)
+                        for k, s in zip(keypoints, scores)
+                    ]
+                )
+            )
+
+        # Convert (h, w) to (x, y)
+        keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+        # Compute the dense descriptors
+        cDa = self.relu(self.convDa(x))
+        descriptors = self.convDb(cDa)
+        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+        # Extract descriptors
+        descriptors = [
+            sample_descriptors(k[None], d[None], 8)[0]
+            for k, d in zip(keypoints, descriptors)
+        ]
+
+        return {
+            "keypoints": torch.stack(keypoints, 0),
+            "keypoint_scores": torch.stack(scores, 0),
+            "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
+        }
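
Review note on the score reshaping in `SuperPoint.forward` above: the two `permute`/`reshape` calls spread the 64 per-cell channels over an 8 x 8 pixel block. The sketch below spells out the resulting index mapping for a single score; `channel_to_pixel` is a hypothetical helper written for this note, not a function from the repository.

```python
def channel_to_pixel(cell_row, cell_col, channel, cell=8):
    """Map a (cell, channel) score to its (row, col) in the full-resolution map."""
    if not 0 <= channel < cell * cell:
        raise ValueError("channel must index one of the cell * cell pixel bins")
    # Channel c is split as (c // cell, c % cell), the offset inside the block.
    return cell_row * cell + channel // cell, cell_col * cell + channel % cell
```

The 65th channel (the "no keypoint" bin) is dropped by the `[:, :-1]` slice before the reshape, which is why only 64 channels take part in the mapping.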

+ 165 - 0
python/LightGlue/lightglue/utils.py

@@ -0,0 +1,165 @@
+import collections.abc as collections
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple, Union
+
+import cv2
+import kornia
+import numpy as np
+import torch
+
+
+class ImagePreprocessor:
+    default_conf = {
+        "resize": None,  # target edge length, None for no resizing
+        "side": "long",
+        "interpolation": "bilinear",
+        "align_corners": None,
+        "antialias": True,
+    }
+
+    def __init__(self, **conf) -> None:
+        super().__init__()
+        self.conf = {**self.default_conf, **conf}
+        self.conf = SimpleNamespace(**self.conf)
+
+    def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Resize and preprocess an image, return image and resize scale"""
+        h, w = img.shape[-2:]
+        if self.conf.resize is not None:
+            img = kornia.geometry.transform.resize(
+                img,
+                self.conf.resize,
+                interpolation=self.conf.interpolation,
+                side=self.conf.side,
+                antialias=self.conf.antialias,
+                align_corners=self.conf.align_corners,
+            )
+        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
+        return img, scale
+
+
+def map_tensor(input_, func: Callable):
+    string_classes = (str, bytes)
+    if isinstance(input_, string_classes):
+        return input_
+    elif isinstance(input_, collections.Mapping):
+        return {k: map_tensor(sample, func) for k, sample in input_.items()}
+    elif isinstance(input_, collections.Sequence):
+        return [map_tensor(sample, func) for sample in input_]
+    elif isinstance(input_, torch.Tensor):
+        return func(input_)
+    else:
+        return input_
+
+
+def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
+    """Move batch (dict) to device"""
+
+    def _func(tensor):
+        return tensor.to(device=device, non_blocking=non_blocking).detach()
+
+    return map_tensor(batch, _func)
+
+
+def rbd(data: dict) -> dict:
+    """Remove batch dimension from elements in data"""
+    return {
+        k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
+        for k, v in data.items()
+    }
+
+
+def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
+    """Read an image from path as RGB or grayscale"""
+    if not Path(path).exists():
+        raise FileNotFoundError(f"No image at path {path}.")
+    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+    image = cv2.imread(str(path), mode)
+    if image is None:
+        raise IOError(f"Could not read image at {path}.")
+    if not grayscale:
+        image = image[..., ::-1]
+    return image
+
+
+def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
+    """Normalize the image tensor and reorder the dimensions."""
+    if image.ndim == 3:
+        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
+    elif image.ndim == 2:
+        image = image[None]  # add channel axis
+    else:
+        raise ValueError(f"Not an image: {image.shape}")
+    return torch.tensor(image / 255.0, dtype=torch.float)
+
+
+def resize_image(
+    image: np.ndarray,
+    size: Union[List[int], int],
+    fn: str = "max",
+    interp: Optional[str] = "area",
+) -> Tuple[np.ndarray, Tuple[float, float]]:
+    """Resize to a fixed size or by max/min edge; return (image, (x, y) scale)."""
+    h, w = image.shape[:2]
+
+    fn = {"max": max, "min": min}[fn]
+    if isinstance(size, int):
+        scale = size / fn(h, w)
+        h_new, w_new = int(round(h * scale)), int(round(w * scale))
+        scale = (w_new / w, h_new / h)
+    elif isinstance(size, (tuple, list)):
+        h_new, w_new = size
+        scale = (w_new / w, h_new / h)
+    else:
+        raise ValueError(f"Incorrect new size: {size}")
+    mode = {
+        "linear": cv2.INTER_LINEAR,
+        "cubic": cv2.INTER_CUBIC,
+        "nearest": cv2.INTER_NEAREST,
+        "area": cv2.INTER_AREA,
+    }[interp]
+    return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
+
+
+def load_image(path: Path, resize: Optional[int] = None, **kwargs) -> torch.Tensor:
+    image = read_image(path)
+    if resize is not None:
+        image, _ = resize_image(image, resize, **kwargs)
+    return numpy_image_to_torch(image)
+
+
+class Extractor(torch.nn.Module):
+    def __init__(self, **conf):
+        super().__init__()
+        self.conf = SimpleNamespace(**{**self.default_conf, **conf})
+
+    @torch.no_grad()
+    def extract(self, img: torch.Tensor, **conf) -> dict:
+        """Perform extraction with online resizing"""
+        if img.dim() == 3:
+            img = img[None]  # add batch dim
+        assert img.dim() == 4 and img.shape[0] == 1
+        shape = img.shape[-2:][::-1]
+        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+        feats = self.forward({"image": img})
+        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
+        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
+        return feats
+
+
+def match_pair(
+    extractor,
+    matcher,
+    image0: torch.Tensor,
+    image1: torch.Tensor,
+    device: str = "cpu",
+    **preprocess,
+):
+    """Match a pair of images (image0, image1) with an extractor and matcher"""
+    feats0 = extractor.extract(image0, **preprocess)
+    feats1 = extractor.extract(image1, **preprocess)
+    matches01 = matcher({"image0": feats0, "image1": feats1})
+    data = [feats0, feats1, matches01]
+    # remove batch dim and move to target device
+    feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
+    return feats0, feats1, matches01
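
Reviewer note: `Extractor.extract` maps keypoints detected in the resized image back to the original image with `(kp + 0.5) / scale - 0.5`, which treats integer coordinates as pixel centres. Below is a minimal sketch of that round trip in plain Python. It assumes the long-edge rounding used by `resize_image` with `fn="max"`; the helpers `resized_shape` and `to_original` are hypothetical names introduced for illustration.

```python
from typing import Tuple


def resized_shape(h: int, w: int, target: int) -> Tuple[int, int]:
    """Hypothetical helper: scale so the longer edge equals `target`."""
    factor = target / max(h, w)
    return int(round(h * factor)), int(round(w * factor))


def to_original(
    kp: Tuple[float, float], scale: Tuple[float, float]
) -> Tuple[float, float]:
    """Undo the resize for one (x, y) keypoint, using pixel-centre coordinates."""
    return tuple((c + 0.5) / s - 0.5 for c, s in zip(kp, scale))


h, w = 480, 640
h_new, w_new = resized_shape(h, w, target=320)
scale = (w_new / w, h_new / h)  # (x scale, y scale), same order as keypoints
print((h_new, w_new), scale)    # (240, 320) (0.5, 0.5)
print(to_original((0.0, 0.0), scale))      # (0.5, 0.5)
print(to_original((319.0, 239.0), scale))  # (638.5, 478.5)
```

With a scale of 0.5, the first resized pixel covers two original pixels, so its centre lands at 0.5 in the original image rather than at 0. Dividing by the scale alone would not account for that half-pixel shift.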

+ 203 - 0
python/LightGlue/lightglue/viz2d.py

@@ -0,0 +1,203 @@
+"""
+2D visualization primitives based on Matplotlib.
+1) Plot images with `plot_images`.
+2) Call `plot_keypoints` or `plot_matches` any number of times.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib
+import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def cm_RdGn(x):
+    """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+    x = np.clip(x, 0, 1)[..., None] * 2
+    c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+    return np.clip(c, 0, 1)
+
+
+def cm_BlRdGn(x_):
+    """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
+    x = np.clip(x_, 0, 1)[..., None] * 2
+    c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
+
+    xn = -np.clip(x_, -1, 0)[..., None] * 2
+    cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
+    out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
+    return out
+
+
+def cm_prune(x_):
+    """Custom colormap to visualize pruning"""
+    if isinstance(x_, torch.Tensor):
+        x_ = x_.cpu().numpy()
+    max_i = x_.max()
+    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
+    return cm_BlRdGn(norm_x)
+
+
+def cm_grad2d(xy):
+    """2D grad. colormap: yellow (0, 0) -> green (1, 0) -> red (0, 1) -> blue (1, 1)."""
+    tl = np.array([1.0, 0, 0])  # red
+    tr = np.array([0, 0.0, 1])  # blue
+    ll = np.array([1.0, 1.0, 0])  # yellow
+    lr = np.array([0, 1.0, 0])  # green
+
+    xy = np.clip(xy, 0, 1)
+    x = xy[..., :1]
+    y = xy[..., -1:]
+    rgb = (1 - x) * (1 - y) * ll + x * (1 - y) * lr + x * y * tr + (1 - x) * y * tl
+    return rgb.clip(0, 1)
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
+    """Plot a set of images horizontally.
+    Args:
+        imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
+        titles: a list of strings, as titles for each image.
+        cmaps: colormaps for monochrome images.
+        adaptive: whether the figure size should fit the image aspect ratios.
+    """
+    # conversion to (H, W, 3) for torch.Tensor
+    imgs = [
+        (
+            img.permute(1, 2, 0).cpu().numpy()
+            if (isinstance(img, torch.Tensor) and img.dim() == 3)
+            else img
+        )
+        for img in imgs
+    ]
+
+    n = len(imgs)
+    if not isinstance(cmaps, (list, tuple)):
+        cmaps = [cmaps] * n
+
+    if adaptive:
+        ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
+    else:
+        ratios = [4 / 3] * n
+    figsize = [sum(ratios) * 4.5, 4.5]
+    fig, ax = plt.subplots(
+        1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+    )
+    if n == 1:
+        ax = [ax]
+    for i in range(n):
+        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+        ax[i].get_yaxis().set_ticks([])
+        ax[i].get_xaxis().set_ticks([])
+        ax[i].set_axis_off()
+        for spine in ax[i].spines.values():  # remove frame
+            spine.set_visible(False)
+        if titles:
+            ax[i].set_title(titles[i])
+    fig.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
+    """Plot keypoints for existing images.
+    Args:
+        kpts: list of ndarrays of size (N, 2).
+        colors: string, or list of lists of tuples (one list per set of keypoints).
+        ps: size of the keypoints as float.
+    """
+    if not isinstance(colors, list):
+        colors = [colors] * len(kpts)
+    if not isinstance(a, list):
+        a = [a] * len(kpts)
+    if axes is None:
+        axes = plt.gcf().axes
+    for ax, k, c, alpha in zip(axes, kpts, colors, a):
+        if isinstance(k, torch.Tensor):
+            k = k.cpu().numpy()
+        ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
+    """Plot matches for a pair of existing images.
+    Args:
+        kpts0, kpts1: corresponding keypoints of size (N, 2).
+        color: color of each match, string or RGB tuple. Random if not given.
+        lw: width of the lines.
+        ps: size of the end points (no endpoint if ps=0)
+        labels: optional list of labels, one per match line.
+        axes: pair of axes to draw on; defaults to the first two of the figure.
+        a: alpha opacity of the match lines.
+    """
+    fig = plt.gcf()
+    if axes is None:
+        ax = fig.axes
+        ax0, ax1 = ax[0], ax[1]
+    else:
+        ax0, ax1 = axes
+    if isinstance(kpts0, torch.Tensor):
+        kpts0 = kpts0.cpu().numpy()
+    if isinstance(kpts1, torch.Tensor):
+        kpts1 = kpts1.cpu().numpy()
+    assert len(kpts0) == len(kpts1)
+    if color is None:
+        # guard against 0/0 when all keypoints share an x or y coordinate
+        span = np.maximum(np.ptp(kpts0, axis=0, keepdims=True), 1e-8)
+        kpts_norm = (kpts0 - kpts0.min(axis=0, keepdims=True)) / span
+        color = cm_grad2d(kpts_norm)  # gradient color
+    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+        color = [color] * len(kpts0)
+
+    if lw > 0:
+        for i in range(len(kpts0)):
+            line = matplotlib.patches.ConnectionPatch(
+                xyA=(kpts0[i, 0], kpts0[i, 1]),
+                xyB=(kpts1[i, 0], kpts1[i, 1]),
+                coordsA=ax0.transData,
+                coordsB=ax1.transData,
+                axesA=ax0,
+                axesB=ax1,
+                zorder=1,
+                color=color[i],
+                linewidth=lw,
+                clip_on=True,
+                alpha=a,
+                label=None if labels is None else labels[i],
+                picker=5.0,
+            )
+            line.set_annotation_clip(True)
+            fig.add_artist(line)
+
+    # freeze the axes to prevent the transform to change
+    ax0.autoscale(enable=False)
+    ax1.autoscale(enable=False)
+
+    if ps > 0:
+        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def add_text(
+    idx,
+    text,
+    pos=(0.01, 0.99),
+    fs=15,
+    color="w",
+    lcolor="k",
+    lwidth=2,
+    ha="left",
+    va="top",
+):
+    ax = plt.gcf().axes[idx]
+    t = ax.text(
+        *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
+    )
+    if lcolor is not None:
+        t.set_path_effects(
+            [
+                path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+                path_effects.Normal(),
+            ]
+        )
+
+
+def save_plot(path, **kw):
+    """Save the current figure without any white margin."""
+    plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
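
Reviewer note: the `cm_RdGn` colormap above blends red and green weights and then clips, which is what produces yellow at the midpoint. Here is a scalar sketch of the same ramp in plain Python, without NumPy. `red_to_green` is a hypothetical name introduced for illustration.

```python
from typing import Tuple


def red_to_green(x: float) -> Tuple[float, float, float]:
    """Hypothetical scalar version of the red -> yellow -> green ramp."""
    t = min(max(x, 0.0), 1.0) * 2  # clip to [0, 1], then stretch to [0, 2]
    red = min(2 - t, 1.0)          # full red until the midpoint, then fades
    green = min(t, 1.0)            # ramps up to the midpoint, then saturates
    return (red, green, 0.0)


print(red_to_green(0.0))  # (1.0, 0.0, 0.0) red
print(red_to_green(0.5))  # (1.0, 1.0, 0.0) yellow
print(red_to_green(1.0))  # (0.0, 1.0, 0.0) green
```

Inputs outside `[0, 1]` are clipped first, so a score of 1.2 maps to the same green as 1.0.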

+ 30 - 0
python/LightGlue/pyproject.toml

@@ -0,0 +1,30 @@
+[project]
+name = "lightglue"
+description = "LightGlue: Local Feature Matching at Light Speed"
+version = "0.0"
+authors = [
+    {name = "Philipp Lindenberger"},
+    {name = "Paul-Edouard Sarlin"},
+]
+readme = "README.md"
+requires-python = ">=3.6"
+license = {file = "LICENSE"}
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: OS Independent",
+]
+urls = {Repository = "https://github.com/cvg/LightGlue/"}
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+dev = ["black==23.12.1", "flake8", "isort"]
+
+[tool.setuptools]
+packages = ["lightglue"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.isort]
+profile = "black"

+ 6 - 0
python/LightGlue/requirements.txt

@@ -0,0 +1,6 @@
+torch>=1.9.1
+torchvision>=0.3
+numpy
+opencv-python
+matplotlib
+kornia>=0.6.11
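
Reviewer note: this commit also vendors `torchvision-0.26.0`, so it may be worth checking that the pinned wheels satisfy the minimums listed above. The following is a rough sketch using only the standard library. It assumes every line is either a bare name or `name>=X.Y.Z` with purely numeric versions, which is a simplification of the real requirement syntax. `parse_minimums` is a hypothetical helper, and a real project would more likely rely on a dedicated parser such as the `packaging` library.

```python
import re
from typing import Dict, Optional, Tuple


def parse_minimums(text: str) -> Dict[str, Optional[Tuple[int, ...]]]:
    """Map each requirement name to its `>=` minimum version, or None."""
    minimums = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        match = re.fullmatch(r"([A-Za-z0-9_.-]+)(?:>=([0-9.]+))?", line)
        if match is None:
            raise ValueError(f"Unsupported requirement line: {line}")
        name, version = match.groups()
        minimums[name] = (
            tuple(int(part) for part in version.split(".")) if version else None
        )
    return minimums


requirements = """\
torch>=1.9.1
torchvision>=0.3
numpy
opencv-python
matplotlib
kornia>=0.6.11
"""
mins = parse_minimums(requirements)
print(mins["torch"])        # (1, 9, 1)
print(mins["torchvision"])  # (0, 3)
print((0, 26, 0) >= mins["torchvision"])  # True
```

Comparing versions as integer tuples avoids the string-comparison trap where `"0.26"` sorts before `"0.3"`.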

+ 1 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/INSTALLER

@@ -0,0 +1 @@
+pip

+ 29 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/LICENSE

@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) Soumith Chintala 2016, 
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 128 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/METADATA

@@ -0,0 +1,128 @@
+Metadata-Version: 2.1
+Name: torchvision
+Version: 0.26.0
+Summary: image and video datasets and models for torch deep learning
+Home-page: https://github.com/pytorch/vision
+Author: PyTorch Core Team
+Author-email: soumith@pytorch.org
+License: BSD
+Requires-Python: >=3.10,!=3.14.1
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy
+Requires-Dist: torch (==2.11.0)
+Requires-Dist: pillow (!=8.3.*,>=5.3.0)
+Provides-Extra: gdown
+Requires-Dist: gdown (>=4.7.3) ; extra == 'gdown'
+Provides-Extra: scipy
+Requires-Dist: scipy ; extra == 'scipy'
+
+# torchvision
+
+[![total torchvision downloads](https://pepy.tech/badge/torchvision)](https://pepy.tech/project/torchvision)
+[![documentation](https://img.shields.io/badge/dynamic/json.svg?label=docs&url=https%3A%2F%2Fpypi.org%2Fpypi%2Ftorchvision%2Fjson&query=%24.info.version&colorB=brightgreen&prefix=v)](https://pytorch.org/vision/stable/index.html)
+
+The torchvision package consists of popular datasets, model architectures, and common image transformations for computer
+vision.
+
+## Installation
+
+Please refer to the [official
+instructions](https://pytorch.org/get-started/locally/) to install the stable
+versions of `torch` and `torchvision` on your system.
+
+To build source, refer to our [contributing
+page](https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#development-installation).
+
+The following are the corresponding `torchvision` versions and supported Python
+versions.
+
+| `torch`            | `torchvision`      | Python              |
+| ------------------ | ------------------ | ------------------- |
+| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.14`  |
+| `2.10`             | `0.25`             | `>=3.10`, `<=3.14`  |
+| `2.9`              | `0.24`             | `>=3.10`, `<=3.14`  |
+| `2.8`              | `0.23`             | `>=3.9`, `<=3.13`   |
+| `2.7`              | `0.22`             | `>=3.9`, `<=3.13`   |
+| `2.6`              | `0.21`             | `>=3.9`, `<=3.12`   |
+
+<details>
+    <summary>older versions</summary>
+
+| `torch` | `torchvision`     | Python                    |
+|---------|-------------------|---------------------------|
+| `2.5`              | `0.20`             | `>=3.9`, `<=3.12`   |
+| `2.4`              | `0.19`             | `>=3.8`, `<=3.12`   |
+| `2.3`              | `0.18`             | `>=3.8`, `<=3.12`   |
+| `2.2`              | `0.17`             | `>=3.8`, `<=3.11`   |
+| `2.1`              | `0.16`             | `>=3.8`, `<=3.11`   |
+| `2.0`              | `0.15`             | `>=3.8`, `<=3.11`   |
+| `1.13`  | `0.14`            | `>=3.7.2`, `<=3.10`       |
+| `1.12`  | `0.13`            | `>=3.7`, `<=3.10`         |
+| `1.11`  | `0.12`            | `>=3.7`, `<=3.10`         |
+| `1.10`  | `0.11`            | `>=3.6`, `<=3.9`          |
+| `1.9`   | `0.10`            | `>=3.6`, `<=3.9`          |
+| `1.8`   | `0.9`             | `>=3.6`, `<=3.9`          |
+| `1.7`   | `0.8`             | `>=3.6`, `<=3.9`          |
+| `1.6`   | `0.7`             | `>=3.6`, `<=3.8`          |
+| `1.5`   | `0.6`             | `>=3.5`, `<=3.8`          |
+| `1.4`   | `0.5`             | `==2.7`, `>=3.5`, `<=3.8` |
+| `1.3`   | `0.4.2` / `0.4.3` | `==2.7`, `>=3.5`, `<=3.7` |
+| `1.2`   | `0.4.1`           | `==2.7`, `>=3.5`, `<=3.7` |
+| `1.1`   | `0.3`             | `==2.7`, `>=3.5`, `<=3.7` |
+| `<=1.0` | `0.2`             | `==2.7`, `>=3.5`, `<=3.7` |
+
+</details>
+
+## Image Backends
+
+Torchvision currently supports the following image backends:
+
+- torch tensors
+- PIL images:
+    - [Pillow](https://python-pillow.org/)
+    - [Pillow-SIMD](https://github.com/uploadcare/pillow-simd) - a **much faster** drop-in replacement for Pillow with SIMD.
+
+Read more in our [docs](https://pytorch.org/vision/stable/transforms.html).
+
+## Documentation
+
+You can find the API documentation on the pytorch website: <https://pytorch.org/vision/stable/index.html>
+
+## Contributing
+
+See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.
+
+## Disclaimer on Datasets
+
+This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets,
+vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to
+determine whether you have permission to use the dataset under the dataset's license.
+
+If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset
+to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML
+community!
+
+## Pre-trained Model License
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the
+dataset used for training. It is your responsibility to determine whether you have permission to use the models for your
+use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See
+[SWAG LICENSE](https://github.com/facebookresearch/SWAG/blob/main/LICENSE) for additional details.
+
+## Citing TorchVision
+
+If you find TorchVision useful in your work, please consider citing the following BibTeX entry:
+
+```bibtex
+@software{torchvision2016,
+    title        = {TorchVision: PyTorch's Computer Vision library},
+    author       = {TorchVision maintainers and contributors},
+    year         = 2016,
+    journal      = {GitHub repository},
+    publisher    = {GitHub},
+    howpublished = {\url{https://github.com/pytorch/vision}}
+}
+```

+ 380 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/RECORD

@@ -0,0 +1,380 @@
+torchvision-0.26.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+torchvision-0.26.0.dist-info/LICENSE,sha256=wGNj-dM2J9xRc7E1IkRMyF-7Rzn2PhbUWH1cChZbWx4,1546
+torchvision-0.26.0.dist-info/METADATA,sha256=qXPWLwUMoM8pTitzRpq1hsD0_ExL0uCjBKxPKrt_FPw,5476
+torchvision-0.26.0.dist-info/RECORD,,
+torchvision-0.26.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchvision-0.26.0.dist-info/WHEEL,sha256=KNRoynpGu-d6mheJI-zfvcGl1iN-y8BewbiCDXsF3cY,101
+torchvision-0.26.0.dist-info/top_level.txt,sha256=ucJZoaluBW9BGYT4TuCE6zoZY_JuSP30wbDh-IRpxUU,12
+torchvision/_C.pyd,sha256=alB-UeGAPNt5QWLzXpaKuQeaxdqEw7KPW7uL64qkh_o,729600
+torchvision/__init__.py,sha256=2Yvhu1HRR0nclxjSfULNQlAgPmJ20OkKtgIarIB6RzU,1980
+torchvision/__pycache__/__init__.cpython-312.pyc,,
+torchvision/__pycache__/_internally_replaced_utils.cpython-312.pyc,,
+torchvision/__pycache__/_meta_registrations.cpython-312.pyc,,
+torchvision/__pycache__/_utils.cpython-312.pyc,,
+torchvision/__pycache__/extension.cpython-312.pyc,,
+torchvision/__pycache__/utils.cpython-312.pyc,,
+torchvision/__pycache__/version.cpython-312.pyc,,
+torchvision/_internally_replaced_utils.py,sha256=FvTnzXWF39nUah3q7DTjDrAVuP2ibtgguFR3Cs0Z1mU,1510
+torchvision/_meta_registrations.py,sha256=_rNIAEDFkHMmgYGm982-WgSmhA9o2KVJpd8DsoUlQmk,7433
+torchvision/_utils.py,sha256=8TvBoMxnfm9Ow4ypvjoWxrF7hKwJ2hX5UUxN209zbV8,988
+torchvision/datasets/__init__.py,sha256=Ajw_RJcPNYz-iKvkvQEkgw2P341llDsDgd8cke-bxas,3753
+torchvision/datasets/__pycache__/__init__.cpython-312.pyc,,
+torchvision/datasets/__pycache__/_optical_flow.cpython-312.pyc,,
+torchvision/datasets/__pycache__/_stereo_matching.cpython-312.pyc,,
+torchvision/datasets/__pycache__/caltech.cpython-312.pyc,,
+torchvision/datasets/__pycache__/celeba.cpython-312.pyc,,
+torchvision/datasets/__pycache__/cifar.cpython-312.pyc,,
+torchvision/datasets/__pycache__/cityscapes.cpython-312.pyc,,
+torchvision/datasets/__pycache__/clevr.cpython-312.pyc,,
+torchvision/datasets/__pycache__/coco.cpython-312.pyc,,
+torchvision/datasets/__pycache__/country211.cpython-312.pyc,,
+torchvision/datasets/__pycache__/dtd.cpython-312.pyc,,
+torchvision/datasets/__pycache__/eurosat.cpython-312.pyc,,
+torchvision/datasets/__pycache__/fakedata.cpython-312.pyc,,
+torchvision/datasets/__pycache__/fer2013.cpython-312.pyc,,
+torchvision/datasets/__pycache__/fgvc_aircraft.cpython-312.pyc,,
+torchvision/datasets/__pycache__/flickr.cpython-312.pyc,,
+torchvision/datasets/__pycache__/flowers102.cpython-312.pyc,,
+torchvision/datasets/__pycache__/folder.cpython-312.pyc,,
+torchvision/datasets/__pycache__/food101.cpython-312.pyc,,
+torchvision/datasets/__pycache__/gtsrb.cpython-312.pyc,,
+torchvision/datasets/__pycache__/hmdb51.cpython-312.pyc,,
+torchvision/datasets/__pycache__/imagenet.cpython-312.pyc,,
+torchvision/datasets/__pycache__/imagenette.cpython-312.pyc,,
+torchvision/datasets/__pycache__/inaturalist.cpython-312.pyc,,
+torchvision/datasets/__pycache__/kinetics.cpython-312.pyc,,
+torchvision/datasets/__pycache__/kitti.cpython-312.pyc,,
+torchvision/datasets/__pycache__/lfw.cpython-312.pyc,,
+torchvision/datasets/__pycache__/lsun.cpython-312.pyc,,
+torchvision/datasets/__pycache__/mnist.cpython-312.pyc,,
+torchvision/datasets/__pycache__/moving_mnist.cpython-312.pyc,,
+torchvision/datasets/__pycache__/omniglot.cpython-312.pyc,,
+torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-312.pyc,,
+torchvision/datasets/__pycache__/pcam.cpython-312.pyc,,
+torchvision/datasets/__pycache__/phototour.cpython-312.pyc,,
+torchvision/datasets/__pycache__/places365.cpython-312.pyc,,
+torchvision/datasets/__pycache__/rendered_sst2.cpython-312.pyc,,
+torchvision/datasets/__pycache__/sbd.cpython-312.pyc,,
+torchvision/datasets/__pycache__/sbu.cpython-312.pyc,,
+torchvision/datasets/__pycache__/semeion.cpython-312.pyc,,
+torchvision/datasets/__pycache__/stanford_cars.cpython-312.pyc,,
+torchvision/datasets/__pycache__/stl10.cpython-312.pyc,,
+torchvision/datasets/__pycache__/sun397.cpython-312.pyc,,
+torchvision/datasets/__pycache__/svhn.cpython-312.pyc,,
+torchvision/datasets/__pycache__/ucf101.cpython-312.pyc,,
+torchvision/datasets/__pycache__/usps.cpython-312.pyc,,
+torchvision/datasets/__pycache__/utils.cpython-312.pyc,,
+torchvision/datasets/__pycache__/video_utils.cpython-312.pyc,,
+torchvision/datasets/__pycache__/vision.cpython-312.pyc,,
+torchvision/datasets/__pycache__/voc.cpython-312.pyc,,
+torchvision/datasets/__pycache__/widerface.cpython-312.pyc,,
+torchvision/datasets/_optical_flow.py,sha256=ub6UKPv1P3MFBl4wQXT8oEqkoNOGn33BgynJEcYVj_0,21724
+torchvision/datasets/_stereo_matching.py,sha256=i1JemwbjQhsiMLB3Hrmdo4vYpX4FcNjGpdDQAYF_Pdc,50263
+torchvision/datasets/caltech.py,sha256=TFgTrxDPfpIxRm7rfcWdpyTUIylkZNMG1lNBeInUUgk,9170
+torchvision/datasets/celeba.py,sha256=LwsUrrd0RxpFjBNxTl30pIlTFKu3PbOOCqcBSPyt3Qc,9257
+torchvision/datasets/cifar.py,sha256=1T71RgltoNVjjWxjCUSRBCQh6Bf_O50Sim8w3cOA40Q,5951
+torchvision/datasets/cityscapes.py,sha256=BPT91qlUiaCz7Zve2bhTTlnKUR11TEOgJdex_CPuwwg,10552
+torchvision/datasets/clevr.py,sha256=VdRijEXJBsR1ED-z0FjJUj7QVF7tEgmPHwflamXe1DU,3948
+torchvision/datasets/coco.py,sha256=KS_Ny3rtnahIZ7vzILgjWCnnupIHnVB6UVYV1pMHnOE,4456
+torchvision/datasets/country211.py,sha256=aw2FzdX-OTQtvHFldL7p5CVA02-7Qt6AXX6j4UQEC8M,2956
+torchvision/datasets/dtd.py,sha256=QkPVyrfzJBwVfSWjwk2m7Gx8fOxDG8KUsvZJ18lTJ4E,4525
+torchvision/datasets/eurosat.py,sha256=vLcnf6XlVIT9MG1ZTPB8kfY10TXqTJcZ6QelpQNmqSE,2832
+torchvision/datasets/fakedata.py,sha256=pdSyCMLTLSW0KWIoYgTbPv0rcYFWj8p6J7vjODl7h0o,2507
+torchvision/datasets/fer2013.py,sha256=Clw7azQ4RpapE2giqWK-eOUiR-SuYwfOM5ikjaOAYLc,5226
+torchvision/datasets/fgvc_aircraft.py,sha256=4RblkmHUiXoxFngtnIL1iDpiHhe29pKGt5f51d7ptPk,5088
+torchvision/datasets/flickr.py,sha256=v5zMcyDQcnzCvp7Y4qN6560q_lFSyBB3Poq0S1eFwyI,6346
+torchvision/datasets/flowers102.py,sha256=luufpQWTl9rdFTPQheQBbOlzV8edFOaAxnZnDMenpmk,7706
+torchvision/datasets/folder.py,sha256=68dCiaX8Pvv8JIiuIhbx0kaVIyp22RyST8zOnxB2YBc,13322
+torchvision/datasets/food101.py,sha256=KGA1JCwbXBEbEQjbuyhsiCEldsfFlLahFies35IMkxQ,4243
+torchvision/datasets/gtsrb.py,sha256=pERh6_VqsEPJKXYMY1UdiyZTjrl8cp2g6og04ifgIp0,3881
+torchvision/datasets/hmdb51.py,sha256=FKxcu2yOGSQVCqZ7kir2VFxzR7h7XjlM7yP-Xiv4oic,6104
+torchvision/datasets/imagenet.py,sha256=758dmmmpAm8HN4PZ2Z9sqj6QmyhO6D-bvkWj7eKsIE4,9144
+torchvision/datasets/imagenette.py,sha256=LmLm2dq98tv5Znw31tH09cO26VyYZK3YcEJ6hrXALXE,4730
+torchvision/datasets/inaturalist.py,sha256=KiWpvP3kBTel1vor198wjZ68X4xv0q83_SNsTb4Eo38,10547
+torchvision/datasets/kinetics.py,sha256=ujFVpL3yrbLWwLQs8g0_TS8JWKWzqNN6P_TpBEnqtZo,10102
+torchvision/datasets/kitti.py,sha256=XsFC8DaCCMiwvs20Qtyxpltr-7yDj4A1_CSVGO6UdE0,5782
+torchvision/datasets/lfw.py,sha256=WT1g-zKyszYDArSTBA-r-nU1yRhLobgevi13GhS0DRA,11671
+torchvision/datasets/lsun.py,sha256=EaPp0k9EcCktMoN12dAGbuu92PuxVtEzUOTyycN_RdU,5898
+torchvision/datasets/mnist.py,sha256=yhBEtZ2CZqCybEo0fDLcXjL9WGpS8ISy-pqzl7xE0vo,22364
+torchvision/datasets/moving_mnist.py,sha256=Cmw-Pj4xEzmfGZab5O0SGHjxYntg7lyL6XCwuzr1FTQ,3738
+torchvision/datasets/omniglot.py,sha256=e51DpqaTFrffgrT3g_HoCCB26xqAa6ILo1FQxC_kwTU,4593
+torchvision/datasets/oxford_iiit_pet.py,sha256=cl4OoiEw5l2QzjUcl1ZsSOeTcChM1gTR_eURvP2HErs,5831
+torchvision/datasets/pcam.py,sha256=QROSVEEcVZ7WP_GNmUBzot49r54jQ2PzhW8cqFuRbrw,5412
+torchvision/datasets/phototour.py,sha256=7ZT9Le--HQU87kzWw1qTDL8UZdLcLJDSsyMzcvh_bKk,8085
+torchvision/datasets/places365.py,sha256=z7XKs8zVImJvvcrh0nt0FnlrP9Xz1rZ6oEDnSnObMmg,7643
+torchvision/datasets/rendered_sst2.py,sha256=FMlBdmva-_oWor7uSgDFH4HlYm8EoE0QyenTPrSALD8,4046
+torchvision/datasets/samplers/__init__.py,sha256=yX8XD3h-VwTCoqF9IdQmcuGblZxvXb2EIqc5lPqXXtY,164
+torchvision/datasets/samplers/__pycache__/__init__.cpython-312.pyc,,
+torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-312.pyc,,
+torchvision/datasets/samplers/clip_sampler.py,sha256=oBhAHsW7BYEKGnUtJ7EqCYy3yTtJYahofAdSPcKiHtA,6438
+torchvision/datasets/sbd.py,sha256=STXgWtZNOJ_P6pj4nYwR9p7HjjPLmLz26b8O5qH_T0Q,5533
+torchvision/datasets/sbu.py,sha256=oEQZm7y5Xu_XMT3WN-PGTpy864IT-Y4CWCeLaeitsKI,4578
+torchvision/datasets/semeion.py,sha256=klYpsKYYuYAZpLylwp3OQX9HBWcKfyJipTDpW3B47H4,3191
+torchvision/datasets/stanford_cars.py,sha256=ezj66uMKHejcdkGY0Im6rhRbRvrGKH_F482Ne-gwLC4,4386
+torchvision/datasets/stl10.py,sha256=lWXtiNAbbRQxVMTnHIY74pfl5vZWsjM3jiGIcWFq0Co,7401
+torchvision/datasets/sun397.py,sha256=fDAmuyll2fw-Wph6cckmpsar9fm3xtAkjq3a7gRiruQ,3257
+torchvision/datasets/svhn.py,sha256=TEzQE4RUgwdH-ezfDib-sdQD1ORpQM6YWn_WJF4zmMM,4951
+torchvision/datasets/ucf101.py,sha256=vTTf8KY82qKdd8HuHY1NcM_ZUt0Bu53BOF0eFZc-820,5645
+torchvision/datasets/usps.py,sha256=74zymHuzkJenGD_5Gbd2aV_Wsi9R7E-1Xsd-tbr-1vU,3605
+torchvision/datasets/utils.py,sha256=F_NX_BoNlomVt-MlFJ0mpud34Nj7IrLN3OjvLZjb1WA,16382
+torchvision/datasets/video_utils.py,sha256=mfAq-bR4aLV6O0Nqfr6ea14VJt27fCqacPzaoqfrKes,15718
+torchvision/datasets/vision.py,sha256=s3_TFYSzplp-TSLDW_QNu4z0Xpua48D5iDZjJDQo__I,4347
+torchvision/datasets/voc.py,sha256=cFi8zSfU8CSig6dvnQIBEnjEMVqJEE5Jgs6e8ehojnA,9040
+torchvision/datasets/widerface.py,sha256=jQpfLZ2by9F8ABigruWY6nDQwMfAv4z63IZQ9XwC43s,8437
+torchvision/extension.py,sha256=H5mQvztA6-hYvXge4Us-CYY9bEO0W4blvCFvw9CZUMk,2809
+torchvision/image.pyd,sha256=j4LepzzmqMGL2cNzQl7epzPnsPE3TSHWZbmLv2tfHQU,270848
+torchvision/io/__init__.py,sha256=8XBTjEsxdhUa5nbRKh8d7OaZ7CG3bWJ5Qd9N4OS4Cwc,1200
+torchvision/io/__pycache__/__init__.cpython-312.pyc,,
+torchvision/io/__pycache__/image.cpython-312.pyc,,
+torchvision/io/image.py,sha256=XK8dVVHIzzY6BV2fQAy1beFfo3A6vrL_DgeIoNgFwuU,22506
+torchvision/jpeg8.dll,sha256=aM-Kj2MkrdHI0gkgpHfh86_icuM26XiMu6gyMGeuKig,552448
+torchvision/libjpeg.dll,sha256=qyn4xdAUJAGvpE8dg79mDDsWt55sj-9g0GGVb9zAj6M,285512
+torchvision/libpng16.dll,sha256=IQX97buEli3tb3C9YiDExfHglGOO_JG3P3psWjU0zvE,209224
+torchvision/libsharpyuv.dll,sha256=W5eBRnuuGzCl3wcHq3RXMMmK2WodsqgluPE1881s-UY,37192
+torchvision/libwebp.dll,sha256=6ev4ynB60ulk-M3ZiSyeU2UXhWaWVHxGcfYqfNTTlPY,387912
+torchvision/models/__init__.py,sha256=6QlTJfvjKcUmMJvwSapWUNFXbf2Vo15dVRcBuNSaYko,888
+torchvision/models/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/__pycache__/_api.cpython-312.pyc,,
+torchvision/models/__pycache__/_meta.cpython-312.pyc,,
+torchvision/models/__pycache__/_utils.cpython-312.pyc,,
+torchvision/models/__pycache__/alexnet.cpython-312.pyc,,
+torchvision/models/__pycache__/convnext.cpython-312.pyc,,
+torchvision/models/__pycache__/densenet.cpython-312.pyc,,
+torchvision/models/__pycache__/efficientnet.cpython-312.pyc,,
+torchvision/models/__pycache__/feature_extraction.cpython-312.pyc,,
+torchvision/models/__pycache__/googlenet.cpython-312.pyc,,
+torchvision/models/__pycache__/inception.cpython-312.pyc,,
+torchvision/models/__pycache__/maxvit.cpython-312.pyc,,
+torchvision/models/__pycache__/mnasnet.cpython-312.pyc,,
+torchvision/models/__pycache__/mobilenet.cpython-312.pyc,,
+torchvision/models/__pycache__/mobilenetv2.cpython-312.pyc,,
+torchvision/models/__pycache__/mobilenetv3.cpython-312.pyc,,
+torchvision/models/__pycache__/regnet.cpython-312.pyc,,
+torchvision/models/__pycache__/resnet.cpython-312.pyc,,
+torchvision/models/__pycache__/shufflenetv2.cpython-312.pyc,,
+torchvision/models/__pycache__/squeezenet.cpython-312.pyc,,
+torchvision/models/__pycache__/swin_transformer.cpython-312.pyc,,
+torchvision/models/__pycache__/vgg.cpython-312.pyc,,
+torchvision/models/__pycache__/vision_transformer.cpython-312.pyc,,
+torchvision/models/_api.py,sha256=V6lWopggNs_WSppjateyFVwj-x_MZUmAV1_m3khSGN0,10241
+torchvision/models/_meta.py,sha256=2NSIICoq4MDzPZc00DlGJTgHOCwTBSObSTeRTh3E0tQ,30429
+torchvision/models/_utils.py,sha256=gOh6tA08U7NHLGt6zznG7dLJfi18fWy_Wp_dhaH8xko,11136
+torchvision/models/alexnet.py,sha256=XcldP2UuOkdOUfdxitGS8qHzLH78Ny7VCzTzKsaWITU,4607
+torchvision/models/convnext.py,sha256=xM6IWUmP3-TSn79_APA4WDSgzL5JSRXvIleJyVmmRYg,15762
+torchvision/models/densenet.py,sha256=Q1_8tERSWn2uwhjylxkQlMVcYlirp3ARyxZEiG-Zw2Q,17260
+torchvision/models/detection/__init__.py,sha256=D4cs338Z4BQn5TgX2IKuJC9TD2rtw2svUDZlALR-lwI,175
+torchvision/models/detection/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/_utils.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/anchor_utils.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/backbone_utils.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/faster_rcnn.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/fcos.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/generalized_rcnn.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/image_list.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/mask_rcnn.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/retinanet.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/roi_heads.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/rpn.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/ssd.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/ssdlite.cpython-312.pyc,,
+torchvision/models/detection/__pycache__/transform.cpython-312.pyc,,
+torchvision/models/detection/_utils.py,sha256=dee4UORMRBeHz4YKTOOeyKsEb9Q4AYmL3iG__sfZioM,22646
+torchvision/models/detection/anchor_utils.py,sha256=M6Gm4Qrji8DSfR3IVpfS9VHNEpIKT_AEHJlINioIzns,12121
+torchvision/models/detection/backbone_utils.py,sha256=SC0TRbPxy3P6Lzq0Z9iYRhHsoIGUex43v7ntyFJUwPU,10780
+torchvision/models/detection/faster_rcnn.py,sha256=3ZiR2D3XovDfdIHBJjFfLHwxxfgIDvSfLh-pmMWvNNM,37812
+torchvision/models/detection/fcos.py,sha256=0uOtXF6cOZhaSeTIB1rqreAC3XdcD4WXzk4QTMaF89A,34991
+torchvision/models/detection/generalized_rcnn.py,sha256=7ovthQx8aEi6oxXovSqO7D7Qanp2yzd1HHXAQbQS3Tk,5062
+torchvision/models/detection/image_list.py,sha256=rUmPJI-F1EQ_0nNsNnHrhsMc8fusrDFgUdloMLOJbgo,774
+torchvision/models/detection/keypoint_rcnn.py,sha256=nXGqgQ6uhLLYw785Ta3fFKKL7Mc0KFp8QZG0sqY3T7k,22455
+torchvision/models/detection/mask_rcnn.py,sha256=W_3sXiSIgFz5gE9fmSh-cw9NAt45CX-TGkLTqp-zvsg,27303
+torchvision/models/detection/retinanet.py,sha256=nZNicKRgBZuJh-cM-FqPe9GgaB6GFOIgpopSmz4HE6M,38184
+torchvision/models/detection/roi_heads.py,sha256=D6ifzMGfqlqqncmEqePcKWhKSpRNvVNT3WW4LL73_Is,35182
+torchvision/models/detection/rpn.py,sha256=AP9-p3_pujcDita2wSiRhbLnU8oSTFDPZNjw0w7ezUM,16205
+torchvision/models/detection/ssd.py,sha256=BCU2bgRnPGmxsVFXriz54lsS7Xbdzdfm3N3BmW4OeI8,29642
+torchvision/models/detection/ssdlite.py,sha256=_MKsJEExuIUtqrJ-BWcXCpyTK-jE4jjvaOgn4lUjn64,13538
+torchvision/models/detection/transform.py,sha256=Jqtqf9uLo0uBEraJ1pDaMBSuALY5_vDELnLCQcOwwm0,12489
+torchvision/models/efficientnet.py,sha256=_qqZVYMrNIhUw97v2B_9cgyLprm84hQk8A-X5m6uk4Q,44230
+torchvision/models/feature_extraction.py,sha256=wu9xI_QcRW4XXQYggrveacZI74yAWbr7j-ckW8uHj-M,28521
+torchvision/models/googlenet.py,sha256=aIlM44ZT8hsuRz3t2RLdcZv1vhYQP_3x8Rc-QGE27RU,13138
+torchvision/models/inception.py,sha256=3DNzEmfmschGRrCyA4-mDkRVUBbv7GSnbZa_oaiaVYU,19316
+torchvision/models/maxvit.py,sha256=LHPsHqa-RaWIqbbdMrsm2YEk1332RgJS6wIJ-agqPX4,32944
+torchvision/models/mnasnet.py,sha256=yMR8vdb4nj1kdUqI8r53ERJNRRjYl4tHlTYkMRh57NU,17996
+torchvision/models/mobilenet.py,sha256=alrEJwktmXVcCphU8h2EAJZX0YdKfcz4tJEOdG2BXB8,217
+torchvision/models/mobilenetv2.py,sha256=G0ZvYkujutNjjPWJ1aJh-O0d3F02ExdvajeLyZHVr4Y,9964
+torchvision/models/mobilenetv3.py,sha256=jIF7HMwQ7fqCPtWY1jw7PT17KMIIVYEXRAq17InrQSI,16724
+torchvision/models/optical_flow/__init__.py,sha256=uuRFAdvcDobdAbY2VmxEZ7_CLH_f5-JRkCSuJRkj4RY,21
+torchvision/models/optical_flow/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/optical_flow/__pycache__/_utils.cpython-312.pyc,,
+torchvision/models/optical_flow/__pycache__/raft.cpython-312.pyc,,
+torchvision/models/optical_flow/_utils.py,sha256=PRcuU-IB6EL3hAOLiyC5q-NBzlvIKfhSF_BMplHbzfY,2125
+torchvision/models/optical_flow/raft.py,sha256=H96CpsOcUd5nyX8kFyOyTLWB6sI7txsWUMgqjH1jduM,40938
+torchvision/models/quantization/__init__.py,sha256=YOJmYqWQTfP5P2ypteZNKQOMW4VEB2WHJlYoSlSaL1Y,130
+torchvision/models/quantization/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/googlenet.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/inception.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/mobilenet.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/mobilenetv2.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/mobilenetv3.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/resnet.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/shufflenetv2.cpython-312.pyc,,
+torchvision/models/quantization/__pycache__/utils.cpython-312.pyc,,
+torchvision/models/quantization/googlenet.py,sha256=KfUQoqXi8JluXlJ66t2aUBjFNDN0B5Y6zsu29tWdnYQ,8324
+torchvision/models/quantization/inception.py,sha256=9cb3yiBYnmaC2vC4LEFYbWPvO8iMFNcYQx6LksHmomc,11116
+torchvision/models/quantization/mobilenet.py,sha256=alrEJwktmXVcCphU8h2EAJZX0YdKfcz4tJEOdG2BXB8,217
+torchvision/models/quantization/mobilenetv2.py,sha256=xR7Vpq7xcq2GSdsxq6UT6OA5LEPaWFmxF2NnxRFnzVY,6071
+torchvision/models/quantization/mobilenetv3.py,sha256=ZNb3egv-RdSjIW2nw6mEn6xE13AQZmDryMdDe1B96yg,9495
+torchvision/models/quantization/resnet.py,sha256=0fzTtBt-1iP0CembaGBoRckI2akVSz2MxKzGfBx5xMQ,18547
+torchvision/models/quantization/shufflenetv2.py,sha256=jnrMO3Js8rr3EZzAQi4AQvyHo8wl0zr2EqmRRjwU8o8,17441
+torchvision/models/quantization/utils.py,sha256=mwO6t0K7PMcev2LLndIEmsXNKvYEJBM4f7NzsiM8jk4,2103
+torchvision/models/regnet.py,sha256=vblO9fo8WnU6tbVnNZuHIabNrXyoVHnGVQDru2yRbEU,65105
+torchvision/models/resnet.py,sha256=fDKRmPQU5BHdZNZ3hMH377sehb244WC4xdXFYZ-asRk,39905
+torchvision/models/segmentation/__init__.py,sha256=TLL2SSmqE08HLiv_yyIWyIyrvf2xaOsZi0muDv_Y5Vc,69
+torchvision/models/segmentation/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/segmentation/__pycache__/_utils.cpython-312.pyc,,
+torchvision/models/segmentation/__pycache__/deeplabv3.cpython-312.pyc,,
+torchvision/models/segmentation/__pycache__/fcn.cpython-312.pyc,,
+torchvision/models/segmentation/__pycache__/lraspp.cpython-312.pyc,,
+torchvision/models/segmentation/_utils.py,sha256=P8658dSbojoau7DLTiE4RZmDZBX7AfgE2V7FjrP_YXs,1228
+torchvision/models/segmentation/deeplabv3.py,sha256=krcN6ustrnFvTmdD2MEc7821lCSUjMWitedz6Zw0IbQ,15433
+torchvision/models/segmentation/fcn.py,sha256=mQ1Wi4S9j5G6OQbNciuxNwVbJ6e9miTzIWj6mUF5JwA,9205
+torchvision/models/segmentation/lraspp.py,sha256=QvcS-sGmSpJg3HQdi4e09jRLd8nPHIE05_GPbMvryRA,7815
+torchvision/models/shufflenetv2.py,sha256=Z-0YBL0T3oLOcjX6CBCj755n3nUbDN0u8-q3xhsTWXY,15846
+torchvision/models/squeezenet.py,sha256=Dha-ci350KU15D0LS9N07kw6MlNuusUHSBnC83Ery_E,8986
+torchvision/models/swin_transformer.py,sha256=vjt095dNfuvMX1-WEG72tqe_xxAtsBLQytts88L2vB8,40364
+torchvision/models/vgg.py,sha256=1SWU2kj-cjL7YDnV0dOh15Om9D_zSt3Ue4d0CjUBIt4,19724
+torchvision/models/video/__init__.py,sha256=xHHR5c6kP0cMDXck7XnXq19iJL_Uemcxg3OC00cqE6A,97
+torchvision/models/video/__pycache__/__init__.cpython-312.pyc,,
+torchvision/models/video/__pycache__/mvit.cpython-312.pyc,,
+torchvision/models/video/__pycache__/resnet.cpython-312.pyc,,
+torchvision/models/video/__pycache__/s3d.cpython-312.pyc,,
+torchvision/models/video/__pycache__/swin_transformer.cpython-312.pyc,,
+torchvision/models/video/mvit.py,sha256=_JaX-FjLbaW3YQGu6WdG6ZTsi0dZO3GVMpNkG-SYu8U,33904
+torchvision/models/video/resnet.py,sha256=7XjOvzBVJQQmIXCTNcaSfrvKMK56BBg3wt2uKhjLsnk,17283
+torchvision/models/video/s3d.py,sha256=Rn-iypP13jrETAap1Qd4NY6kkpYDuSXjGkEKZDOxemI,8034
+torchvision/models/video/swin_transformer.py,sha256=vqE3_gXMbfgaWmrxRB1pp8fh6aqnpz-djmgKhtML3c4,28418
+torchvision/models/vision_transformer.py,sha256=qDIjgxi0xAapBRyriUFTjuIsTeWiQQ02nZaks7ShCoE,32988
+torchvision/ops/__init__.py,sha256=7wibGxcF1JHDviSNs9O9Pwlc8dhMSFfZo0wzVjTFnAY,2001
+torchvision/ops/__pycache__/__init__.cpython-312.pyc,,
+torchvision/ops/__pycache__/_box_convert.cpython-312.pyc,,
+torchvision/ops/__pycache__/_register_onnx_ops.cpython-312.pyc,,
+torchvision/ops/__pycache__/_utils.cpython-312.pyc,,
+torchvision/ops/__pycache__/boxes.cpython-312.pyc,,
+torchvision/ops/__pycache__/ciou_loss.cpython-312.pyc,,
+torchvision/ops/__pycache__/deform_conv.cpython-312.pyc,,
+torchvision/ops/__pycache__/diou_loss.cpython-312.pyc,,
+torchvision/ops/__pycache__/drop_block.cpython-312.pyc,,
+torchvision/ops/__pycache__/feature_pyramid_network.cpython-312.pyc,,
+torchvision/ops/__pycache__/focal_loss.cpython-312.pyc,,
+torchvision/ops/__pycache__/giou_loss.cpython-312.pyc,,
+torchvision/ops/__pycache__/misc.cpython-312.pyc,,
+torchvision/ops/__pycache__/poolers.cpython-312.pyc,,
+torchvision/ops/__pycache__/ps_roi_align.cpython-312.pyc,,
+torchvision/ops/__pycache__/ps_roi_pool.cpython-312.pyc,,
+torchvision/ops/__pycache__/roi_align.cpython-312.pyc,,
+torchvision/ops/__pycache__/roi_pool.cpython-312.pyc,,
+torchvision/ops/__pycache__/stochastic_depth.cpython-312.pyc,,
+torchvision/ops/_box_convert.py,sha256=_Uu8BkweU4hXeaUfUEiNZ1BC-O3FPo2Ra_cu3VPOjG4,7188
+torchvision/ops/_register_onnx_ops.py,sha256=g4M5Fp7n_5ZTzIQcUXvEct3YFlUMPNVSQQfBP-J0eQQ,4288
+torchvision/ops/_utils.py,sha256=MxfEBDZ9tnL-QIYyDj6FEmbE6MPQZtfOljz2Yhy-Oxg,3723
+torchvision/ops/boxes.py,sha256=mY52_JDDrv265eIWc-Dwdg23lxXpo_zNSyas7PN4hao,21125
+torchvision/ops/ciou_loss.py,sha256=EKl6TTnEFJpUlFgS3iaPuslllzxLm19G1h8B2duVx0s,2832
+torchvision/ops/deform_conv.py,sha256=NyILJV9kq_nui7-rMjSCeHGATM_zZzags-6bNfBpdtc,7178
+torchvision/ops/diou_loss.py,sha256=m8ML9PsaYosJHPe-8KBnj69ZoBmAcznf3V3WaEP7IfQ,3426
+torchvision/ops/drop_block.py,sha256=kgQHx7tE9AoDcSBZ5NjWC4RqBKLWrBpIkVKooEWOwhQ,6248
+torchvision/ops/feature_pyramid_network.py,sha256=Ojq68D4xf7I1ii6HQ1UYX6PDS5zzTuJZBNxbVUva3Uw,8933
+torchvision/ops/focal_loss.py,sha256=A-Ec5GG7sbyE8ydGP6QuAPdtkUbDfdg5j4zYvs5PwzA,2480
+torchvision/ops/giou_loss.py,sha256=SQ42KOFbx9pAJ-n1r9MwT3Hu5M890chnUL6oP0pLzRs,2770
+torchvision/ops/misc.py,sha256=VjA5TrbwRXXZxLc0ATAZnMNH8oXZj0Gn0XO8u_2X8E0,13907
+torchvision/ops/poolers.py,sha256=WPb3VxCC6fAhCCmDarUVasTkRbUQnteqZnaT9fESb9k,12228
+torchvision/ops/ps_roi_align.py,sha256=6_kmnE6z_3AZZ1N2hrS_uK3cbuzqZhjdM2rC50mfYUo,3715
+torchvision/ops/ps_roi_pool.py,sha256=2JrjJwzVtEeEg0BebkCnGfq4xEDwMCD-Xh915mvNcyI,2940
+torchvision/ops/roi_align.py,sha256=eL--jezfuGpIjNh9FZNrL-1tjGWXHZCJ2tlN-63Vvkk,11608
+torchvision/ops/roi_pool.py,sha256=kbvY49SbmfuSeKaObttS5Tbz8PztGubNpzRfHsBATM8,3009
+torchvision/ops/stochastic_depth.py,sha256=9T4Zu_BaemKZafSmRwrPCVr5aaGH8tmzlsQAZO-1_-Y,2302
+torchvision/python312.dll,sha256=csG-p44MSfdGiJo6-MSXPgROUOMZwcl23ZgFmkA0BNo,7409480
+torchvision/transforms/__init__.py,sha256=WCNXTJUbJ1h7YaN9UfrBSvt--ST2PAV4sLICbTS-L5A,55
+torchvision/transforms/__pycache__/__init__.cpython-312.pyc,,
+torchvision/transforms/__pycache__/_functional_pil.cpython-312.pyc,,
+torchvision/transforms/__pycache__/_functional_tensor.cpython-312.pyc,,
+torchvision/transforms/__pycache__/_functional_video.cpython-312.pyc,,
+torchvision/transforms/__pycache__/_presets.cpython-312.pyc,,
+torchvision/transforms/__pycache__/_transforms_video.cpython-312.pyc,,
+torchvision/transforms/__pycache__/autoaugment.cpython-312.pyc,,
+torchvision/transforms/__pycache__/functional.cpython-312.pyc,,
+torchvision/transforms/__pycache__/transforms.cpython-312.pyc,,
+torchvision/transforms/_functional_pil.py,sha256=QCm2U14OXmqGsZ87B8U86yxeXMmxm3LqRzk3-_WOwyc,12514
+torchvision/transforms/_functional_tensor.py,sha256=tZ2tPIT8lrnUwcFtw_6rL9ZLCbY1Lk2BXPiGEuNFTxs,34888
+torchvision/transforms/_functional_video.py,sha256=c4BbUi3Y2LvskozFdy619piLBd5acsjxgogYAXmY5P8,3971
+torchvision/transforms/_presets.py,sha256=iX5Lr8qZvHA-UxOAjBH-iTiu5m6l2hQRwJROsxYDs_w,8721
+torchvision/transforms/_transforms_video.py,sha256=ub2gCT5ELiK918Bq-Pp6mzhWrAZxlj7blfpkA8Dhb1o,5124
+torchvision/transforms/autoaugment.py,sha256=WwWSYH-q71z748AVQu9Mtf6Fo6xcmmbwsMCgS22-67M,28839
+torchvision/transforms/functional.py,sha256=LKl3poyvirCF5LmBZJg2aHd-5_NgHbXl5FHEg6br-ak,69447
+torchvision/transforms/transforms.py,sha256=xIeObaDzC6TXFyOgDmcR6T_VRcFQvokjL1cyPSGbBX0,88144
+torchvision/transforms/v2/__init__.py,sha256=yVHPT05BS24xmogo2EJbqFyTu8QLo_pafp0NZ25Pu0s,1677
+torchvision/transforms/v2/__pycache__/__init__.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_augment.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_auto_augment.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_color.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_container.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_deprecated.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_geometry.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_meta.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_misc.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_temporal.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_transform.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_type_conversion.cpython-312.pyc,,
+torchvision/transforms/v2/__pycache__/_utils.cpython-312.pyc,,
+torchvision/transforms/v2/_augment.py,sha256=tnnle6FAHq4-xdemNBygk0PIfJDBaNOC9rUM4ir1jJo,16725
+torchvision/transforms/v2/_auto_augment.py,sha256=EsuyOv_ZwIALwsNMNwJOaFg5TOl-5_oNlELbYNxObho,32868
+torchvision/transforms/v2/_color.py,sha256=IcieYLHLRa4FGXEeC1fGFunMjJ0pjpmsoKo_ulVDhFo,17378
+torchvision/transforms/v2/_container.py,sha256=UnxEDUeR0fwhS561bDII811tKaJ6jyxPHyXy0qdccbs,6516
+torchvision/transforms/v2/_deprecated.py,sha256=EA9nX2T5it6sJiOup2BEtaEynDh6lzzSWZzTH5fcrPw,1990
+torchvision/transforms/v2/_geometry.py,sha256=1Fdv8lV6Gvp2Z9GmbyIekuGFNYil_-tTn-yJjyzGtPY,69255
+torchvision/transforms/v2/_meta.py,sha256=J9s54OCuhC2aOlSfZAA9qO-xaXnjOCJtWk952ORBxYw,3246
+torchvision/transforms/v2/_misc.py,sha256=WW0z0waZi80zpEh2zfU5h_yYPeGieh5OuGvcaaOoDHY,24677
+torchvision/transforms/v2/_temporal.py,sha256=LnmhVdwGsRYK4jk1H_oSOSlQO5_PdbRrRPFFLDx1pO8,925
+torchvision/transforms/v2/_transform.py,sha256=sjB32n5ZOn1KpyeQPJBZZZGEg4hn8AO935FHctqmqnU,9510
+torchvision/transforms/v2/_type_conversion.py,sha256=kHikcSejjaZmAJYLl5folMOvPCSrXrf80qyBLgd1zaQ,3245
+torchvision/transforms/v2/_utils.py,sha256=LetUYsYnyWPbdqod4xZM5mXVvmYI6SkjCo8WDY2BMDc,9541
+torchvision/transforms/v2/functional/__init__.py,sha256=eqz2Y6kX8NfGn8YNhUbpa8eIS09mh3ni5DBdY1SUc9U,4052
+torchvision/transforms/v2/functional/__pycache__/__init__.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_augment.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_color.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_geometry.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_meta.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_misc.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-312.pyc,,
+torchvision/transforms/v2/functional/__pycache__/_utils.cpython-312.pyc,,
+torchvision/transforms/v2/functional/_augment.py,sha256=oFPymp04zhQpTQR3O-5XUJHDnQW6-pfwWJKoP-dpWkM,3579
+torchvision/transforms/v2/functional/_color.py,sha256=hLx_h4TjXFz0dZsarb3nIAFWH6h3r3Wn1lY3rlgvh8o,31117
+torchvision/transforms/v2/functional/_deprecated.py,sha256=-Aza1I4OSqpbF7skVMIJA7YA6jYDgK7A54MTsSn95oI,819
+torchvision/transforms/v2/functional/_geometry.py,sha256=9xEtxUK-BFBZnoShK-DMXsNB14Q_42wguTiQMG1TSJQ,114853
+torchvision/transforms/v2/functional/_meta.py,sha256=j-InmeBoJFjUi91xAPfy3OuPmM-P3eVaLrLUmmFxaAo,29318
+torchvision/transforms/v2/functional/_misc.py,sha256=qlykcaCtQk1I5mzLeguTAmK_EMPqVS0C5V1ism0q9MU,22425
+torchvision/transforms/v2/functional/_temporal.py,sha256=tSRkkqOqUQ0QXjENF82F16El1-J0IDoFKIH-ss_cpC4,1163
+torchvision/transforms/v2/functional/_type_conversion.py,sha256=oYf4LMgiClvEZwwc3WbKI7fJ-rRFhDrVSBKiPA5vxio,896
+torchvision/transforms/v2/functional/_utils.py,sha256=8yfZCpdXKX5Py-qZgYEq8Kd2Vp98iZ3zyWLPGmsym2k,5630
+torchvision/tv_tensors/__init__.py,sha256=MKT5n-sVks-UuvAiqKN1VUwz9XJxwEEa36vJyQLUo7Y,1835
+torchvision/tv_tensors/__pycache__/__init__.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_bounding_boxes.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_dataset_wrapper.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_image.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_keypoints.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_mask.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_torch_function_helpers.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_tv_tensor.cpython-312.pyc,,
+torchvision/tv_tensors/__pycache__/_video.cpython-312.pyc,,
+torchvision/tv_tensors/_bounding_boxes.py,sha256=E36IiYCDJcFGiFWaWGbhDWM3j0ypuB8Yq6gdVwqDYRs,7902
+torchvision/tv_tensors/_dataset_wrapper.py,sha256=pgcasnVwGiQb0mmEqPkIgycI-4AV703xbBRwUFbMo9k,25171
+torchvision/tv_tensors/_image.py,sha256=yg9LaAPSwpKWx08bJtTvwhuLwYfWdGk-A0AY8kpX9hw,2016
+torchvision/tv_tensors/_keypoints.py,sha256=A-CzXaDhwsswZh0zQyn9Lkk_BjSllvZLTcOK659mGOM,4686
+torchvision/tv_tensors/_mask.py,sha256=H7wiK_uB9oiVQu22H1PWghyipLMmndb9pnoGIPLzqG4,1486
+torchvision/tv_tensors/_torch_function_helpers.py,sha256=vr1G4egyQfRjUtDedTWRope1gP4OB1hzAjKFZGXfc2Y,2402
+torchvision/tv_tensors/_tv_tensor.py,sha256=RhkzUeyCELJvcm_n4WW0HCDKNCx7zRKQVBocc-Esfhw,6354
+torchvision/tv_tensors/_video.py,sha256=r5_pwyvMM5h10swXEb1NhrKXAmI2GWY4ijNhD6gn24Y,1422
+torchvision/utils.py,sha256=xPx2pk2Jz4DH2E7KeZ7n2Iz-bVyqVIYLwreVRwN0B7k,35361
+torchvision/version.py,sha256=xgaW6qfflyTotCkq60_PfB93-jCjaGuICyGk0ug9KfI,206
+torchvision/zlib.dll,sha256=u15-RCxwvfotPEXJYgg17MXpgZG94rwia015mgmtMLM,101192
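
Note on the RECORD rows above: each row has the layout `path,sha256=<digest>,size`, and the `__pycache__` rows leave the hash and size fields empty. The following is a minimal sketch of how one row could be checked. It assumes the digest is the URL-safe base64 encoding of the SHA-256 hash with the trailing `=` padding stripped, as the wheel RECORD format describes. `record_digest` and `verify_row` are hypothetical helper names, and the example path and payload are made up for illustration; none of them belong to pip or torchvision.

```python
import base64
import csv
import hashlib
import io


def record_digest(data: bytes) -> str:
    """Return a 'sha256=...' string in the form used by RECORD rows."""
    raw = hashlib.sha256(data).digest()
    # URL-safe base64 with the trailing '=' padding removed.
    return "sha256=" + base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def verify_row(row: str, data: bytes) -> bool:
    """Check one RECORD row against file contents (hypothetical helper)."""
    path, digest, size = next(csv.reader(io.StringIO(row)))
    if not digest and not size:
        # Rows such as *.pyc entries carry no hash or size; nothing to verify.
        return True
    return digest == record_digest(data) and int(size) == len(data)


# Made-up file contents and path, used only to exercise the helpers.
payload = b"example contents\n"
row = f"example_pkg/data.txt,{record_digest(payload)},{len(payload)}"

print(verify_row(row, payload))            # matching contents
print(verify_row(row, payload + b"x"))     # tampered contents
print(verify_row("example_pkg/__pycache__/mod.cpython-312.pyc,,", b"anything"))
```

Parsing with `csv` rather than `str.split(",")` keeps the sketch correct for paths that contain commas and are therefore quoted.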

+ 0 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/REQUESTED


+ 5 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/WHEEL

@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: setuptools (72.1.0)
+Root-Is-Purelib: false
+Tag: cp312-cp312-win_amd64
+
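
Note on the WHEEL file above: it uses `Key: value` header lines, so it can be read with the standard-library email parser. The sketch below assumes the simple case shown in this diff, a single `Tag` entry made of three dash-separated parts (Python tag, ABI tag, platform tag). The variable names are illustrative only.

```python
import email

# Contents of the WHEEL file from the diff, without the leading '+' markers.
WHEEL_TEXT = (
    "Wheel-Version: 1.0\n"
    "Generator: setuptools (72.1.0)\n"
    "Root-Is-Purelib: false\n"
    "Tag: cp312-cp312-win_amd64\n"
    "\n"
)

meta = email.message_from_string(WHEEL_TEXT)

# 'false' means the wheel ships platform-specific (compiled) files.
is_purelib = meta["Root-Is-Purelib"].strip().lower() == "true"

# Split the tag into its three components; the platform part contains an
# underscore but no dash, so a plain split is enough for this tag.
python_tag, abi_tag, platform_tag = meta["Tag"].split("-")

print(is_purelib)
print(python_tag, abi_tag, platform_tag)
```

Read this way, the tag says the vendored wheel targets CPython 3.12 on 64-bit Windows, which matches the `.dll` and `.pyd` files listed in RECORD.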

+ 1 - 0
python/py/Lib/site-packages/torchvision-0.26.0.dist-info/top_level.txt

@@ -0,0 +1 @@
+torchvision

binary
python/py/Lib/site-packages/torchvision/_C.pyd


+ 73 - 0
python/py/Lib/site-packages/torchvision/__init__.py

@@ -0,0 +1,73 @@
+from modulefinder import Module
+
+import torch
+
+# Don't re-order these, we need to load the _C extension (done when importing
+# .extension) before entering _meta_registrations.
+from . import extension  # usort:skip  # noqa: F401
+from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
+
+try:
+    from .version import __version__  # noqa: F401
+except ImportError:
+    pass
+
+
+_image_backend = "PIL"
+
+_video_backend = "pyav"
+
+
+def set_image_backend(backend):
+    """
+    Specifies the package used to load images.
+
+    Args:
+        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
+            The :mod:`accimage` package uses the Intel IPP library. It is
+            generally faster than PIL, but does not support as many operations.
+    """
+    global _image_backend
+    if backend not in ["PIL", "accimage"]:
+        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
+    _image_backend = backend
+
+
+def get_image_backend():
+    """
+    Gets the name of the package used to load images
+    """
+    return _image_backend
+
+
+def set_video_backend(backend):
+    """
+    Specifies the package used to decode videos.
+
+    Args:
+        backend (string): Name of the video backend. Only 'pyav' is supported.
+            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
+            binding for the FFmpeg libraries.
+    """
+    pass
+
+
+def get_video_backend():
+    """
+    Returns the currently active video backend used to decode videos.
+
+    Returns:
+        str: Name of the video backend. Currently only 'pyav' is supported.
+    """
+
+    return _video_backend
+
+
+def _is_tracing():
+    return torch._C._get_tracing_state()
+
+
+def disable_beta_transforms_warning():
+    # Noop, only exists to avoid breaking existing code.
+    # See https://github.com/pytorch/vision/issues/7896
+    pass

+ 51 - 0
python/py/Lib/site-packages/torchvision/_internally_replaced_utils.py

@@ -0,0 +1,51 @@
+import importlib.machinery
+import os
+
+from torch.hub import _get_torch_home
+
+
+_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
+_USE_SHARDED_DATASETS = False
+IN_FBCODE = False
+
+
+def _download_file_from_remote_location(fpath: str, url: str) -> None:
+    pass
+
+
+def _is_remote_location_available() -> bool:
+    return False
+
+
+try:
+    from torch.hub import load_state_dict_from_url  # noqa: 401
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
+
+
+def _get_extension_path(lib_name):
+
+    lib_dir = os.path.dirname(__file__)
+    if os.name == "nt":
+        # Register the main torchvision library location on the default DLL path
+        import ctypes
+
+        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+        prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+        if with_load_library_flags:
+            kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+        os.add_dll_directory(lib_dir)
+
+        kernel32.SetErrorMode(prev_error_mode)
+
+    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
+
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    ext_specs = extfinder.find_spec(lib_name)
+    if ext_specs is None:
+        raise ImportError(f"Could not find module '{lib_name}' in {lib_dir}")
+
+    return ext_specs.origin
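
Note on `_get_extension_path` above: the lookup is restricted to a single directory by building a `FileFinder` for that directory and asking it for a spec. The sketch below reproduces that pattern under one stated assumption: a `SourceFileLoader` stands in for the `ExtensionFileLoader`, so the example runs without a compiled module. `find_module_path` and `demo_ext` are hypothetical names.

```python
import importlib.machinery
import os
import tempfile


def find_module_path(lib_dir: str, lib_name: str) -> str:
    """Locate lib_name inside lib_dir only (hypothetical helper)."""
    # Assumption for this sketch: a source loader replaces the extension
    # loader used in the diff, so no compiled .pyd/.so file is needed.
    loader_details = (
        importlib.machinery.SourceFileLoader,
        importlib.machinery.SOURCE_SUFFIXES,
    )
    finder = importlib.machinery.FileFinder(lib_dir, loader_details)
    spec = finder.find_spec(lib_name)
    if spec is None:
        raise ImportError(f"Could not find module '{lib_name}' in {lib_dir}")
    return spec.origin


with tempfile.TemporaryDirectory() as tmp:
    # Create a throwaway module so the finder has something to locate.
    with open(os.path.join(tmp, "demo_ext.py"), "w") as fh:
        fh.write("VALUE = 1\n")

    found = find_module_path(tmp, "demo_ext")
    print(os.path.basename(found))

    try:
        find_module_path(tmp, "missing_ext")
    except ImportError as exc:
        print("ImportError:", exc)
```

Because the finder is bound to one directory, a module with the same name elsewhere on `sys.path` is never picked up, which appears to be the point of resolving the extension path this way.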

+ 225 - 0
python/py/Lib/site-packages/torchvision/_meta_registrations.py

@@ -0,0 +1,225 @@
+import functools
+
+import torch
+import torch._custom_ops
+import torch.library
+
+# Ensure that torch.ops.torchvision is visible
+import torchvision.extension  # noqa: F401
+
+
+@functools.lru_cache(None)
+def get_meta_lib():
+    return torch.library.Library("torchvision", "IMPL", "Meta")
+
+
+def register_meta(op_name, overload_name="default"):
+    def wrapper(fn):
+        if torchvision.extension._has_ops():
+            get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
+        return fn
+
+    return wrapper
+
+
+@register_meta("roi_align")
+def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    return input.new_empty((num_rois, channels, pooled_height, pooled_width))
+
+
+@register_meta("_roi_align_backward")
+def meta_roi_align_backward(
+    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_align")
+def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
+
+
+@register_meta("_ps_roi_align_backward")
+def meta_ps_roi_align_backward(
+    grad,
+    rois,
+    channel_mapping,
+    spatial_scale,
+    pooled_height,
+    pooled_width,
+    sampling_ratio,
+    batch_size,
+    channels,
+    height,
+    width,
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("roi_pool")
+def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    out_size = (num_rois, channels, pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_roi_pool_backward")
+def meta_roi_pool_backward(
+    grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_pool")
+def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_ps_roi_pool_backward")
+def meta_ps_roi_pool_backward(
+    grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@torch.library.register_fake("torchvision::nms")
+def meta_nms(dets, scores, iou_threshold):
+    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
+    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
+    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
+    torch._check(
+        dets.size(0) == scores.size(0),
+        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
+    )
+    ctx = torch._custom_ops.get_ctx()
+    num_to_keep = ctx.create_unbacked_symint()
+    return dets.new_empty(num_to_keep, dtype=torch.long)
+
+
+@register_meta("deform_conv2d")
+def meta_deform_conv2d(
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dil_h,
+    dil_w,
+    n_weight_grps,
+    n_offset_grps,
+    use_mask,
+):
+
+    out_height, out_width = offset.shape[-2:]
+    out_channels = weight.shape[0]
+    batch_size = input.shape[0]
+    return input.new_empty((batch_size, out_channels, out_height, out_width))
+
+
+@register_meta("_deform_conv2d_backward")
+def meta_deform_conv2d_backward(
+    grad,
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    groups,
+    offset_groups,
+    use_mask,
+):
+
+    grad_input = input.new_empty(input.shape)
+    grad_weight = weight.new_empty(weight.shape)
+    grad_offset = offset.new_empty(offset.shape)
+    grad_mask = mask.new_empty(mask.shape)
+    grad_bias = bias.new_empty(bias.shape)
+    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
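
Note on the meta registrations above: each function only validates its arguments and returns empty tensors of the right shape, without computing values. The shape arithmetic for `roi_align` and `ps_roi_align` can be sketched in plain Python. This is an illustration that works on shape tuples rather than tensors; the two function names are hypothetical and the `ValueError` stands in for `torch._check`.

```python
def roi_align_out_shape(input_shape, rois_shape, pooled_height, pooled_width):
    """Output shape (K, C, pooled_h, pooled_w), as in meta_roi_align."""
    if rois_shape[1] != 5:
        raise ValueError("rois must have shape as Tensor[K, 5]")
    num_rois = rois_shape[0]
    channels = input_shape[1]
    return (num_rois, channels, pooled_height, pooled_width)


def ps_roi_align_out_shape(input_shape, rois_shape, pooled_height, pooled_width):
    """Output shape (K, C // (ph * pw), ph, pw), as in meta_ps_roi_align."""
    if rois_shape[1] != 5:
        raise ValueError("rois must have shape as Tensor[K, 5]")
    channels = input_shape[1]
    bins = pooled_height * pooled_width
    if channels % bins != 0:
        raise ValueError(
            "input channels must be a multiple of pooling height * pooling width"
        )
    return (rois_shape[0], channels // bins, pooled_height, pooled_width)


# A batch of 2 feature maps with 98 channels, and 10 regions of interest.
feature_shape = (2, 98, 32, 32)
rois_shape = (10, 5)

print(roi_align_out_shape(feature_shape, rois_shape, 7, 7))     # (10, 98, 7, 7)
print(ps_roi_align_out_shape(feature_shape, rois_shape, 7, 7))  # (10, 2, 7, 7)
```

The position-sensitive variant divides the channel count by the number of pooling bins, which is why it requires the channel count to be a multiple of `pooled_height * pooled_width`.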

+ 33 - 0
python/py/Lib/site-packages/torchvision/_utils.py

@@ -0,0 +1,33 @@
+import enum
+from collections.abc import Sequence
+from typing import TypeVar
+
+T = TypeVar("T", bound=enum.Enum)
+
+
+class StrEnumMeta(enum.EnumMeta):
+    auto = enum.auto
+
+    def from_str(self: type[T], member: str) -> T:  # type: ignore[misc]
+        try:
+            return self[member]
+        except KeyError:
+            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
+            #  soon as it is migrated.
+            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
+
+
+class StrEnum(enum.Enum, metaclass=StrEnumMeta):
+    pass
+
+
+def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
+    if len(seq) == 1:
+        return f"'{seq[0]}'"
+
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail
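
Note on `_utils.py` above: the behaviour of `from_str` and `sequence_to_str` is easiest to see on small inputs. The sketch below copies the definitions from the diff (type annotations on `from_str` are dropped for brevity) so that it runs without torchvision installed. The `Color` enum is a hypothetical example and not part of the package.

```python
import enum
from collections.abc import Sequence


class StrEnumMeta(enum.EnumMeta):
    auto = enum.auto

    def from_str(self, member: str):
        # Look the member up by name; turn KeyError into a clearer ValueError.
        try:
            return self[member]
        except KeyError:
            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    pass


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    if not seq:
        return ""
    if len(seq) == 1:
        return f"'{seq[0]}'"

    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"

    return head + tail


class Color(StrEnum):  # hypothetical enum, for illustration only
    RED = "red"
    GREEN = "green"


print(Color.from_str("RED"))                              # Color.RED
print(sequence_to_str(["a", "b", "c"]))                   # 'a', 'b', 'c'
print(sequence_to_str(["a", "b", "c"], "or "))            # 'a', 'b', or 'c'
print(sequence_to_str(["a", "b"], separate_last="and "))  # 'a' and 'b'

try:
    Color.from_str("red")  # lookup is by member name, not by value
except ValueError as exc:
    print(exc)
```

Note that `from_str` matches the member name (`"RED"`), so passing the value (`"red"`) raises `ValueError`.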

+ 147 - 0
python/py/Lib/site-packages/torchvision/datasets/__init__.py

@@ -0,0 +1,147 @@
+from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
+from ._stereo_matching import (
+    CarlaStereo,
+    CREStereo,
+    ETH3DStereo,
+    FallingThingsStereo,
+    InStereo2k,
+    Kitti2012Stereo,
+    Kitti2015Stereo,
+    Middlebury2014Stereo,
+    SceneFlowStereo,
+    SintelStereo,
+)
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
+from .cifar import CIFAR10, CIFAR100
+from .cityscapes import Cityscapes
+from .clevr import CLEVRClassification
+from .coco import CocoCaptions, CocoDetection
+from .country211 import Country211
+from .dtd import DTD
+from .eurosat import EuroSAT
+from .fakedata import FakeData
+from .fer2013 import FER2013
+from .fgvc_aircraft import FGVCAircraft
+from .flickr import Flickr30k, Flickr8k
+from .flowers102 import Flowers102
+from .folder import DatasetFolder, ImageFolder
+from .food101 import Food101
+from .gtsrb import GTSRB
+from .hmdb51 import HMDB51
+from .imagenet import ImageNet
+from .imagenette import Imagenette
+from .inaturalist import INaturalist
+from .kinetics import Kinetics
+from .kitti import Kitti
+from .lfw import LFWPairs, LFWPeople
+from .lsun import LSUN, LSUNClass
+from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+from .moving_mnist import MovingMNIST
+from .omniglot import Omniglot
+from .oxford_iiit_pet import OxfordIIITPet
+from .pcam import PCAM
+from .phototour import PhotoTour
+from .places365 import Places365
+from .rendered_sst2 import RenderedSST2
+from .sbd import SBDataset
+from .sbu import SBU
+from .semeion import SEMEION
+from .stanford_cars import StanfordCars
+from .stl10 import STL10
+from .sun397 import SUN397
+from .svhn import SVHN
+from .ucf101 import UCF101
+from .usps import USPS
+from .vision import VisionDataset
+from .voc import VOCDetection, VOCSegmentation
+from .widerface import WIDERFace
+
+__all__ = (
+    "LSUN",
+    "LSUNClass",
+    "ImageFolder",
+    "DatasetFolder",
+    "FakeData",
+    "CocoCaptions",
+    "CocoDetection",
+    "CIFAR10",
+    "CIFAR100",
+    "EMNIST",
+    "FashionMNIST",
+    "QMNIST",
+    "MNIST",
+    "KMNIST",
+    "MovingMNIST",
+    "StanfordCars",
+    "STL10",
+    "SUN397",
+    "SVHN",
+    "PhotoTour",
+    "SEMEION",
+    "Omniglot",
+    "SBU",
+    "Flickr8k",
+    "Flickr30k",
+    "Flowers102",
+    "VOCSegmentation",
+    "VOCDetection",
+    "Cityscapes",
+    "ImageNet",
+    "Caltech101",
+    "Caltech256",
+    "CelebA",
+    "WIDERFace",
+    "SBDataset",
+    "VisionDataset",
+    "USPS",
+    "Kinetics",
+    "HMDB51",
+    "UCF101",
+    "Places365",
+    "Kitti",
+    "INaturalist",
+    "LFWPeople",
+    "LFWPairs",
+    "KittiFlow",
+    "Sintel",
+    "FlyingChairs",
+    "FlyingThings3D",
+    "HD1K",
+    "Food101",
+    "DTD",
+    "FER2013",
+    "GTSRB",
+    "CLEVRClassification",
+    "OxfordIIITPet",
+    "PCAM",
+    "Country211",
+    "FGVCAircraft",
+    "EuroSAT",
+    "RenderedSST2",
+    "Kitti2012Stereo",
+    "Kitti2015Stereo",
+    "CarlaStereo",
+    "Middlebury2014Stereo",
+    "CREStereo",
+    "FallingThingsStereo",
+    "SceneFlowStereo",
+    "SintelStereo",
+    "InStereo2k",
+    "ETH3DStereo",
+    "wrap_dataset_for_transforms_v2",
+    "Imagenette",
+)
+
+
+# We override current module's attributes to handle the import:
+# from torchvision.datasets import wrap_dataset_for_transforms_v2
+# without a cyclic error.
+# Ref: https://peps.python.org/pep-0562/
+def __getattr__(name):
+    if name in ("wrap_dataset_for_transforms_v2",):
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
+
+        return wrap_dataset_for_transforms_v2
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
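
Reviewer note: the module-level `__getattr__` above relies on PEP 562 to defer an import until the attribute is first requested. Here is a minimal, self-contained sketch of the same pattern. The module name `lazy_demo`, the attribute `heavy_helper`, and the use of `json.dumps` as the deferred import are all assumptions chosen for illustration, not names from torchvision.

```python
import types


def make_lazy_module():
    # Build a throwaway module object so the sketch needs no files on disk.
    mod = types.ModuleType("lazy_demo")
    resolved = []  # records which names went through the lazy path

    def __getattr__(name):
        # Called only when normal attribute lookup on the module fails.
        if name in ("heavy_helper",):
            resolved.append(name)
            import json  # stand-in for an import that would be cyclic at module load time

            return json.dumps
        raise AttributeError(f"module {mod.__name__!r} has no attribute {name!r}")

    # PEP 562: a function named __getattr__ in the module namespace acts as the fallback.
    mod.__getattr__ = __getattr__
    return mod, resolved


mod, resolved = make_lazy_module()
print(mod.heavy_helper({"a": 1}))
print(hasattr(mod, "missing"))
```

Because the fallback runs on every access to a name that is not already set on the module, the real code could cache the result in the module namespace; the vendored version does not, and this sketch follows it.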

+ 520 - 0
python/py/Lib/site-packages/torchvision/datasets/_optical_flow.py

@@ -0,0 +1,520 @@
+import itertools
+import os
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..io.image import decode_png, read_file
+from .folder import default_loader
+from .utils import _read_pfm, verify_str_arg
+from .vision import VisionDataset
+
+T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
+T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+
+__all__ = (
+    "KittiFlow",
+    "Sintel",
+    "FlyingThings3D",
+    "FlyingChairs",
+    "HD1K",
+)
+
+
+class FlowDataset(ABC, VisionDataset):
+    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
+    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
+    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
+    _has_builtin_flow_mask = False
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transforms: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._flow_list: list[str] = []
+        self._image_list: list[list[str]] = []
+        self._loader = loader
+
+    def _read_img(self, file_name: str) -> Union[Image.Image, torch.Tensor]:
+        return self._loader(file_name)
+
+    @abstractmethod
+    def _read_flow(self, file_name: str):
+        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+
+        img1 = self._read_img(self._image_list[index][0])
+        img2 = self._read_img(self._image_list[index][1])
+
+        if self._flow_list:  # it will be empty for some datasets when split="test"
+            flow = self._read_flow(self._flow_list[index])
+            if self._has_builtin_flow_mask:
+                flow, valid_flow_mask = flow
+            else:
+                valid_flow_mask = None
+        else:
+            flow = valid_flow_mask = None
+
+        if self.transforms is not None:
+            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
+
+        if self._has_builtin_flow_mask or valid_flow_mask is not None:
+            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
+            return img1, img2, flow, valid_flow_mask  # type: ignore[return-value]
+        else:
+            return img1, img2, flow  # type: ignore[return-value]
+
+    def __len__(self) -> int:
+        return len(self._image_list)
+
+    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
+        return torch.utils.data.ConcatDataset([self] * v)
+
+
+class Sintel(FlowDataset):
+    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                testing
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                training
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                    flow
+                        scene_1
+                        scene_2
+                        ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
+            details on the different passes.
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms, loader=loader)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = ["clean", "final"] if pass_name == "both" else [pass_name]
+
+        root = Path(root) / "Sintel"
+        flow_root = root / "training" / "flow"
+
+        for pass_name in passes:
+            split_dir = "training" if split == "train" else split
+            image_root = root / split_dir / pass_name
+            for scene in os.listdir(image_root):
+                image_list = sorted(glob(str(image_root / scene / "*.png")))
+                for i in range(len(image_list) - 1):
+                    self._image_list += [[image_list[i], image_list[i + 1]]]
+
+                if split == "train":
+                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="test"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
+
+class KittiFlow(FlowDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            KittiFlow
+                testing
+                    image_2
+                training
+                    image_2
+                    flow_occ
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transforms: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms, loader=loader)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "KittiFlow" / (split + "ing")
+        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
+        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
+
+        if not images1 or not images2:
+            raise FileNotFoundError(
+                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
+            )
+
+        for img1, img2 in zip(images1, images2):
+            self._image_list += [[img1, img2]]
+
+        if split == "train":
+            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
+            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+
+class FlyingChairs(FlowDataset):
+    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.
+
+    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingChairs
+                data
+                    00001_flow.flo
+                    00001_img1.ppm
+                    00001_img2.ppm
+                    ...
+                FlyingChairs_train_val.txt
+
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "val"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "val"))
+
+        root = Path(root) / "FlyingChairs"
+        images = sorted(glob(str(root / "data" / "*.ppm")))
+        flows = sorted(glob(str(root / "data" / "*.flo")))
+
+        split_file_name = "FlyingChairs_train_val.txt"
+
+        if not os.path.exists(root / split_file_name):
+            raise FileNotFoundError(
+                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
+            )
+
+        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
+        for i in range(len(flows)):
+            split_id = split_list[i]
+            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
+                self._flow_list += [flows[i]]
+                self._image_list += [[images[2 * i], images[2 * i + 1]]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            Both splits (``"train"`` and ``"val"``) provide ground-truth flow.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
+
+class FlyingThings3D(FlowDataset):
+    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingThings3D
+                frames_cleanpass
+                    TEST
+                    TRAIN
+                frames_finalpass
+                    TEST
+                    TRAIN
+                optical_flow
+                    TEST
+                    TRAIN
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the FlyingThings3D Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
+            details on the different passes.
+        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        pass_name: str = "clean",
+        camera: str = "left",
+        transforms: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms, loader=loader)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        split = split.upper()
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
+        cameras = ["left", "right"] if camera == "both" else [camera]
+
+        root = Path(root) / "FlyingThings3D"
+
+        directions = ("into_future", "into_past")
+        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
+            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
+            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)
+
+            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
+            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)
+
+            if not image_dirs or not flow_dirs:
+                raise FileNotFoundError(
+                    "Could not find the FlyingThings3D flow images. "
+                    "Please make sure the directory structure is correct."
+                )
+
+            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
+                images = sorted(glob(str(image_dir / "*.png")))
+                flows = sorted(glob(str(flow_dir / "*.pfm")))
+                for i in range(len(flows) - 1):
+                    if direction == "into_future":
+                        self._image_list += [[images[i], images[i + 1]]]
+                        self._flow_list += [flows[i]]
+                    elif direction == "into_past":
+                        self._image_list += [[images[i + 1], images[i]]]
+                        self._flow_list += [flows[i + 1]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            Ground-truth flow is loaded for both the ``"train"`` and ``"test"`` splits.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_pfm(file_name)
+
+
+class HD1K(FlowDataset):
+    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            hd1k
+                hd1k_challenge
+                    image_2
+                hd1k_flow_gt
+                    flow_occ
+                hd1k_input
+                    image_2
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transforms: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms, loader=loader)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "hd1k"
+        if split == "train":
+            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
+            for seq_idx in range(36):
+                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
+                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
+                for i in range(len(flows) - 1):
+                    self._flow_list += [flows[i]]
+                    self._image_list += [[images[i], images[i + 1]]]
+        else:
+            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
+            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
+            for image1, image2 in zip(images1, images2):
+                self._image_list += [[image1, image2]]
+
+        if not self._image_list:
+            raise FileNotFoundError(
+                "Could not find the HD1K images. Please make sure the directory structure is correct."
+            )
+
+    def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
+            is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+
+def _read_flo(file_name: str) -> np.ndarray:
+    """Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+    # Everything needs to be in little Endian according to
+    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+    with open(file_name, "rb") as f:
+        magic = np.fromfile(f, "c", count=4).tobytes()
+        if magic != b"PIEH":
+            raise ValueError("Magic number incorrect. Invalid .flo file")
+
+        w = np.fromfile(f, "<i4", count=1).item()
+        h = np.fromfile(f, "<i4", count=1).item()
+        data = np.fromfile(f, "<f4", count=2 * w * h)
+        return data.reshape(h, w, 2).transpose(2, 0, 1)
+
+
+def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> tuple[np.ndarray, np.ndarray]:
+
+    flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
+    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
+    flow = (flow - 2**15) / 64  # KITTI stores flow as uint16 = 64 * flow + 2**15 (see the KITTI flow devkit readme)
+    valid_flow_mask = valid_flow_mask.bool()
+
+    # For consistency with other datasets, we convert to numpy
+    return flow.numpy(), valid_flow_mask.numpy()
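
Reviewer note: the two readers at the end of this file depend on fixed binary layouts. Here is a minimal sketch of both, using only the standard library so it runs without NumPy or torchvision. It assumes the Middlebury `.flo` layout that `_read_flo` parses (4-byte magic `PIEH`, little-endian `int32` width and height, then interleaved little-endian `float32` `u, v` values in row-major order) and the `(raw - 2**15) / 64` decoding used for the 16-bit PNGs. The names `write_flo`, `read_flo`, `decode_kitti_flow`, and `demo_round_trip` are hypothetical and not part of the vendored code.

```python
import os
import struct
import tempfile

FLO_MAGIC = b"PIEH"


def write_flo(path, u, v):
    """Write two H x W nested lists (horizontal and vertical flow) as a .flo file."""
    h, w = len(u), len(u[0])
    interleaved = []
    for y in range(h):
        for x in range(w):
            interleaved += [u[y][x], v[y][x]]  # u and v alternate per pixel
    with open(path, "wb") as f:
        f.write(FLO_MAGIC)
        f.write(struct.pack("<ii", w, h))  # width first, then height
        f.write(struct.pack(f"<{2 * w * h}f", *interleaved))


def read_flo(path):
    """Return flow as nested lists indexed [channel][row][col], i.e. shape (2, H, W)."""
    with open(path, "rb") as f:
        if f.read(4) != FLO_MAGIC:
            raise ValueError("Magic number incorrect. Invalid .flo file")
        w, h = struct.unpack("<ii", f.read(8))
        data = struct.unpack(f"<{2 * w * h}f", f.read(4 * 2 * w * h))
    # Same index mapping as reshape(h, w, 2).transpose(2, 0, 1) in the diff above.
    return [[[data[(y * w + x) * 2 + c] for x in range(w)] for y in range(h)] for c in range(2)]


def decode_kitti_flow(raw):
    """Map one stored uint16 channel value to a flow component in pixels."""
    return (raw - 2**15) / 64


def demo_round_trip():
    u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    v = [[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]]
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.flo")
        write_flo(path, u, v)
        return read_flo(path)


flow = demo_round_trip()
print(flow[0])  # horizontal component, H x W
print(decode_kitti_flow(2**15 + 64))
```

In the 16-bit PNG format the third channel is the validity flag; the vendored reader converts it with `.bool()`, so any nonzero value marks a valid pixel.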

+ 1223 - 0
python/py/Lib/site-packages/torchvision/datasets/_stereo_matching.py

@@ -0,0 +1,1223 @@
+import functools
+import json
+import os
+import random
+import shutil
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, cast, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
+T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+__all__ = ()
+
+_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
+
+
+class StereoMatchingDataset(ABC, VisionDataset):
+    """Base interface for Stereo matching datasets"""
+
+    _has_built_in_disparity_mask = False
+
+    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
+        """
+        Args:
+            root(str): Root directory of the dataset.
+            transforms(callable, optional): A function/transform that takes in Tuples of
+                (images, disparities, valid_masks) and returns a transformed version of each of them.
+                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
+                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
+                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
+                In some cases, when a dataset does not provide disparities, the ``disparities`` and
+                ``valid_masks`` can be Tuples containing None values.
+                For training splits generally the datasets provide a minimal guarantee of
+                images: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
+                Optionally, based on the dataset, it can return a ``mask`` as well:
+                valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
+                For some test splits, the datasets provide outputs that look like:
+                images: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``None``, ``None``)
+                Optionally, based on the dataset, it can return a ``mask`` as well:
+                valid_masks: (``None``, ``None``)
+        """
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._images = []  # type: ignore
+        self._disparities = []  # type: ignore
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        img = Image.open(file_path)
+        if img.mode != "RGB":
+            img = img.convert("RGB")  # type: ignore [assignment]
+        return img
+
+    def _scan_pairs(
+        self,
+        paths_left_pattern: str,
+        paths_right_pattern: Optional[str] = None,
+    ) -> list[tuple[str, Optional[str]]]:
+
+        left_paths = list(sorted(glob(paths_left_pattern)))
+
+        right_paths: list[Union[None, str]]
+        if paths_right_pattern:
+            right_paths = list(sorted(glob(paths_right_pattern)))
+        else:
+            right_paths = list(None for _ in left_paths)
+
+        if not left_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
+
+        if not right_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
+
+        if len(left_paths) != len(right_paths):
+            raise ValueError(
+                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
+                f"left pattern: {paths_left_pattern}\n"
+                f"right pattern: {paths_right_pattern}\n"
+            )
+
+        paths = list((left, right) for left, right in zip(left_paths, right_paths))
+        return paths
+
+    @abstractmethod
+    def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+        # Returns a (disparity map, valid mask) pair; either element may be None
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
+                can be a numpy boolean mask of shape (H, W) if the dataset provides a file
+                indicating which disparity pixels are valid. The disparity is a numpy array of
+                shape (1, H, W) and the images are PIL images. ``disparity`` is None for
+                datasets on which for ``split="test"`` the authors did not provide annotations.
+        """
+        img_left = self._read_img(self._images[index][0])
+        img_right = self._read_img(self._images[index][1])
+
+        dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
+        dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])
+
+        imgs = (img_left, img_right)
+        dsp_maps = (dsp_map_left, dsp_map_right)
+        valid_masks = (valid_mask_left, valid_mask_right)
+
+        if self.transforms is not None:
+            (
+                imgs,
+                dsp_maps,
+                valid_masks,
+            ) = self.transforms(imgs, dsp_maps, valid_masks)
+
+        if self._has_built_in_disparity_mask or valid_masks[0] is not None:
+            return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
+        else:
+            return imgs[0], imgs[1], dsp_maps[0]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+
+class CarlaStereo(StereoMatchingDataset):
+    """
+    Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            carla-highres
+                trainingF
+                    scene1
+                        im0.png
+                        im1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    scene2
+                        im0.png
+                        im1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "carla-highres"
+
+        left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
+        right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
+        imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+        self._images = imgs
+
+        left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
+        right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
+        disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        self._disparities = disparities
+
+    def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Kitti2012Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
+    Uses the RGB images for consistency with KITTI 2015.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2012
+                testing
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                training
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                    disp_noc
+                        1.png
+                        2.png
+                        ...
+                    calib
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2012" / (split + "ing")
+
+        left_img_pattern = str(root / "colored_0" / "*_10.png")
+        right_img_pattern = str(root / "colored_1" / "*_10.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            disparity_pattern = str(root / "disp_noc" / "*.png")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
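Review note: the KITTI reader above divides the raw PNG values by 256 and adds a leading channel axis. The sketch below reproduces that arithmetic on plain nested lists so it runs without NumPy or Pillow; the function name and the sample values are made up for illustration.

```python
def decode_kitti_disparity(raw_rows):
    """Convert an H x W grid of raw 16-bit PNG values to a (1, H, W) disparity grid."""
    # Dividing by 256.0 turns the fixed-point encoding into disparity in pixels.
    scaled = [[value / 256.0 for value in row] for row in raw_rows]
    # Wrapping in one more list plays the role of disparity_map[None, :, :].
    return [scaled]


raw = [
    [0, 256, 384],
    [25600, 512, 128],
]
disparity = decode_kitti_disparity(raw)
```

Here `disparity[0][1][0]` is `100.0`, since 25600 / 256 = 100.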
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Kitti2015Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2015
+                testing
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                training
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_0
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_1
+                        img1.png
+                        img2.png
+                        ...
+                    calib
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2015" / (split + "ing")
+        left_img_pattern = str(root / "image_2" / "*.png")
+        right_img_pattern = str(root / "image_3" / "*.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
+            right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
+            self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Middlebury2014Stereo(StereoMatchingDataset):
+    """Publicly available scenes from the Middlebury dataset `2014 version <https://vision.middlebury.edu/stereo/data/scenes2014/>`_.
+
+    The dataset mostly follows the original format, without containing the ambient subdirectories: ::
+
+        root
+            Middlebury2014
+                train
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    scene2-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                additional
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                test
+                    scene1
+                        calib.txt
+                        im{0,1}.png
+                    scene2
+                        calib.txt
+                        im{0,1}.png
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Middlebury 2014 Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional".
+        use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible.
+            The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
+        calibration (string, optional): Which calibration setting to use, either "perfect" (default), "imperfect" or "both".
+            Must be ``None`` for the "test" split, which has no calibration settings.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+        download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
+    """
+
+    splits = {
+        "train": [
+            "Adirondack",
+            "Jadeplant",
+            "Motorcycle",
+            "Piano",
+            "Pipes",
+            "Playroom",
+            "Playtable",
+            "Recycle",
+            "Shelves",
+            "Vintage",
+        ],
+        "additional": [
+            "Backpack",
+            "Bicycle1",
+            "Cable",
+            "Classroom1",
+            "Couch",
+            "Flowers",
+            "Mask",
+            "Shopvac",
+            "Sticks",
+            "Storage",
+            "Sword1",
+            "Sword2",
+            "Umbrella",
+        ],
+        "test": [
+            "Plants",
+            "Classroom2E",
+            "Classroom2",
+            "Australia",
+            "DjembeL",
+            "CrusadeP",
+            "Crusade",
+            "Hoops",
+            "Bicycle2",
+            "Staircase",
+            "Newkuba",
+            "AustraliaP",
+            "Djembe",
+            "Livingroom",
+            "Computer",
+        ],
+    }
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        calibration: Optional[str] = "perfect",
+        use_ambient_views: bool = False,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
+        self.split = split
+
+        if calibration:
+            verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None))  # type: ignore
+            if split == "test":
+                raise ValueError("Split 'test' has no calibration settings, please set `calibration=None`.")
+        else:
+            if split != "test":
+                raise ValueError(
+                    f"Split '{split}' has calibration settings, however None was provided as an argument."
+                    f"\nAvailable calibration settings are: 'perfect', 'imperfect', 'both'.",
+                )
+
+        if download:
+            self._download_dataset(root)
+
+        root = Path(root) / "Middlebury2014"
+
+        if not os.path.exists(root / split):
+            raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
+
+        split_scenes = self.splits[split]
+        # check that the provided root folder contains the scene splits
+        if not any(
+            # using startswith to account for perfect / imperfect calibration
+            scene.startswith(s)
+            for scene in os.listdir(root / split)
+            for s in split_scenes
+        ):
+            raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
+
+        calibration_suffixes = {
+            None: [""],
+            "perfect": ["-perfect"],
+            "imperfect": ["-imperfect"],
+            "both": ["-perfect", "-imperfect"],
+        }[calibration]
+
+        for calibration_suffix in calibration_suffixes:
+            scene_pattern = "*" + calibration_suffix
+            left_img_pattern = str(root / split / scene_pattern / "im0.png")
+            right_img_pattern = str(root / split / scene_pattern / "im1.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            if split == "test":
+                self._disparities = list((None, None) for _ in self._images)
+            else:
+                left_disparity_pattern = str(root / split / scene_pattern / "disp0.pfm")
+                right_disparity_pattern = str(root / split / scene_pattern / "disp1.pfm")
+                self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+        self.use_ambient_views = use_ambient_views
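Review note: the constructor above selects scenes by appending a calibration suffix to a `*` glob pattern. The following sketch shows that selection in isolation on a throwaway directory tree. It uses only the standard library; `find_left_images` and `make_demo_tree` are hypothetical helpers, not functions from this commit.

```python
import tempfile
from glob import glob
from pathlib import Path

# Same suffix table as in the constructor above.
CALIBRATION_SUFFIXES = {
    None: [""],
    "perfect": ["-perfect"],
    "imperfect": ["-imperfect"],
    "both": ["-perfect", "-imperfect"],
}


def find_left_images(split_root: Path, calibration):
    """Collect im0.png paths for every scene matching the calibration suffix."""
    paths = []
    for suffix in CALIBRATION_SUFFIXES[calibration]:
        pattern = str(split_root / ("*" + suffix) / "im0.png")
        paths += sorted(glob(pattern))
    return paths


def make_demo_tree(root: Path) -> Path:
    """Create a tiny fake 'train' split with empty image files."""
    split_root = root / "train"
    for scene in ("Piano-perfect", "Piano-imperfect", "Pipes-perfect"):
        scene_dir = split_root / scene
        scene_dir.mkdir(parents=True)
        (scene_dir / "im0.png").touch()
    return split_root


with tempfile.TemporaryDirectory() as tmp:
    split_root = make_demo_tree(Path(tmp))
    perfect = [Path(p).parent.name for p in find_left_images(split_root, "perfect")]
    both = [Path(p).parent.name for p in find_left_images(split_root, "both")]
```

Note that `*-perfect` does not match `Piano-imperfect`, because that directory name ends in `-imperfect`, not `-perfect`.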
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        """
+        Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
+        When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
+        as the right image.
+        """
+        ambient_file_paths: list[Union[str, Path]]  # make mypy happy
+
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        if file_path.name == "im1.png" and self.use_ambient_views:
+            base_path = file_path.parent
+            # initialize sampleable container
+            ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
+            # double check that we're not going to try to read from an invalid file path
+            ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
+            # keep the original image as an option as well for uniform sampling between base views
+            ambient_file_paths.append(file_path)
+            file_path = random.choice(ambient_file_paths)  # type: ignore
+        return super()._read_img(file_path)
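Review note: `_read_img` above only randomizes the right image, and only among view files that actually exist. A standalone sketch of that sampling logic, assuming a hypothetical helper `pick_right_view` and an explicitly seeded `random.Random` so the run is repeatable:

```python
import random
import tempfile
from pathlib import Path


def pick_right_view(file_path: Path, use_ambient_views: bool, rng: random.Random) -> Path:
    """Return file_path, or a randomly chosen ambient variant of im1.png."""
    if file_path.name != "im1.png" or not use_ambient_views:
        return file_path
    # Candidate ambient views that may or may not be present on disk.
    candidates = [file_path.parent / name for name in ("im1E.png", "im1L.png")]
    candidates = [p for p in candidates if p.exists()]
    # Keep the original view so every available view is equally likely.
    candidates.append(file_path)
    return rng.choice(candidates)


with tempfile.TemporaryDirectory() as tmp:
    scene = Path(tmp)
    # This fake scene has im1E.png but no im1L.png.
    for name in ("im0.png", "im1.png", "im1E.png"):
        (scene / name).touch()
    rng = random.Random(0)
    picked = {pick_right_view(scene / "im1.png", True, rng).name for _ in range(200)}
    left = pick_right_view(scene / "im0.png", True, rng).name
    disabled = pick_right_view(scene / "im1.png", False, rng).name
```

The missing `im1L.png` is never returned, and the left image is never replaced.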
+
+    def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        disparity_map[disparity_map == np.inf] = 0  # remove infinite disparities
+        valid_mask = (disparity_map > 0).squeeze(0)  # mask out invalid disparities
+        return disparity_map, valid_mask
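Review note: the Middlebury reader above takes absolute values, zeroes out infinite disparities, and marks strictly positive pixels as valid. The sketch below applies the same three steps to a plain 2D nested list (the real code works on a NumPy array with a leading channel axis); `clean_disparity` is a hypothetical name.

```python
import math


def clean_disparity(rows):
    """Return (cleaned_disparity, valid_mask) for an H x W grid of floats."""
    # abs() makes disparities positive; infinite values are replaced by 0.
    cleaned = [[0.0 if math.isinf(v) else abs(v) for v in row] for row in rows]
    # A pixel is valid only if its cleaned disparity is strictly positive.
    mask = [[v > 0 for v in row] for row in cleaned]
    return cleaned, mask


cleaned, mask = clean_disparity([[-3.0, math.inf], [0.0, 2.5]])
```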
+
+    def _download_dataset(self, root: Union[str, Path]) -> None:
+        base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
+        # train and additional splits have 2 different calibration settings
+        root = Path(root) / "Middlebury2014"
+        split_name = self.split
+
+        if split_name != "test":
+            for split_scene in self.splits[split_name]:
+                split_root = root / split_name
+                for calibration in ["perfect", "imperfect"]:
+                    scene_name = f"{split_scene}-{calibration}"
+                    scene_url = f"{base_url}/{scene_name}.zip"
+                    # download the scene only if it doesn't exist
+                    if not (split_root / scene_name).exists():
+                        download_and_extract_archive(
+                            url=scene_url,
+                            filename=f"{scene_name}.zip",
+                            download_root=str(split_root),
+                            remove_finished=True,
+                        )
+        else:
+            os.makedirs(root / "test", exist_ok=True)
+            if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
+                # test split is downloaded from a different location
+                test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
+                # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
+                # we want to move the contents from testF into the test directory
+                download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
+                for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
+                    for scene in scene_names:
+                        scene_dst_dir = root / "test"
+                        scene_src_dir = Path(scene_dir) / scene
+                        os.makedirs(scene_dst_dir, exist_ok=True)
+                        shutil.move(str(scene_src_dir), str(scene_dst_dir))
+
+                # cleanup MiddEval3 directory
+                shutil.rmtree(str(root / "MiddEval3"))
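Review note: for the train and additional splits, `_download_dataset` builds one archive URL per scene and calibration setting. The sketch below only constructs those URLs and performs no network access; `scene_archive_urls` is a hypothetical helper, and the base URL is copied from the method above.

```python
BASE_URL = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"


def scene_archive_urls(scenes):
    """Map each '<scene>-<calibration>' directory name to its zip URL."""
    return {
        f"{scene}-{calibration}": f"{BASE_URL}/{scene}-{calibration}.zip"
        for scene in scenes
        for calibration in ("perfect", "imperfect")
    }


urls = scene_archive_urls(["Piano", "Pipes"])
```

Two scenes yield four archives, one per calibration setting.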
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            Both ``disparity`` and ``valid_mask`` are ``None`` for ``split="test"``.
+        """
+        return cast(T2, super().__getitem__(index))
+
+
+class CREStereo(StereoMatchingDataset):
+    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
+    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            CREStereo
+                tree
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    img2_left.jpg
+                    img2_right.jpg
+                    img2_left.disp.png
+                    img2_right.disp.png
+                    ...
+                shapenet
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+                reflective
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+                hole
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "CREStereo"
+
+        dirs = ["shapenet", "reflective", "tree", "hole"]
+
+        for s in dirs:
+            left_image_pattern = str(root / s / "*_left.jpg")
+            right_image_pattern = str(root / s / "*_right.jpg")
+            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+            self._images += imgs
+
+            left_disparity_pattern = str(root / s / "*_left.disp.png")
+            right_disparity_pattern = str(root / s / "*_right.disp.png")
+            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+            self._disparities += disparities
+
+    def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :] / 32.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class FallingThingsStereo(StereoMatchingDataset):
+    """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FallingThings
+                single
+                    dir1
+                        scene1
+                            _object_settings.json
+                            _camera_settings.json
+                            image1.left.depth.png
+                            image1.right.depth.png
+                            image1.left.jpg
+                            image1.right.jpg
+                            image2.left.depth.png
+                            image2.right.depth.png
+                            image2.left.jpg
+                            image2.right
+                            ...
+                        scene2
+                    ...
+                mixed
+                    scene1
+                        _object_settings.json
+                        _camera_settings.json
+                        image1.left.depth.png
+                        image1.right.depth.png
+                        image1.left.jpg
+                        image1.right.jpg
+                        image2.left.depth.png
+                        image2.right.depth.png
+                        image2.left.jpg
+                        image2.right
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where FallingThings is located.
+        variant (string): Which variant to use. Either "single", "mixed", or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "FallingThings"
+
+        verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))
+
+        variants = {
+            "single": ["single"],
+            "mixed": ["mixed"],
+            "both": ["single", "mixed"],
+        }[variant]
+
+        split_prefix = {
+            "single": Path("*") / "*",
+            "mixed": Path("*"),
+        }
+
+        for s in variants:
+            left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
+            right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
+            right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+        # (H, W) image
+        depth = np.asarray(Image.open(file_path))
+        # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
+        # in order to extract disparity from depth maps
+        camera_settings_path = Path(file_path).parent / "_camera_settings.json"
+        with open(camera_settings_path) as f:
+            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
+            intrinsics = json.load(f)
+            focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
+            baseline, pixel_constant = 6, 100  # pixel constant is inverted
+            disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
+            # unsqueeze disparity to (C, H, W)
+            disparity_map = disparity_map[None, :, :]
+            valid_mask = None
+            return disparity_map, valid_mask
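Review note: the FallingThings reader above converts depth to disparity as `baseline * focal * pixel_constant / depth`, with `baseline = 6` and `pixel_constant = 100`. A worked scalar version follows. The focal length of 768.0 is an assumed example value rather than one read from a real `_camera_settings.json`, and the guard against non-positive depth is an addition the original code does not have.

```python
def depth_to_disparity(depth, focal, baseline=6.0, pixel_constant=100.0):
    """Scalar form of the depth-to-disparity conversion used above."""
    if depth <= 0:
        # Added guard: the original divides without checking for zero depth.
        raise ValueError("depth must be positive")
    return (baseline * focal * pixel_constant) / depth


# 6 * 768 * 100 = 460800, and 460800 / 10000 = 46.08
near = depth_to_disparity(10000, focal=768.0)
# Doubling the depth halves the disparity.
far = depth_to_disparity(20000, focal=768.0)
```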
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class SceneFlowStereo(StereoMatchingDataset):
+    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
+    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            SceneFlow
+                Monkaa
+                    frames_cleanpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        scene2
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                    frames_finalpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        ...
+                        ...
+                    disparity
+                        scene1
+                            left
+                                img1.pfm
+                                img2.pfm
+                            right
+                                img1.pfm
+                                img2.pfm
+                FlyingThings3D
+                    ...
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
+        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
+        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        variant: str = "FlyingThings3D",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "SceneFlow"
+
+        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        root = root / variant
+
+        prefix_directories = {
+            "Monkaa": Path("*"),
+            "FlyingThings3D": Path("*") / "*" / "*",
+            "Driving": Path("*") / "*" / "*",
+        }
+
+        for p in passes:
+            left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
+            right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
+            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
+
+            left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
+            right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class SintelStereo(StereoMatchingDataset):
+    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                training
+                    final_left
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    final_right
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    disparities
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    occlusions
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    outofframe
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
+        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
+
+        root = Path(root) / "Sintel"
+        pass_names = {
+            "final": ["final"],
+            "clean": ["clean"],
+            "both": ["final", "clean"],
+        }[pass_name]
+
+        for p in pass_names:
+            left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
+            right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
+            self._disparities += self._scan_pairs(disparity_pattern, None)
+
+    def _get_occlussion_mask_paths(self, file_path: str) -> tuple[str, str]:
+        # helper function to get the occlusion mask paths
+        # a path will look like  .../.../.../training/disparities/scene1/img1.png
+        # we want to get something like .../.../.../training/occlusions/scene1/img1.png
+        fpath = Path(file_path)
+        basename = fpath.name
+        scenedir = fpath.parent
+        # the parent of the scenedir is the disparities dir; its parent is the split dir (e.g. training)
+        sampledir = scenedir.parent.parent
+
+        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
+        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)
+
+        if not os.path.exists(occlusion_path):
+            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")
+
+        if not os.path.exists(outofframe_path):
+            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")
+
+        return occlusion_path, outofframe_path
+
+    def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+        if file_path is None:
+            return None, None
+
+        # disparity decoding as per Sintel instructions in the README provided with the dataset
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        r, g, b = np.split(disparity_map, 3, axis=-1)
+        disparity_map = r * 4 + g / (2**6) + b / (2**14)
+        # reshape into (C, H, W) format
+        disparity_map = np.transpose(disparity_map, (2, 0, 1))
+        # find the appropriate file paths
+        occluded_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
+        # occlusion masks
+        valid_mask = np.asarray(Image.open(occluded_mask_path)) == 0
+        # out of frame masks
+        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
+        # combine the masks together
+        valid_mask = np.logical_and(off_mask, valid_mask)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index (int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
+            the valid_mask is a numpy array of shape (H, W).
+        """
+        return cast(T2, super().__getitem__(index))
+
+
+class InStereo2k(StereoMatchingDataset):
+    """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            InStereo2k
+                train
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+                test
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
+        split (string): Either "train" or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "InStereo2k" / split
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        left_img_pattern = str(root / "*" / "left.png")
+        right_img_pattern = str(root / "*" / "right.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        left_disparity_pattern = str(root / "*" / "left_disp.png")
+        right_disparity_pattern = str(root / "*" / "right_disp.png")
+        self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze disparity to (C, H, W)
+        disparity_map = disparity_map[None, :, :] / 1024.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index (int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class ETH3DStereo(StereoMatchingDataset):
+    """ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            ETH3D
+                two_view_training
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+                two_view_training_gt
+                    scene1
+                        disp0GT.pfm
+                        mask0nocc.png
+                    scene2
+                        disp0GT.pfm
+                        mask0nocc.png
+                    ...
+                two_view_test
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "ETH3D"
+
+        img_dir = "two_view_training" if split == "train" else "two_view_test"
+        anot_dir = "two_view_training_gt"
+
+        left_img_pattern = str(root / img_dir / "*" / "im0.png")
+        right_img_pattern = str(root / img_dir / "*" / "im1.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "test":
+            self._disparities = list((None, None) for _ in self._images)
+        else:
+            disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+
+    def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        mask_path = Path(file_path).parent / "mask0nocc.png"
+        valid_mask = Image.open(mask_path)
+        valid_mask = np.asarray(valid_mask).astype(bool)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index (int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T2, super().__getitem__(index))
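
A minimal sketch of the Sintel disparity decoding and mask combination that `SintelStereo._read_disparity` performs above, written against an in-memory array so it can be run without the dataset. The helper names `decode_sintel_disparity` and `combine_validity_masks` are hypothetical and are not part of torchvision. The sketch assumes an RGB array of shape (H, W, 3) and reads only the first three channels.

```python
import numpy as np


def decode_sintel_disparity(rgb: np.ndarray) -> np.ndarray:
    """Decode a Sintel disparity PNG (already loaded as H x W x 3) to (1, H, W) float32."""
    rgb = np.asarray(rgb, dtype=np.float32)
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    # Same weighting as in the diff: r * 4 + g / 2**6 + b / 2**14
    disparity = r * 4 + g / (2**6) + b / (2**14)
    # Add a leading channel axis to get (C, H, W) with C == 1
    return disparity[None, :, :]


def combine_validity_masks(occlusion: np.ndarray, out_of_frame: np.ndarray) -> np.ndarray:
    """A pixel is valid only if it is neither occluded nor out of frame (both masks are 0)."""
    return np.logical_and(np.asarray(occlusion) == 0, np.asarray(out_of_frame) == 0)


# Hypothetical 2 x 2 image: one pixel encodes r=1, g=64, b=0
rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[0, 0] = (1, 64, 0)
disparity = decode_sintel_disparity(rgb)
print(disparity.shape, disparity[0, 0, 0])
```

Indexing the channels directly, instead of `np.split` followed by a transpose, gives the same (1, H, W) result for a three-channel input.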

+ 241 - 0
python/py/Lib/site-packages/torchvision/datasets/caltech.py

@@ -0,0 +1,241 @@
+import os
+import os.path
+import shutil
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Caltech101(VisionDataset):
+    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        target_type (string or list, optional): Type of target to use, ``category`` or
+            ``annotation``. Can also be a list to output a tuple with all specified
+            target types.  ``category`` represents the target class, and
+            ``annotation`` is a list of points from a hand-generated outline.
+            Defaults to ``category``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        target_type: Union[list[str], str] = "category",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+        if isinstance(target_type, str):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+        self.categories.remove("BACKGROUND_Google")  # this is not a real class
+
+        # For some reason, the category names in "101_ObjectCategories" and
+        # "Annotations" do not always match. This is a manual map between the
+        # two. Defaults to using same name, since most names are fine.
+        name_map = {
+            "Faces": "Faces_2",
+            "Faces_easy": "Faces_3",
+            "Motorbikes": "Motorbikes_16",
+            "airplanes": "Airplanes_Side_2",
+        }
+        self.annotation_categories = [name_map.get(c, c) for c in self.categories]
+
+        self.index: list[int] = []
+        self.y = []
+        for i, c in enumerate(self.categories):
+            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target is specified by target_type.
+        """
+        import scipy.io
+
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "101_ObjectCategories",
+                self.categories[self.y[index]],
+                f"image_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "category":
+                target.append(self.y[index])
+            elif t == "annotation":
+                data = scipy.io.loadmat(
+                    os.path.join(
+                        self.root,
+                        "Annotations",
+                        self.annotation_categories[self.y[index]],
+                        f"annotation_{self.index[index]:04d}.mat",
+                    )
+                )
+                target.append(data["obj_contour"])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        download_and_extract_archive(
+            "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
+            download_root=self.root,
+            filename="caltech-101.zip",
+            md5="3138e1922a9193bfa496528edbbc45d0",
+        )
+        gzip_folder = os.path.join(self.root, "caltech-101")
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith(".gz"):
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.root)
+        shutil.rmtree(gzip_folder)
+        os.remove(os.path.join(self.root, "caltech-101.zip"))
+
+    def extra_repr(self) -> str:
+        return "Target type: {target_type}".format(**self.__dict__)
+
+
+class Caltech256(VisionDataset):
+    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``caltech256`` exists or will be saved to if download is set to True.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
+        self.index: list[int] = []
+        self.y = []
+        for i, c in enumerate(self.categories):
+            n = len(
+                [
+                    item
+                    for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
+                    if item.endswith(".jpg")
+                ]
+            )
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "256_ObjectCategories",
+                self.categories[self.y[index]],
+                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target = self.y[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        download_and_extract_archive(
+            "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar",
+            self.root,
+            filename="256_ObjectCategories.tar",
+            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+        )
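
A minimal sketch of the flat index that both Caltech classes above build in `__init__`, and of how `__getitem__` turns a position in that index into a file name. The function names are hypothetical and the per-category counts are made-up values. Only the naming patterns (`image_0001.jpg` and `001_0001.jpg`) are taken from the diff.

```python
def build_flat_index(counts):
    """Given images-per-category (in sorted category order), return (index, y).

    index[i] is the 1-based image number inside its category,
    y[i] is the 0-based category label of sample i.
    """
    index, y = [], []
    for label, n in enumerate(counts):
        index.extend(range(1, n + 1))
        y.extend(n * [label])
    return index, y


def caltech101_filename(index, i):
    # Caltech101 pattern: image_<4-digit image number>.jpg
    return f"image_{index[i]:04d}.jpg"


def caltech256_filename(index, y, i):
    # Caltech256 pattern: <3-digit 1-based category>_<4-digit image number>.jpg
    return f"{y[i] + 1:03d}_{index[i]:04d}.jpg"


# Hypothetical dataset with two categories holding 2 and 3 images
index, y = build_flat_index([2, 3])
print(index, y)
print(caltech101_filename(index, 4), caltech256_filename(index, y, 4))
```

In the diff, `Caltech101` counts every directory entry while `Caltech256` counts only `.jpg` files. A stray non-image file in a Caltech101 category folder would therefore appear to add an index entry with no matching image.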

+ 210 - 0
python/py/Lib/site-packages/torchvision/datasets/celeba.py

@@ -0,0 +1,210 @@
+import csv
+import os
+from collections import namedtuple
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import PIL
+import torch
+
+from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CSV = namedtuple("CSV", ["header", "index", "data"])
+
+
+class CelebA(VisionDataset):
+    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+        split (string): One of {'train', 'valid', 'test', 'all'}.
+            Selects the corresponding subset of the dataset.
+        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+            The targets represent:
+
+                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
+                - ``identity`` (int): label for each person (data points with the same identity are the same person)
+                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height).
+
+                  .. warning::
+
+                      These bounding box coordinates correspond to the original uncropped
+                      CelebA images, not the cropped and aligned images returned by this
+                      dataset. As a result, the coordinates will not match and may fall
+                      outside the image boundaries.
+
+                      See `Issue #9008 <https://github.com/pytorch/vision/issues/9008>`_ for
+                      details and potential workarounds.
+
+                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+
+            Defaults to ``attr``. If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    base_folder = "celeba"
+    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
+    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+    # right now.
+    file_list = [
+        # File ID                                      MD5 Hash                            Filename
+        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+    ]
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        target_type: Union[list[str], str] = "attr",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = split
+        if isinstance(target_type, list):
+            self.target_type = target_type
+        else:
+            self.target_type = [target_type]
+
+        if not self.target_type and self.target_transform is not None:
+            raise RuntimeError("target_transform is specified but target_type is empty")
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split_ = split_map[
+            verify_str_arg(
+                split.lower() if isinstance(split, str) else split,
+                "split",
+                ("train", "valid", "test", "all"),
+            )
+        ]
+        splits = self._load_csv("list_eval_partition.txt")
+        identity = self._load_csv("identity_CelebA.txt")
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
+        attr = self._load_csv("list_attr_celeba.txt", header=1)
+
+        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+        if split_ is None:  # split == "all"
+            self.filename = splits.index
+        else:
+            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]  # type: ignore[arg-type]
+        self.identity = identity.data[mask]
+        self.bbox = bbox.data[mask]
+        self.landmarks_align = landmarks_align.data[mask]
+        self.attr = attr.data[mask]
+        # map from {-1, 1} to {0, 1}
+        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
+        self.attr_names = attr.header
+
+    def _load_csv(
+        self,
+        filename: str,
+        header: Optional[int] = None,
+    ) -> CSV:
+        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
+            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
+
+        if header is not None:
+            headers = data[header]
+            data = data[header + 1 :]
+        else:
+            headers = []
+
+        indices = [row[0] for row in data]
+        data = [row[1:] for row in data]
+        data_int = [list(map(int, i)) for i in data]
+
+        return CSV(headers, indices, torch.tensor(data_int))
+
+    def _check_integrity(self) -> bool:
+        for _, md5, filename in self.file_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            _, ext = os.path.splitext(filename)
+            # Allow original archive to be deleted (zip and 7z)
+            # Only need the extracted images
+            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+                return False
+
+        # Should check a hash of the images
+        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        for file_id, md5, filename in self.file_list:
+            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "attr":
+                target.append(self.attr[index, :])
+            elif t == "identity":
+                target.append(self.identity[index, 0])
+            elif t == "bbox":
+                target.append(self.bbox[index, :])
+            elif t == "landmarks":
+                target.append(self.landmarks_align[index, :])
+            else:
+                # TODO: refactor with utils.verify_str_arg
+                raise ValueError(f'Target type "{t}" is not recognized.')
+
+        if self.transform is not None:
+            X = self.transform(X)
+
+        if target:
+            target = tuple(target) if len(target) > 1 else target[0]
+
+            if self.target_transform is not None:
+                target = self.target_transform(target)
+        else:
+            target = None
+
+        return X, target
+
+    def __len__(self) -> int:
+        return len(self.attr)
+
+    def extra_repr(self) -> str:
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
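
A minimal sketch of the parsing done by `CelebA._load_csv` above, using only the standard library and plain lists in place of `torch.tensor`. `ParsedTable`, `parse_annotation_text` and `to_binary` are hypothetical names, and the sample text is made up to mimic the layout of an attribute file: a count line, a header line, then one row per image.

```python
import csv
import io
from typing import NamedTuple, Optional


class ParsedTable(NamedTuple):
    header: list
    index: list
    data: list


def parse_annotation_text(text: str, header: Optional[int] = None) -> ParsedTable:
    """Parse space-delimited annotation text; `header` is the row number of the header line."""
    rows = list(csv.reader(io.StringIO(text), delimiter=" ", skipinitialspace=True))
    if header is not None:
        headers = rows[header]
        rows = rows[header + 1:]  # everything before and including the header is dropped
    else:
        headers = []
    index = [row[0] for row in rows]  # first column: image file name
    data = [[int(v) for v in row[1:]] for row in rows]
    return ParsedTable(headers, index, data)


def to_binary(values):
    """Map attribute labels from {-1, 1} to {0, 1} with floor division, as in the diff."""
    return [(v + 1) // 2 for v in values]


sample = "2\nSmiling Young\n000001.jpg  1 -1\n000002.jpg -1  1\n"
table = parse_annotation_text(sample, header=1)
print(table.header, table.index, [to_binary(row) for row in table.data])
```

`skipinitialspace=True` is what lets columns padded with several spaces parse into clean fields.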

+ 167 - 0
python/py/Lib/site-packages/torchvision/datasets/cifar.py

@@ -0,0 +1,167 @@
+import os.path
+import pickle
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive
+from .vision import VisionDataset
+
+
+class CIFAR10(VisionDataset):
+    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+        train (bool, optional): If True, creates dataset from training set, otherwise
+            creates from test set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    base_folder = "cifar-10-batches-py"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
+    train_list = [
+        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
+        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
+        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
+        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
+        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
+    ]
+
+    test_list = [
+        ["test_batch", "40351d587109b95175f43aff81a1287e"],
+    ]
+    meta = {
+        "filename": "batches.meta",
+        "key": "label_names",
+        "md5": "5ff9c542aee3614f3951f8cda6e48888",
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        if self.train:
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data: Any = []
+        self.targets = []
+
+        # now load the pickled numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, "rb") as f:
+                entry = pickle.load(f, encoding="latin1")
+                self.data.append(entry["data"])
+                if "labels" in entry:
+                    self.targets.extend(entry["labels"])
+                else:
+                    self.targets.extend(entry["fine_labels"])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
+
+        self._load_meta()
+
+    def _load_meta(self) -> None:
+        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
+        if not check_integrity(path, self.meta["md5"]):
+            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
+        with open(path, "rb") as infile:
+            data = pickle.load(infile, encoding="latin1")
+            self.classes = data[self.meta["key"]]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
+    """
+
+    base_folder = "cifar-100-python"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
+    train_list = [
+        ["train", "16019d7e3df5f24257cddd939b257f8d"],
+    ]
+
+    test_list = [
+        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
+    ]
+    meta = {
+        "filename": "meta",
+        "key": "fine_label_names",
+        "md5": "7973b15100ade9c7d40fb424638fde48",
+    }
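Review note: the `reshape(-1, 3, 32, 32)` followed by `transpose((0, 2, 3, 1))` in `CIFAR10.__init__` above turns each flat 3072-value record (stored channel by channel) into a height x width x channel image. The sketch below reproduces that index mapping for a single record in pure Python; the function name `chw_flat_to_hwc` is hypothetical and is not part of torchvision.

```python
def chw_flat_to_hwc(flat, channels=3, height=32, width=32):
    """Rearrange one flat channel-major record into nested [H][W][C] lists."""
    plane = height * width
    if len(flat) != channels * plane:
        raise ValueError(f"expected {channels * plane} values, got {len(flat)}")
    # Value for (h, w, c) sits at c * plane + h * width + w in the flat record.
    return [
        [[flat[c * plane + h * width + w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# Use the positions themselves as pixel values so the mapping is visible.
record = list(range(3 * 32 * 32))
image = chw_flat_to_hwc(record)
print(image[0][0])  # the three channel values of the top-left pixel
```

With this layout each pixel's three channel values become adjacent, which is the form `Image.fromarray` is given in `__getitem__`.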

+ 222 - 0
python/py/Lib/site-packages/torchvision/datasets/cityscapes.py

@@ -0,0 +1,222 @@
+import json
+import os
+from collections import namedtuple
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import extract_archive, iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class Cityscapes(VisionDataset):
+    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
+            and ``gtFine`` or ``gtCoarse`` are located.
+        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
+            otherwise ``train``, ``train_extra`` or ``val``
+        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+            or ``color``. Can also be a list to output a tuple with all specified target types.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Examples:
+
+        Get semantic segmentation target
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+
+        Get multiple targets
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type=['instance', 'color', 'polygon'])
+
+            img, (inst, col, poly) = dataset[0]
+
+        Validate on the "coarse" set
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+    """
+
+    # Based on https://github.com/mcordts/cityscapesScripts
+    CityscapesClass = namedtuple(
+        "CityscapesClass",
+        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
+    )
+
+    classes = [
+        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
+        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
+        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
+        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
+        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
+        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
+        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
+        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
+        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
+        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
+        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
+        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
+        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
+        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
+        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
+        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
+        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
+        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
+        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
+        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
+        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
+        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
+        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
+        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
+        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
+        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
+        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
+        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
+        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
+        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
+    ]
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        mode: str = "fine",
+        target_type: Union[list[str], str] = "instance",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
+        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
+        self.targets_dir = os.path.join(self.root, self.mode, split)
+        self.target_type = target_type
+        self.split = split
+        self.images = []
+        self.targets = []
+
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
+
+        if not isinstance(target_type, list):
+            self.target_type = [target_type]
+        for value in self.target_type:
+            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
+
+        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+
+            if split == "train_extra":
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
+            else:
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
+
+            if self.mode == "gtFine":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
+            elif self.mode == "gtCoarse":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
+
+            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
+                extract_archive(from_path=image_dir_zip, to_path=self.root)
+                extract_archive(from_path=target_dir_zip, to_path=self.root)
+            else:
+                raise RuntimeError(
+                    "Dataset not found or incomplete. Please make sure all required folders for the"
+                    ' specified "split" and "mode" are inside the "root" directory'
+                )
+
+        for city in os.listdir(self.images_dir):
+            img_dir = os.path.join(self.images_dir, city)
+            target_dir = os.path.join(self.targets_dir, city)
+            for file_name in os.listdir(img_dir):
+                target_types = []
+                for t in self.target_type:
+                    target_name = "{}_{}".format(
+                        file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
+                    )
+                    target_types.append(os.path.join(target_dir, target_name))
+
+                self.images.append(os.path.join(img_dir, file_name))
+                self.targets.append(target_types)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
+        """
+
+        image = Image.open(self.images[index]).convert("RGB")
+
+        targets: Any = []
+        for i, t in enumerate(self.target_type):
+            if t == "polygon":
+                target = self._load_json(self.targets[index][i])
+            else:
+                target = Image.open(self.targets[index][i])  # type: ignore[assignment]
+
+            targets.append(target)
+
+        target = tuple(targets) if len(targets) > 1 else targets[0]  # type: ignore[assignment]
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def _load_json(self, path: str) -> dict[str, Any]:
+        with open(path) as file:
+            data = json.load(file)
+        return data
+
+    def _get_target_suffix(self, mode: str, target_type: str) -> str:
+        if target_type == "instance":
+            return f"{mode}_instanceIds.png"
+        elif target_type == "semantic":
+            return f"{mode}_labelIds.png"
+        elif target_type == "color":
+            return f"{mode}_color.png"
+        else:
+            return f"{mode}_polygons.json"
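Review note: the loop at the end of `Cityscapes.__init__` derives each target path from the image file name by cutting at `_leftImg8bit` and appending a mode-specific suffix. A minimal sketch of that naming rule follows. The helper `target_name`, the `SUFFIXES` table, and the sample file name are illustrative assumptions. Unlike `_get_target_suffix`, which falls through to the polygon suffix for any other value, this sketch raises `KeyError` on an unknown target type.

```python
SUFFIXES = {
    "instance": "instanceIds.png",
    "semantic": "labelIds.png",
    "color": "color.png",
    "polygon": "polygons.json",
}


def target_name(image_name: str, mode: str, target_type: str) -> str:
    """Build the annotation file name that pairs with an image file name.

    `mode` is the directory-style mode string, e.g. "gtFine" or "gtCoarse".
    """
    stem = image_name.split("_leftImg8bit")[0]
    return f"{stem}_{mode}_{SUFFIXES[target_type]}"


# Hypothetical file name, used only to show the transformation.
example = target_name("aachen_000000_000019_leftImg8bit.png", "gtFine", "semantic")
print(example)
```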

+ 93 - 0
python/py/Lib/site-packages/torchvision/datasets/clevr.py

@@ -0,0 +1,93 @@
+import json
+import pathlib
+from typing import Any, Callable, Optional, Union
+from urllib.parse import urlparse
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class CLEVRClassification(VisionDataset):
+    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_  classification dataset.
+
+    The number of objects in a scene is used as the label.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
+            set to True.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
+            dataset is already downloaded, it is not downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
+    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.loader = loader
+        self._base_folder = pathlib.Path(self.root) / "clevr"
+        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))
+
+        self._labels: list[Optional[int]]
+        if self._split != "test":
+            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
+                content = json.load(file)
+            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
+            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
+        else:
+            self._labels = [None] * len(self._image_files)
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file = self._image_files[idx]
+        label = self._labels[idx]
+
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_folder.exists() and self._data_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
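Review note: for the non-test splits, `CLEVRClassification` builds its labels by counting the entries under `objects` for each scene and looking the count up by image file name. The sketch below shows the same derivation on a small in-memory JSON string. The function `labels_for` and the sample scene data are assumptions made for illustration, not real CLEVR content.

```python
import json

# Tiny stand-in for a scenes file; real files hold full object descriptions.
scenes_text = json.dumps(
    {
        "scenes": [
            {"image_filename": "CLEVR_train_000000.png", "objects": [{}, {}, {}]},
            {"image_filename": "CLEVR_train_000001.png", "objects": [{}]},
        ]
    }
)


def labels_for(image_names, scenes_json):
    """Return the object count for each image name, in the order given."""
    content = json.loads(scenes_json)
    num_objects = {s["image_filename"]: len(s["objects"]) for s in content["scenes"]}
    return [num_objects[name] for name in image_names]


labels = labels_for(
    sorted(["CLEVR_train_000001.png", "CLEVR_train_000000.png"]), scenes_text
)
print(labels)
```

Sorting the file names first mirrors how `_image_files` is sorted before the labels are aligned to it.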

+ 111 - 0
python/py/Lib/site-packages/torchvision/datasets/coco.py

@@ -0,0 +1,111 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class CocoDetection(VisionDataset):
+    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
+
+    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
+    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        annFile: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+
+        self.coco = COCO(annFile)
+        self.ids = sorted(self.coco.imgs.keys())
+
+    def _load_image(self, id: int) -> Image.Image:
+        path = self.coco.loadImgs(id)[0]["file_name"]
+        return Image.open(os.path.join(self.root, path)).convert("RGB")
+
+    def _load_target(self, id: int) -> list[Any]:
+        return self.coco.loadAnns(self.coco.getAnnIds(id))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+
+        if not isinstance(index, int):
+            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
+
+        id = self.ids[index]
+        image = self._load_image(id)
+        target = self._load_target(id)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class CocoCaptions(CocoDetection):
+    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
+
+    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
+    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Example:
+
+        .. code:: python
+
+            import torchvision.datasets as dset
+            import torchvision.transforms as transforms
+            cap = dset.CocoCaptions(root = 'dir where images are',
+                                    annFile = 'json annotation file',
+                                    transform=transforms.PILToTensor())
+
+            print('Number of samples: ', len(cap))
+            img, target = cap[3] # load 4th sample
+
+            print("Image Size: ", img.size())
+            print(target)
+
+        Output: ::
+
+            Number of samples: 82783
+            Image Size: (3L, 427L, 640L)
+            [u'A plane emitting smoke stream flying over a mountain.',
+            u'A plane darts across a bright blue sky behind a mountain covered in snow',
+            u'A plane leaves a contrail above the snowy mountain top.',
+            u'A mountain that has a plane flying overheard in the distance.',
+            u'A mountain view with a plume of smoke in the background']
+
+    """
+
+    def _load_target(self, id: int) -> list[str]:
+        return [ann["caption"] for ann in super()._load_target(id)]
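Review note: `CocoCaptions` differs from `CocoDetection` only in `_load_target`, which keeps the `caption` field of each annotation for the image. The sketch below imitates that selection on plain dictionaries so it runs without `pycocotools`. The helper `captions_for` and the sample records are hypothetical, and the real class obtains the annotations through `getAnnIds` and `loadAnns` rather than by filtering a list.

```python
def captions_for(image_id, annotations):
    """Collect the caption strings of all annotations attached to one image."""
    return [ann["caption"] for ann in annotations if ann["image_id"] == image_id]


# Made-up annotation records in the general shape of caption annotations.
annotations = [
    {"image_id": 7, "caption": "A plane over a mountain."},
    {"image_id": 9, "caption": "A dog on a beach."},
    {"image_id": 7, "caption": "Snowy peaks under a blue sky."},
]

target = captions_for(7, annotations)
print(target)
```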

+ 67 - 0
python/py/Lib/site-packages/torchvision/datasets/country211.py

@@ -0,0 +1,67 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, ImageFolder
+from .utils import download_and_extract_archive, verify_str_arg
+
+
+class Country211(ImageFolder):
+    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.
+
+    This dataset was built by filtering the images from the YFCC100m dataset
+    that have GPS coordinates corresponding to an ISO-3166 country code. The
+    dataset is balanced by sampling 150 train images, 50 validation images, and
+    100 test images for each country.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
+    _MD5 = "84988d7644798601126c29e9877aab6a"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
+
+        root = Path(root).expanduser()
+        self.root = str(root)
+        self._base_folder = root / "country211"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(
+            str(self._base_folder / self._split),
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
+        self.root = str(root)
+
+    def _check_exists(self) -> bool:
+        return self._base_folder.exists() and self._base_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

+ 105 - 0
python/py/Lib/site-packages/torchvision/datasets/dtd.py

@@ -0,0 +1,105 @@
+import os
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class DTD(VisionDataset):
+    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
+
+            .. note::
+
+                The partition only changes which split each image belongs to. Thus, regardless of the selected
+                partition, combining all splits will result in all images.
+
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
+    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        partition: int = 1,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        if not isinstance(partition, int) or not (1 <= partition <= 10):
+            raise ValueError(
+                "Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
+                f"but got {partition} instead"
+            )
+        self._partition = partition
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
+        self._data_folder = self._base_folder / "dtd"
+        self._meta_folder = self._data_folder / "labels"
+        self._images_folder = self._data_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._image_files = []
+        classes = []
+        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
+            for line in file:
+                cls, name = line.strip().split("/")
+                self._image_files.append(self._images_folder.joinpath(cls, name))
+                classes.append(cls)
+
+        self.classes = sorted(set(classes))
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._labels = [self.class_to_idx[cls] for cls in classes]
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}, partition={self._partition}"
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
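Review note: `DTD.__init__` reads a split file whose lines have the form `class/name`, then assigns label indices by sorting the distinct class names. A minimal sketch of that parsing step follows. The function `parse_split` and the sample lines are assumptions for illustration only, and paths are kept as plain strings here rather than `pathlib` objects.

```python
def parse_split(lines):
    """Parse 'class/name' lines into image paths, a class index, and labels."""
    image_paths, classes = [], []
    for line in lines:
        cls, name = line.strip().split("/")
        image_paths.append(f"{cls}/{name}")
        classes.append(cls)
    # Indices follow the alphabetical order of the class names.
    class_names = sorted(set(classes))
    class_to_idx = {cls: i for i, cls in enumerate(class_names)}
    labels = [class_to_idx[cls] for cls in classes]
    return image_paths, class_to_idx, labels


paths, class_to_idx, labels = parse_split(
    ["dotted/dotted_0001.jpg\n", "banded/banded_0002.jpg\n", "dotted/dotted_0003.jpg\n"]
)
print(class_to_idx, labels)
```

Because the index comes from the classes present in the chosen split file, the mapping is determined by that file rather than by the directory listing.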

+ 71 - 0
python/py/Lib/site-packages/torchvision/datasets/eurosat.py

@@ -0,0 +1,71 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, ImageFolder
+from .utils import download_and_extract_archive
+
+
+class EuroSAT(ImageFolder):
+    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
+
+    For the MS version of the dataset, see
+    `TorchGeo <https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat>`__.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        self.root = os.path.expanduser(root)
+        self._base_folder = os.path.join(self.root, "eurosat")
+        self._data_folder = os.path.join(self._base_folder, "2750")
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(
+            self._data_folder,
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
+        self.root = os.path.expanduser(root)
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder)
+
+    def download(self) -> None:
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._base_folder, exist_ok=True)
+        download_and_extract_archive(
+            "https://huggingface.co/datasets/torchgeo/eurosat/resolve/c877bcd43f099cd0196738f714544e355477f3fd/EuroSAT.zip",
+            download_root=self._base_folder,
+            md5="c8fa014336c82ac7804f0398fcb19387",
+        )

+ 67 - 0
python/py/Lib/site-packages/torchvision/datasets/fakedata.py

@@ -0,0 +1,67 @@
+from typing import Any, Callable, Optional
+
+import torch
+
+from .. import transforms
+from .vision import VisionDataset
+
+
+class FakeData(VisionDataset):
+    """A fake dataset that generates random images and returns them as PIL images.
+
+    Args:
+        size (int, optional): Size of the dataset. Default: 1000 images
+        image_size (tuple, optional): Size of the returned images. Default: (3, 224, 224)
+        num_classes (int, optional): Number of classes in the dataset. Default: 10
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        random_offset (int): Offsets the index-based random seed used to
+            generate each image. Default: 0
+
+    """
+
+    def __init__(
+        self,
+        size: int = 1000,
+        image_size: tuple[int, int, int] = (3, 224, 224),
+        num_classes: int = 10,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        random_offset: int = 0,
+    ) -> None:
+        super().__init__(transform=transform, target_transform=target_transform)
+        self.size = size
+        self.num_classes = num_classes
+        self.image_size = image_size
+        self.random_offset = random_offset
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        # create random image that is consistent with the index id
+        if index >= len(self):
+            raise IndexError(f"{self.__class__.__name__} index out of range")
+        rng_state = torch.get_rng_state()
+        torch.manual_seed(index + self.random_offset)
+        img = torch.randn(*self.image_size)
+        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+        torch.set_rng_state(rng_state)
+
+        # convert to PIL Image
+        img = transforms.ToPILImage()(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target.item()
+
+    def __len__(self) -> int:
+        return self.size
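
Review note on `fakedata.py`: `__getitem__` makes every sample reproducible by seeding the generator from `index + random_offset`, and it saves and restores the global RNG state so the caller's random stream is not disturbed. The sketch below illustrates the same pattern with the standard-library `random` module swapped in for `torch`, so it runs without PyTorch. The helper name `fake_sample` and its `n_values` parameter are hypothetical and are not part of torchvision.

```python
import random


def fake_sample(index: int, num_classes: int = 10, random_offset: int = 0, n_values: int = 4):
    """Return (values, target) that depend only on index + random_offset."""
    saved_state = random.getstate()      # remember the caller's RNG state
    random.seed(index + random_offset)   # seed deterministically from the index
    values = [random.random() for _ in range(n_values)]
    target = random.randrange(num_classes)
    random.setstate(saved_state)         # restore it so the caller is unaffected
    return values, target
```

Calling `fake_sample(5)` twice returns the same pair, and `fake_sample(5, random_offset=1)` matches `fake_sample(6)`, which mirrors how `random_offset` shifts the seed in `FakeData`.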

+ 120 - 0
python/py/Lib/site-packages/torchvision/datasets/fer2013.py

@@ -0,0 +1,120 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+import torch
+from PIL import Image
+
+from .utils import check_integrity, verify_str_arg
+from .vision import VisionDataset
+
+
+class FER2013(VisionDataset):
+    """`FER2013
+    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
+
+    .. note::
+        This dataset can return test labels only if ``fer2013.csv`` or
+        ``icml_face_data.csv`` is present in ``root/fer2013/``. If only
+        ``train.csv`` and ``test.csv`` are present, the test labels are set to
+        ``None``.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``root/fer2013`` exists. This directory may contain either
+            ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
+            ``test.csv``. Precedence is given in that order, i.e. if
+            ``fer2013.csv`` is present then the rest of the files will be
+            ignored. All these (combinations of) files contain the same data and
+            are supported for convenience, but only ``fer2013.csv`` and
+            ``icml_face_data.csv`` are able to return non-None test labels.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+    """
+
+    _RESOURCES = {
+        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
+        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
+        # The fer2013.csv and icml_face_data.csv files contain both train and
+        # test instances, and unlike test.csv they contain the labels for the
+        # test instances. We give these 2 files precedence over train.csv and
+        # test.csv. And yes, they both contain the same data, but with different
+        # column names (note the spaces) and ordering:
+        # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
+        # ==> fer2013.csv <==
+        # emotion,pixels,Usage
+        #
+        # ==> icml_face_data.csv <==
+        # emotion, Usage, pixels
+        #
+        # ==> train.csv <==
+        # emotion,pixels
+        #
+        # ==> test.csv <==
+        # pixels
+        "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
+        "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
+    }
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        base_folder = pathlib.Path(self.root) / "fer2013"
+        use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
+        use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
+        file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
+        data_file = base_folder / file_name
+        if not check_integrity(str(data_file), md5=md5):
+            raise RuntimeError(
+                f"{file_name} not found in {base_folder} or corrupted. "
+                f"You can download it from "
+                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+            )
+
+        pixels_key = " pixels" if use_icml_file else "pixels"
+        usage_key = " Usage" if use_icml_file else "Usage"
+
+        def get_img(row):
+            return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)
+
+        def get_label(row):
+            if use_fer_file or use_icml_file or self._split == "train":
+                return int(row["emotion"])
+            else:
+                return None
+
+        with open(data_file, newline="") as file:
+            rows = (row for row in csv.DictReader(file))
+
+            if use_fer_file or use_icml_file:
+                valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
+                rows = (row for row in rows if row[usage_key] in valid_keys)
+
+            self._samples = [(get_img(row), get_label(row)) for row in rows]
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_tensor, target = self._samples[idx]
+        image = Image.fromarray(image_tensor.numpy())
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
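
Review note on `fer2013.py`: the constructor filters rows by the usage column and turns each space-separated pixel string into a square image. The header of `icml_face_data.csv` has leading spaces in its column names, which is why the code looks up `" pixels"` and `" Usage"`. The following is a minimal stdlib-only sketch of that parsing step. The `SAMPLE` text, the `load_rows` helper and the `side` parameter are illustrative assumptions; the real rows hold 48 * 48 values and are stored as `torch.uint8` tensors.

```python
import csv
import io

# Hypothetical two-row file in the icml_face_data.csv layout. Note the leading
# spaces in " Usage" and " pixels"; real rows hold 48 * 48 = 2304 pixel values.
SAMPLE = "emotion, Usage, pixels\n3,Training,0 255 7 9\n1,PublicTest,1 2 3 4\n"


def load_rows(text: str, split: str, side: int):
    """Return [(image_rows, label), ...] for the requested split."""
    valid = ("Training",) if split == "train" else ("PublicTest", "PrivateTest")
    samples = []
    for row in csv.DictReader(io.StringIO(text, newline="")):
        if row[" Usage"] not in valid:
            continue  # row belongs to the other split
        flat = [int(v) for v in row[" pixels"].split()]
        # Reshape the flat list into `side` rows of `side` pixels each.
        image = [flat[r * side:(r + 1) * side] for r in range(side)]
        samples.append((image, int(row["emotion"])))
    return samples
```

With the sample text above, `load_rows(SAMPLE, "train", side=2)` keeps only the `Training` row, and `split="test"` keeps the `PublicTest` row.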

+ 120 - 0
python/py/Lib/site-packages/torchvision/datasets/fgvc_aircraft.py

@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Any, Callable
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class FGVCAircraft(VisionDataset):
+    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
+
+    The dataset contains 10,000 images of aircraft, with 100 images for each of 100
+    different aircraft model variants, most of which are airplanes.
+    Aircraft models are organized in a three-level hierarchy. The three levels, from
+    finer to coarser, are:
+
+    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
+        indistinguishable into one class. The dataset comprises 100 different variants.
+    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
+    - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
+        split (string, optional): The dataset split, supports ``train``, ``val``,
+            ``trainval`` and ``test``.
+        annotation_level (str, optional): The annotation level, supports ``variant``,
+            ``family`` and ``manufacturer``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
+
+    def __init__(
+        self,
+        root: str | Path,
+        split: str = "trainval",
+        annotation_level: str = "variant",
+        transform: Callable | None = None,
+        target_transform: Callable | None = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+        self._annotation_level = verify_str_arg(
+            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
+        )
+
+        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        annotation_file = os.path.join(
+            self._data_path,
+            "data",
+            {
+                "variant": "variants.txt",
+                "family": "families.txt",
+                "manufacturer": "manufacturers.txt",
+            }[self._annotation_level],
+        )
+        with open(annotation_file) as f:
+            self.classes = [line.strip() for line in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        image_data_folder = os.path.join(self._data_path, "data", "images")
+        labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
+
+        self._image_files = []
+        self._labels = []
+
+        with open(labels_file) as f:
+            for line in f:
+                image_name, label_name = line.strip().split(" ", 1)
+                self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
+                self._labels.append(self.class_to_idx[label_name])
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _download(self) -> None:
+        """
+        Download the FGVC Aircraft dataset archive and extract it under root.
+        """
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, self.root)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
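
Review note on `fgvc_aircraft.py`: each line of the labels file is split with `split(" ", 1)`, so only the first space separates the image name from the label. This matters because label names such as "Boeing 737" contain spaces themselves. Below is a small sketch of that step under stated assumptions: `parse_labels` is a hypothetical helper, and the image names in the usage example are made up for illustration.

```python
def parse_labels(lines, classes):
    """Map 'image_name label name' lines to (file_name, class_index) pairs."""
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for line in lines:
        # maxsplit=1 keeps multi-word labels such as "Boeing 737" intact.
        image_name, label_name = line.strip().split(" ", 1)
        samples.append((f"{image_name}.jpg", class_to_idx[label_name]))
    return samples


samples = parse_labels(
    ["0000001 Boeing 737\n", "0000002 Airbus A320\n"],  # hypothetical image names
    classes=["Boeing 737", "Airbus A320"],
)
```

A plain `split()` would yield three fields for these lines and break the two-name unpacking.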

+ 176 - 0
python/py/Lib/site-packages/torchvision/datasets/flickr.py

@@ -0,0 +1,176 @@
+import glob
+import os
+from collections import defaultdict
+from html.parser import HTMLParser
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+from .vision import VisionDataset
+
+
+class Flickr8kParser(HTMLParser):
+    """Parser for extracting captions from the Flickr8k dataset web page."""
+
+    def __init__(self, root: Union[str, Path]) -> None:
+        super().__init__()
+
+        self.root = root
+
+        # Data structure to store captions
+        self.annotations: dict[str, list[str]] = {}
+
+        # State variables
+        self.in_table = False
+        self.current_tag: Optional[str] = None
+        self.current_img: Optional[str] = None
+
+    def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None:
+        self.current_tag = tag
+
+        if tag == "table":
+            self.in_table = True
+
+    def handle_endtag(self, tag: str) -> None:
+        self.current_tag = None
+
+        if tag == "table":
+            self.in_table = False
+
+    def handle_data(self, data: str) -> None:
+        if self.in_table:
+            if data == "Image Not Found":
+                self.current_img = None
+            elif self.current_tag == "a":
+                img_id = data.split("/")[-2]
+                img_id = os.path.join(self.root, img_id + "_*.jpg")
+                img_id = glob.glob(img_id)[0]
+                self.current_img = img_id
+                self.annotations[img_id] = []
+            elif self.current_tag == "li" and self.current_img:
+                img_id = self.current_img
+                self.annotations[img_id].append(data.strip())
+
+
+class Flickr8k(VisionDataset):
+    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        parser = Flickr8kParser(self.root)
+        with open(self.ann_file) as fh:
+            parser.feed(fh.read())
+        self.annotations = parser.annotations
+
+        self.ids = list(sorted(self.annotations.keys()))
+        self.loader = loader
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        img = self.loader(img_id)
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class Flickr30k(VisionDataset):
+    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        self.annotations = defaultdict(list)
+        with open(self.ann_file) as fh:
+            for line in fh:
+                img_id, caption = line.strip().split("\t")
+                self.annotations[img_id[:-2]].append(caption)
+
+        self.ids = list(sorted(self.annotations.keys()))
+        self.loader = loader
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        filename = os.path.join(self.root, img_id)
+        img = self.loader(filename)
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
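
Review note on `flickr.py`: `Flickr30k` reads a tab-separated annotation file and groups captions by image, dropping the last two characters of the first field with `img_id[:-2]`. The sketch below assumes those two characters are a caption suffix such as `#0`, which is what the slice implies. `group_captions` is a hypothetical helper and the sample lines are invented for illustration.

```python
from collections import defaultdict


def group_captions(lines):
    """Group 'image.jpg#N<TAB>caption' lines into {image.jpg: [captions]}."""
    annotations = defaultdict(list)
    for line in lines:
        img_id, caption = line.strip().split("\t")
        # Drop the assumed two-character "#N" suffix to recover the file name.
        annotations[img_id[:-2]].append(caption)
    return dict(annotations)


captions = group_captions(
    [
        "a.jpg#0\tA dog runs.\n",
        "a.jpg#1\tA brown dog is running.\n",
        "b.jpg#0\tTwo people talk.\n",
    ]
)
```

Note that the slice removes exactly two characters, so a suffix with a two-digit caption number would leave a stray `#` in the key.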

+ 225 - 0
python/py/Lib/site-packages/torchvision/datasets/flowers102.py

@@ -0,0 +1,225 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class Flowers102(VisionDataset):
+    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
+    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
+    between 40 and 258 images.
+
+    The images have large scale, pose and light variations. In addition, there are categories that
+    have large variations within the category, and several very similar categories.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
+    _file_dict = {  # filename, md5
+        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
+        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
+        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
+    }
+    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[Union[str, Path]], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._base_folder = Path(self.root) / "flowers-102"
+        self._images_folder = self._base_folder / "jpg"
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        from scipy.io import loadmat
+
+        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
+        image_ids = set_ids[self._splits_map[self._split]].tolist()
+
+        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
+        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
+
+        self._labels = []
+        self._image_files = []
+        for image_id in image_ids:
+            self._labels.append(image_id_to_label[image_id])
+            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
+
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_integrity(self):
+        if not (self._images_folder.exists() and self._images_folder.is_dir()):
+            return False
+
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            if not check_integrity(str(self._base_folder / filename), md5):
+                return False
+        return True
+
+    def download(self):
+        if self._check_integrity():
+            return
+        download_and_extract_archive(
+            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
+            str(self._base_folder),
+            md5=self._file_dict["image"][1],
+        )
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
+
+    classes = [
+        "pink primrose",
+        "hard-leaved pocket orchid",
+        "canterbury bells",
+        "sweet pea",
+        "english marigold",
+        "tiger lily",
+        "moon orchid",
+        "bird of paradise",
+        "monkshood",
+        "globe thistle",
+        "snapdragon",
+        "colt's foot",
+        "king protea",
+        "spear thistle",
+        "yellow iris",
+        "globe-flower",
+        "purple coneflower",
+        "peruvian lily",
+        "balloon flower",
+        "giant white arum lily",
+        "fire lily",
+        "pincushion flower",
+        "fritillary",
+        "red ginger",
+        "grape hyacinth",
+        "corn poppy",
+        "prince of wales feathers",
+        "stemless gentian",
+        "artichoke",
+        "sweet william",
+        "carnation",
+        "garden phlox",
+        "love in the mist",
+        "mexican aster",
+        "alpine sea holly",
+        "ruby-lipped cattleya",
+        "cape flower",
+        "great masterwort",
+        "siam tulip",
+        "lenten rose",
+        "barbeton daisy",
+        "daffodil",
+        "sword lily",
+        "poinsettia",
+        "bolero deep blue",
+        "wallflower",
+        "marigold",
+        "buttercup",
+        "oxeye daisy",
+        "common dandelion",
+        "petunia",
+        "wild pansy",
+        "primula",
+        "sunflower",
+        "pelargonium",
+        "bishop of llandaff",
+        "gaura",
+        "geranium",
+        "orange dahlia",
+        "pink-yellow dahlia?",
+        "cautleya spicata",
+        "japanese anemone",
+        "black-eyed susan",
+        "silverbush",
+        "californian poppy",
+        "osteospermum",
+        "spring crocus",
+        "bearded iris",
+        "windflower",
+        "tree poppy",
+        "gazania",
+        "azalea",
+        "water lily",
+        "rose",
+        "thorn apple",
+        "morning glory",
+        "passion flower",
+        "lotus",
+        "toad lily",
+        "anthurium",
+        "frangipani",
+        "clematis",
+        "hibiscus",
+        "columbine",
+        "desert-rose",
+        "tree mallow",
+        "magnolia",
+        "cyclamen",
+        "watercress",
+        "canna lily",
+        "hippeastrum",
+        "bee balm",
+        "ball moss",
+        "foxglove",
+        "bougainvillea",
+        "camellia",
+        "mallow",
+        "mexican petunia",
+        "bromelia",
+        "blanket flower",
+        "trumpet creeper",
+        "blackberry lily",
+    ]
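
Review note on `flowers102.py`: the constructor builds `image_id_to_label` with `enumerate(..., 1)`, so image IDs start at 1, while the labels are shifted by `- 1` so that class indices start at 0. The sketch below reproduces that indexing with a plain list instead of the array returned by `scipy.io.loadmat`. The `raw_labels` values and the `image_file` helper are hypothetical.

```python
# Hypothetical 1-based class labels, listed in image order (image 1 first).
raw_labels = [77, 77, 3, 102]

# Image IDs are 1-based (enumerate start=1); class indices become 0-based.
image_id_to_label = dict(enumerate((label - 1 for label in raw_labels), 1))


def image_file(image_id: int) -> str:
    """Build the zero-padded file name used under the jpg/ folder."""
    return f"image_{image_id:05d}.jpg"
```

Here image 4 maps to class index 101, the last entry of the 102-item `classes` list.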

+ 337 - 0
python/py/Lib/site-packages/torchvision/datasets/folder.py

@@ -0,0 +1,337 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+def has_file_allowed_extension(filename: str, extensions: Union[str, tuple[str, ...]]) -> bool:
+    """Checks if a file has an allowed extension.
+
+    Args:
+        filename (string): path to a file
+        extensions (tuple of strings): extensions to consider (lowercase)
+
+    Returns:
+        bool: True if the filename ends with one of given extensions
+    """
+    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
+
+
+def is_image_file(filename: str) -> bool:
+    """Checks if a file has an allowed image extension.
+
+    Args:
+        filename (string): path to a file
+
+    Returns:
+        bool: True if the filename ends with a known image extension
+    """
+    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def find_classes(directory: Union[str, Path]) -> tuple[list[str], dict[str, int]]:
+    """Finds the class folders in a dataset.
+
+    See :class:`DatasetFolder` for details.
+    """
+    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+    if not classes:
+        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+    return classes, class_to_idx
+
+
+def make_dataset(
+    directory: Union[str, Path],
+    class_to_idx: Optional[dict[str, int]] = None,
+    extensions: Optional[Union[str, tuple[str, ...]]] = None,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+) -> list[tuple[str, int]]:
+    """Generates a list of samples of the form (path_to_sample, class).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is optional here; by default it is built with the logic of the
+    ``find_classes`` function.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            for fname in sorted(fnames):
+                path = os.path.join(root, fname)
+                if is_valid_file(path):
+                    item = path, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes and not allow_empty:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class DatasetFolder(VisionDataset):
+    """A generic data loader.
+
+    This default directory structure can be customized by overriding the
+    :meth:`find_classes` method.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory path.
+        loader (callable): A function to load a sample given its path.
+        extensions (tuple[string]): A tuple of allowed extensions.
+            Pass either extensions or is_valid_file, not both.
+        transform (callable, optional): A function/transform that takes in
+            a sample and returns a transformed version.
+            E.g, ``transforms.RandomCrop`` for images.
+        target_transform (callable, optional): A function/transform that takes
+            in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes the path of a file
+            and checks whether the file is valid (used to check for corrupt files).
+            Pass either extensions or is_valid_file, not both.
+        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
+            An error is raised on empty folders if False (default).
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        samples (list): List of (sample path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        loader: Callable[[str], Any],
+        extensions: Optional[tuple[str, ...]] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+        allow_empty: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        classes, class_to_idx = self.find_classes(self.root)
+        samples = self.make_dataset(
+            self.root,
+            class_to_idx=class_to_idx,
+            extensions=extensions,
+            is_valid_file=is_valid_file,
+            allow_empty=allow_empty,
+        )
+
+        self.loader = loader
+        self.extensions = extensions
+
+        self.classes = classes
+        self.class_to_idx = class_to_idx
+        self.samples = samples
+        self.targets = [s[1] for s in samples]
+
+    @staticmethod
+    def make_dataset(
+        directory: Union[str, Path],
+        class_to_idx: dict[str, int],
+        extensions: Optional[tuple[str, ...]] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+        allow_empty: bool = False,
+    ) -> list[tuple[str, int]]:
+        """Generates a list of samples of a form (path_to_sample, class).
+
+        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
+
+        Args:
+            directory (str): root dataset directory, corresponding to ``self.root``.
+            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
+            extensions (optional): A tuple of allowed extensions.
+                Pass either extensions or is_valid_file, not both. Defaults to None.
+            is_valid_file (optional): A function that takes the path of a file
+                and checks whether the file is valid
+                (used to check for corrupt files). Pass either extensions or
+                is_valid_file, not both. Defaults to None.
+            allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
+                An error is raised on empty folders if False (default).
+
+        Raises:
+            ValueError: In case ``class_to_idx`` is empty.
+            ValueError: In case ``extensions`` and ``is_valid_file`` are both None or both not None.
+            FileNotFoundError: In case at least one class has no valid file and ``allow_empty`` is False.
+
+        Returns:
+            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
+        """
+        if class_to_idx is None:
+            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
+            # find_classes() function, instead of using that of the find_classes() method, which
+            # is potentially overridden and thus could have a different logic.
+            raise ValueError("The class_to_idx parameter cannot be None.")
+        return make_dataset(
+            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
+        )
+
+    def find_classes(self, directory: Union[str, Path]) -> tuple[list[str], dict[str, int]]:
+        """Find the class folders in a dataset structured as follows::
+
+            directory/
+            ├── class_x
+            │   ├── xxx.ext
+            │   ├── xxy.ext
+            │   └── ...
+            │       └── xxz.ext
+            └── class_y
+                ├── 123.ext
+                ├── nsdf3.ext
+                └── ...
+                └── asd932_.ext
+
+        This method can be overridden to only consider
+        a subset of classes, or to adapt to a different dataset directory structure.
+
+        Args:
+            directory(str): Root directory path, corresponding to ``self.root``
+
+        Raises:
+            FileNotFoundError: If ``directory`` has no class folders.
+
+        Returns:
+            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
+        """
+        return find_classes(directory)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def pil_loader(path: Union[str, Path]) -> Image.Image:
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, "rb") as f:
+        img = Image.open(f)
+        return img.convert("RGB")
+
+
+# TODO: specify the return type
+def accimage_loader(path: Union[str, Path]) -> Any:
+    import accimage
+
+    try:
+        return accimage.Image(path)
+    except OSError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_loader(path: Union[str, Path]) -> Any:
+    from torchvision import get_image_backend
+
+    if get_image_backend() == "accimage":
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
+
+
+class ImageFolder(DatasetFolder):
+    """A generic data loader where the images are arranged in this way by default: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/[...]/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/[...]/asd932_.png
+
+    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
+    the same methods can be overridden to customize the dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks whether the file is valid (used to check for corrupt files).
+        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
+            An error is raised on empty folders if False (default).
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+        allow_empty: bool = False,
+    ):
+        super().__init__(
+            root,
+            loader,
+            IMG_EXTENSIONS if is_valid_file is None else None,
+            transform=transform,
+            target_transform=target_transform,
+            is_valid_file=is_valid_file,
+            allow_empty=allow_empty,
+        )
+        self.imgs = self.samples
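
Review note on `folder.py`: the class discovery and sample collection above can be tried without torchvision. The following is a minimal stdlib sketch, not the vendored code itself. The names `find_classes_sketch`, `collect_samples_sketch` and `build_demo_tree` are hypothetical, and the case-insensitive suffix check is an assumption about what `has_file_allowed_extension` does, since that helper is not part of this hunk.

```python
import os
import tempfile


def find_classes_sketch(directory):
    """Sorted sub-directory names become classes; the index follows sort order."""
    classes = sorted(e.name for e in os.scandir(directory) if e.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    return classes, {name: i for i, name in enumerate(classes)}


def collect_samples_sketch(directory, class_to_idx, extensions):
    """Walk each class folder recursively and keep files with an allowed suffix."""
    samples = []
    for cls in sorted(class_to_idx):
        target_dir = os.path.join(directory, cls)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                # Assumption: suffix matching ignores case.
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(root, fname), class_to_idx[cls]))
    return samples


def build_demo_tree(root):
    """Create a tiny hypothetical dataset: two classes, one nested image, one non-image."""
    for rel in ("cat/a.png", "dog/b.png", "dog/nested/c.PNG", "dog/notes.txt"):
        path = os.path.join(root, *rel.split("/"))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()


with tempfile.TemporaryDirectory() as demo_root:
    build_demo_tree(demo_root)
    demo_classes, demo_class_to_idx = find_classes_sketch(demo_root)
    demo_samples = collect_samples_sketch(demo_root, demo_class_to_idx, (".png",))
    demo_labels = [label for _, label in demo_samples]
```

Files in nested sub-folders inherit the label of their top-level class folder, and `notes.txt` is skipped because its suffix is not in the allowed tuple.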

+ 98 - 0
python/py/Lib/site-packages/torchvision/datasets/food101.py

@@ -0,0 +1,98 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Food101(VisionDataset):
+    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
+
+    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
+    For each class, 250 manually reviewed test images are provided as well as 750 training images.
+    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
+    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
+    rescaled to have a maximum side length of 512 pixels.
+
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
+    _MD5 = "85eeb15f3717b99a5da872d97d918f87"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[Union[str, Path]], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = Path(self.root) / "food-101"
+        self._meta_folder = self._base_folder / "meta"
+        self._images_folder = self._base_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._labels = []
+        self._image_files = []
+        with open(self._meta_folder / f"{split}.json") as f:
+            metadata = json.loads(f.read())
+
+        self.classes = sorted(metadata.keys())
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        for class_label, im_rel_paths in metadata.items():
+            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
+            self._image_files += [
+                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
+            ]
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
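
Review note on `food101.py`: the constructor turns the split metadata (a JSON object mapping each class name to a list of relative image paths) into parallel label and file lists. A minimal sketch of that indexing step follows. The function name `index_food101_metadata` and the sample metadata are hypothetical and only illustrate the shape the code above expects.

```python
from pathlib import Path


def index_food101_metadata(metadata, images_folder):
    """Build (classes, labels, image_files) the way the constructor does."""
    classes = sorted(metadata.keys())
    class_to_idx = dict(zip(classes, range(len(classes))))
    labels, image_files = [], []
    # Iteration follows the order of the metadata dict, not the sorted class order.
    for class_label, rel_paths in metadata.items():
        labels += [class_to_idx[class_label]] * len(rel_paths)
        image_files += [images_folder.joinpath(*f"{p}.jpg".split("/")) for p in rel_paths]
    return classes, labels, image_files


# Hypothetical metadata in the same shape as meta/train.json.
sample_metadata = {
    "waffles": ["waffles/1001"],
    "apple_pie": ["apple_pie/42", "apple_pie/43"],
}
food_classes, food_labels, food_files = index_food101_metadata(
    sample_metadata, Path("food-101") / "images"
)
```

The label values come from the sorted class list, so `apple_pie` is class 0 even though it appears second in the metadata.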

+ 103 - 0
python/py/Lib/site-packages/torchvision/datasets/gtsrb.py

@@ -0,0 +1,103 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+import PIL
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class GTSRB(VisionDataset):
+    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "gtsrb"
+        self._target_folder = (
+            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+        )
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        if self._split == "train":
+            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
+        else:
+            with open(self._base_folder / "GT-final_test.csv") as csv_file:
+                samples = [
+                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+                ]
+
+        self._samples = samples
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+
+        path, target = self._samples[index]
+        sample = PIL.Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            sample = self.transform(sample)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def _check_exists(self) -> bool:
+        return self._target_folder.is_dir()
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+        if self._split == "train":
+            download_and_extract_archive(
+                f"{base_url}GTSRB-Training_fixed.zip",
+                download_root=str(self._base_folder),
+                md5="513f3c79a4c5141765e10e952eaa2478",
+            )
+        else:
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_Images.zip",
+                download_root=str(self._base_folder),
+                md5="c7e4e6327067d32654124b0fe9e82185",
+            )
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_GT.zip",
+                download_root=str(self._base_folder),
+                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+            )
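
Review note on `gtsrb.py`: for the test split, samples come from a semicolon-delimited ground-truth CSV rather than from the folder layout. The sketch below shows that parsing step on an in-memory file. Only the `Filename` and `ClassId` columns are taken from the code above; the `Width` column and the row values are made up for illustration, and `PurePosixPath` is used only to keep the output platform-independent.

```python
import csv
import io
import pathlib

# Hypothetical excerpt in the style of GT-final_test.csv.
csv_text = "Filename;Width;ClassId\n00000.ppm;53;16\n00001.ppm;42;1\n"

target_folder = pathlib.PurePosixPath("gtsrb/GTSRB/Final_Test/Images")

gtsrb_samples = [
    (str(target_folder / row["Filename"]), int(row["ClassId"]))
    for row in csv.DictReader(io.StringIO(csv_text), delimiter=";", skipinitialspace=True)
]
```

Each sample is a `(path, class_index)` pair, the same shape that `make_dataset` returns for the training split.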

+ 152 - 0
python/py/Lib/site-packages/torchvision/datasets/hmdb51.py

@@ -0,0 +1,152 @@
+import glob
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class HMDB51(VisionDataset):
+    """
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    dataset.
+
+    HMDB51 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the HMDB51 Dataset.
+        annotation_path (str): Path to the folder containing the split files.
+        frames_per_clip (int): Number of frames in a clip.
+        step_between_clips (int): Number of frames between each clip.
+        fold (int, optional): Which fold to use. Should be between 1 and 3.
+        train (bool, optional): If ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            - audio (Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    splits = {
+        "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
+    }
+    TRAIN_TAG = 1
+    TEST_TAG = 2
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if fold not in (1, 2, 3):
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(
+            self.root,
+            class_to_idx,
+            extensions,
+        )
+
+        video_paths = [path for (path, _) in self.samples]
+        video_clips = VideoClips(
+            video_paths,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.fold = fold
+        self.train = train
+        self.indices = self._select_fold(video_paths, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: list[str], annotations_dir: str, fold: int, train: bool) -> list[int]:
+        target_tag = self.TRAIN_TAG if train else self.TEST_TAG
+        split_pattern_name = f"*test_split{fold}.txt"
+        split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
+        annotation_paths = glob.glob(split_pattern_path)
+        selected_files = set()
+        for filepath in annotation_paths:
+            with open(filepath) as fid:
+                lines = fid.readlines()
+            for line in lines:
+                video_filename, tag_string = line.split()
+                tag = int(tag_string)
+                if tag == target_tag:
+                    selected_files.add(video_filename)
+
+        indices = []
+        for video_index, video_path in enumerate(video_list):
+            if os.path.basename(video_path) in selected_files:
+                indices.append(video_index)
+
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+        video, audio, _, video_idx = self.video_clips.get_clip(idx)
+        sample_index = self.indices[video_idx]
+        _, class_index = self.samples[sample_index]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, class_index
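
Review note on `hmdb51.py`: two pieces of logic here are easy to check by hand. The first is the clip count from the docstring example, the second is how `_select_fold` picks videos by tag. The sketch below is a stdlib approximation. `count_clips` assumes a plain sliding window that drops incomplete clips (the actual counting lives in `VideoClips`, which is not part of this hunk), and the split lines are hypothetical.

```python
import os

TRAIN_TAG = 1  # same tag values as the class attributes in the diff
TEST_TAG = 2


def count_clips(num_frames, frames_per_clip, step_between_clips):
    """Number of full clips in a sliding window; incomplete clips are dropped."""
    if num_frames < frames_per_clip:
        return 0
    return (num_frames - frames_per_clip) // step_between_clips + 1


def select_fold_sketch(video_paths, split_lines, train):
    """Return indices of videos whose basename carries the requested tag."""
    target_tag = TRAIN_TAG if train else TEST_TAG
    selected = set()
    for line in split_lines:
        video_filename, tag_string = line.split()
        if int(tag_string) == target_tag:
            selected.add(video_filename)
    return [i for i, path in enumerate(video_paths) if os.path.basename(path) in selected]


# Docstring example: videos with 10 and 15 frames, 5 frames per clip, step 5.
demo_total = count_clips(10, 5, 5) + count_clips(15, 5, 5)

# Hypothetical video list and split-file lines ("<filename> <tag>").
demo_paths = ["brush_hair/a.avi", "brush_hair/b.avi", "cartwheel/c.avi"]
demo_lines = ["a.avi 1", "b.avi 2", "c.avi 1", "d.avi 0"]
demo_train = select_fold_sketch(demo_paths, demo_lines, train=True)
demo_test = select_fold_sketch(demo_paths, demo_lines, train=False)
```

Matching is done on the basename only, so two videos with the same file name in different class folders would be selected together.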

+ 222 - 0
python/py/Lib/site-packages/torchvision/datasets/imagenet.py

@@ -0,0 +1,222 @@
+import os
+import shutil
+import tempfile
+from collections.abc import Iterator
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+
+from .folder import ImageFolder
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+    "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
+    "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
+    "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
+}
+
+META_FILE = "meta.bin"
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    .. note::
+        Before using this class, it is required to download ImageNet 2012 dataset from
+        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
+        place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
+        or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train``, or ``val``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
+        root = self.root = os.path.expanduser(root)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
+
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
+
+        super().__init__(self.split_folder, **kwargs)
+        self.root = root
+
+        self.wnids = self.classes
+        self.wnid_to_idx = self.class_to_idx
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
+
+    def parse_archives(self) -> None:
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
+
+        if not os.path.isdir(self.split_folder):
+            if self.split == "train":
+                parse_train_archive(self.root)
+            elif self.split == "val":
+                parse_val_archive(self.root)
+
+    @property
+    def split_folder(self) -> str:
+        return os.path.join(self.root, self.split)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+
+def load_meta_file(root: Union[str, Path], file: Optional[str] = None) -> tuple[dict[str, str], list[str]]:
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
+
+    if check_integrity(file):
+        return torch.load(file, weights_only=True)
+    else:
+        msg = (
+            "The meta file {} is not present in the root directory or is corrupted. "
+            "This file is automatically created by the ImageNet dataset."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = (
+            "The archive {} is not present in the root directory or is corrupted. "
+            "You need to download it externally and place it in {}."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
+    import scipy.io as sio
+
+    def parse_meta_mat(devkit_root: str) -> tuple[dict[int, str], dict[str, tuple[str, ...]]]:
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(", ")) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root: str) -> list[int]:
+        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
+        with open(file) as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir() -> Iterator[str]:
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
+
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
+        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
+
+
+def parse_val_archive(
+    root: Union[str, Path], file: Optional[str] = None, wnids: Optional[list[str]] = None, folder: str = "val"
+) -> None:
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
+
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
+
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
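
Review note: the validation step above relies on the sorted image list lining up one-to-one with the ground-truth WordNet IDs. A minimal sketch of that regrouping, run on toy data; the helper name `sort_val_images` and the three empty placeholder files are assumptions for illustration and are not part of the committed file:

```python
import os
import shutil
import tempfile


def sort_val_images(val_root: str, wnids: list[str]) -> None:
    """Move a flat folder of images into one subfolder per WordNet ID."""
    # Sort first so that the i-th image pairs with the i-th ground-truth entry.
    images = sorted(os.path.join(val_root, name) for name in os.listdir(val_root))
    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))
    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))


# Hypothetical toy data: three empty files standing in for validation images.
demo_root = tempfile.mkdtemp()
for name in ("val_00000001.JPEG", "val_00000002.JPEG", "val_00000003.JPEG"):
    open(os.path.join(demo_root, name), "w").close()

sort_val_images(demo_root, ["n01440764", "n03445777", "n01440764"])

# Record the resulting layout, then clean up the temporary directory.
layout = {d: sorted(os.listdir(os.path.join(demo_root, d))) for d in sorted(os.listdir(demo_root))}
shutil.rmtree(demo_root)
```

Because `zip` stops at the shorter sequence, a mismatch between the number of images and the number of IDs would leave files unsorted without raising; a length check before the move loop may be worth considering.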

+ 104 - 0
python/py/Lib/site-packages/torchvision/datasets/imagenette.py

@@ -0,0 +1,104 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, find_classes, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Imagenette(VisionDataset):
+    """`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Imagenette dataset.
+        split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
+        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
+    Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class name, class index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (WordNet ID, class index).
+    """
+
+    _ARCHIVES = {
+        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
+        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
+        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
+    }
+    _WNID_TO_CLASS = {
+        "n01440764": ("tench", "Tinca tinca"),
+        "n02102040": ("English springer", "English springer spaniel"),
+        "n02979186": ("cassette player",),
+        "n03000684": ("chain saw", "chainsaw"),
+        "n03028079": ("church", "church building"),
+        "n03394916": ("French horn", "horn"),
+        "n03417042": ("garbage truck", "dustcart"),
+        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
+        "n03445777": ("golf ball",),
+        "n03888257": ("parachute", "chute"),
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        size: str = "full",
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ["train", "val"])
+        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])
+
+        self._url, self._md5 = self._ARCHIVES[self._size]
+        self._size_root = Path(self.root) / Path(self._url).stem
+        self._image_root = str(self._size_root / self._split)
+
+        if download:
+            self._download()
+        elif not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
+        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
+        self.class_to_idx = {
+            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
+        }
+        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")
+        self.loader = loader
+
+    def _check_exists(self) -> bool:
+        return self._size_root.exists()
+
+    def _download(self):
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(self._url, self.root, md5=self._md5)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        path, label = self._samples[idx]
+        image = self.loader(path)
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def __len__(self) -> int:
+        return len(self._samples)
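
Review note: in the constructor above, every alias of a class maps to the same index. A minimal sketch of that mapping on a two-class subset; it assumes the folder scan returns the WordNet IDs in sorted order with consecutive indices, and the upper-case names here are illustrative stand-ins:

```python
# Two entries copied from the class table above, as a toy subset.
WNID_TO_CLASS = {
    "n01440764": ("tench", "Tinca tinca"),
    "n03445777": ("golf ball",),
}

# Assumption: the directory scan yields sorted folder names with indices 0..N-1.
wnids = sorted(WNID_TO_CLASS)
wnid_to_idx = {wnid: idx for idx, wnid in enumerate(wnids)}

# One tuple of names per class, in index order.
classes = [WNID_TO_CLASS[wnid] for wnid in wnids]

# Flatten the aliases: each name of a class points at that class's index.
class_to_idx = {
    class_name: idx for wnid, idx in wnid_to_idx.items() for class_name in WNID_TO_CLASS[wnid]
}
```

With this layout, `class_to_idx` has more keys than there are classes, so `len(class_to_idx)` should not be used as the class count.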

+ 245 - 0
python/py/Lib/site-packages/torchvision/datasets/inaturalist.py

@@ -0,0 +1,245 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
+
+DATASET_URLS = {
+    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
+}
+
+DATASET_MD5 = {
+    "2017": "7c784ea5e424efaec655bd392f87301f",
+    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+    "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
+    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
+}
+
+
+class INaturalist(VisionDataset):
+    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
+            This class does not require/use annotation files.
+        version (string, optional): Which version of the dataset to download/use. One of
+            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
+            Default: `2021_train`.
+        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
+
+            - ``full``: the full category (species)
+            - ``kingdom``: e.g. "Animalia"
+            - ``phylum``: e.g. "Arthropoda"
+            - ``class``: e.g. "Insecta"
+            - ``order``: e.g. "Coleoptera"
+            - ``family``: e.g. "Cleridae"
+            - ``genus``: e.g. "Trichodes"
+
+            for 2017-2019 versions, one of:
+
+            - ``full``: the full (numeric) category
+            - ``super``: the super category, e.g. "Amphibians"
+
+            Can also be a list to output a tuple with all specified target types.
+            Defaults to ``full``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        version: str = "2021_train",
+        target_type: Union[list[str], str] = "full",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Optional[Callable[[Union[str, Path]], Any]] = None,
+    ) -> None:
+        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
+
+        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
+
+        os.makedirs(root, exist_ok=True)
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.all_categories: list[str] = []
+
+        # map: category type -> name of category -> index
+        self.categories_index: dict[str, dict[str, int]] = {}
+
+        # list indexed by category id, containing mapping from category type -> index
+        self.categories_map: list[dict[str, int]] = []
+
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        if self.version[:4] == "2021":
+            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
+            self._init_2021()
+        else:
+            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
+            self._init_pre2021()
+
+        # index of all files: (full category id, filename)
+        self.index: list[tuple[int, str]] = []
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            files = os.listdir(os.path.join(self.root, dir_name))
+            for fname in files:
+                self.index.append((dir_index, fname))
+
+        self.loader = loader
+
+    def _init_2021(self) -> None:
+        """Initialize based on 2021 layout"""
+
+        self.all_categories = sorted(os.listdir(self.root))
+
+        # map: category type -> name of category -> index
+        self.categories_index = {k: {} for k in CATEGORIES_2021}
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            pieces = dir_name.split("_")
+            if len(pieces) != 8:
+                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+            if pieces[0] != f"{dir_index:05d}":
+                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
+            cat_map = {}
+            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
+                if name in self.categories_index[cat]:
+                    cat_id = self.categories_index[cat][name]
+                else:
+                    cat_id = len(self.categories_index[cat])
+                    self.categories_index[cat][name] = cat_id
+                cat_map[cat] = cat_id
+            self.categories_map.append(cat_map)
+
+    def _init_pre2021(self) -> None:
+        """Initialize based on 2017-2019 layout"""
+
+        # map: category type -> name of category -> index
+        self.categories_index = {"super": {}}
+
+        cat_index = 0
+        super_categories = sorted(os.listdir(self.root))
+        for sindex, scat in enumerate(super_categories):
+            self.categories_index["super"][scat] = sindex
+            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
+            for subcat in subcategories:
+                if self.version == "2017":
+                    # this version does not use ids as directory names
+                    subcat_i = cat_index
+                    cat_index += 1
+                else:
+                    try:
+                        subcat_i = int(subcat)
+                    except ValueError:
+                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
+                if subcat_i >= len(self.categories_map):
+                    old_len = len(self.categories_map)
+                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
+                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
+                if self.categories_map[subcat_i]:
+                    raise RuntimeError(f"Duplicate category {subcat}")
+                self.categories_map[subcat_i] = {"super": sindex}
+                self.all_categories[subcat_i] = os.path.join(scat, subcat)
+
+        # validate the dictionary
+        for cindex, c in enumerate(self.categories_map):
+            if not c:
+                raise RuntimeError(f"Missing category {cindex}")
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target is specified by target_type.
+        """
+
+        cat_id, fname = self.index[index]
+        image_path = os.path.join(self.root, self.all_categories[cat_id], fname)
+        img = self.loader(image_path) if self.loader is not None else Image.open(image_path)
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "full":
+                target.append(cat_id)
+            else:
+                target.append(self.categories_map[cat_id][t])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def category_name(self, category_type: str, category_id: int) -> str:
+        """
+        Args:
+            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
+            category_id(int): an index (class id) from this category
+
+        Returns:
+            the name of the category
+        """
+        if category_type == "full":
+            return self.all_categories[category_id]
+        else:
+            if category_type not in self.categories_index:
+                raise ValueError(f"Invalid category type '{category_type}'")
+            else:
+                for name, id in self.categories_index[category_type].items():
+                    if id == category_id:
+                        return name
+                raise ValueError(f"Invalid category id {category_id} for {category_type}")
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        base_root = os.path.dirname(self.root)
+
+        download_and_extract_archive(
+            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+        )
+
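+        # str.removesuffix drops the exact ".tar.gz" suffix; str.rstrip would strip a set of characters instead.
+        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).removesuffix(".tar.gz"))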
+        if not os.path.exists(orig_dir_name):
+            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
+        os.rename(orig_dir_name, self.root)
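
Review note: the download step derives the extracted directory name from the archive URL. A minimal sketch of suffix stripping, assuming Python 3.9+ for `str.removesuffix`; the helper name `extracted_dir_name` and the `example.com` URL are hypothetical:

```python
import posixpath


def extracted_dir_name(url: str, suffix: str = ".tar.gz") -> str:
    """Return the archive's base name with the exact suffix removed."""
    return posixpath.basename(url).removesuffix(suffix)


# Hypothetical URL, used only to exercise the helper.
name = extracted_dir_name("https://example.com/2021/train_mini.tar.gz")

# str.rstrip treats its argument as a set of characters, not as a suffix,
# so it can eat into the stem when the stem ends in one of those characters.
over_stripped = "data.tar.gz".rstrip(".tar.gz")
exact = "data.tar.gz".removesuffix(".tar.gz")
```

The dataset archive names listed in this file all end in a character outside the set `.targz` once the suffix is gone, so the difference only shows up for other file names.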

+ 237 - 0
python/py/Lib/site-packages/torchvision/datasets/kinetics.py

@@ -0,0 +1,237 @@
+import csv
+import os
+import urllib.parse
+from functools import partial
+from multiprocessing import Pool
+from os import path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+def _dl_wrap(tarpath: Union[str, Path], videopath: Union[str, Path], line: str) -> None:
+    download_and_extract_archive(line, tarpath, videopath)
+
+
+class Kinetics(VisionDataset):
+    """`Generic Kinetics <https://www.deepmind.com/open-source/kinetics>`_
+    dataset.
+
+    Kinetics-400/600/700 are action recognition video datasets.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Kinetics Dataset.
+            Directory should be structured as follows:
+            .. code::
+
+                root/
+                ├── split
+                │   ├──  class1
+                │   │   ├──  vid1.mp4
+                │   │   ├──  vid2.mp4
+                │   │   ├──  vid3.mp4
+                │   │   ├──  ...
+                │   ├──  class2
+                │   │   ├──   vidx.mp4
+                │   │    └── ...
+
+            Note: split is appended automatically using the split argument.
+        frames_per_clip (int): number of frames in a clip
+        num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700
+        split (str): split of the dataset to consider; supports ``"train"`` (default), ``"val"``, and ``"test"``
+        frame_rate (float): If omitted, interpolate different frame rate for each clip.
+        step_between_clips (int): number of frames between each clip
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        download (bool): Download the official version of the dataset to root folder.
+        num_workers (int): Use multiple workers for VideoClips creation
+        num_download_workers (int): Use multiprocessing in order to speed up download.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" or "TCHW" (default).
+            Note that in most other utils and datasets, the default is actually "THWC".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
+            - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points in torch.float tensor
+            - label (int): class of the video clip
+
+    Raises:
+        ValueError: If ``download`` is ``True`` together with the legacy folder structure.
+    """
+
+    _TAR_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/{split}/k400_{split}_path.txt",
+        "600": "https://s3.amazonaws.com/kinetics/600/{split}/k600_{split}_path.txt",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/{split}/k700_2020_{split}_path.txt",
+    }
+    _ANNOTATION_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv",
+        "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv",
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        frames_per_clip: int,
+        num_classes: str = "400",
+        split: str = "train",
+        frame_rate: Optional[int] = None,
+        step_between_clips: int = 1,
+        transform: Optional[Callable] = None,
+        extensions: tuple[str, ...] = ("avi", "mp4"),
+        download: bool = False,
+        num_download_workers: int = 1,
+        num_workers: int = 1,
+        _precomputed_metadata: Optional[dict[str, Any]] = None,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        _legacy: bool = False,
+        output_format: str = "TCHW",
+    ) -> None:
+
+        # TODO: support test
+        self.num_classes = verify_str_arg(num_classes, arg="num_classes", valid_values=["400", "600", "700"])
+        self.extensions = extensions
+        self.num_download_workers = num_download_workers
+
+        self.root = root
+        self._legacy = _legacy
+
+        if _legacy:
+            self.split_folder = root
+            self.split = "unknown"
+            output_format = "THWC"
+            if download:
+                raise ValueError("Cannot download the videos using legacy_structure.")
+        else:
+            self.split_folder = path.join(root, split)
+            self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
+
+        if download:
+            self.download_and_process_videos()
+
+        super().__init__(self.root)
+
+        self.classes, class_to_idx = find_classes(self.split_folder)
+        self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        self.video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            _audio_channels=_audio_channels,
+            output_format=output_format,
+        )
+        self.transform = transform
+
+    def download_and_process_videos(self) -> None:
+        """Downloads all the videos to the _root_ folder in the expected format."""
+        self._download_videos()
+        self._make_ds_structure()
+
+    def _download_videos(self) -> None:
+        """Download tarballs containing the videos to the "tars" folder and extract them into the _split_ folder,
+        where split is one of the official dataset splits.
+
+        Returns early if the split folder already exists, to avoid downloading the entire dataset again.
+        """
+        if path.exists(self.split_folder):
+            return
+        tar_path = path.join(self.root, "tars")
+        file_list_path = path.join(self.root, "files")
+
+        split_url = self._TAR_URLS[self.num_classes].format(split=self.split)
+        split_url_filepath = path.join(file_list_path, path.basename(split_url))
+        if not check_integrity(split_url_filepath):
+            download_url(split_url, file_list_path)
+        with open(split_url_filepath) as file:
+            list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]
+
+        if self.num_download_workers == 1:
+            for line in list_video_urls:
+                download_and_extract_archive(line, tar_path, self.split_folder)
+        else:
+            part = partial(_dl_wrap, tar_path, self.split_folder)
+            # Use the pool as a context manager so its worker processes are released after the map.
+            with Pool(self.num_download_workers) as poolproc:
+                poolproc.map(part, list_video_urls)
+
+    def _make_ds_structure(self) -> None:
+        """move videos from
+        split_folder/
+            ├── clip1.avi
+            ├── clip2.avi
+
+        to the correct format as described below:
+        split_folder/
+            ├── class1
+            │   ├── clip1.avi
+
+        """
+        annotation_path = path.join(self.root, "annotations")
+        if not check_integrity(path.join(annotation_path, f"{self.split}.csv")):
+            download_url(self._ANNOTATION_URLS[self.num_classes].format(split=self.split), annotation_path)
+        annotations = path.join(annotation_path, f"{self.split}.csv")
+
+        file_fmtstr = "{ytid}_{start:06}_{end:06}.mp4"
+        with open(annotations) as csvfile:
+            reader = csv.DictReader(csvfile)
+            for row in reader:
+                f = file_fmtstr.format(
+                    ytid=row["youtube_id"],
+                    start=int(row["time_start"]),
+                    end=int(row["time_end"]),
+                )
+                label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")
+                os.makedirs(path.join(self.split_folder, label), exist_ok=True)
+                downloaded_file = path.join(self.split_folder, f)
+                if path.isfile(downloaded_file):
+                    os.replace(
+                        downloaded_file,
+                        path.join(self.split_folder, label, f),
+                    )
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        return self.video_clips.metadata
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[video_idx][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label
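
Review note: the class docstring above gives a worked example of the dataset size (videos of 10 and 15 frames yielding 2 + 3 = 5 clips). A minimal sketch of that arithmetic only; the function name `num_clips` is hypothetical, and the sketch ignores frame-rate resampling and any other processing the clip sampler may perform:

```python
def num_clips(num_frames: int, frames_per_clip: int, step_between_clips: int) -> int:
    """Count the full-length windows that fit into a video of num_frames frames."""
    if num_frames < frames_per_clip:
        # Incomplete clips are dropped, so a short video contributes nothing.
        return 0
    return (num_frames - frames_per_clip) // step_between_clips + 1


# The docstring example: two videos, frames_per_clip=5, step_between_clips=5.
per_video = [num_clips(n, frames_per_clip=5, step_between_clips=5) for n in (10, 15)]
total = sum(per_video)

# With a step of 1 the windows overlap, so the same videos give many more clips.
overlapping = [num_clips(n, frames_per_clip=5, step_between_clips=1) for n in (10, 15)]
```

A step smaller than `frames_per_clip` produces overlapping clips, which increases the dataset length without adding new frames.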

+ 158 - 0
python/py/Lib/site-packages/torchvision/datasets/kitti.py

@@ -0,0 +1,158 @@
+import csv
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class Kitti(VisionDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.
+
+    It corresponds to the "left color images of object" dataset, for object detection.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── Kitti
+                        └─ raw
+                            ├── training
+                            |   ├── image_2
+                            |   └── label_2
+                            └── testing
+                                └── image_2
+        train (bool, optional): Use ``train`` split if true, else ``test`` split.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample
+            and its target as entry and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
+    resources = [
+        "data_object_image_2.zip",
+        "data_object_label_2.zip",
+    ]
+    image_dir_name = "image_2"
+    labels_dir_name = "label_2"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        super().__init__(
+            root,
+            transform=transform,
+            target_transform=target_transform,
+            transforms=transforms,
+        )
+        self.images = []
+        self.targets = []
+        self.train = train
+        self._location = "training" if self.train else "testing"
+
+        if download:
+            self.download()
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You may use download=True to download it.")
+
+        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
+        if self.train:
+            labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
+        for img_file in os.listdir(image_dir):
+            self.images.append(os.path.join(image_dir, img_file))
+            if self.train:
+                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """Get item at a given index.
+
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target), where
+            target is a list of dictionaries with the following keys:
+
+            - type: str
+            - truncated: float
+            - occluded: int
+            - alpha: float
+            - bbox: float[4]
+            - dimensions: float[3]
+            - location: float[3]
+            - rotation_y: float
+
+        """
+        image = Image.open(self.images[index])
+        target = self._parse_target(index) if self.train else None
+        if self.transforms:
+            image, target = self.transforms(image, target)
+        return image, target
+
+    def _parse_target(self, index: int) -> list:
+        target = []
+        with open(self.targets[index]) as inp:
+            content = csv.reader(inp, delimiter=" ")
+            for line in content:
+                target.append(
+                    {
+                        "type": line[0],
+                        "truncated": float(line[1]),
+                        "occluded": int(line[2]),
+                        "alpha": float(line[3]),
+                        "bbox": [float(x) for x in line[4:8]],
+                        "dimensions": [float(x) for x in line[8:11]],
+                        "location": [float(x) for x in line[11:14]],
+                        "rotation_y": float(line[14]),
+                    }
+                )
+        return target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    @property
+    def _raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    def _check_exists(self) -> bool:
+        """Check if the data directory exists."""
+        folders = [self.image_dir_name]
+        if self.train:
+            folders.append(self.labels_dir_name)
+        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
+
+    def download(self) -> None:
+        """Download the KITTI data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._raw_folder, exist_ok=True)
+
+        # download files
+        for fname in self.resources:
+            download_and_extract_archive(
+                url=f"{self.data_url}{fname}",
+                download_root=self._raw_folder,
+                filename=fname,
+            )
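
Note on `_parse_target` in the file above: each row of a KITTI label file is split on single spaces and mapped to one dictionary per object. What follows is a minimal standalone sketch of that mapping under stated assumptions, not torchvision's API. The function name `parse_kitti_labels` and the sample row are hypothetical and exist only for illustration.

```python
import csv
import io


def parse_kitti_labels(text):
    """Parse KITTI-style label text into a list of dicts (illustrative helper)."""
    target = []
    for line in csv.reader(io.StringIO(text), delimiter=" "):
        if not line:  # skip blank rows, which csv.reader yields as []
            continue
        target.append(
            {
                "type": line[0],
                "truncated": float(line[1]),
                "occluded": int(line[2]),
                "alpha": float(line[3]),
                "bbox": [float(x) for x in line[4:8]],
                "dimensions": [float(x) for x in line[8:11]],
                "location": [float(x) for x in line[11:14]],
                "rotation_y": float(line[14]),
            }
        )
    return target


# Hypothetical sample row with the 15 space-separated fields used above.
sample = "Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59\n"
objects = parse_kitti_labels(sample)
```

Note that the key produced by the parser is `location` (singular), so downstream code should index the target dictionaries with that name.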

+ 268 - 0
python/py/Lib/site-packages/torchvision/datasets/lfw.py

@@ -0,0 +1,268 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class _LFW(VisionDataset):
+
+    base_folder = "lfw-py"
+    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
+
+    file_dict = {
+        "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
+        "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
+        "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
+    }
+    checksums = {
+        "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
+        "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
+        "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
+        "people.txt": "450f0863dd89e85e73936a6d71a3474b",
+        "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
+        "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
+        "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
+    }
+    annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
+    names = "lfw-names.txt"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str,
+        image_set: str,
+        view: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
+
+        self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
+        images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
+
+        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
+        self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
+        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
+        self.data: list[Any] = []
+
+        if download:
+            raise ValueError(
+                "LFW dataset is no longer available for download. "
+                "Please download the dataset manually and place it in the specified directory."
+            )
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. Download it manually and place it in the specified directory.")
+
+        self.images_dir = os.path.join(self.root, images_dir)
+        self._loader = loader
+
+    def _check_integrity(self) -> bool:
+        st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
+        st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
+        if not st1 or not st2:
+            return False
+        if self.view == "people":
+            return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+        url = f"{self.download_url_prefix}{self.filename}"
+        download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
+        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
+        if self.view == "people":
+            download_url(f"{self.download_url_prefix}{self.names}", self.root)
+
+    def _get_path(self, identity: str, no: Union[int, str]) -> str:
+        return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
+
+    def extra_repr(self) -> str:
+        return f"Alignment: {self.image_set}\nSplit: {self.split}"
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class LFWPeople(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    .. warning::
+
+        The LFW dataset is no longer available for automatic download. Please
+        download it manually and place it in the specified directory.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold`` (default).
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, split, image_set, "people", transform, target_transform, download, loader=loader)
+
+        self.class_to_idx = self._get_classes()
+        self.data, self.targets = self._get_people()
+
+    def _get_people(self) -> tuple[list[str], list[int]]:
+        data, targets = [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
+
+            for fold in range(n_folds):
+                n_lines = int(lines[s])
+                people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
+                s += n_lines + 1
+                for i, (identity, num_imgs) in enumerate(people):
+                    for num in range(1, int(num_imgs) + 1):
+                        img = self._get_path(identity, num)
+                        data.append(img)
+                        targets.append(self.class_to_idx[identity])
+
+        return data, targets
+
+    def _get_classes(self) -> dict[str, int]:
+        with open(os.path.join(self.root, self.names)) as f:
+            lines = f.readlines()
+            names = [line.strip().split()[0] for line in lines]
+        class_to_idx = {name: i for i, name in enumerate(names)}
+        return class_to_idx
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the identity of the person.
+        """
+        img = self._loader(self.data[index])
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def extra_repr(self) -> str:
+        return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
+
+
+class LFWPairs(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    .. warning::
+
+        The LFW dataset is no longer available for automatic download. Please
+        download it manually and place it in the specified directory.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold``. Defaults to ``10fold``.
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomRotation``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, split, image_set, "pairs", transform, target_transform, download, loader=loader)
+
+        self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
+
+    def _get_pairs(self, images_dir: str) -> tuple[list[tuple[str, str]], list[tuple[str, str]], list[int]]:
+        pair_names, data, targets = [], [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            if self.split == "10fold":
+                n_folds, n_pairs = lines[0].split("\t")
+                n_folds, n_pairs = int(n_folds), int(n_pairs)
+            else:
+                n_folds, n_pairs = 1, int(lines[0])
+            s = 1
+
+            for fold in range(n_folds):
+                matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
+                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
+                s += 2 * n_pairs
+                for pair in matched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
+                    pair_names.append((pair[0], pair[0]))
+                    data.append((img1, img2))
+                    targets.append(same)
+                for pair in unmatched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
+                    pair_names.append((pair[0], pair[2]))
+                    data.append((img1, img2))
+                    targets.append(same)
+
+        return pair_names, data, targets
+
+    def __getitem__(self, index: int) -> tuple[Any, Any, int]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image1, image2, target) where target is `0` for different identities and `1` for same identities.
+        """
+        img1, img2 = self.data[index]
+        img1, img2 = self._loader(img1), self._loader(img2)
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img1, img2 = self.transform(img1), self.transform(img2)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img1, img2, target
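
Note on `LFWPairs._get_pairs` in the file above: for the `train` and `test` splits, the first line of the pairs file holds the pair count, followed by that many matched rows (`name`, `n1`, `n2`) and then the same number of unmatched rows (`name1`, `n1`, `name2`, `n2`). The sketch below mirrors that layout for the single-fold case only. It is an illustration under assumptions, not the library's code: `get_path`, `parse_pairs`, and the identities `Alice_Example` and `Bob_Example` are hypothetical.

```python
import os


def get_path(images_dir, identity, no):
    # Same file naming scheme as _get_path: <identity>/<identity>_<4-digit index>.jpg
    return os.path.join(images_dir, identity, f"{identity}_{int(no):04d}.jpg")


def parse_pairs(lines, images_dir):
    """Parse a single-fold pairs listing into (image path pairs, same/different labels)."""
    n_pairs = int(lines[0])
    matched = [line.strip().split("\t") for line in lines[1 : 1 + n_pairs]]
    unmatched = [line.strip().split("\t") for line in lines[1 + n_pairs : 1 + 2 * n_pairs]]

    data, targets = [], []
    for name, first, second in matched:
        data.append((get_path(images_dir, name, first), get_path(images_dir, name, second)))
        targets.append(1)  # same identity
    for name1, first, name2, second in unmatched:
        data.append((get_path(images_dir, name1, first), get_path(images_dir, name2, second)))
        targets.append(0)  # different identities
    return data, targets


# Hypothetical listing: one matched pair followed by one unmatched pair.
lines = [
    "1\n",
    "Alice_Example\t1\t2\n",
    "Alice_Example\t1\tBob_Example\t3\n",
]
data, targets = parse_pairs(lines, "lfw_funneled")
```

The `10fold` split differs only in that the first line carries both the fold count and the per-fold pair count, and the matched/unmatched blocks repeat once per fold.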

+ 168 - 0
python/py/Lib/site-packages/torchvision/datasets/lsun.py

@@ -0,0 +1,168 @@
+import io
+import os.path
+import pickle
+import string
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+from PIL import Image
+
+from .utils import iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class LSUNClass(VisionDataset):
+    def __init__(
+        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
+    ) -> None:
+        import lmdb
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
+        with self.env.begin(write=False) as txn:
+            self.length = txn.stat()["entries"]
+        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
+        if os.path.isfile(cache_file):
+            self.keys = pickle.load(open(cache_file, "rb"))
+        else:
+            with self.env.begin(write=False) as txn:
+                self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
+            pickle.dump(self.keys, open(cache_file, "wb"))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        img, target = None, None
+        env = self.env
+        with env.begin(write=False) as txn:
+            imgbuf = txn.get(self.keys[index])
+
+        buf = io.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        img = Image.open(buf).convert("RGB")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+
+class LSUN(VisionDataset):
+    """`LSUN <https://paperswithcode.com/dataset/lsun>`_ dataset.
+
+    You will need to install the ``lmdb`` package to use this dataset: run
+    ``pip install lmdb``
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory for the database files.
+        classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        classes: Union[str, list[str]] = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.classes = self._verify_classes(classes)
+
+        # for each class, create an LSUNClass dataset
+        self.dbs = []
+        for c in self.classes:
+            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
+
+        self.indices = []
+        count = 0
+        for db in self.dbs:
+            count += len(db)
+            self.indices.append(count)
+
+        self.length = count
+
+    def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]:
+        categories = [
+            "bedroom",
+            "bridge",
+            "church_outdoor",
+            "classroom",
+            "conference_room",
+            "dining_room",
+            "kitchen",
+            "living_room",
+            "restaurant",
+            "tower",
+        ]
+        dset_opts = ["train", "val", "test"]
+
+        try:
+            classes = cast(str, classes)
+            verify_str_arg(classes, "classes", dset_opts)
+            if classes == "test":
+                classes = [classes]
+            else:
+                classes = [c + "_" + classes for c in categories]
+        except ValueError:
+            if not isinstance(classes, Iterable):
+                msg = "Expected type str or Iterable for argument classes, but got type {}."
+                raise ValueError(msg.format(type(classes)))
+
+            classes = list(classes)
+            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
+            for c in classes:
+                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
+                c_short = c.split("_")
+                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+        return classes
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the index of the target category.
+        """
+        target = 0
+        sub = 0
+        for ind in self.indices:
+            if index < ind:
+                break
+            target += 1
+            sub = ind
+
+        db = self.dbs[target]
+        index = index - sub
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        img, _ = db[index]
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+    def extra_repr(self) -> str:
+        return "Classes: {classes}".format(**self.__dict__)
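
Note on `LSUN.__getitem__` in the file above: `self.indices` stores the cumulative sizes of the per-class databases, and the loop walks that list to turn a global index into a class index plus a local index inside that class. The sketch below reproduces the loop in isolation and compares it with a `bisect`-based lookup. The `bisect` variant is an alternative shown for illustration, not what the file does; `locate`, `locate_bisect`, and the sizes are hypothetical.

```python
from bisect import bisect_right


def locate(indices, index):
    """Linear scan, mirroring the loop in LSUN.__getitem__."""
    target = 0
    sub = 0
    for ind in indices:
        if index < ind:
            break
        target += 1
        sub = ind
    return target, index - sub


def locate_bisect(indices, index):
    """Equivalent lookup using binary search over the cumulative sizes."""
    target = bisect_right(indices, index)
    sub = indices[target - 1] if target else 0
    return target, index - sub


# Hypothetical databases of sizes 3, 2 and 4 give cumulative sizes [3, 5, 9].
indices = [3, 5, 9]
results = [locate(indices, i) for i in range(9)]
```

With many classes the binary search avoids scanning the whole list, although for the ten LSUN categories the difference is unlikely to matter.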

+ 560 - 0
python/py/Lib/site-packages/torchvision/datasets/mnist.py

@@ -0,0 +1,560 @@
+import codecs
+import os
+import os.path
+import shutil
+import string
+import sys
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from urllib.error import URLError
+
+import numpy as np
+import torch
+
+from ..utils import _Image_fromarray
+from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class MNIST(VisionDataset):
+    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
+            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        transform (callable, optional): A function/transform that  takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    mirrors = [
+        "https://ossci-datasets.s3.amazonaws.com/mnist/",
+        "http://yann.lecun.com/exdb/mnist/",
+    ]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
+    ]
+
+    training_file = "training.pt"
+    test_file = "test.pt"
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    @property
+    def train_labels(self):
+        warnings.warn("train_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def test_labels(self):
+        warnings.warn("test_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def train_data(self):
+        warnings.warn("train_data has been renamed data")
+        return self.data
+
+    @property
+    def test_data(self):
+        warnings.warn("test_data has been renamed data")
+        return self.data
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.train = train  # training set or test set
+
+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = _Image_fromarray(img.numpy(), mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    @property
+    def processed_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "processed")
+
+    @property
+    def class_to_idx(self) -> dict[str, int]:
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self) -> bool:
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )
+
+    def download(self) -> None:
+        """Download the MNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        # download files
+        for filename, md5 in self.resources:
+            errors = []
+            for mirror in self.mirrors:
+                url = f"{mirror}{filename}"
+                try:
+                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+                except URLError as e:
+                    errors.append(e)
+                    continue
+                break
+            else:
+                s = f"Error downloading {filename}:\n"
+                for mirror, err in zip(self.mirrors, errors):
+                    s += f"Tried {mirror}, got:\n{str(err)}\n"
+                raise RuntimeError(s)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class FashionMNIST(MNIST):
+    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
+            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        transform (callable, optional): A function/transform that  takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
+        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
+        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
+        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
+    ]
+    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
+
+
+class KMNIST(MNIST):
+    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
+            and  ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        transform (callable, optional): A function/transform that  takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
+        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
+        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
+        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
+    ]
+    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
+
+
+class EMNIST(MNIST):
+    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
+            and  ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    url = "https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip"
+    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
+    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
+    # Merged classes assume the same structure for both the uppercase and lowercase versions
+    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
+    _all_classes = set(string.digits + string.ascii_letters)
+    classes_split_dict = {
+        "byclass": sorted(list(_all_classes)),
+        "bymerge": sorted(list(_all_classes - _merged_classes)),
+        "balanced": sorted(list(_all_classes - _merged_classes)),
+        "letters": ["N/A"] + list(string.ascii_lowercase),
+        "digits": list(string.digits),
+        "mnist": list(string.digits),
+    }
+
+    def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None:
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super().__init__(root, **kwargs)
+        self.classes = self.classes_split_dict[self.split]
+
+    @staticmethod
+    def _training_file(split) -> str:
+        return f"training_{split}.pt"
+
+    @staticmethod
+    def _test_file(split) -> str:
+        return f"test_{split}.pt"
+
+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def download(self) -> None:
+        """Download the EMNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
+        gzip_folder = os.path.join(self.raw_folder, "gzip")
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith(".gz"):
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
+        shutil.rmtree(gzip_folder)
+
+
+class QMNIST(MNIST):
+    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset whose ``raw``
+            subdir contains binary files of the datasets.
+        what (string, optional): Can be 'train', 'test', 'test10k',
+            'test50k', or 'nist' for respectively the mnist compatible
+            training set, the 60k qmnist testing set, the 10k qmnist
+            examples that match the mnist testing set, the 50k
+            remaining qmnist testing examples, or all the nist
+            digits. The default is to select 'train' or 'test'
+            according to the compatibility argument 'train'.
+        compat (bool, optional): A boolean that says whether the target
+            for each example is class number (for compatibility with
+            the MNIST dataloader) or a torch vector containing the
+            full qmnist information. Default=True.
+        train (bool, optional, compatibility): When argument 'what' is
+            not specified, this boolean decides whether to load the
+            training set or the testing set.  Default: True.
+        download (bool, optional): If True, downloads the dataset from
+            the internet and puts it in root directory. If dataset is
+            already downloaded, it is not downloaded again.
+        transform (callable, optional): A function/transform that
+            takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform
+            that takes in the target and transforms it.
+    """
+
+    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
+    resources: dict[str, list[tuple[str, str]]] = {  # type: ignore[assignment]
+        "train": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
+                "ed72d4157d28c017586c42bc6afe6370",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
+                "0058f8dd561b90ffdd0f734c6a30e5e4",
+            ),
+        ],
+        "test": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
+                "1394631089c404de565df7b7aeaf9412",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
+                "5b5b05890a5e13444e108efe57b788aa",
+            ),
+        ],
+        "nist": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
+                "7f124b3b8ab81486c9d8c2749c17f834",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
+                "5ed0e788978e45d4a8bd4b7caec3d79d",
+            ),
+        ],
+    }
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    def __init__(
+        self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
+    ) -> None:
+        if what is None:
+            what = "train" if train else "test"
+        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
+        self.compat = compat
+        self.data_file = what + ".pt"
+        self.training_file = self.data_file
+        self.test_file = self.data_file
+        super().__init__(root, train, **kwargs)
+
+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        if data.dtype != torch.uint8:
+            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
+        if data.ndimension() != 3:
+            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        if targets.ndimension() != 2:
+            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
+
+        if self.what == "test10k":
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == "test50k":
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
+    def download(self) -> None:
+        """Download the QMNIST data if it doesn't exist already.
+        Note that we only download what has been asked for (argument 'what').
+        """
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+        split = self.resources[self.subsets[self.what]]
+
+        for url, md5 in split:
+            download_and_extract_archive(url, self.raw_folder, md5=md5)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        # redefined to handle the compat flag
+        img, target = self.data[index], self.targets[index]
+        img = _Image_fromarray(img.numpy(), mode="L")
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.compat:
+            target = int(target[0])
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"Split: {self.what}"
+
+
+def get_int(b: bytes) -> int:
+    return int(codecs.encode(b, "hex"), 16)
+
+
+SN3_PASCALVINCENT_TYPEMAP = {
+    8: torch.uint8,
+    9: torch.int8,
+    11: torch.int16,
+    12: torch.int32,
+    13: torch.float32,
+    14: torch.float64,
+}
+
+
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
+    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
+    Argument must be the path to an uncompressed file, which is read in binary mode.
+    """
+    # read
+    with open(path, "rb") as f:
+        data = f.read()
+
+    # parse
+    if sys.byteorder == "little" or sys.platform == "aix":
+        magic = get_int(data[0:4])
+        nd = magic % 256
+        ty = magic // 256
+    else:
+        nd = get_int(data[0:1])
+        ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256
+
+    assert 1 <= nd <= 3
+    assert 8 <= ty <= 14
+    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
+    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
+
+    if sys.byteorder == "big" and not sys.platform == "aix":
+        for i in range(len(s)):
+            s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
+
+    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
+
+    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
+    # that is little endian and the dtype has more than one byte, we need to flip them.
+    if sys.byteorder == "little" and parsed.element_size() > 1:
+        parsed = _flip_byte_order(parsed)
+
+    assert parsed.shape[0] == np.prod(s) or not strict
+    return parsed.view(*s)
+
+
+def read_label_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 1:
+        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
+    return x.long()
+
+
+def read_image_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 3:
+        raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}")
+    return x
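
The header parsing done by `read_sn3_pascalvincent_tensor` above can be sketched without torch. This is an illustrative sketch only, assuming the standard IDX layout: a 4-byte big-endian magic number whose lowest byte is the number of dimensions and whose next byte is the type code, followed by one 4-byte big-endian size per dimension. `parse_idx_header` and `VALID_TYPE_CODES` are hypothetical names, not part of torchvision.

```python
import struct

# Type codes taken from the keys of SN3_PASCALVINCENT_TYPEMAP in the diff above.
VALID_TYPE_CODES = {8, 9, 11, 12, 13, 14}


def parse_idx_header(buf: bytes) -> tuple[int, tuple[int, ...], int]:
    """Return (type_code, shape, payload_offset) for an IDX buffer (hypothetical helper)."""
    # Reading the magic number as big-endian is independent of the host byte order.
    magic = int.from_bytes(buf[0:4], "big")
    ndim = magic % 256
    type_code = magic // 256
    if type_code not in VALID_TYPE_CODES:
        raise ValueError(f"unsupported IDX type code {type_code}")
    if not 1 <= ndim <= 3:
        raise ValueError(f"expected 1 to 3 dimensions, got {ndim}")
    # One unsigned 32-bit big-endian integer per dimension follows the magic number.
    shape = struct.unpack(f">{ndim}I", buf[4 : 4 + 4 * ndim])
    # The payload starts right after the magic number and the dimension sizes.
    return type_code, shape, 4 * (ndim + 1)


# A fake header for two 28x28 uint8 images (type code 8, three dimensions).
header = bytes([0, 0, 8, 3]) + struct.pack(">3I", 2, 28, 28)
type_code, shape, offset = parse_idx_header(header)
```

Using `int.from_bytes(..., "big")` sidesteps the platform-specific branches in the vendored function; whether that covers every platform the original targets is not verified here.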

+ 94 - 0
python/py/Lib/site-packages/torchvision/datasets/moving_mnist.py

@@ -0,0 +1,94 @@
+import os.path
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class MovingMNIST(VisionDataset):
+    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
+        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
+            If ``split=None``, the full data is returned.
+        split_ratio (int, optional): The frame index at which each sequence is split. If ``split="train"``, the first
+            ``split_ratio`` frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the remaining frames
+            ``data[:, split_ratio:]`` are returned. If ``split=None``, this parameter is ignored and all frames are returned.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a torch Tensor
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+    """
+
+    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: Optional[str] = None,
+        split_ratio: int = 10,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform)
+
+        self._base_folder = os.path.join(self.root, self.__class__.__name__)
+        self._filename = self._URL.split("/")[-1]
+
+        if split is not None:
+            verify_str_arg(split, "split", ("train", "test"))
+        self.split = split
+
+        if not isinstance(split_ratio, int):
+            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
+        elif not (1 <= split_ratio <= 19):
+            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
+        self.split_ratio = split_ratio
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
+        if self.split == "train":
+            data = data[: self.split_ratio]
+        elif self.split == "test":
+            data = data[self.split_ratio :]
+        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        """
+        Args:
+            idx (int): Index
+        Returns:
+            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
+        """
+        data = self.data[idx]
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(os.path.join(self._base_folder, self._filename))
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_url(
+            url=self._URL,
+            root=self._base_folder,
+            filename=self._filename,
+            md5="be083ec986bfe91a449d63653c411eb2",
+        )

+ 107 - 0
python/py/Lib/site-packages/torchvision/datasets/omniglot.py

@@ -0,0 +1,107 @@
+from os.path import join
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
+from .vision import VisionDataset
+
+
+class Omniglot(VisionDataset):
+    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``omniglot-py`` exists.
+        background (bool, optional): If True, creates dataset from the "background" set, otherwise
+            creates from the "evaluation" set. This terminology is defined by the authors.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset zip files from the internet and
+            puts them in root directory. If the zip files are already downloaded, they are not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    folder = "omniglot-py"
+    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
+    zips_md5 = {
+        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
+        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        background: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Optional[Callable[[Union[str, Path]], Any]] = None,
+    ) -> None:
+        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
+        self.background = background
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.target_folder = join(self.root, self._get_target_folder())
+        self._alphabets = list_dir(self.target_folder)
+        self._characters: list[str] = sum(
+            ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
+        )
+        self._character_images = [
+            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
+            for idx, character in enumerate(self._characters)
+        ]
+        self._flat_character_images: list[tuple[str, int]] = sum(self._character_images, [])
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._flat_character_images)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target character class.
+        """
+        image_name, character_class = self._flat_character_images[index]
+        image_path = join(self.target_folder, self._characters[character_class], image_name)
+        image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            character_class = self.target_transform(character_class)
+
+        return image, character_class
+
+    def _check_integrity(self) -> bool:
+        zip_filename = self._get_target_folder()
+        if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        filename = self._get_target_folder()
+        zip_filename = filename + ".zip"
+        url = self.download_url_prefix + "/" + zip_filename
+        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
+
+    def _get_target_folder(self) -> str:
+        return "images_background" if self.background else "images_evaluation"

+ 135 - 0
python/py/Lib/site-packages/torchvision/datasets/oxford_iiit_pet.py

@@ -0,0 +1,135 @@
+import os
+import os.path
+import pathlib
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+    """`Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+        target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default),
+            ``binary-category`` or ``segmentation``. Can also be a list to output a tuple with all specified target
+            types. The types represent:
+
+                - ``category`` (int): Label for one of the 37 pet categories.
+                - ``binary-category`` (int): Binary label for cat or dog.
+                - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+            If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample
+            and its target as entry and returns a transformed version.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
+    """
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+    _VALID_TARGET_TYPES = ("category", "binary-category", "segmentation")
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "trainval",
+        target_types: Union[Sequence[str], str] = "category",
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        self._split = verify_str_arg(split, "split", ("trainval", "test"))
+        if isinstance(target_types, str):
+            target_types = [target_types]
+        self._target_types = [
+            verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+        ]
+
+        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+        self._images_folder = self._base_folder / "images"
+        self._anns_folder = self._base_folder / "annotations"
+        self._segs_folder = self._anns_folder / "trimaps"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        image_ids = []
+        self._labels = []
+        self._bin_labels = []
+        with open(self._anns_folder / f"{self._split}.txt") as file:
+            for line in file:
+                image_id, label, bin_label, _ = line.strip().split()
+                image_ids.append(image_id)
+                self._labels.append(int(label) - 1)
+                self._bin_labels.append(int(bin_label) - 1)
+
+        self.bin_classes = ["Cat", "Dog"]
+        self.classes = [
+            " ".join(part.title() for part in raw_cls.split("_"))
+            for raw_cls, _ in sorted(
+                {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                key=lambda image_id_and_label: image_id_and_label[1],
+            )
+        ]
+        self.bin_class_to_idx = dict(zip(self.bin_classes, range(len(self.bin_classes))))
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+        self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image = Image.open(self._images[idx]).convert("RGB")
+
+        target: Any = []
+        for target_type in self._target_types:
+            if target_type == "category":
+                target.append(self._labels[idx])
+            elif target_type == "binary-category":
+                target.append(self._bin_labels[idx])
+            else:  # target_type == "segmentation"
+                target.append(Image.open(self._segs[idx]))
+
+        if not target:
+            target = None
+        elif len(target) == 1:
+            target = target[0]
+        else:
+            target = tuple(target)
+
+        if self.transforms:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        for folder in (self._images_folder, self._anns_folder):
+            if not (os.path.exists(folder) and os.path.isdir(folder)):
+                return False
+        else:
+            return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for url, md5 in self._RESOURCES:
+            download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
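
The way `OxfordIIITPet.__init__` derives `self.classes` from the image IDs can be reproduced in isolation. The sketch below is illustrative: `derive_classes` is a hypothetical helper, and the sample IDs are made up in the `<breed>_<number>` shape that the `rsplit("_", 1)` call above relies on.

```python
def derive_classes(image_ids: list[str], labels: list[int]) -> list[str]:
    """Rebuild human-readable class names, ordered by label (hypothetical helper)."""
    # Drop the trailing "_<number>" from each image id and deduplicate (breed, label) pairs.
    pairs = {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, labels)}
    # Sort by label, then title-case each underscore-separated part of the breed name.
    return [
        " ".join(part.title() for part in raw_cls.split("_"))
        for raw_cls, _ in sorted(pairs, key=lambda pair: pair[1])
    ]


# Made-up sample ids; labels are zero-based, as after the `int(label) - 1` step above.
sample_ids = ["Abyssinian_100", "american_bulldog_7", "Abyssinian_3"]
sample_labels = [0, 1, 0]
classes = derive_classes(sample_ids, sample_labels)
class_to_idx = dict(zip(classes, range(len(classes))))
```

Because the pairs are sorted by label, the position of each name in `classes` matches its integer label, which is what makes `class_to_idx` consistent with the values stored in `_labels`.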

+ 134 - 0
python/py/Lib/site-packages/torchvision/datasets/pcam.py

@@ -0,0 +1,134 @@
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import _decompress, download_file_from_google_drive, verify_str_arg
+from .vision import VisionDataset
+
+
+class PCAM(VisionDataset):
+    """`PCAM Dataset <https://github.com/basveeling/pcam>`_.
+
+    The PatchCamelyon dataset is a binary classification dataset with 327,680
+    color images (96px x 96px), extracted from histopathologic scans of lymph node
+    sections. Each image is annotated with a binary label indicating presence of
+    metastatic tissue.
+
+    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.
+
+    Args:
+         root (str or ``pathlib.Path``): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+             version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
+             dataset is already downloaded, it is not downloaded again.
+
+             .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    _FILES = {
+        "train": {
+            "images": (
+                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
+                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
+                "1571f514728f59376b705fc836ff4b63",  # md5 hash
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_train_y.h5",
+                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
+                "35c2d7259d906cfc8143347bb8e05be7",
+            ),
+        },
+        "test": {
+            "images": (
+                "camelyonpatch_level_2_split_test_x.h5",
+                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
+                "d8c2d60d490dbd479f8199bdfa0cf6ec",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_test_y.h5",
+                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
+                "60a7035772fbdb7f34eb86d4420cf66a",
+            ),
+        },
+        "val": {
+            "images": (
+                "camelyonpatch_level_2_split_valid_x.h5",
+                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
+                "d5b63470df7cfa627aeec8b9dc0c066e",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_valid_y.h5",
+                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
+                "2b85f58b927af9964a4c15b8f7e8f179",
+            ),
+        },
+    }
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        try:
+            import h5py
+
+            self.h5py = h5py
+        except ImportError:
+            raise RuntimeError(
+                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
+            )
+
+        self._split = verify_str_arg(split, "split", ("train", "test", "val"))
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "pcam"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+    def __len__(self) -> int:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            return images_data["x"].shape[0]
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            image = Image.fromarray(images_data["x"][idx]).convert("RGB")
+
+        targets_file = self._FILES[self._split]["targets"][0]
+        with self.h5py.File(self._base_folder / targets_file) as targets_data:
+            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]
+
+        if self.transform:
+            image = self.transform(image)
+        if self.target_transform:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        images_file = self._FILES[self._split]["images"][0]
+        targets_file = self._FILES[self._split]["targets"][0]
+        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for file_name, file_id, md5 in self._FILES[self._split].values():
+            archive_name = file_name + ".gz"
+            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
+            _decompress(str(self._base_folder / archive_name))

+ 230 - 0
python/py/Lib/site-packages/torchvision/datasets/phototour.py

@@ -0,0 +1,230 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class PhotoTour(VisionDataset):
+    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.
+
+    .. note::
+
+        We only provide the newer version of the dataset, since the authors state that it
+
+            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
+            patches are centred on real interest point detections, rather than being projections of 3D points as is the
+            case in the old dataset.
+
+        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
+
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images are.
+        name (string): Name of the dataset to load.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    urls = {
+        "notredame_harris": [
+            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
+            "notredame_harris.zip",
+            "69f8c90f78e171349abdf0307afefe4d",
+        ],
+        "yosemite_harris": [
+            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
+            "yosemite_harris.zip",
+            "a73253d1c6fbd3ba2613c45065c00d46",
+        ],
+        "liberty_harris": [
+            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
+            "liberty_harris.zip",
+            "c731fcfb3abb4091110d0ae8c7ba182c",
+        ],
+        "notredame": [
+            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
+            "notredame.zip",
+            "509eda8535847b8c0a90bbb210c83484",
+        ],
+        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
+        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
+    }
+    means = {
+        "notredame": 0.4854,
+        "yosemite": 0.4844,
+        "liberty": 0.4437,
+        "notredame_harris": 0.4854,
+        "yosemite_harris": 0.4844,
+        "liberty_harris": 0.4437,
+    }
+    stds = {
+        "notredame": 0.1864,
+        "yosemite": 0.1818,
+        "liberty": 0.2019,
+        "notredame_harris": 0.1864,
+        "yosemite_harris": 0.1818,
+        "liberty_harris": 0.2019,
+    }
+    lens = {
+        "notredame": 468159,
+        "yosemite": 633587,
+        "liberty": 450092,
+        "liberty_harris": 379587,
+        "yosemite_harris": 450912,
+        "notredame_harris": 325295,
+    }
+    image_ext = "bmp"
+    info_file = "info.txt"
+    matches_files = "m50_100000_100000_0.txt"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        name: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform)
+        self.name = name
+        self.data_dir = os.path.join(self.root, name)
+        self.data_down = os.path.join(self.root, f"{name}.zip")
+        self.data_file = os.path.join(self.root, f"{name}.pt")
+
+        self.train = train
+        self.mean = self.means[name]
+        self.std = self.stds[name]
+
+        if download:
+            self.download()
+
+        if not self._check_datafile_exists():
+            self.cache()
+
+        # load the serialized data
+        self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)
+
+    def __getitem__(self, index: int) -> Union[torch.Tensor, tuple[Any, Any, torch.Tensor]]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            Tensor if ``train`` is True, otherwise tuple: (data1, data2, matches)
+        """
+        if self.train:
+            data = self.data[index]
+            if self.transform is not None:
+                data = self.transform(data)
+            return data
+        m = self.matches[index]
+        data1, data2 = self.data[m[0]], self.data[m[1]]
+        if self.transform is not None:
+            data1 = self.transform(data1)
+            data2 = self.transform(data2)
+        return data1, data2, m[2]
+
+    def __len__(self) -> int:
+        return len(self.data if self.train else self.matches)
+
+    def _check_datafile_exists(self) -> bool:
+        return os.path.exists(self.data_file)
+
+    def _check_downloaded(self) -> bool:
+        return os.path.exists(self.data_dir)
+
+    def download(self) -> None:
+        if self._check_datafile_exists():
+            return
+
+        if not self._check_downloaded():
+            # download files
+            url = self.urls[self.name][0]
+            filename = self.urls[self.name][1]
+            md5 = self.urls[self.name][2]
+            fpath = os.path.join(self.root, filename)
+
+            download_url(url, self.root, filename, md5)
+
+            import zipfile
+
+            with zipfile.ZipFile(fpath, "r") as z:
+                z.extractall(self.data_dir)
+
+            os.unlink(fpath)
+
+    def cache(self) -> None:
+        # process and save as torch files
+
+        dataset = (
+            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
+            read_info_file(self.data_dir, self.info_file),
+            read_matches_files(self.data_dir, self.matches_files),
+        )
+
+        with open(self.data_file, "wb") as f:
+            torch.save(dataset, f)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
+    """Return a Tensor containing the patches"""
+
+    def PIL2array(_img: Image.Image) -> np.ndarray:
+        """Convert PIL image type to numpy 2D array"""
+        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
+
+    def find_files(_data_dir: str, _image_ext: str) -> list[str]:
+        """Return a list with the file names of the images containing the patches"""
+        files = []
+        # find those files with the specified extension
+        for file_dir in os.listdir(_data_dir):
+            if file_dir.endswith(_image_ext):
+                files.append(os.path.join(_data_dir, file_dir))
+        return sorted(files)  # sort files in ascending order to keep relations
+
+    patches = []
+    list_files = find_files(data_dir, image_ext)
+
+    for fpath in list_files:
+        img = Image.open(fpath)
+        for y in range(0, img.height, 64):
+            for x in range(0, img.width, 64):
+                patch = img.crop((x, y, x + 64, y + 64))
+                patches.append(PIL2array(patch))
+    return torch.ByteTensor(np.array(patches[:n]))
+
+
+def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
+    """Return a Tensor containing the list of labels
+    Read the file and keep only the ID of the 3D point.
+    """
+    with open(os.path.join(data_dir, info_file)) as f:
+        labels = [int(line.split()[0]) for line in f]
+    return torch.LongTensor(labels)
+
+
+def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
+    """Return a Tensor containing the ground truth matches
+    Read the file and keep only 3D point ID.
+    Matches are represented with a 1, non matches with a 0.
+    """
+    matches = []
+    with open(os.path.join(data_dir, matches_file)) as f:
+        for line in f:
+            line_split = line.split()
+            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
+    return torch.LongTensor(matches)
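
Review note on `read_matches_files` above: each line of the matches file is reduced to two patch indices plus a 0/1 flag that says whether both patches share a 3D point ID. The following is a minimal pure-Python sketch of that per-line step only. The helper name `parse_match_line` and the sample lines are hypothetical and are not taken from the dataset; the column positions (0, 1, 3, 4) are the ones the code above reads.

```python
def parse_match_line(line: str) -> list[int]:
    """Return [patch_index_1, patch_index_2, is_match] for one matches-file line."""
    fields = line.split()
    # Columns 0 and 3 are the patch indices; columns 1 and 4 are the 3D point IDs.
    same_point = fields[1] == fields[4]
    return [int(fields[0]), int(fields[3]), int(same_point)]


# Hypothetical sample lines, made up for illustration only.
sample_lines = [
    "0 10 0 5 10 0",  # both patches reference 3D point 10 -> match
    "1 10 0 7 42 0",  # different 3D points -> non-match
]
pairs = [parse_match_line(line) for line in sample_lines]
print(pairs)
```

The 3D point IDs are compared as strings, as in the original, so the flag depends on the text of the two fields being identical.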

+ 176 - 0
python/py/Lib/site-packages/torchvision/datasets/places365.py

@@ -0,0 +1,176 @@
+import os
+from os import path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+from urllib.parse import urljoin
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Places365(VisionDataset):
+    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
+        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
+            ``val``, ``test``.
+        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
+            high resolution ones.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+    Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+
+    Raises:
+        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
+        RuntimeError: If ``download is True`` and the image archive is already extracted.
+    """
+
+    _SPLITS = ("train-standard", "train-challenge", "val", "test")
+    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
+    # {variant: (archive, md5)}
+    _DEVKIT_META = {
+        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
+        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
+    }
+    # (file, md5)
+    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
+    # {split: (file, md5)}
+    _FILE_LIST_META = {
+        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
+        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
+        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
+        "test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
+    }
+    # {(split, small): (file, md5)}
+    _IMAGES_META = {
+        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
+        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
+        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
+        ("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
+        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
+        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
+        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
+        ("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train-standard",
+        small: bool = False,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.split = self._verify_split(split)
+        self.small = small
+        self.loader = loader
+
+        self.classes, self.class_to_idx = self.load_categories(download)
+        self.imgs, self.targets = self.load_file_list(download)
+
+        if download:
+            self.download_images()
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        file, target = self.imgs[index]
+        image = self.loader(file)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.imgs)
+
+    @property
+    def variant(self) -> str:
+        return "challenge" if "challenge" in self.split else "standard"
+
+    @property
+    def images_dir(self) -> str:
+        size = "256" if self.small else "large"
+        if self.split.startswith("train"):
+            dir = f"data_{size}_{self.variant}"
+        else:
+            dir = f"{self.split}_{size}"
+        return path.join(self.root, dir)
+
+    def load_categories(self, download: bool = True) -> tuple[list[str], dict[str, int]]:
+        def process(line: str) -> tuple[str, int]:
+            cls, idx = line.split()
+            return cls, int(idx)
+
+        file, md5 = self._CATEGORIES_META
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            class_to_idx = dict(process(line) for line in fh)
+
+        return sorted(class_to_idx.keys()), class_to_idx
+
+    def load_file_list(
+        self, download: bool = True
+    ) -> tuple[list[tuple[str, Union[int, None]]], list[Union[int, None]]]:
+        def process(line: str, sep="/") -> tuple[str, Union[int, None]]:
+            image, idx = (line.split() + [None])[:2]
+            image = cast(str, image)
+            idx = int(idx) if idx is not None else None
+            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx
+
+        file, md5 = self._FILE_LIST_META[self.split]
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            images = [process(line) for line in fh]
+
+        _, targets = zip(*images)
+        return images, list(targets)
+
+    def download_devkit(self) -> None:
+        file, md5 = self._DEVKIT_META[self.variant]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+    def download_images(self) -> None:
+        if path.exists(self.images_dir):
+            return
+
+        file, md5 = self._IMAGES_META[(self.split, self.small)]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+        if self.split.startswith("train"):
+            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
+
+    def extra_repr(self) -> str:
+        return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
+
+    def _verify_split(self, split: str) -> str:
+        return verify_str_arg(split, "split", self._SPLITS)
+
+    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
+        integrity = check_integrity(file, md5=md5)
+        if not integrity and not download:
+            raise RuntimeError(
+                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
+            )
+        return integrity
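
Review note on `load_file_list` above: the `(line.split() + [None])[:2]` idiom lets one parser handle both labelled lines (path plus class index) and unlabelled lines (path only, as a test split would have). The following is a small sketch of that idiom, assuming a hypothetical helper name `parse_file_list_line` and made-up sample lines. It leaves out the join with `images_dir` and the `os.sep` replacement that the original performs.

```python
from typing import Optional


def parse_file_list_line(line: str) -> tuple[str, Optional[int]]:
    """Split a file-list line into (relative image path, class index or None)."""
    # Padding with None guarantees two items even when the label is absent.
    image, idx = (line.split() + [None])[:2]
    label = int(idx) if idx is not None else None
    # Strip the leading "/" so the path can be joined onto an images directory.
    return image.lstrip("/"), label


# Hypothetical sample lines, for illustration only.
labelled = parse_file_list_line("/a/abbey/00000001.jpg 3")
unlabelled = parse_file_list_line("some_test_image.jpg")
print(labelled, unlabelled)
```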

+ 89 - 0
python/py/Lib/site-packages/torchvision/datasets/rendered_sst2.py

@@ -0,0 +1,89 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class RenderedSST2(VisionDataset):
+    """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.
+
+    Rendered SST2 is an image classification dataset used to evaluate the model's capability on optical
+    character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment
+    Treebank v2 dataset.
+
+    This dataset contains two classes (positive and negative) and is divided into three splits: a train
+    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
+    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
+    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
+        self._base_folder = Path(self.root) / "rendered-sst2"
+        self.classes = ["negative", "positive"]
+        self.class_to_idx = {"negative": 0, "positive": 1}
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._samples[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        for class_label in set(self.classes):
+            if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir():
+                return False
+        return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

+ 3 - 0
python/py/Lib/site-packages/torchvision/datasets/samplers/__init__.py

@@ -0,0 +1,3 @@
+from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler
+
+__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")

+ 173 - 0
python/py/Lib/site-packages/torchvision/datasets/samplers/clip_sampler.py

@@ -0,0 +1,173 @@
+import math
+from collections.abc import Iterator, Sized
+from typing import cast, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from torchvision.datasets.video_utils import VideoClips
+
+
+class DistributedSampler(Sampler):
+    """
+    Extension of DistributedSampler, as discussed in
+    https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+    when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+    when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
+
+    """
+
+    def __init__(
+        self,
+        dataset: Sized,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = False,
+        group_size: int = 1,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if len(dataset) % group_size != 0:
+            raise ValueError(
+                f"dataset length must be a multiple of group size; dataset length: {len(dataset)}, group size: {group_size}"
+            )
+        self.dataset = dataset
+        self.group_size = group_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
+        self.num_samples = self.num_group_samples * group_size
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+
+    def __iter__(self) -> Iterator[int]:
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices: Union[torch.Tensor, list[int]]
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
+
+        # subsample
+        indices = indices[self.rank : total_group_size : self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
+        assert len(indices) == self.num_samples
+
+        if isinstance(self.dataset, Sampler):
+            orig_indices = list(iter(self.dataset))
+            indices = [orig_indices[i] for i in indices]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+
+class UniformClipSampler(Sampler):
+    """
+    Sample `num_clips_per_video` clips for each video, equally spaced.
+    When the number of unique clips in the video is fewer than `num_clips_per_video`,
+    repeat the clips until `num_clips_per_video` clips are collected.
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        num_clips_per_video (int): number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.num_clips_per_video = num_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select num_clips_per_video for each video, uniformly spaced
+        for c in self.video_clips.clips:
+            length = len(c)
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+
+            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
+            s += length
+            idxs.append(sampled)
+        return iter(cast(list[int], torch.cat(idxs).tolist()))
+
+    def __len__(self) -> int:
+        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
+
+
+class RandomClipSampler(Sampler):
+    """
+    Samples at most `max_clips_per_video` clips for each video randomly.
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        max_clips_per_video (int): maximum number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.max_clips_per_video = max_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select at most max_clips_per_video for each video, randomly
+        for c in self.video_clips.clips:
+            length = len(c)
+            size = min(length, self.max_clips_per_video)
+            sampled = torch.randperm(length)[:size] + s
+            s += length
+            idxs.append(sampled)
+        idxs_ = torch.cat(idxs)
+        # shuffle all clips randomly
+        perm = torch.randperm(len(idxs_))
+        return iter(idxs_[perm].tolist())
+
+    def __len__(self) -> int:
+        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
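
Review note on `DistributedSampler` in clip_sampler.py above: the sharding tables in its docstring can be reproduced without torch. The following is a pure-Python sketch of the non-shuffled path only: pad the index list to an even size, cut it into groups of `group_size`, then give every `num_replicas`-th group to each rank. The function name `shard_indices` is hypothetical. Shuffling, the epoch seed, and the nested-sampler remapping from the original are left out.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int, group_size: int = 1) -> list[int]:
    """Return the indices one rank would see when shuffle=False."""
    if dataset_len % group_size != 0:
        raise ValueError("dataset length must be a multiple of group size")

    num_groups = dataset_len // group_size
    groups_per_rank = math.ceil(num_groups / num_replicas)
    total_size = groups_per_rank * group_size * num_replicas

    # Pad by wrapping around to the start so every rank gets the same count.
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]

    # Cut into consecutive groups, then take every num_replicas-th group.
    groups = [indices[i : i + group_size] for i in range(0, total_size, group_size)]
    return [idx for group in groups[rank::num_replicas] for idx in group]


# Same setup as the docstring: 14 items, 4 replicas.
table_g1 = [shard_indices(14, 4, r, group_size=1) for r in range(4)]
table_g2 = [shard_indices(14, 4, r, group_size=2) for r in range(4)]
print(table_g1)
print(table_g2)
```

With `group_size=2` neighbouring indices stay on the same rank. This matters when consecutive indices belong together, for example clips taken from the same video.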

+ 126 - 0
python/py/Lib/site-packages/torchvision/datasets/sbd.py

@@ -0,0 +1,126 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SBDataset(VisionDataset):
+    """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
+
+    The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
+
+    .. note::
+
+        Please note that the train and val splits included with this dataset are different from
+        the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
+        VOC2012 val.
+        If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
+        which excludes all val images.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the Semantic Boundaries Dataset
+        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
+            Image set ``train_noval`` excludes VOC 2012 val images.
+        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
+            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
+            where `num_classes=20`.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version. Input sample is PIL image and target is a numpy array
+            if `mode='boundaries'` or PIL image if `mode='segmentation'`.
+    """
+
+    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
+    filename = "benchmark.tgz"
+
+    voc_train_url = "https://www.cs.cornell.edu/~bharathh/train_noval.txt"
+    voc_split_filename = "train_noval.txt"
+    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        image_set: str = "train",
+        mode: str = "boundaries",
+        download: bool = False,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+
+        try:
+            from scipy.io import loadmat
+
+            self._loadmat = loadmat
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transforms)
+        self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
+        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
+        self.num_classes = 20
+
+        sbd_root = self.root
+        image_dir = os.path.join(sbd_root, "img")
+        mask_dir = os.path.join(sbd_root, "cls")
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+            extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
+            for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
+                old_path = os.path.join(extracted_ds_root, f)
+                shutil.move(old_path, sbd_root)
+            if self.image_set == "train_noval":
+                # Note: this is failing as of June 2024 https://github.com/pytorch/vision/issues/8471
+                download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
+
+        if not os.path.isdir(sbd_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
+
+        with open(os.path.join(split_f)) as fh:
+            file_names = [x.strip() for x in fh.readlines()]
+
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
+
+        self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
+
+    def _get_segmentation_target(self, filepath: str) -> Image.Image:
+        mat = self._loadmat(filepath)
+        return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
+
+    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
+        mat = self._loadmat(filepath)
+        return np.concatenate(
+            [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
+            axis=0,
+        )
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self._get_target(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return "\n".join(lines).format(**self.__dict__)

+ 114 - 0
python/py/Lib/site-packages/torchvision/datasets/sbu.py

@@ -0,0 +1,114 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import check_integrity, download_and_extract_archive, download_url
+from .vision import VisionDataset
+
+
+class SBU(VisionDataset):
+    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where tarball
+            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+    filename = "SBUCaptionedPhotoDataset.tar.gz"
+    md5_checksum = "9aec147b3488753cf758b4d493422285"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.loader = loader
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # Read the caption for each photo
+        self.photos = []
+        self.captions = []
+
+        file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
+        file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
+
+        for line1, line2 in zip(open(file1), open(file2)):
+            url = line1.rstrip()
+            photo = os.path.basename(url)
+            filename = os.path.join(self.root, "dataset", photo)
+            if os.path.exists(filename):
+                caption = line2.rstrip()
+                self.photos.append(photo)
+                self.captions.append(caption)
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a caption for the photo.
+        """
+        filename = os.path.join(self.root, "dataset", self.photos[index])
+        img = self.loader(filename)
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = self.captions[index]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        """The number of photos in the dataset."""
+        return len(self.photos)
+
+    def _check_integrity(self) -> bool:
+        """Check the md5 checksum of the downloaded tarball."""
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        """Download and extract the tarball, and download each individual photo."""
+
+        if self._check_integrity():
+            return
+
+        download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
+
+        # Download individual photos
+        with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
+            for line in fh:
+                url = line.rstrip()
+                try:
+                    download_url(url, os.path.join(self.root, "dataset"))
+                except OSError:
+                    # The images point to public images on Flickr.
+                    # Note: Images might be removed by users at anytime.
+                    pass
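The SBU constructor above pairs the URL list with the caption list line by line and keeps only the photos that are actually present on disk. A minimal standalone sketch of that pairing follows, assuming two parallel text files. The names `pair_photos_with_captions` and `demo` are hypothetical and not part of torchvision; unlike the committed loop, the sketch closes both files through context managers.

```python
import os
import tempfile


def pair_photos_with_captions(urls_path, captions_path, image_dir):
    """Return (photos, captions) for URLs whose image file exists in image_dir.

    Hypothetical helper mirroring the loop in SBU.__init__; not a torchvision API.
    """
    photos, captions = [], []
    # Context managers close both files even if an error occurs mid-loop.
    with open(urls_path) as urls, open(captions_path) as caps:
        for url_line, caption_line in zip(urls, caps):
            photo = os.path.basename(url_line.rstrip())
            if os.path.exists(os.path.join(image_dir, photo)):
                photos.append(photo)
                captions.append(caption_line.rstrip())
    return photos, captions


def demo():
    """Build a tiny fake dataset directory and pair its two index files."""
    with tempfile.TemporaryDirectory() as root:
        urls_path = os.path.join(root, "urls.txt")
        captions_path = os.path.join(root, "captions.txt")
        with open(urls_path, "w") as fh:
            fh.write("http://example.com/a.jpg\nhttp://example.com/b.jpg\n")
        with open(captions_path, "w") as fh:
            fh.write("first caption\nsecond caption\n")
        # Only b.jpg "was downloaded"; a.jpg is missing and gets skipped.
        open(os.path.join(root, "b.jpg"), "wb").close()
        return pair_photos_with_captions(urls_path, captions_path, root)
```

Because missing images are skipped silently, the dataset length can be smaller than the number of lines in the URL file.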

+ 92 - 0
python/py/Lib/site-packages/torchvision/datasets/semeion.py

@@ -0,0 +1,92 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ..utils import _Image_fromarray
+from .utils import check_integrity, download_url
+from .vision import VisionDataset
+
+
+class SEMEION(VisionDataset):
+    r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where the file
+            ``semeion.data`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
+    filename = "semeion.data"
+    md5_checksum = "cb545d371d2ce14ec121470795a77432"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        fp = os.path.join(self.root, self.filename)
+        data = np.loadtxt(fp)
+        # convert values to 8-bit unsigned integers
+        # so that set pixels become white (255)
+        self.data = (data[:, :256] * 255).astype("uint8")
+        self.data = np.reshape(self.data, (-1, 16, 16))
+        self.labels = np.nonzero(data[:, 256:])[1]
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = _Image_fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        root = self.root
+        download_url(self.url, root, self.filename, self.md5_checksum)
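The slicing in `SEMEION.__init__` treats the first 256 values of each row as a 16x16 image and the remaining columns as a class indicator. The following pure-Python sketch decodes one such row. It assumes the trailing columns are one-hot, which is what the `np.nonzero(...)[1]` call suggests; the helper name `decode_semeion_row` and the 10-column tail in the demo row are illustrative assumptions, not torchvision API.

```python
def decode_semeion_row(values, side=16):
    """Split one row into a side x side image (0..255) and a class index.

    Assumes the first side*side values are pixel intensities in [0, 1] and the
    remaining values are a one-hot class indicator. Hypothetical helper.
    """
    n_pixels = side * side
    pixels = [int(v * 255) for v in values[:n_pixels]]
    image = [pixels[r * side:(r + 1) * side] for r in range(side)]
    one_hot = values[n_pixels:]
    # Index of the first non-zero entry, like np.nonzero(...)[1] for one row.
    label = next(i for i, v in enumerate(one_hot) if v != 0)
    return image, label


row = [0.0] * 256 + [0.0] * 10  # assumed layout: 256 pixels + 10 class columns
row[17] = 1.0                   # pixel at row 1, column 1
row[256 + 7] = 1.0              # one-hot entry for class 7
image, label = decode_semeion_row(row)
```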

+ 105 - 0
python/py/Lib/site-packages/torchvision/datasets/stanford_cars.py

@@ -0,0 +1,105 @@
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """Stanford Cars Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly 50-50.
+
+    The original URL is https://ai.stanford.edu/~jkrause/cars/car_dataset.html,
+    but the dataset isn't available online anymore.
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): This parameter exists for backward compatibility but it does not
+            download the dataset, since the original URL is not available anymore.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found.")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        """Returns pil_image and class_id for given index"""
+        image_path, target = self._samples[idx]
+        image = self.loader(image_path)
+
+        if self.transform is not None:
+            image = self.transform(image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return image, target
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
+
+    def download(self):
+        raise ValueError("The original URL is broken so the StanfordCars dataset cannot be downloaded anymore.")
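Since `download()` always raises, the dataset has to be placed on disk manually in the layout that `StanfordCars.__init__` expects. The sketch below reproduces only the path computation from the constructor so a local copy can be checked before instantiating the class. The helper name `expected_stanford_cars_paths` is hypothetical; the constructor additionally reads `devkit/cars_meta.mat` for the class names.

```python
from pathlib import Path


def expected_stanford_cars_paths(root, split):
    """Return (annotations_file, images_dir) as computed in StanfordCars.__init__.

    Hypothetical helper for checking a manually prepared copy of the dataset.
    """
    base = Path(root) / "stanford_cars"
    if split == "train":
        return base / "devkit" / "cars_train_annos.mat", base / "cars_train"
    if split == "test":
        # Note: the test annotations live next to devkit, not inside it.
        return base / "cars_test_annos_withlabels.mat", base / "cars_test"
    raise ValueError(f"unknown split: {split!r}")


train_annos, train_images = expected_stanford_cars_paths("data", "train")
test_annos, test_images = expected_stanford_cars_paths("data", "test")
```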

+ 174 - 0
python/py/Lib/site-packages/torchvision/datasets/stl10.py

@@ -0,0 +1,174 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class STL10(VisionDataset):
+    """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where directory
+            ``stl10_binary`` exists.
+        split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
+            The dataset is selected accordingly.
+        folds (int, optional): One of {0-9} or None.
+            For training, loads one of the 10 pre-defined folds of 1k samples for the
+            standard evaluation procedure. If no value is passed, loads the 5k samples.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    base_folder = "stl10_binary"
+    url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
+    filename = "stl10_binary.tar.gz"
+    tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
+    class_names_file = "class_names.txt"
+    folds_list_file = "fold_indices.txt"
+    train_list = [
+        ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
+        ["train_y.bin", "5a34089d4802c674881badbb80307741"],
+        ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
+    ]
+
+    test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
+    splits = ("train", "train+unlabeled", "unlabeled", "test")
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        folds: Optional[int] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.folds = self._verify_folds(folds)
+
+        if download:
+            self.download()
+        elif not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # now load the picked numpy arrays
+        self.labels: Optional[np.ndarray]
+        if self.split == "train":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+
+        elif self.split == "train+unlabeled":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
+            self.data = np.concatenate((self.data, unlabeled_data))
+            self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
+
+        elif self.split == "unlabeled":
+            self.data, _ = self.__loadfile(self.train_list[2][0])
+            self.labels = np.asarray([-1] * self.data.shape[0])
+        else:  # self.split == 'test':
+            self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
+
+        class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
+        if os.path.isfile(class_file):
+            with open(class_file) as f:
+                self.classes = f.read().splitlines()
+
+    def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
+        if folds is None:
+            return folds
+        elif isinstance(folds, int):
+            if folds in range(10):
+                return folds
+            msg = "Value for argument folds should be in the range [0, 10), but got {}."
+            raise ValueError(msg.format(folds))
+        else:
+            msg = "Expected type None or int for argument folds, but got type {}."
+            raise ValueError(msg.format(type(folds)))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        target: Optional[int]
+        if self.labels is not None:
+            img, target = self.data[index], int(self.labels[index])
+        else:
+            img, target = self.data[index], None
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.data.shape[0]
+
+    def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> tuple[np.ndarray, Optional[np.ndarray]]:
+        labels = None
+        if labels_file:
+            path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
+            with open(path_to_labels, "rb") as f:
+                labels = np.fromfile(f, dtype=np.uint8) - 1  # 0-based
+
+        path_to_data = os.path.join(self.root, self.base_folder, data_file)
+        with open(path_to_data, "rb") as f:
+            # read whole file in uint8 chunks
+            everything = np.fromfile(f, dtype=np.uint8)
+            images = np.reshape(everything, (-1, 3, 96, 96))
+            images = np.transpose(images, (0, 1, 3, 2))
+
+        return images, labels
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+        self._check_integrity()
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+    def __load_folds(self, folds: Optional[int]) -> None:
+        # loads one of the folds if specified
+        if folds is None:
+            return
+        path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
+        with open(path_to_folds) as f:
+            str_idx = f.read().splitlines()[folds]
+            list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
+            self.data = self.data[list_idx, :, :, :]
+            if self.labels is not None:
+                self.labels = self.labels[list_idx]
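In `__loadfile`, the raw bytes are reshaped to `(-1, 3, 96, 96)` and the last two axes are then swapped, which implies that each channel is stored column by column. The pure-Python sketch below shows the same index mapping for a single tiny channel. The function name `column_major_to_rows` is hypothetical, and the column-major reading is an inference from the transpose rather than something the diff states.

```python
def column_major_to_rows(flat, height, width):
    """Rebuild a row-major image from a flat, column-major pixel list.

    Mirrors reshape-then-transpose for a single channel: the value for
    (row y, column x) sits at flat[x * height + y].
    """
    return [[flat[x * height + y] for x in range(width)] for y in range(height)]


# A 2 x 3 image whose pixel at (y, x) has the value 10 * y + x,
# written out column by column.
flat = [0, 10, 1, 11, 2, 12]
image = column_major_to_rows(flat, height=2, width=3)
```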

+ 81 - 0
python/py/Lib/site-packages/torchvision/datasets/sun397.py

@@ -0,0 +1,81 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class SUN397(VisionDataset):
+    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
+
+    SUN397, or Scene UNderstanding (SUN), is a dataset for scene recognition consisting of
+    397 categories with 108,754 images.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+    """
+
+    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
+    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+        loader: Callable[[Union[str, Path]], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._data_dir = Path(self.root) / "SUN397"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        with open(self._data_dir / "ClassName.txt") as f:
+            self.classes = [c[3:].strip() for c in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
+
+        self._labels = [
+            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
+        ]
+        self.loader = loader
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = self.loader(image_file)
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_dir.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
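`SUN397.__init__` derives class names twice: from `ClassName.txt` via `c[3:].strip()`, and from each image path via `parts[1:-1]`. Both must produce the same string for the `class_to_idx` lookup to succeed. The sketch below illustrates this with made-up entries; the listing format (`/a/abbey`) and the file names are assumptions for illustration, and both helper names are hypothetical.

```python
from pathlib import PurePosixPath


def class_name_from_listing(line):
    """Strip the leading '/<letter>/' prefix, as `c[3:].strip()` does."""
    return line[3:].strip()


def class_name_from_path(image_path, data_dir):
    """Join the directories between the letter folder and the file name."""
    parts = PurePosixPath(image_path).relative_to(data_dir).parts
    return "/".join(parts[1:-1])


data_dir = PurePosixPath("root/SUN397")
flat_cls = class_name_from_path(data_dir / "a" / "abbey" / "sun_0001.jpg", data_dir)
nested_cls = class_name_from_path(
    data_dir / "a" / "apartment_building" / "outdoor" / "sun_0002.jpg", data_dir
)
```

Nested categories keep their slash, so a class such as `apartment_building/outdoor` is a single label rather than two.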

+ 130 - 0
python/py/Lib/site-packages/torchvision/datasets/svhn.py

@@ -0,0 +1,130 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SVHN(VisionDataset):
+    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
+    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
+    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
+    expect the class labels to be in the range `[0, C-1]`
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset where the data is stored.
+        split (string): One of {'train', 'test', 'extra'}.
+            The dataset is selected accordingly. 'extra' is the extra training set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+            "train_32x32.mat",
+            "e26dedcc434d2e4c54c9b2d4a06d8373",
+        ],
+        "test": [
+            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+            "test_32x32.mat",
+            "eb5a983be6a315427106f1b164d9cef3",
+        ],
+        "extra": [
+            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+            "extra_32x32.mat",
+            "a93ce644f1a588dc4d68dda5feec44a7",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
+        self.url = self.split_list[split][0]
+        self.filename = self.split_list[split][1]
+        self.file_md5 = self.split_list[split][2]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # import here rather than at top of file because this is
+        # an optional dependency for torchvision
+        import scipy.io as sio
+
+        # load the .mat file as arrays
+        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+        self.data = loaded_mat["X"]
+        # loading from the .mat file gives an np.ndarray of type np.uint8
+        # converting to np.int64, so that we have a LongTensor after
+        # the conversion from the numpy array
+        # the squeeze is needed to obtain a 1D tensor
+        self.labels = loaded_mat["y"].astype(np.int64).squeeze()
+
+        # the svhn dataset assigns the class label "10" to the digit 0
+        # this makes it inconsistent with several loss functions
+        # which expect the class labels to be in the range [0, C-1]
+        np.place(self.labels, self.labels == 10, 0)
+        self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        md5 = self.split_list[self.split][2]
+        fpath = os.path.join(root, self.filename)
+        return check_integrity(fpath, md5)
+
+    def download(self) -> None:
+        md5 = self.split_list[self.split][2]
+        download_url(self.url, self.root, self.filename, md5)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
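The SVHN class docstring notes that the raw label `10` stands for the digit `0` and is remapped so that labels fall in `[0, C-1]`. A pure-Python equivalent of the `np.place` call is sketched below; `remap_svhn_labels` is a hypothetical name, and the real code modifies a NumPy array in place rather than returning a new list.

```python
def remap_svhn_labels(labels):
    """Map the SVHN label 10 (digit 0) to 0 and leave 1..9 unchanged."""
    return [0 if label == 10 else label for label in labels]


remapped = remap_svhn_labels([10, 1, 9, 10, 5])
```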

+ 131 - 0
python/py/Lib/site-packages/torchvision/datasets/ucf101.py

@@ -0,0 +1,131 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class UCF101(VisionDataset):
+    """
+    `UCF101 <https://www.crcv.ucf.edu/data/UCF101.php>`_ dataset.
+
+    UCF101 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``. The dataset itself can be downloaded from the dataset website;
+    annotations that ``annotation_path`` should be pointing to can be downloaded from `here
+    <https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip>`_.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the UCF101 Dataset.
+        annotation_path (str): path to the folder containing the split files;
+            see docstring above for download instructions of these files
+        frames_per_clip (int): number of frames in a clip.
+        step_between_clips (int, optional): number of frames between each clip.
+        fold (int, optional): which fold to use. Should be between 1 and 3.
+        train (bool, optional): if ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            - audio (Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if not 1 <= fold <= 3:
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.fold = fold
+        self.train = train
+
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.indices = self._select_fold(video_list, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: list[str], annotation_path: str, fold: int, train: bool) -> list[int]:
+        name = "train" if train else "test"
+        name = f"{name}list{fold:02d}.txt"
+        f = os.path.join(annotation_path, name)
+        selected_files = set()
+        with open(f) as fid:
+            data = fid.readlines()
+            data = [x.strip().split(" ")[0] for x in data]
+            data = [os.path.join(self.root, *x.split("/")) for x in data]
+            selected_files.update(data)
+        indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[self.indices[video_idx]][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label
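The UCF101 docstring gives a worked example: videos of 10 and 15 frames with `frames_per_clip=5` and `step_between_clips=5` yield 2 + 3 = 5 clips. The sketch below reproduces that arithmetic and the split file name built in `_select_fold`. Both function names are hypothetical, and the clip count is plain arithmetic for the docstring's example, not the `VideoClips` implementation, which may differ in details such as frame-rate resampling.

```python
def clips_per_video(num_frames, frames_per_clip, step_between_clips):
    """Count full clips of `frames_per_clip` frames taken every `step_between_clips` frames."""
    if num_frames < frames_per_clip:
        return 0  # clips with fewer than frames_per_clip frames are dropped
    return (num_frames - frames_per_clip) // step_between_clips + 1


def split_file_name(fold, train):
    """Build the split file name the same way `_select_fold` does."""
    prefix = "train" if train else "test"
    return f"{prefix}list{fold:02d}.txt"


counts = [clips_per_video(n, frames_per_clip=5, step_between_clips=5) for n in (10, 15)]
```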

+ 96 - 0
python/py/Lib/site-packages/torchvision/datasets/usps.py

@@ -0,0 +1,96 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ..utils import _Image_fromarray
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class USPS(VisionDataset):
+    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
+    The data format is: [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
+    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
+    and make pixel values in ``[0, 255]``.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset to store ``USPS`` data files.
+        train (bool, optional): If True, creates dataset from ``usps.bz2``,
+            otherwise from ``usps.t.bz2``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+            "usps.bz2",
+            "ec16c51db3855ca6c91edd34d0e9b197",
+        ],
+        "test": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+            "usps.t.bz2",
+            "8ea070ee2aca1ac39742fdd1ef5ed118",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        split = "train" if train else "test"
+        url, filename, checksum = self.split_list[split]
+        full_path = os.path.join(self.root, filename)
+
+        if download and not os.path.exists(full_path):
+            download_url(url, self.root, filename, md5=checksum)
+
+        import bz2
+
+        with bz2.open(full_path) as fp:
+            raw_data = [line.decode().split() for line in fp.readlines()]
+            tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
+            imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+            targets = [int(d[0]) - 1 for d in raw_data]
+
+        self.data = imgs
+        self.targets = targets
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = _Image_fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
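
The parsing in `USPS.__init__` above can be sketched on a single line of input. This is a simplified illustration that uses plain Python lists instead of NumPy; the helper name `parse_usps_line` and the synthetic sample line are assumptions, not part of torchvision.

```python
def parse_usps_line(line):
    """Parse one 'label index:value ...' line into (16x16 pixel rows, label in [0, 9])."""
    fields = line.split()
    label = int(fields[0]) - 1  # shift labels from [1, 10] to [0, 9]
    values = [float(field.split(":")[-1]) for field in fields[1:]]
    if len(values) != 256:
        raise ValueError(f"expected 256 pixel values, got {len(values)}")
    # Map each value from [-1, 1] to [0, 255], truncating like a uint8 cast.
    pixels = [int((value + 1) / 2 * 255) for value in values]
    rows = [pixels[r * 16:(r + 1) * 16] for r in range(16)]
    return rows, label


# Synthetic line: label 10, first pixel at +1.0, every other pixel at -1.0.
sample = "10 1:1.0 " + " ".join(f"{i}:-1.0" for i in range(2, 257))
image, label = parse_usps_line(sample)
print(len(image), len(image[0]), label, image[0][0], image[0][1])
```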

+ 468 - 0
python/py/Lib/site-packages/torchvision/datasets/utils.py

@@ -0,0 +1,468 @@
+import bz2
+import gzip
+import hashlib
+import lzma
+import os
+import os.path
+import pathlib
+import re
+import tarfile
+import urllib
+import urllib.error
+import urllib.request
+import zipfile
+from collections.abc import Iterable
+from typing import Any, Callable, IO, Optional, TypeVar, Union
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch.utils.model_zoo import tqdm
+
+from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
+
+USER_AGENT = "pytorch/vision"
+
+
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar:
+            while chunk := response.read(chunk_size):
+                fh.write(chunk)
+                pbar.update(len(chunk))
+
+
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
+    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
+    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
+    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
+    md5 = hashlib.md5(usedforsecurity=False)
+    with open(fpath, "rb") as f:
+        while chunk := f.read(chunk_size):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def _get_redirect_url(url: str, max_hops: int = 3) -> str:
+    initial_url = url
+    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
+
+    for _ in range(max_hops + 1):
+        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
+            if response.url == url or response.url is None:
+                return url
+
+            url = response.url
+    else:
+        raise RecursionError(
+            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
+        )
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+    parts = urlparse(url)
+
+    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+        return None
+
+    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+    if match is None:
+        return None
+
+    return match.group("id")
+
+
+def download_url(
+    url: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
+    max_redirect_hops: int = 3,
+) -> None:
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
+    """
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.fspath(os.path.join(root, filename))
+
+    os.makedirs(root, exist_ok=True)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        return
+
+    if _is_remote_location_available():
+        _download_file_from_remote_location(fpath, url)
+    else:
+        # expand redirect chain if needed
+        url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+        # check if file is located on Google Drive
+        file_id = _get_google_drive_file_id(url)
+        if file_id is not None:
+            return download_file_from_google_drive(file_id, root, filename, md5)
+
+        # download the file
+        try:
+            _urlretrieve(url, fpath)
+        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
+            if url[:5] == "https":
+                url = url.replace("https:", "http:")
+                _urlretrieve(url, fpath)
+            else:
+                raise e
+
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> list[str]:
+    """List all directories at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the directories found
+    """
+    root = os.path.expanduser(root)
+    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
+    if prefix is True:
+        directories = [os.path.join(root, d) for d in directories]
+    return directories
+
+
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> list[str]:
+    """List all files ending with a suffix at a given root
+
+    Args:
+        root (str): Path to directory whose files need to be listed
+        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+            It uses the Python "str.endswith" method and is passed directly
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the files found
+    """
+    root = os.path.expanduser(root)
+    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
+    if prefix is True:
+        files = [os.path.join(root, d) for d in files]
+    return files
+
+
+def download_file_from_google_drive(
+    file_id: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
+):
+    """Download a file from Google Drive and place it in root.
+
+    Args:
+        file_id (str): id of file to be downloaded
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the id of the file.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    try:
+        import gdown
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+        )
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = file_id
+    fpath = os.fspath(os.path.join(root, filename))
+
+    os.makedirs(root, exist_ok=True)
+
+    if check_integrity(fpath, md5):
+        return
+
+    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
+
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def _extract_tar(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
+    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
+        tar.extractall(to_path)
+
+
+_ZIP_COMPRESSION_MAP: dict[str, int] = {
+    ".bz2": zipfile.ZIP_BZIP2,
+    ".xz": zipfile.ZIP_LZMA,
+}
+
+
+def _extract_zip(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
+    with zipfile.ZipFile(
+        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+    ) as zip:
+        zip.extractall(to_path)
+
+
+_ARCHIVE_EXTRACTORS: dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
+    ".tar": _extract_tar,
+    ".zip": _extract_zip,
+}
+_COMPRESSED_FILE_OPENERS: dict[str, Callable[..., IO]] = {
+    ".bz2": bz2.open,
+    ".gz": gzip.open,
+    ".xz": lzma.open,
+}
+_FILE_TYPE_ALIASES: dict[str, tuple[Optional[str], Optional[str]]] = {
+    ".tbz": (".tar", ".bz2"),
+    ".tbz2": (".tar", ".bz2"),
+    ".tgz": (".tar", ".gz"),
+}
+
+
+def _detect_file_type(file: Union[str, pathlib.Path]) -> tuple[str, Optional[str], Optional[str]]:
+    """Detect the archive type and/or compression of a file.
+
+    Args:
+        file (str): the filename
+
+    Returns:
+        (tuple): tuple of suffix, archive type, and compression
+
+    Raises:
+        RuntimeError: if file has no suffix or suffix is not supported
+    """
+    suffixes = pathlib.Path(file).suffixes
+    if not suffixes:
+        raise RuntimeError(
+            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
+        )
+    suffix = suffixes[-1]
+
+    # check if the suffix is a known alias
+    if suffix in _FILE_TYPE_ALIASES:
+        return (suffix, *_FILE_TYPE_ALIASES[suffix])
+
+    # check if the suffix is an archive type
+    if suffix in _ARCHIVE_EXTRACTORS:
+        return suffix, suffix, None
+
+    # check if the suffix is a compression
+    if suffix in _COMPRESSED_FILE_OPENERS:
+        # check for suffix hierarchy
+        if len(suffixes) > 1:
+            suffix2 = suffixes[-2]
+
+            # check if the suffix2 is an archive type
+            if suffix2 in _ARCHIVE_EXTRACTORS:
+                return suffix2 + suffix, suffix2, suffix
+
+        return suffix, None, suffix
+
+    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
+    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
+
+
+def _decompress(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> pathlib.Path:
+    r"""Decompress a file.
+
+    The compression is automatically detected from the file name.
+
+    Args:
+        from_path (str): Path to the file to be decompressed.
+        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (pathlib.Path): Path to the decompressed file.
+    """
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not compression:
+        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
+
+    if to_path is None:
+        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
+
+    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
+        wfh.write(rfh.read())
+
+    if remove_finished:
+        os.remove(from_path)
+
+    return pathlib.Path(to_path)
+
+
+def extract_archive(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
+    """Extract an archive.
+
+    The archive type and a possible compression are automatically detected from the file name. If the file is compressed
+    but not an archive, the call is dispatched to :func:`_decompress`.
+
+    Args:
+        from_path (str): Path to the file to be extracted.
+        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
+            used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the directory the file was extracted to.
+    """
+
+    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+        if isinstance(from_path, str):
+            return os.fspath(ret_path)
+        else:
+            return ret_path
+
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not archive_type:
+        ret_path = _decompress(
+            from_path,
+            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
+            remove_finished=remove_finished,
+        )
+        return path_or_str(ret_path)
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    extractor = _ARCHIVE_EXTRACTORS[archive_type]
+
+    extractor(from_path, to_path, compression)
+    if remove_finished:
+        os.remove(from_path)
+
+    return path_or_str(pathlib.Path(to_path))
+
+
+def download_and_extract_archive(
+    url: str,
+    download_root: Union[str, pathlib.Path],
+    extract_root: Optional[Union[str, pathlib.Path]] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable: Iterable) -> str:
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+    value: T,
+    arg: Optional[str] = None,
+    valid_values: Optional[Iterable[T]] = None,
+    custom_msg: Optional[str] = None,
+) -> T:
+    if not isinstance(value, str):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
+            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
+
+
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
+    """Read file in .pfm format. Might contain either 1 or 3 channels of data.
+
+    Args:
+        file_name (str): Path to the file.
+        slice_channels (int): Number of channels to slice out of the file.
+            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
+    """
+
+    with open(file_name, "rb") as f:
+        header = f.readline().rstrip()
+        if header not in [b"PF", b"Pf"]:
+            raise ValueError("Invalid PFM file")
+
+        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
+        if not dim_match:
+            raise Exception("Malformed PFM header.")
+        w, h = (int(dim) for dim in dim_match.groups())
+
+        scale = float(f.readline().rstrip())
+        if scale < 0:  # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            endian = ">"  # big-endian
+
+        data = np.fromfile(f, dtype=endian + "f")
+
+    pfm_channels = 3 if header == b"PF" else 1
+
+    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
+    data = np.flip(data, axis=1)  # flip on h dimension
+    data = data[:slice_channels, :, :]
+    return data.astype(np.float32)
+
+
+def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
+    return (
+        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
+    )
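
The suffix-based detection in `_detect_file_type` above can be sketched with the standard library alone. The lookup tables below are simplified stand-ins for `_ARCHIVE_EXTRACTORS`, `_COMPRESSED_FILE_OPENERS` and `_FILE_TYPE_ALIASES` (they hold only the suffixes, not the extractor callables), and the function name `detect_file_type` is hypothetical.

```python
import pathlib

ARCHIVES = {".tar", ".zip"}
COMPRESSIONS = {".bz2", ".gz", ".xz"}
ALIASES = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def detect_file_type(name):
    """Return (suffix, archive_type, compression) for a file name."""
    suffixes = pathlib.Path(name).suffixes
    if not suffixes:
        raise RuntimeError(f"File '{name}' has no suffix.")
    suffix = suffixes[-1]
    if suffix in ALIASES:  # e.g. ".tgz" stands for ".tar" + ".gz"
        return (suffix, *ALIASES[suffix])
    if suffix in ARCHIVES:  # plain archive, no compression
        return suffix, suffix, None
    if suffix in COMPRESSIONS:
        # A compressed archive such as "x.tar.gz" has the archive suffix second to last.
        if len(suffixes) > 1 and suffixes[-2] in ARCHIVES:
            return suffixes[-2] + suffix, suffixes[-2], suffix
        return suffix, None, suffix
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.")


names = ("data.tar.gz", "data.tgz", "usps.bz2", "images.zip")
results = {name: detect_file_type(name) for name in names}
for name, detected in results.items():
    print(name, detected)
```

A result with no archive type (such as `usps.bz2`) is the case that `extract_archive` hands over to `_decompress`.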

+ 384 - 0
python/py/Lib/site-packages/torchvision/datasets/video_utils.py

@@ -0,0 +1,384 @@
+import bisect
+import math
+import warnings
+from typing import Any, Optional, TypeVar, Union
+
+import torch
+
+from .utils import tqdm
+
+T = TypeVar("T")
+
+
+def _get_torchcodec():
+    try:
+        import torchcodec  # type: ignore[import-not-found]
+    except ImportError:
+        raise ImportError(
+            "Video decoding capabilities were removed from torchvision and migrated "
+            "to TorchCodec. Please install TorchCodec following instructions at "
+            "https://github.com/pytorch/torchcodec#installing-torchcodec"
+        )
+    return torchcodec
+
+
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
+    """
+    Similar to tensor.unfold, but with dilation
+    and specialized for 1d tensors.
+
+    Returns all consecutive windows of `size` elements, with
+    `step` between windows. The distance between each element
+    in a window is given by `dilation`.
+    """
+    if tensor.dim() != 1:
+        raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
+    o_stride = tensor.stride(0)
+    numel = tensor.numel()
+    new_stride = (step * o_stride, dilation * o_stride)
+    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
+    if new_size[0] < 1:
+        new_size = (0, size)
+    return torch.as_strided(tensor, new_size, new_stride)
+
+
+class _VideoTimestampsDataset:
+    """
+    Dataset used to parallelize the reading of the timestamps
+    of a list of videos, given their paths in the filesystem.
+
+    Used in VideoClips and defined at top level, so it can be
+    pickled when forking.
+    """
+
+    def __init__(self, video_paths: list[str]) -> None:
+        self.video_paths = video_paths
+
+    def __len__(self) -> int:
+        return len(self.video_paths)
+
+    def __getitem__(self, idx: int) -> tuple[list[int], Optional[float]]:
+        torchcodec = _get_torchcodec()
+        decoder = torchcodec.decoders.VideoDecoder(self.video_paths[idx])
+        num_frames = decoder.metadata.num_frames
+        fps = decoder.metadata.average_fps
+        return list(range(num_frames)), fps
+
+
+def _collate_fn(x: T) -> T:
+    """
+    Dummy collate function to be used with _VideoTimestampsDataset
+    """
+    return x
+
+
+class VideoClips:
+    """
+    Given a list of video files, computes all consecutive subvideos of size
+    `clip_length_in_frames`, where the distance between each subvideo in the
+    same video is defined by `frames_between_clips`.
+    If `frame_rate` is specified, it will also resample all the videos to have
+    the same frame rate, and the clips will refer to this frame rate.
+
+    Creating this instance the first time is time-consuming, as it needs to
+    decode all the videos in `video_paths`. It is recommended that you
+    cache the results after instantiation of the class.
+
+    Recreating the clips for different clip lengths is fast, and can be done
+    with the `compute_clips` method.
+
+    Args:
+        video_paths (List[str]): paths to the video files
+        clip_length_in_frames (int): size of a clip in number of frames
+        frames_between_clips (int): step (in frames) between each clip
+        frame_rate (float, optional): if specified, it will resample the video
+            so that it has `frame_rate`, and then the clips will be defined
+            on the resampled video
+        num_workers (int): how many subprocesses to use for data loading.
+            0 means that the data will be loaded in the main process. (default: 0)
+        output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+    """
+
+    def __init__(
+        self,
+        video_paths: list[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[float] = None,
+        _precomputed_metadata: Optional[dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+
+        self.video_paths = video_paths
+        self.num_workers = num_workers
+
+        # these options are not valid for pyav backend
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._video_max_dimension = _video_max_dimension
+        self._audio_samples = _audio_samples
+        self._audio_channels = _audio_channels
+        self.output_format = output_format.upper()
+        if self.output_format not in ("THWC", "TCHW"):
+            raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+        if _precomputed_metadata is None:
+            self._compute_frame_pts()
+        else:
+            self._init_from_metadata(_precomputed_metadata)
+        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
+
+    def _compute_frame_pts(self) -> None:
+        self.video_pts = []  # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
+        self.video_fps: list[float] = []  # len = num_videos
+
+        # strategy: use a DataLoader to parallelize read_video_timestamps
+        # so need to create a dummy dataset first
+        import torch.utils.data
+
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
+            batch_size=16,
+            num_workers=self.num_workers,
+            collate_fn=_collate_fn,
+        )
+
+        with tqdm(total=len(dl)) as pbar:
+            for batch in dl:
+                pbar.update(1)
+                batch_pts, batch_fps = list(zip(*batch))
+                # we need to specify dtype=torch.long because for empty list,
+                # torch.as_tensor will use torch.float as default dtype. This
+                # happens when decoding fails and no pts is returned in the list.
+                batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts]
+                self.video_pts.extend(batch_pts)
+                self.video_fps.extend(batch_fps)
+
+    def _init_from_metadata(self, metadata: dict[str, Any]) -> None:
+        self.video_paths = metadata["video_paths"]
+        assert len(self.video_paths) == len(metadata["video_pts"])
+        self.video_pts = metadata["video_pts"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        _metadata = {
+            "video_paths": self.video_paths,
+            "video_pts": self.video_pts,
+            "video_fps": self.video_fps,
+        }
+        return _metadata
+
+    def subset(self, indices: list[int]) -> "VideoClips":
+        video_paths = [self.video_paths[i] for i in indices]
+        video_pts = [self.video_pts[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
+        metadata = {
+            "video_paths": video_paths,
+            "video_pts": video_pts,
+            "video_fps": video_fps,
+        }
+        return type(self)(
+            video_paths,
+            clip_length_in_frames=self.num_frames,
+            frames_between_clips=self.step,
+            frame_rate=self.frame_rate,
+            _precomputed_metadata=metadata,
+            num_workers=self.num_workers,
+            _video_width=self._video_width,
+            _video_height=self._video_height,
+            _video_min_dimension=self._video_min_dimension,
+            _video_max_dimension=self._video_max_dimension,
+            _audio_samples=self._audio_samples,
+            _audio_channels=self._audio_channels,
+            output_format=self.output_format,
+        )
+
+    @staticmethod
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
+    ) -> tuple[torch.Tensor, Union[list[slice], torch.Tensor]]:
+        if fps is None:
+            # if for some reason the video doesn't have fps (because it doesn't have a video stream)
+            # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
+            fps = 1
+        if frame_rate is None:
+            frame_rate = fps
+        total_frames = len(video_pts) * frame_rate / fps
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
+        clips = unfold(video_pts, num_frames, step)
+        if not clips.numel():
+            warnings.warn(
+                "There aren't enough frames in the current video to get a clip for the given clip length and "
+                "frames between clips. The video (and potentially others) will be skipped."
+            )
+        idxs: Union[list[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
+        else:
+            idxs = unfold(_idxs, num_frames, step)
+        return clips, idxs
+
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
+        """
+        Compute all consecutive sequences of clips from video_pts.
+        Always returns clips of size `num_frames`, meaning that the
+        last few frames in a video can potentially be dropped.
+
+        Args:
+            num_frames (int): number of frames for the clip
+            step (int): distance between two clips
+            frame_rate (float, optional): The frame rate
+        """
+        self.num_frames = num_frames
+        self.step = step
+        self.frame_rate = frame_rate
+        self.clips = []
+        self.resampling_idxs = []
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
+        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
+        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
+
+    def __len__(self) -> int:
+        return self.num_clips()
+
+    def num_videos(self) -> int:
+        return len(self.video_paths)
+
+    def num_clips(self) -> int:
+        """
+        Number of subclips that are available in the video list.
+        """
+        return self.cumulative_sizes[-1]
+
+    def get_clip_location(self, idx: int) -> tuple[int, int]:
+        """
+        Converts a flattened representation of the indices into a video_idx, clip_idx
+        representation.
+        """
+        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if video_idx == 0:
+            clip_idx = idx
+        else:
+            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
+        return video_idx, clip_idx
+
+    @staticmethod
+    def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
+        step = original_fps / new_fps
+        if step.is_integer():
+            # optimization: if step is integer, don't need to perform
+            # advanced indexing
+            step = int(step)
+            return slice(None, None, step)
+        idxs = torch.arange(num_frames, dtype=torch.float32) * step
+        idxs = idxs.floor().to(torch.int64)
+        return idxs
+
+    def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]:
+        """
+        Gets a subclip from a list of videos.
+
+        Args:
+            idx (int): index of the subclip. Must satisfy ``0 <= idx < num_clips()``.
+
+        Returns:
+            video (Tensor)
+            audio (Tensor)
+            info (Dict)
+            video_idx (int): index of the video in `video_paths`
+        """
+        if idx >= self.num_clips():
+            raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
+        video_idx, clip_idx = self.get_clip_location(idx)
+        video_path = self.video_paths[video_idx]
+        clip_pts = self.clips[video_idx][clip_idx]
+
+        start_idx = int(clip_pts[0].item())
+        end_idx = int(clip_pts[-1].item())
+
+        torchcodec = _get_torchcodec()
+
+        dimension_order = "NHWC" if self.output_format == "THWC" else "NCHW"
+        decoder = torchcodec.decoders.VideoDecoder(video_path, dimension_order=dimension_order)
+        video = decoder.get_frames_at(indices=list(range(start_idx, end_idx + 1))).data
+
+        # Audio via TorchCodec
+        fps = decoder.metadata.average_fps
+        start_sec = start_idx / fps
+        end_sec = (end_idx + 1) / fps
+        try:
+            audio_decoder = torchcodec.decoders.AudioDecoder(video_path)
+            audio_samples = audio_decoder.get_samples_played_in_range(start_seconds=start_sec, stop_seconds=end_sec)
+            audio = audio_samples.data
+        except Exception:
+            audio = torch.empty((1, 0), dtype=torch.float32)
+
+        info = {"video_fps": fps}
+
+        if self.frame_rate is not None:
+            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
+            if isinstance(resampling_idx, torch.Tensor):
+                resampling_idx = resampling_idx - resampling_idx[0]
+            video = video[resampling_idx]
+            info["video_fps"] = self.frame_rate
+        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
+
+        return video, audio, info, video_idx
+
+    def __getstate__(self) -> dict[str, Any]:
+        video_pts_sizes = [len(v) for v in self.video_pts]
+        # To be back-compatible, we convert data to dtype torch.long as needed
+        # because for empty list, in legacy implementation, torch.as_tensor will
+        # use torch.float as default dtype. This happens when decoding fails and
+        # no pts is returned in the list.
+        video_pts = [x.to(torch.int64) for x in self.video_pts]
+        # video_pts can be an empty list if no frames have been decoded
+        if video_pts:
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
+            # avoid bug in https://github.com/pytorch/pytorch/issues/32351
+            # TODO: Revert it once the bug is fixed.
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
+
+        # make a copy of the fields of self
+        d = self.__dict__.copy()
+        d["video_pts_sizes"] = video_pts_sizes
+        d["video_pts"] = video_pts
+        # delete the following attributes to reduce the size of dictionary. They
+        # will be re-computed in "__setstate__()"
+        del d["clips"]
+        del d["resampling_idxs"]
+        del d["cumulative_sizes"]
+
+        # for backwards-compatibility
+        d["_version"] = 2
+        return d
+
+    def __setstate__(self, d: dict[str, Any]) -> None:
+        # for backwards-compatibility
+        if "_version" not in d:
+            self.__dict__ = d
+            return
+
+        video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
+        video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
+        # don't need this info anymore
+        del d["video_pts_sizes"]
+
+        d["video_pts"] = video_pts
+        self.__dict__ = d
+        # recompute attributes "clips", "resampling_idxs" and other derivative ones
+        self.compute_clips(self.num_frames, self.step, self.frame_rate)
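
The flat-index lookup in `get_clip_location` and the frame-rate resampling in `_resample_video_idx` can be tried without torch. The following is a minimal sketch, not the library's implementation: `locate_clip`, `resample_indices` and the clip counts are hypothetical names and values chosen for illustration, and plain Python lists stand in for tensors.

```python
import bisect
import math
from itertools import accumulate


def locate_clip(cumulative_sizes, idx):
    """Map a flat clip index to a (video_idx, clip_idx) pair."""
    # bisect_right finds the first video whose running total exceeds idx.
    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    if video_idx == 0:
        return 0, idx
    return video_idx, idx - cumulative_sizes[video_idx - 1]


def resample_indices(num_frames, original_fps, new_fps):
    """Choose source-frame positions that approximate new_fps."""
    step = original_fps / new_fps
    if step.is_integer():
        # Integer ratio: a strided slice is enough.
        return slice(None, None, int(step))
    # Otherwise floor each fractional position to a real frame index.
    return [math.floor(i * step) for i in range(num_frames)]


clips_per_video = [3, 2, 4]  # hypothetical clip counts per video
cumulative = list(accumulate(clips_per_video))  # running totals: [3, 5, 9]

print(locate_clip(cumulative, 3))
print(resample_indices(4, 30, 20))
```

Storing only the running totals keeps the lookup logarithmic in the number of videos, which is presumably why the class keeps `cumulative_sizes` instead of a flat list of every clip.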

+ 111 - 0
python/py/Lib/site-packages/torchvision/datasets/vision.py

@@ -0,0 +1,111 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import torch.utils.data as data
+
+from ..utils import _log_api_usage_once
+
+
+class VisionDataset(data.Dataset):
+    """
+    Base class for making datasets that are compatible with torchvision.
+    It is necessary to override the ``__getitem__`` and ``__len__`` methods.
+
+    Args:
+        root (string, optional): Root directory of dataset. Only used for `__repr__`.
+        transforms (callable, optional): A function/transform that takes in
+            an image and a label and returns the transformed versions of both.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+
+    .. note::
+
+        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
+    """
+
+    _repr_indent = 4
+
+    def __init__(
+        self,
+        root: Union[str, Path] = None,  # type: ignore[assignment]
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        _log_api_usage_once(self)
+        if isinstance(root, str):
+            root = os.path.expanduser(root)
+        self.root = root
+
+        has_transforms = transforms is not None
+        has_separate_transform = transform is not None or target_transform is not None
+        if has_transforms and has_separate_transform:
+            raise ValueError("Only transforms or transform/target_transform can be passed as argument")
+
+        # for backwards-compatibility
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if has_separate_transform:
+            transforms = StandardTransform(transform, target_transform)
+        self.transforms = transforms
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            (Any): Sample and meta data, optionally transformed by the respective transforms.
+        """
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        head = "Dataset " + self.__class__.__name__
+        body = [f"Number of datapoints: {self.__len__()}"]
+        if self.root is not None:
+            body.append(f"Root location: {self.root}")
+        body += self.extra_repr().splitlines()
+        if hasattr(self, "transforms") and self.transforms is not None:
+            body += [repr(self.transforms)]
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return "\n".join(lines)
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def extra_repr(self) -> str:
+        return ""
+
+
+class StandardTransform:
+    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __call__(self, input: Any, target: Any) -> tuple[Any, Any]:
+        if self.transform is not None:
+            input = self.transform(input)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return input, target
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def __repr__(self) -> str:
+        body = [self.__class__.__name__]
+        if self.transform is not None:
+            body += self._format_transform_repr(self.transform, "Transform: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform, "Target transform: ")
+
+        return "\n".join(body)
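
The rule that `transforms` and the `transform`/`target_transform` pair are mutually exclusive can be checked in isolation. This is a rough sketch under stated assumptions: `PairTransform` and `resolve_transforms` are hypothetical stand-ins that mirror the constructor logic above, and nothing here imports torchvision.

```python
from typing import Any, Callable, Optional


class PairTransform:
    """Hypothetical stand-in for StandardTransform: applies each callable to its own half."""

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, sample: Any, target: Any) -> tuple[Any, Any]:
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


def resolve_transforms(
    transforms: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
) -> Optional[Callable]:
    """Return one joint callable, or raise if both styles are mixed."""
    has_joint = transforms is not None
    has_separate = transform is not None or target_transform is not None
    if has_joint and has_separate:
        raise ValueError("Only transforms or transform/target_transform can be passed as argument")
    # Separate callables are wrapped so callers always see a joint transform.
    return PairTransform(transform, target_transform) if has_separate else transforms


joint = resolve_transforms(transform=lambda x: x * 2, target_transform=str)
print(joint(3, 7))
```

Wrapping the separate callables means a subclass's `__getitem__` only ever needs to call `self.transforms(img, target)`, as `VOCSegmentation` and `VOCDetection` do.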

+ 224 - 0
python/py/Lib/site-packages/torchvision/datasets/voc.py

@@ -0,0 +1,224 @@
+import collections
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from xml.etree.ElementTree import Element as ET_Element
+
+try:
+    from defusedxml.ElementTree import parse as ET_parse
+except ImportError:
+    from xml.etree.ElementTree import parse as ET_parse
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+DATASET_YEAR_DICT = {
+    "2012": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
+        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
+    },
+    "2011": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+        "filename": "VOCtrainval_25-May-2011.tar",
+        "md5": "6c3384ef61512963050cb5d687e5bf1e",
+        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
+    },
+    "2010": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+        "filename": "VOCtrainval_03-May-2010.tar",
+        "md5": "da459979d0c395079b5c75ee67908abb",
+        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
+    },
+    "2009": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+        "filename": "VOCtrainval_11-May-2009.tar",
+        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
+        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
+    },
+    "2008": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+        "filename": "VOCtrainval_14-Jul-2008.tar",
+        "md5": "2629fa636546599198acfcfbfcf1904a",
+        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
+    },
+    "2007": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+        "filename": "VOCtrainval_06-Nov-2007.tar",
+        "md5": "c52e279531787c972589f7e41ab4ae64",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+    "2007-test": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+        "filename": "VOCtest_06-Nov-2007.tar",
+        "md5": "b6e924de25625d8de591ea690078ad9f",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+}
+
+
+class _VOCBase(VisionDataset):
+    _SPLITS_DIR: str
+    _TARGET_DIR: str
+    _TARGET_FILE_EXT: str
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        year: str = "2012",
+        image_set: str = "train",
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        super().__init__(root, transforms, transform, target_transform)
+
+        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
+
+        valid_image_sets = ["train", "trainval", "val"]
+        if year == "2007":
+            valid_image_sets.append("test")
+        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
+
+        key = "2007-test" if year == "2007" and image_set == "test" else year
+        dataset_year_dict = DATASET_YEAR_DICT[key]
+
+        self.url = dataset_year_dict["url"]
+        self.filename = dataset_year_dict["filename"]
+        self.md5 = dataset_year_dict["md5"]
+
+        base_dir = dataset_year_dict["base_dir"]
+        voc_root = os.path.join(self.root, base_dir)
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+
+        if not os.path.isdir(voc_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
+        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
+        with open(split_f) as f:
+            file_names = [x.strip() for x in f.readlines()]
+
+        image_dir = os.path.join(voc_root, "JPEGImages")
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+
+        target_dir = os.path.join(voc_root, self._TARGET_DIR)
+        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
+
+        assert len(self.images) == len(self.targets)
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+
+class VOCSegmentation(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Segmentation"
+    _TARGET_DIR = "SegmentationClass"
+    _TARGET_FILE_EXT = ".png"
+
+    @property
+    def masks(self) -> list[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is the image segmentation.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = Image.open(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+
+class VOCDetection(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Main"
+    _TARGET_DIR = "Annotations"
+    _TARGET_FILE_EXT = ".xml"
+
+    @property
+    def annotations(self) -> list[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dictionary of the XML tree.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    @staticmethod
+    def parse_voc_xml(node: ET_Element) -> dict[str, Any]:
+        voc_dict: dict[str, Any] = {}
+        children = list(node)
+        if children:
+            def_dic: dict[str, Any] = collections.defaultdict(list)
+            for dc in map(VOCDetection.parse_voc_xml, children):
+                for ind, v in dc.items():
+                    def_dic[ind].append(v)
+            if node.tag == "annotation":
+                def_dic["object"] = [def_dic["object"]]
+            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
+        if node.text:
+            text = node.text.strip()
+            if not children:
+                voc_dict[node.tag] = text
+        return voc_dict
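
To see what shape `parse_voc_xml` produces, the same recursion can be run on a small hand-written annotation. This is a sketch, assuming a made-up XML snippet; the function body follows the method above with renamed locals, and the standard-library parser is used directly.

```python
import collections
import xml.etree.ElementTree as ET


def parse_voc_xml(node):
    """Recursively turn an XML element into nested dicts, as in the method above."""
    voc_dict = {}
    children = list(node)
    if children:
        merged = collections.defaultdict(list)
        for child_dict in map(parse_voc_xml, children):
            for key, value in child_dict.items():
                merged[key].append(value)
        if node.tag == "annotation":
            # Wrap once more so "object" stays a list even with a single object.
            merged["object"] = [merged["object"]]
        voc_dict = {node.tag: {k: v[0] if len(v) == 1 else v for k, v in merged.items()}}
    if node.text:
        text = node.text.strip()
        if not children:
            voc_dict[node.tag] = text
    return voc_dict


# Hypothetical annotation, far smaller than a real VOC file.
xml_text = "<annotation><filename>a.jpg</filename><object><name>dog</name></object></annotation>"
parsed = parse_voc_xml(ET.fromstring(xml_text))
print(parsed)
```

The extra wrapping for the `annotation` tag is the design choice worth noticing: repeated tags collapse to a scalar when they occur once, so without it downstream code would have to handle both a dict and a list for `object`.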

+ 196 - 0
python/py/Lib/site-packages/torchvision/datasets/widerface.py

@@ -0,0 +1,196 @@
+import os
+from os.path import abspath, expanduser
+from pathlib import Path
+
+from typing import Any, Callable, Optional, Union
+
+import torch
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images and annotations are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── widerface
+                        ├── wider_face_split ('wider_face_split.zip' if compressed)
+                        ├── WIDER_train ('WIDER_train.zip' if compressed)
+                        ├── WIDER_val ('WIDER_val.zip' if compressed)
+                        └── WIDER_test ('WIDER_test.zip' if compressed)
+        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+
+    """
+
+    BASE_FOLDER = "widerface"
+    FILE_LIST = [
+        # File ID                             MD5 Hash                            Filename
+        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
+    ]
+    ANNOTATIONS_FILE = (
+        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
+        "0e3767bcf0e326556d407bf5bff5d27c",
+        "wider_face_split.zip",
+    )
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(
+            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+        )
+        # check arguments
+        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
+
+        self.img_info: list[dict[str, Union[str, dict[str, torch.Tensor]]]] = []
+        if self.split in ("train", "val"):
+            self.parse_train_val_annotations_file()
+        else:
+            self.parse_test_annotations_file()
+
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dict of annotations for all faces in the image.
+            target=None for the test split.
+        """
+
+        # stay consistent with other datasets and return a PIL Image
+        img = Image.open(self.img_info[index]["img_path"])  # type: ignore[arg-type]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = None if self.split == "test" else self.img_info[index]["annotations"]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.img_info)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def parse_train_val_annotations_file(self) -> None:
+        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+        filepath = os.path.join(self.root, "wider_face_split", filename)
+
+        with open(filepath) as f:
+            lines = f.readlines()
+            file_name_line, num_boxes_line, box_annotation_line = True, False, False
+            num_boxes, box_counter = 0, 0
+            labels = []
+            for line in lines:
+                line = line.rstrip()
+                if file_name_line:
+                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+                    img_path = abspath(expanduser(img_path))
+                    file_name_line = False
+                    num_boxes_line = True
+                elif num_boxes_line:
+                    num_boxes = int(line)
+                    num_boxes_line = False
+                    box_annotation_line = True
+                elif box_annotation_line:
+                    box_counter += 1
+                    line_split = line.split(" ")
+                    line_values = [int(x) for x in line_split]
+                    labels.append(line_values)
+                    if box_counter >= num_boxes:
+                        box_annotation_line = False
+                        file_name_line = True
+                        labels_tensor = torch.tensor(labels)
+                        self.img_info.append(
+                            {
+                                "img_path": img_path,
+                                "annotations": {
+                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                                    "blur": labels_tensor[:, 4].clone(),
+                                    "expression": labels_tensor[:, 5].clone(),
+                                    "illumination": labels_tensor[:, 6].clone(),
+                                    "occlusion": labels_tensor[:, 7].clone(),
+                                    "pose": labels_tensor[:, 8].clone(),
+                                    "invalid": labels_tensor[:, 9].clone(),
+                                },
+                            }
+                        )
+                        box_counter = 0
+                        labels.clear()
+                else:
+                    raise RuntimeError(f"Error parsing annotation file {filepath}")
+
+    def parse_test_annotations_file(self) -> None:
+        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+        filepath = abspath(expanduser(filepath))
+        with open(filepath) as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.rstrip()
+                img_path = os.path.join(self.root, "WIDER_test", "images", line)
+                img_path = abspath(expanduser(img_path))
+                self.img_info.append({"img_path": img_path})
+
+    def _check_integrity(self) -> bool:
+        # Allow original archive to be deleted (zip). Only need the extracted images
+        all_files = self.FILE_LIST.copy()
+        all_files.append(self.ANNOTATIONS_FILE)
+        for _, md5, filename in all_files:
+            file, ext = os.path.splitext(filename)
+            extracted_dir = os.path.join(self.root, file)
+            if not os.path.exists(extracted_dir):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            return
+
+        # download and extract image data
+        for file_id, md5, filename in self.FILE_LIST:
+            download_file_from_google_drive(file_id, self.root, filename, md5)
+            filepath = os.path.join(self.root, filename)
+            extract_archive(filepath)
+
+        # download and extract annotation files
+        download_and_extract_archive(
+            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+        )
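
The train/val annotation file read by `parse_train_val_annotations_file` is a repeating group of a file name, a box count, and one row of ten integers per box. Below is a minimal sketch of that layout using plain lists instead of tensors; `parse_wider_annotations` and the sample lines are hypothetical, and the loop is restructured around an iterator rather than the three boolean flags used above.

```python
def parse_wider_annotations(lines):
    """Group (file name, box count, box rows) triples into records."""
    records = []
    stripped = iter(line.rstrip() for line in lines)
    for name in stripped:
        num_boxes = int(next(stripped))
        # Like the loop above, always consume at least one row,
        # even when the declared box count is zero.
        rows = [[int(x) for x in next(stripped).split()] for _ in range(max(num_boxes, 1))]
        records.append(
            {
                "img_path": name,
                "bbox": [row[0:4] for row in rows],  # x, y, width, height
                "blur": [row[4] for row in rows],
                "invalid": [row[9] for row in rows],
            }
        )
    return records


# Hypothetical sample in the same layout as the bbx_gt files.
sample = [
    "0--Parade/img_1.jpg\n",
    "2\n",
    "10 20 30 40 0 0 0 0 0 0 \n",
    "50 60 70 80 1 0 0 0 0 0 \n",
    "0--Parade/img_2.jpg\n",
    "0\n",
    "0 0 0 0 0 0 0 0 0 0 \n",
]
records = parse_wider_annotations(sample)
print(len(records), records[0]["bbox"])
```

Unlike the original, this sketch does no error reporting: a truncated file surfaces as an uncaught `StopIteration` or `ValueError` rather than the `RuntimeError` raised above.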

+ 76 - 0
python/py/Lib/site-packages/torchvision/extension.py

@@ -0,0 +1,76 @@
+import os
+
+import torch
+
+from ._internally_replaced_utils import _get_extension_path
+
+
+def _load_library(lib_name):
+    """Load a library, optionally warning on failure based on env variable.
+
+    Returns True if the library was loaded successfully, False otherwise.
+    """
+    try:
+        lib_path = _get_extension_path(lib_name)
+        torch.ops.load_library(lib_path)
+        return True
+    except (ImportError, OSError) as e:
+        if os.environ.get("TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS"):
+            import warnings
+
+            warnings.warn(f"Failed to load '{lib_name}' extension: {type(e).__name__}: {e}")
+        return False
+
+
+def _has_ops():
+    return False
+
+
+if _load_library("_C"):
+
+    def _has_ops():  # noqa: F811
+        return True
+
+
+def _assert_has_ops():
+    if not _has_ops():
+        raise RuntimeError(
+            "Couldn't load custom C++ ops. This can happen if your PyTorch and "
+            "torchvision versions are incompatible, or if you had errors while compiling "
+            "torchvision from source. For further information on the compatible versions, check "
+            "https://github.com/pytorch/vision#installation for the compatibility matrix. "
+            "Please check your PyTorch version with torch.__version__ and your torchvision "
+            "version with torchvision.__version__ and verify if they are compatible, and if not "
+            "please reinstall torchvision so that it matches your PyTorch install. "
+            "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
+        )
+
+
+def _check_cuda_version():
+    """
+    Make sure that CUDA versions match between the pytorch install and torchvision install
+    """
+    if not _has_ops():
+        return -1
+    from torch.version import cuda as torch_version_cuda
+
+    _version = torch.ops.torchvision._cuda_version()
+    if _version != -1 and torch_version_cuda is not None:
+        tv_version = str(_version)
+        assert int(tv_version) >= 12000, f"Unexpected CUDA version {_version}, please file a bug report."
+        tv_major = int(tv_version[0:2])
+        tv_minor = int(tv_version[3])
+        t_version = torch_version_cuda.split(".")
+        t_major = int(t_version[0])
+        t_minor = int(t_version[1])
+        if t_major != tv_major:
+            raise RuntimeError(
+                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
+                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
+                f"CUDA Version={tv_major}.{tv_minor}. "
+                "Please reinstall the torchvision that matches your PyTorch install."
+            )
+    return _version
+
+
+_check_cuda_version()
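
The string slicing in `_check_cuda_version` relies on the integer encoding `1000 * major + 10 * minor` used by the `CUDA_VERSION` macro. A small sketch of that decoding follows, assuming five-digit values as the assertion above requires; `split_cuda_version` and `majors_match` are hypothetical helper names, and no torch import is needed.

```python
def split_cuda_version(encoded: int) -> tuple[int, int]:
    """Decode e.g. 12040 into (12, 4) using the same slicing as above."""
    text = str(encoded)
    major = int(text[0:2])
    minor = int(text[3])
    return major, minor


def split_cuda_version_arith(encoded: int) -> tuple[int, int]:
    """Equivalent arithmetic form of the same decoding."""
    return encoded // 1000, (encoded % 1000) // 10


def majors_match(encoded: int, torch_cuda: str) -> bool:
    """Compare only the major component, as the check above does."""
    tv_major, _ = split_cuda_version(encoded)
    t_major = int(torch_cuda.split(".")[0])
    return tv_major == t_major


print(split_cuda_version(12040))
print(majors_match(12080, "12.4"))
```

The slicing form only works while the major version has two digits and the minor version one, which the arithmetic form does not assume.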

binary
python/py/Lib/site-packages/torchvision/image.pyd


+ 56 - 0
python/py/Lib/site-packages/torchvision/io/__init__.py

@@ -0,0 +1,56 @@
+# In fbcode, import from the fb-only location
+# For OSS, these imports would fail (video_reader not available)
+try:
+    from pytorch.vision.fb.io import (  # type: ignore[import-not-found]
+        _HAS_CPU_VIDEO_DECODER,
+        _HAS_VIDEO_OPT,
+        _probe_video_from_file,
+        _probe_video_from_memory,
+        _read_video_from_file,
+        _read_video_from_memory,
+        _read_video_timestamps_from_file,
+        _read_video_timestamps_from_memory,
+        _video_opt,
+        Timebase,
+        VideoMetaData,
+        VideoReader,
+    )
+except ImportError:
+    pass
+
+from .image import (
+    decode_avif,
+    decode_gif,
+    decode_heic,
+    decode_image,
+    decode_jpeg,
+    decode_png,
+    decode_webp,
+    encode_jpeg,
+    encode_png,
+    ImageReadMode,
+    read_file,
+    read_image,
+    write_file,
+    write_jpeg,
+    write_png,
+)
+
+
+__all__ = [
+    "ImageReadMode",
+    "decode_image",
+    "decode_jpeg",
+    "decode_png",
+    "decode_avif",
+    "decode_heic",
+    "decode_webp",
+    "decode_gif",
+    "encode_jpeg",
+    "encode_png",
+    "read_file",
+    "read_image",
+    "write_file",
+    "write_jpeg",
+    "write_png",
+]

+ 527 - 0
python/py/Lib/site-packages/torchvision/io/image.py

@@ -0,0 +1,527 @@
+from enum import Enum
+from typing import Union
+
+import torch
+
+from ..extension import _load_library
+from ..utils import _log_api_usage_once
+
+
+def _has_image_ops():
+    return False
+
+
+if _load_library("image"):
+
+    def _has_image_ops():  # noqa: F811
+        return True
+
+
+def _assert_has_image_ops():
+    if not _has_image_ops():
+        raise RuntimeError(
+            "Couldn't load the image extension. "
+            "If you built torchvision from source, make sure libjpeg and libpng were found. "
+            "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
+        )
+
+
+class ImageReadMode(Enum):
+    """Allow automatic conversion to RGB, RGBA, etc while decoding.
+
+    .. note::
+
+        You don't need to use this enum; you can just pass strings to all
+        ``mode`` parameters, e.g. ``mode="RGB"``.
+
+    The different available modes are the following.
+
+    - UNCHANGED: loads the image as-is
+    - RGB: converts to RGB
+    - RGBA: converts to RGB with transparency (also aliased as RGB_ALPHA)
+    - GRAY: converts to grayscale
+    - GRAY_ALPHA: converts to grayscale with transparency
+
+    .. note::
+
+        Some decoders won't support all possible values, e.g. GRAY and
+        GRAY_ALPHA are only supported for PNG and JPEG images.
+    """
+
+    UNCHANGED = 0
+    GRAY = 1
+    GRAY_ALPHA = 2
+    RGB = 3
+    RGB_ALPHA = 4
+    RGBA = RGB_ALPHA  # Alias for convenience
+
+
+def read_file(path: str) -> torch.Tensor:
+    """
+    Return the bytes contents of a file as a uint8 1D Tensor.
+
+    Args:
+        path (str or ``pathlib.Path``): the path to the file to be read
+
+    Returns:
+        data (Tensor)
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_file)
+    _assert_has_image_ops()
+    data = torch.ops.image.read_file(str(path))
+    return data
+
+
+def write_file(filename: str, data: torch.Tensor) -> None:
+    """
+    Write the content of a uint8 1D tensor to a file.
+
+    Args:
+        filename (str or ``pathlib.Path``): the path to the file to be written
+        data (Tensor): the contents to be written to the output file
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_file)
+    _assert_has_image_ops()
+    torch.ops.image.write_file(str(filename), data)
+
+
+def decode_png(
+    input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
+) -> torch.Tensor:
+    """
+    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
+
+    The values of the output tensor are in uint8 in [0, 255] for most cases. If
+    the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    pytorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.
+
+    Args:
+        input (Tensor[1]): a one dimensional uint8 tensor containing
+            the raw bytes of the PNG image.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
+            Default: False.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_png)
+    _assert_has_image_ops()
+    if isinstance(mode, str):
+        mode = ImageReadMode[mode.upper()]
+    output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
+    return output
+
+
+def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
+    """
+    Takes an input tensor in CHW layout and returns a buffer with the contents
+    of its corresponding PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+
+    Returns:
+        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
+            PNG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_png)
+    _assert_has_image_ops()
+    output = torch.ops.image.encode_png(input, compression_level)
+    return output
+
+
+def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
+    """
+    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
+    and saves it in a PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        filename (str or ``pathlib.Path``): Path to save the image.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_png)
+    output = encode_png(input, compression_level)
+    write_file(filename, output)
+
+
+def decode_jpeg(
+    input: Union[torch.Tensor, list[torch.Tensor]],
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    device: Union[str, torch.device] = "cpu",
+    apply_exif_orientation: bool = False,
+) -> Union[torch.Tensor, list[torch.Tensor]]:
+    """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.
+
+    The values of the output tensor are uint8 between 0 and 255.
+
+    .. note::
+        When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
+        When using CPU the performance is equivalent.
+        The CUDA version of this function has explicitly been designed with thread-safety in mind.
+        This function does not return partial results in case of an error.
+
+    Args:
+        input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
+            the raw bytes of the JPEG image. The tensor(s) must be on CPU,
+            regardless of the ``device`` parameter.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+        device (str or torch.device): The device on which the decoded image will
+            be stored. If a cuda device is specified, the image will be decoded
+            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
+            supported for CUDA version >= 10.1
+
+            .. betastatus:: device parameter
+
+            .. warning::
+                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
+                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
+        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
+            Default: False. Only implemented for JPEG format on CPU.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
+            The values of the output tensor(s) are uint8 between 0 and 255.
+            ``output.device`` will be set to the specified ``device``
+
+
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_jpeg)
+    _assert_has_image_ops()
+    if isinstance(device, str):
+        device = torch.device(device)
+    if isinstance(mode, str):
+        mode = ImageReadMode[mode.upper()]
+
+    if isinstance(input, list):
+        if len(input) == 0:
+            raise ValueError("Input list must contain at least one element")
+        if not all(isinstance(t, torch.Tensor) for t in input):
+            raise ValueError("All elements of the input list must be tensors.")
+        if not all(t.device.type == "cpu" for t in input):
+            raise ValueError("Input list must contain tensors on CPU.")
+        if device.type == "cuda":
+            return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
+        else:
+            return [torch.ops.image.decode_jpeg(img, mode.value, apply_exif_orientation) for img in input]
+
+    else:  # input is tensor
+        if input.device.type != "cpu":
+            raise ValueError("Input tensor must be a CPU tensor")
+        if device.type == "cuda":
+            return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
+        else:
+            return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
+
+
+def encode_jpeg(
+    input: Union[torch.Tensor, list[torch.Tensor]], quality: int = 75
+) -> Union[torch.Tensor, list[torch.Tensor]]:
+    """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.
+
+    .. note::
+        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
+        For CPU tensors the performance is equivalent.
+
+    Args:
+        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
+            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
+        quality (int): Quality of the resulting JPEG file(s). Must be a number between
+            1 and 100. Default: 75
+
+    Returns:
+        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_jpeg)
+    _assert_has_image_ops()
+    if quality < 1 or quality > 100:
+        raise ValueError("Image quality should be a positive number between 1 and 100")
+    if isinstance(input, list):
+        if not input:
+            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
+        if input[0].device.type == "cuda":
+            return torch.ops.image.encode_jpegs_cuda(input, quality)
+        else:
+            return [torch.ops.image.encode_jpeg(image, quality) for image in input]
+    else:  # single input tensor
+        if input.device.type == "cuda":
+            return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
+        else:
+            return torch.ops.image.encode_jpeg(input, quality)
+
+
+def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
+    """
+    Takes an input tensor in CHW layout and saves it in a JPEG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
+            channels, where ``c`` must be 1 or 3.
+        filename (str or ``pathlib.Path``): Path to save the image.
+        quality (int): Quality of the resulting JPEG file, it must be a number
+            between 1 and 100. Default: 75
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_jpeg)
+    output = encode_jpeg(input, quality)
+    assert isinstance(output, torch.Tensor)  # Needed for torchscript
+    write_file(filename, output)
+
+
+def decode_image(
+    input: Union[torch.Tensor, str],
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
+) -> torch.Tensor:
+    """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.
+
+    Currently supported image formats are jpeg, png, gif and webp.
+
+    The values of the output tensor are in uint8 in [0, 255] for most cases.
+
+    If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    pytorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.
+
+    .. note::
+
+        ``decode_image()`` doesn't work yet on AVIF or HEIC images. For these
+        formats, directly call  :func:`~torchvision.io.decode_avif` or
+        :func:`~torchvision.io.decode_heic`.
+
+    Args:
+        input (Tensor or str or ``pathlib.Path``): The image to decode. If a
+            tensor is passed, it must be a one dimensional uint8 tensor containing
+            the raw bytes of the image. Otherwise, this must be a path to the image file.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
+           Only applies to JPEG and PNG images. Default: False.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_image)
+    _assert_has_image_ops()
+    if not isinstance(input, torch.Tensor):
+        input = read_file(str(input))
+    if isinstance(mode, str):
+        mode = ImageReadMode[mode.upper()]
+    output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
+    return output
+
+
+def read_image(
+    path: str,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
+) -> torch.Tensor:
+    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_image)
+    data = read_file(path)
+    return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)
+
+
+def decode_gif(input: torch.Tensor) -> torch.Tensor:
+    """
+    Decode a GIF image into a 3 or 4 dimensional RGB Tensor.
+
+    The values of the output tensor are uint8 between 0 and 255.
+    The output tensor has shape ``(C, H, W)`` if there is only one image in the
+    GIF, and ``(N, C, H, W)`` if there are ``N`` images.
+
+    Args:
+        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
+            the raw bytes of the GIF image.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_gif)
+    _assert_has_image_ops()
+    return torch.ops.image.decode_gif(input)
+
+
+def decode_webp(
+    input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+) -> torch.Tensor:
+    """
+    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.
+
+    The values of the output tensor are uint8 between 0 and 255.
+
+    Args:
+        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
+            the raw bytes of the WEBP image.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+
+    Returns:
+        Decoded image (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_webp)
+    _assert_has_image_ops()
+    if isinstance(mode, str):
+        mode = ImageReadMode[mode.upper()]
+    return torch.ops.image.decode_webp(input, mode.value)
+
+
+# TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif or
+# decode_heic currently fails, mainly because of the logic in
+# _load_extra_decoders_once() (using global variables, try/except statements,
+# etc.).
+# The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able,
+# and users who need torchscript can always just wrap those.
+
+# TODO_AVIF_HEIC: decode_image() should work for those. The key technical issue
+# we have here is that the format detection logic of decode_image() is
+# implemented in torchvision, and torchvision has zero knowledge of
+# torchvision-extra-decoders, so we cannot call the AVIF/HEIC C++ decoders
+# (those in torchvision-extra-decoders) from there.
+# A trivial check that could be done within torchvision would be to check the
+# file extension, if a path was passed. We could also just implement the
+# AVIF/HEIC detection logic in Python as a fallback, if the file detection
+# didn't find any format. In any case: properly determining whether a file is
+# HEIC is far from trivial, and relying on libmagic would probably be best.
+
+
+_EXTRA_DECODERS_ALREADY_LOADED = False
+
+
+def _load_extra_decoders_once():
+    global _EXTRA_DECODERS_ALREADY_LOADED
+    if _EXTRA_DECODERS_ALREADY_LOADED:
+        return
+
+    try:
+        import torchvision_extra_decoders
+
+        # torchvision-extra-decoders only supports linux for now. BUT, users on
+        # e.g. MacOS can still install it: they will get the pure-python
+        # 0.0.0.dev version:
+        # https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which
+        # is a dummy version that was created to reserve the namespace on PyPI.
+        # We have to check that expose_extra_decoders() exists for those users,
+        # so we can properly error on non-Linux archs.
+        assert hasattr(torchvision_extra_decoders, "expose_extra_decoders")
+    except (AssertionError, ImportError) as e:
+        raise RuntimeError(
+            "In order to enable the AVIF and HEIC decoding capabilities of "
+            "torchvision, you need to `pip install torchvision-extra-decoders`. "
+            "Just install the package, you don't need to update your code. "
+            "This is only supported on Linux, and this feature is still in BETA stage. "
+            "Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. "
+            "Note that `torchvision-extra-decoders` is released under the LGPL license. "
+        ) from e
+
+    # This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic
+    torchvision_extra_decoders.expose_extra_decoders()
+
+    _EXTRA_DECODERS_ALREADY_LOADED = True
+
+
+def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """Decode an AVIF image into a 3 dimensional RGB[A] Tensor.
+
+    .. warning::
+        In order to enable the AVIF decoding capabilities of torchvision, you
+        first need to run ``pip install torchvision-extra-decoders``. Just
+        install the package, you don't need to update your code. This is only
+        supported on Linux, and this feature is still in BETA stage. Please let
+        us know of any issue:
+        https://github.com/pytorch/vision/issues/new/choose. Note that
+        `torchvision-extra-decoders
+        <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
+        released under the LGPL license.
+
+    The values of the output tensor are in uint8 in [0, 255] for most images. If
+    the image has a bit-depth of more than 8, then the output tensor is uint16
+    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
+    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
+    ``scale=True`` after this function to convert the decoded image into a uint8
+    or float tensor.
+
+    Args:
+        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
+            the raw bytes of the AVIF image.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+
+    Returns:
+        Decoded image (Tensor[image_channels, image_height, image_width])
+    """
+    _load_extra_decoders_once()
+    if input.dtype != torch.uint8:
+        raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
+    return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
+
+
+def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
+
+    .. warning::
+        In order to enable the HEIC decoding capabilities of torchvision, you
+        first need to run ``pip install torchvision-extra-decoders``. Just
+        install the package, you don't need to update your code. This is only
+        supported on Linux, and this feature is still in BETA stage. Please let
+        us know of any issue:
+        https://github.com/pytorch/vision/issues/new/choose. Note that
+        `torchvision-extra-decoders
+        <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
+        released under the LGPL license.
+
+    The values of the output tensor are in uint8 in [0, 255] for most images. If
+    the image has a bit-depth of more than 8, then the output tensor is uint16
+    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
+    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
+    ``scale=True`` after this function to convert the decoded image into a uint8
+    or float tensor.
+
+    Args:
+        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
+            the raw bytes of the HEIC image.
+        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+            Default is "UNCHANGED".  See :class:`~torchvision.io.ImageReadMode`
+            for available modes.
+
+    Returns:
+        Decoded image (Tensor[image_channels, image_height, image_width])
+    """
+    _load_extra_decoders_once()
+    if input.dtype != torch.uint8:
+        raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
+    return torch.ops.extra_decoders_ns.decode_heic(input, mode.value)
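
Reviewer note on the file above: several decoders (`decode_png`, `decode_jpeg`, `decode_image`, `decode_webp`) accept ``mode`` either as an enum member or as a string, and convert strings with `ImageReadMode[mode.upper()]`. The sketch below reproduces only that lookup pattern using the standard library, so it runs without torch. `ReadMode` and `normalize_mode` are hypothetical names chosen for illustration; they are not part of torchvision.

```python
from enum import Enum
from typing import Union


class ReadMode(Enum):
    """Hypothetical stand-in mirroring the members of ImageReadMode."""

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4
    RGBA = RGB_ALPHA  # alias: resolves to the same member as RGB_ALPHA


def normalize_mode(mode: Union[str, ReadMode]) -> ReadMode:
    """Return the enum member for a member or a case-insensitive name.

    Hypothetical helper, not a torchvision function. Unknown names
    raise KeyError, which is what a by-name Enum lookup does.
    """
    if isinstance(mode, str):
        return ReadMode[mode.upper()]
    return mode


if __name__ == "__main__":
    print(normalize_mode("rgb"))
    print(normalize_mode("RGBA") is ReadMode.RGB_ALPHA)
```

Because `RGBA` is declared as an alias, a by-name lookup of either `"RGBA"` or `"RGB_ALPHA"` yields the same member, so both strings map to the same integer value passed to the underlying op.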

binary
python/py/Lib/site-packages/torchvision/jpeg8.dll


binary
python/py/Lib/site-packages/torchvision/libjpeg.dll


binary
python/py/Lib/site-packages/torchvision/libpng16.dll


Some files were not shown because too many files changed in this diff