_dataset_viewer.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import shutil
  17. import subprocess
  18. from dataclasses import dataclass
  19. from typing import TYPE_CHECKING, Any, Union
  20. from . import constants
  21. from .utils import get_token
  22. if TYPE_CHECKING:
  23. import duckdb
  24. @dataclass(frozen=True)
  25. class DatasetParquetEntry:
  26. """Represents a single parquet file available for a dataset on the Hub."""
  27. config: str
  28. split: str
  29. url: str
  30. size: int
  31. def execute_raw_sql_query(sql_query: str, *, token: str | bool | None = None) -> list[dict[str, Any]]:
  32. normalized_query = sql_query.strip().rstrip(";").strip()
  33. _raise_on_forbidden_query(normalized_query)
  34. connection = None
  35. try:
  36. connection = _get_duckdb_connection(token=token)
  37. relation = connection.sql(normalized_query)
  38. if relation is None:
  39. raise ValueError("SQL query must return rows.")
  40. if isinstance(relation, _DuckDBCliRelation):
  41. # DuckDB binary => run CLI => parse JSON
  42. return relation.execute()
  43. else:
  44. # DuckDB Python API => fetch columns + rows => convert to dicts
  45. columns = tuple(column[0] for column in relation.description)
  46. rows = tuple(tuple(row) for row in relation.fetchall())
  47. return [dict(zip(columns, row)) for row in rows]
  48. finally:
  49. if connection is not None:
  50. connection.close()
  51. def _raise_on_forbidden_query(query: str) -> None:
  52. if len(query) == 0:
  53. raise ValueError("SQL query cannot be empty.")
  54. # DuckDB CLI meta-commands are dot-prefixed words (e.g. `.shell`, `.output`).
  55. # Let's forbid them for now but allow SQL expressions like `.5` that can legitimately start a line.
  56. for line in query.splitlines():
  57. stripped = line.lstrip()
  58. if stripped.startswith(".") and stripped[1:2].isalpha():
  59. raise ValueError("DuckDB CLI meta-commands are not allowed in SQL queries.")
  60. def _get_duckdb_connection(
  61. token: str | bool | None,
  62. ) -> Union["duckdb.DuckDBPyConnection", "_DuckDBCliConnection"]:
  63. try:
  64. # If DuckDB is installed as a Python package, use it!
  65. import duckdb
  66. except ImportError as error:
  67. # Otherwise, use the DuckDB CLI binary.
  68. duckdb_binary = shutil.which("duckdb")
  69. if duckdb_binary is None:
  70. raise ImportError(
  71. "DuckDB is required for `hf datasets sql`. Install the Python package with `pip install duckdb` or "
  72. "install the DuckDB CLI binary (for example `brew install duckdb`)."
  73. ) from error
  74. return _DuckDBCliConnection(binary_path=duckdb_binary, token=token)
  75. # Create a new connection (Python API).
  76. connection = duckdb.connect()
  77. try:
  78. for statement in _build_duckdb_secret_statements(token):
  79. connection.execute(statement)
  80. return connection
  81. except Exception:
  82. connection.close()
  83. raise
  84. @dataclass
  85. class _DuckDBCliConnection:
  86. """DuckDB connection.
  87. Mimics the DuckDB Python API, but runs the queries via the DuckDB CLI binary.
  88. """
  89. binary_path: str
  90. token: str | bool | None
  91. def __post_init__(self) -> None:
  92. self._setup_statements = _build_duckdb_secret_statements(self.token)
  93. def sql(self, query: str) -> "_DuckDBCliRelation":
  94. return _DuckDBCliRelation(binary_path=self.binary_path, setup_statements=self._setup_statements, query=query)
  95. def close(self) -> None:
  96. pass
  97. @dataclass
  98. class _DuckDBCliRelation:
  99. """DuckDB relation.
  100. Mimics the DuckDB Python API, but runs the queries via the DuckDB CLI binary.
  101. """
  102. binary_path: str
  103. setup_statements: list[str]
  104. query: str
  105. def execute(self) -> list[dict[str, Any]]:
  106. # Build the DuckDB CLI input.
  107. setup = []
  108. if self.setup_statements:
  109. setup = [
  110. f".output {os.devnull}",
  111. *(f"{stmt};" for stmt in self.setup_statements),
  112. ".output",
  113. ]
  114. full_query = "\n".join(setup + [self.query + ";"])
  115. # Run DuckDB binary
  116. result = subprocess.run(
  117. [self.binary_path, "-json"],
  118. input=full_query,
  119. capture_output=True,
  120. text=True,
  121. check=False,
  122. )
  123. if result.returncode != 0:
  124. error_message = result.stderr.strip() or result.stdout.strip() or "DuckDB CLI command failed."
  125. raise RuntimeError(error_message)
  126. # Parse JSON output and return
  127. return json.loads(result.stdout.strip())
  128. def _build_duckdb_secret_statements(token: str | bool | None) -> list[str]:
  129. if token is None or token is True:
  130. token = get_token()
  131. if not token:
  132. return []
  133. escaped_token = token.replace("'", "''")
  134. escaped_endpoint = constants.ENDPOINT.replace("'", "''")
  135. return [
  136. f"CREATE OR REPLACE SECRET hf_hub_token (TYPE HTTP, BEARER_TOKEN '{escaped_token}', SCOPE '{escaped_endpoint}')",
  137. f"CREATE OR REPLACE SECRET hf_token (TYPE HUGGINGFACE, TOKEN '{escaped_token}')",
  138. ]