| 12345678910111213141516171819202122232425 |
- import threading
- from typing import Any, Optional
- import torch._C._lazy
class DeviceContext:
    """Small holder object that records which device it belongs to.

    All instances share two class-level attributes: ``_CONTEXTS``, a dict
    mapping a device string to its context object, and ``_CONTEXTS_LOCK``, a
    ``threading.Lock`` intended to guard that dict (``get_device_context``
    takes this lock around its lookup-or-create of an entry).
    """

    # Shared, class-wide registry: device string -> context instance.
    _CONTEXTS: dict[str, Any] = {}
    # Lock used by callers to serialize reads/writes of _CONTEXTS.
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device: str) -> None:
        """Store ``device`` as this context's device identifier."""
        self.device: str = device
def get_device_context(device: Optional[str] = None) -> DeviceContext:
    """Return the shared ``DeviceContext`` for ``device``, creating it on first use.

    Args:
        device: Device identifier. When ``None``, the value returned by
            ``torch._C._lazy._get_default_device_type()`` is used as-is;
            otherwise the argument is converted with ``str()``.

    Returns:
        The ``DeviceContext`` registered under that device key in
        ``DeviceContext._CONTEXTS``. The lookup and the insertion of a new
        entry both happen while holding ``DeviceContext._CONTEXTS_LOCK``, so
        callers going through this function see one instance per key.
    """
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )

    with DeviceContext._CONTEXTS_LOCK:
        ctx = DeviceContext._CONTEXTS.get(key)
        if ctx is None:
            # First request for this key: build and register it while the
            # lock is still held so no other caller can insert a duplicate.
            ctx = DeviceContext(key)
            DeviceContext._CONTEXTS[key] = ctx
    return ctx
|