#!/usr/bin/env python
"""Finite state machine.

Simple FSM implementation.

States are plain objects that implement one or more of the optional hooks
``on_check``, ``on_state``, ``on_enter``, ``on_stay`` and ``on_exit``.  The
transition table maps a state *class* to an ordered sequence of
:class:`FsmEntry` objects; on every call to :meth:`FsmWithContext.input` the
first entry whose condition returns true is taken.

Usage:

```python
class A:
    def on_state(self, inputs) -> None:
        pass


class B:
    def on_state(self, inputs) -> None:
        pass


def to_b(inputs) -> bool:
    return True


def to_a(inputs) -> bool:
    return True


f = Fsm(
    states=[A(), B()],
    table={
        A: [FsmEntry(condition=to_b, target_state=B)],
        B: [FsmEntry(condition=to_a, target_state=A)],
    },
)
f.input({"input1": 1, "input2": 2})
```
"""

from __future__ import annotations

from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Callable, Generic, Union

try:
    from typing_extensions import Protocol, TypeAlias, TypeVar, runtime_checkable
except ImportError:  # pragma: no cover
    # All four names exist in the standard library on Python 3.10+, so the
    # backport package is only a hard requirement on older interpreters.
    from typing import Protocol, TypeAlias, TypeVar, runtime_checkable

T_FsmInputs = TypeVar("T_FsmInputs", contravariant=True)
T_FsmContext = TypeVar("T_FsmContext")
T_FsmContext_cov = TypeVar("T_FsmContext_cov", covariant=True)
T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)


@runtime_checkable
class FsmStateCheck(Protocol[T_FsmInputs]):
    """State hook called at the start of every ``input`` call.

    ``on_check`` runs on the current state before the transition table is
    evaluated.
    """

    @abstractmethod
    def on_check(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover


@runtime_checkable
class FsmStateOutput(Protocol[T_FsmInputs]):
    """State hook called at the end of every ``input`` call.

    ``on_state`` runs on whichever state is current after the transition
    table has been evaluated.
    """

    @abstractmethod
    def on_state(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover


@runtime_checkable
class FsmStateEnter(Protocol[T_FsmInputs]):
    """State hook called when the machine switches to this state.

    This one-argument form is used when there is no context to hand over
    (the previous state has no ``on_exit`` or it returned ``None``).
    """

    @abstractmethod
    def on_enter(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover


@runtime_checkable
class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
    """State hook called when the machine switches to this state with context.

    ``context`` is the value returned by the previous state's ``on_exit``.
    """

    @abstractmethod
    def on_enter(
        self, inputs: T_FsmInputs, context: T_FsmContext_contra
    ) -> None: ...  # pragma: no cover


@runtime_checkable
class FsmStateStay(Protocol[T_FsmInputs]):
    """State hook called when a fired table entry targets the current state."""

    @abstractmethod
    def on_stay(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover


@runtime_checkable
class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
    """State hook called on the current state whenever a table entry fires.

    The returned value is the context offered to the next state's
    ``on_enter``.
    """

    @abstractmethod
    def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ...  # pragma: no cover


# It would be nice if Python provided optional protocol members, but it does
# not, as described here:
# https://peps.python.org/pep-0544/#support-optional-protocol-members
# Until then, we can only enforce that a state supports at least one protocol
# interface. Unfortunately this does not check the signatures of the other
# protocols the state may also implement.
FsmState: TypeAlias = Union[
    FsmStateCheck[T_FsmInputs],
    FsmStateOutput[T_FsmInputs],
    FsmStateEnter[T_FsmInputs],
    FsmStateEnterWithContext[T_FsmInputs, T_FsmContext],
    FsmStateStay[T_FsmInputs],
    FsmStateExit[T_FsmInputs, T_FsmContext],
]


@dataclass
class FsmEntry(Generic[T_FsmInputs, T_FsmContext]):
    """One row of the transition table.

    Attributes:
        condition: Predicate evaluated against the inputs; the entry fires
            when it returns a true value.
        target_state: Class of the state to switch to when the entry fires.
        action: Optional callable invoked with the inputs when the entry
            fires, before any state hook runs.
    """

    condition: Callable[[T_FsmInputs], bool]
    target_state: type[FsmState[T_FsmInputs, T_FsmContext]]
    action: Callable[[T_FsmInputs], None] | None = None


FsmTableWithContext: TypeAlias = dict[
    type[FsmState[T_FsmInputs, T_FsmContext]],
    Sequence[FsmEntry[T_FsmInputs, T_FsmContext]],
]
FsmTable: TypeAlias = FsmTableWithContext[T_FsmInputs, None]


class FsmWithContext(Generic[T_FsmInputs, T_FsmContext]):
    """Table-driven finite state machine whose states may pass a context.

    The machine holds one instance per state class and looks instances up by
    their exact type.  The initial state is the instance registered for the
    type of ``states[0]``.
    """

    _state_dict: dict[type[FsmState], FsmState]
    _table: FsmTableWithContext[T_FsmInputs, T_FsmContext]
    _state: FsmState[T_FsmInputs, T_FsmContext]
    _states: Sequence[FsmState]

    def __init__(
        self,
        states: Sequence[FsmState],
        table: FsmTableWithContext[T_FsmInputs, T_FsmContext],
    ) -> None:
        """Create the machine.

        Args:
            states: State instances; the first one determines the initial
                state.  If several instances share a type, the last one wins
                because instances are indexed by ``type(instance)``.
            table: Mapping from state class to the ordered entries checked
                while that state is current.

        Raises:
            IndexError: If ``states`` is empty.
        """
        self._states = states
        self._table = table
        self._state_dict = {type(s): s for s in states}
        self._state = self._state_dict[type(states[0])]

    def _transition(
        self,
        inputs: T_FsmInputs,
        new_state: type[FsmState[T_FsmInputs, T_FsmContext]],
        action: Callable[[T_FsmInputs], None] | None,
    ) -> None:
        """Execute a fired table entry.

        Order of operations: ``action``, the current state's ``on_exit``,
        then either ``on_stay`` (target is the current state's class) or the
        state switch followed by the new state's ``on_enter``.

        Args:
            inputs: Inputs forwarded to the action and to every hook.
            new_state: Class of the state to switch to.
            action: Optional callable run before any hook.

        Raises:
            KeyError: If no instance of ``new_state`` was registered.  The
                action and ``on_exit`` have already run at that point.
        """
        if action is not None:
            action(inputs)

        # NOTE: on_exit also runs for self-transitions; its return value is
        # simply unused in that case.
        context: T_FsmContext | None = None
        if isinstance(self._state, FsmStateExit):
            context = self._state.on_exit(inputs)

        if type(self._state) is new_state:
            if isinstance(self._state, FsmStateStay):
                self._state.on_stay(inputs)
            return

        self._state = self._state_dict[new_state]
        # Runtime protocol checks only look at the method *name*, so both
        # on_enter protocols match any state that defines on_enter.  The
        # presence of a context decides which call form is used; compare
        # against None so falsy contexts (0, "", [], False) are delivered.
        if context is not None and isinstance(self._state, FsmStateEnterWithContext):
            self._state.on_enter(inputs, context=context)
        elif isinstance(self._state, FsmStateEnter):
            self._state.on_enter(inputs)

    def _check_transitions(self, inputs: T_FsmInputs) -> None:
        """Fire the first table entry of the current state whose condition holds.

        Entries are evaluated in table order and at most one fires.

        Raises:
            KeyError: If the table has no key for the current state's class.
        """
        for entry in self._table[type(self._state)]:
            if entry.condition(inputs):
                self._transition(inputs, entry.target_state, entry.action)
                return

    def input(self, inputs: T_FsmInputs) -> None:
        """Feed one set of inputs to the machine.

        Calls ``on_check`` on the current state, evaluates the transition
        table, then calls ``on_state`` on the state that is current
        afterwards.  Each hook is only called if the state defines it.
        """
        if isinstance(self._state, FsmStateCheck):
            self._state.on_check(inputs)
        self._check_transitions(inputs)
        if isinstance(self._state, FsmStateOutput):
            self._state.on_state(inputs)


Fsm: TypeAlias = FsmWithContext[T_FsmInputs, None]