fsm.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #!/usr/bin/env python
  2. """Finite state machine.
  3. Simple FSM implementation.
Usage:
```python
class A:
    def on_state(self, inputs) -> None:
        pass
class B:
    def on_state(self, inputs) -> None:
        pass
def to_b(inputs) -> bool:
    return True
def to_a(inputs) -> bool:
    return True
f = Fsm(states=[A(), B()], table={A: [FsmEntry(to_b, B)], B: [FsmEntry(to_a, A)]})
f.input({"input1": 1, "input2": 2})
```
  19. """
  20. from __future__ import annotations
  21. from abc import abstractmethod
  22. from collections.abc import Sequence
  23. from dataclasses import dataclass
  24. from typing import Callable, Generic, Union
  25. from typing_extensions import Protocol, TypeAlias, TypeVar, runtime_checkable
  26. T_FsmInputs = TypeVar("T_FsmInputs", contravariant=True)
  27. T_FsmContext = TypeVar("T_FsmContext")
  28. T_FsmContext_cov = TypeVar("T_FsmContext_cov", covariant=True)
  29. T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
  30. @runtime_checkable
  31. class FsmStateCheck(Protocol[T_FsmInputs]):
  32. @abstractmethod
  33. def on_check(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
  34. @runtime_checkable
  35. class FsmStateOutput(Protocol[T_FsmInputs]):
  36. @abstractmethod
  37. def on_state(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
  38. @runtime_checkable
  39. class FsmStateEnter(Protocol[T_FsmInputs]):
  40. @abstractmethod
  41. def on_enter(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
  42. @runtime_checkable
  43. class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
  44. @abstractmethod
  45. def on_enter(
  46. self, inputs: T_FsmInputs, context: T_FsmContext_contra
  47. ) -> None: ... # pragma: no cover
  48. @runtime_checkable
  49. class FsmStateStay(Protocol[T_FsmInputs]):
  50. @abstractmethod
  51. def on_stay(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
  52. @runtime_checkable
  53. class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
  54. @abstractmethod
  55. def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
  56. # It would be nice if python provided optional protocol members, but it does not as described here:
  57. # https://peps.python.org/pep-0544/#support-optional-protocol-members
  58. # Until then, we can only enforce that a state at least supports one protocol interface. This
  59. # unfortunately will not check the signature of other potential protocols.
  60. FsmState: TypeAlias = Union[
  61. FsmStateCheck[T_FsmInputs],
  62. FsmStateOutput[T_FsmInputs],
  63. FsmStateEnter[T_FsmInputs],
  64. FsmStateEnterWithContext[T_FsmInputs, T_FsmContext],
  65. FsmStateStay[T_FsmInputs],
  66. FsmStateExit[T_FsmInputs, T_FsmContext],
  67. ]
  68. @dataclass
  69. class FsmEntry(Generic[T_FsmInputs, T_FsmContext]):
  70. condition: Callable[[T_FsmInputs], bool]
  71. target_state: type[FsmState[T_FsmInputs, T_FsmContext]]
  72. action: Callable[[T_FsmInputs], None] | None = None
  73. FsmTableWithContext: TypeAlias = dict[
  74. type[FsmState[T_FsmInputs, T_FsmContext]],
  75. Sequence[FsmEntry[T_FsmInputs, T_FsmContext]],
  76. ]
  77. FsmTable: TypeAlias = FsmTableWithContext[T_FsmInputs, None]
  78. class FsmWithContext(Generic[T_FsmInputs, T_FsmContext]):
  79. _state_dict: dict[type[FsmState], FsmState]
  80. _table: FsmTableWithContext[T_FsmInputs, T_FsmContext]
  81. _state: FsmState[T_FsmInputs, T_FsmContext]
  82. _states: Sequence[FsmState]
  83. def __init__(
  84. self,
  85. states: Sequence[FsmState],
  86. table: FsmTableWithContext[T_FsmInputs, T_FsmContext],
  87. ) -> None:
  88. self._states = states
  89. self._table = table
  90. self._state_dict = {type(s): s for s in states}
  91. self._state = self._state_dict[type(states[0])]
  92. def _transition(
  93. self,
  94. inputs: T_FsmInputs,
  95. new_state: type[FsmState[T_FsmInputs, T_FsmContext]],
  96. action: Callable[[T_FsmInputs], None] | None,
  97. ) -> None:
  98. if action:
  99. action(inputs)
  100. context = None
  101. if isinstance(self._state, FsmStateExit):
  102. context = self._state.on_exit(inputs)
  103. prev_state = type(self._state)
  104. if prev_state == new_state:
  105. if isinstance(self._state, FsmStateStay):
  106. self._state.on_stay(inputs)
  107. else:
  108. self._state = self._state_dict[new_state]
  109. if context and isinstance(self._state, FsmStateEnterWithContext):
  110. self._state.on_enter(inputs, context=context)
  111. elif isinstance(self._state, FsmStateEnter):
  112. self._state.on_enter(inputs)
  113. def _check_transitions(self, inputs: T_FsmInputs) -> None:
  114. for entry in self._table[type(self._state)]:
  115. if entry.condition(inputs):
  116. self._transition(inputs, entry.target_state, entry.action)
  117. return
  118. def input(self, inputs: T_FsmInputs) -> None:
  119. if isinstance(self._state, FsmStateCheck):
  120. self._state.on_check(inputs)
  121. self._check_transitions(inputs)
  122. if isinstance(self._state, FsmStateOutput):
  123. self._state.on_state(inputs)
  124. Fsm: TypeAlias = FsmWithContext[T_FsmInputs, None]