timed_input.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """timed_input: add a timeout to standard input.
  2. Approach was inspired by: https://github.com/johejo/inputimeout
  3. """
  4. import sys
  5. import threading
  6. import wandb
  7. SP = " "
  8. CR = "\r"
  9. LF = "\n"
  10. CRLF = CR + LF
  11. def _echo(prompt: str, *, err: bool) -> None:
  12. stream = sys.stderr if err else sys.stdout
  13. stream.write(prompt)
  14. stream.flush()
  15. def _posix_timed_input(prompt: str, timeout: float, err: bool) -> str:
  16. _echo(prompt, err=err)
  17. sel = selectors.DefaultSelector()
  18. sel.register(sys.stdin, selectors.EVENT_READ, data=sys.stdin.readline)
  19. events = sel.select(timeout=timeout)
  20. for key, _ in events:
  21. input_callback = key.data
  22. input_data: str = input_callback()
  23. if not input_data: # end-of-file - treat as timeout
  24. raise TimeoutError
  25. return input_data.rstrip(LF)
  26. _echo(LF, err=err)
  27. termios.tcflush(sys.stdin, termios.TCIFLUSH)
  28. raise TimeoutError
  29. def _windows_timed_input(prompt: str, timeout: float, err: bool) -> str:
  30. interval = 0.1
  31. _echo(prompt, err=err)
  32. begin = time.monotonic()
  33. end = begin + timeout
  34. line = ""
  35. while time.monotonic() < end:
  36. if msvcrt.kbhit(): # type: ignore[attr-defined]
  37. c = msvcrt.getwche() # type: ignore[attr-defined]
  38. if c in (CR, LF):
  39. _echo(CRLF, err=err)
  40. return line
  41. if c == "\003":
  42. raise KeyboardInterrupt
  43. if c == "\b":
  44. line = line[:-1]
  45. cover = SP * len(prompt + line + SP)
  46. _echo("".join([CR, cover, CR, prompt, line]), err=err)
  47. else:
  48. line += c
  49. time.sleep(interval)
  50. _echo(CRLF, err=err)
  51. raise TimeoutError
  52. def _jupyter_timed_input(prompt: str, timeout: float, err: bool) -> str:
  53. clear = True
  54. try:
  55. from IPython.core.display import clear_output # type: ignore
  56. except ImportError:
  57. clear = False
  58. wandb.termwarn(
  59. "Unable to clear output, can't import clear_output from ipython.core"
  60. )
  61. _echo(prompt, err=err)
  62. user_inp = None
  63. event = threading.Event()
  64. def get_input() -> None:
  65. nonlocal user_inp
  66. raw = input()
  67. if event.is_set():
  68. return
  69. user_inp = raw
  70. t = threading.Thread(target=get_input)
  71. t.start()
  72. t.join(timeout)
  73. event.set()
  74. if user_inp:
  75. return user_inp
  76. if clear:
  77. clear_output()
  78. raise TimeoutError
  79. def timed_input(
  80. prompt: str,
  81. timeout: float,
  82. show_timeout: bool = True,
  83. jupyter: bool = False,
  84. err: bool = False,
  85. ) -> str:
  86. """Behaves like builtin `input()` but adds timeout.
  87. Args:
  88. prompt: Prompt to output to stdout.
  89. timeout: Timeout to wait for input.
  90. show_timeout: Show timeout in prompt
  91. jupyter: If True, use jupyter specific code.
  92. err: If True, use stderr instead of stdout.
  93. Raises:
  94. TimeoutError: If a timeout occurred.
  95. KeyboardInterrupt: If the user aborted by pressing Ctrl+C.
  96. """
  97. if show_timeout:
  98. prompt = f"{prompt}({timeout:.0f} second timeout) "
  99. if jupyter:
  100. return _jupyter_timed_input(prompt=prompt, timeout=timeout, err=err)
  101. return _timed_input(prompt=prompt, timeout=timeout, err=err)
# Choose the platform-specific implementation at import time. ``msvcrt`` is
# only importable on Windows, so an ImportError selects the POSIX path. The
# modules imported in each branch become module globals that the matching
# ``_*_timed_input`` function relies on.
try:
    import msvcrt
except ImportError:
    # POSIX: wait on stdin with a selector; flush terminal input on timeout.
    import selectors
    import termios

    _timed_input = _posix_timed_input
else:
    # Windows: poll the console; ``time`` supplies the clock and sleep used by
    # the polling loop.
    import time

    _timed_input = _windows_timed_input