client.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from collections import deque
  16. from collections.abc import Iterator
  17. from typing import Literal, TypedDict
  18. import httpx
  19. from ..utils._headers import build_hf_headers
  20. from ..utils._http import hf_raise_for_status
  21. from .sse_client import SSEClient
  22. from .types import ApiGetReloadEventSourceData, ApiGetReloadRequest
# Port number that ReloadClient splices into the host name, as
# "<subdomain>--<port>", when it builds its base URL.
# NOTE(review): presumably the port the hot-reloading server listens on
# inside the Space — confirm.
HOT_RELOADING_PORT = 7887
class MultiReplicaStreamEvent(TypedDict):
    """Stream item wrapping one reload event received from a replica.

    One of the item shapes yielded by ``multi_replica_reload_events``.
    """

    # Discriminator: marks this item as carrying a reload event.
    kind: Literal["event"]
    # The event payload as produced by ``ReloadClient.get_reload``.
    event: ApiGetReloadEventSourceData
class MultiReplicaStreamReplicaHash(TypedDict):
    """Stream item announcing which replica the following items belong to.

    ``multi_replica_reload_events`` yields it before a replica's events, and
    only when more than one replica is being queried.
    """

    # Discriminator: marks this item as a replica announcement.
    kind: Literal["replicaHash"]
    # Hash identifying the replica whose events come next.
    hash: str
class MultiReplicaStreamFullMatch(TypedDict):
    """Stream item signalling that a replica's events matched the first replica's.

    ``multi_replica_reload_events`` yields it in place of the individual
    events of a later replica when those events all compared equal to the
    first replica's events at the same positions.
    """

    # Discriminator: marks this item as a "same as first replica" marker.
    kind: Literal["fullMatch"]
  32. class ReloadClient:
  33. def __init__(
  34. self,
  35. *,
  36. host: str,
  37. subdomain: str,
  38. replica_hash: str,
  39. token: str | None,
  40. ):
  41. base_host = host.replace(subdomain, f"{subdomain}--{HOT_RELOADING_PORT}")
  42. self.replica_hash = replica_hash
  43. self.client = httpx.Client(
  44. base_url=f"{base_host}/--replicas/+{replica_hash}",
  45. headers=build_hf_headers(token=token),
  46. )
  47. def get_reload(self, reload_id: str) -> Iterator[ApiGetReloadEventSourceData]:
  48. req = ApiGetReloadRequest(reloadId=reload_id)
  49. with self.client.stream("POST", "/get-reload", json=req) as res:
  50. hf_raise_for_status(res)
  51. for event in SSEClient(res.iter_bytes()).events():
  52. if event.event == "message":
  53. yield json.loads(event.data)
  54. def multi_replica_reload_events(
  55. commit_sha: str,
  56. host: str,
  57. subdomain: str,
  58. replica_hashes: list[str],
  59. token: str | None,
  60. ) -> Iterator[MultiReplicaStreamEvent | MultiReplicaStreamReplicaHash | MultiReplicaStreamFullMatch]:
  61. clients = [
  62. ReloadClient(
  63. host=host,
  64. subdomain=subdomain,
  65. replica_hash=hash,
  66. token=token,
  67. )
  68. for hash in replica_hashes
  69. ]
  70. first_client_events: dict[int, ApiGetReloadEventSourceData] = {}
  71. for client_index, client in enumerate(clients):
  72. if len(clients) > 1:
  73. yield {"kind": "replicaHash", "hash": client.replica_hash}
  74. full_match = True
  75. replay: deque[ApiGetReloadEventSourceData] = deque()
  76. for event_index, event in enumerate(client.get_reload(commit_sha)):
  77. if client_index == 0:
  78. first_client_events[event_index] = event
  79. elif full_match := full_match and first_client_events.get(event_index) == event:
  80. replay.append(event)
  81. continue
  82. while replay:
  83. yield {"kind": "event", "event": replay.popleft()}
  84. yield {"kind": "event", "event": event}
  85. if client_index > 0 and full_match:
  86. yield {"kind": "fullMatch"}