from __future__ import annotations
from dataclasses import dataclass, fields, is_dataclass
from enum import StrEnum
from multiprocessing import Pipe
from multiprocessing.connection import Connection
import time
from typing import Iterable, Mapping, TYPE_CHECKING, Type
from magscope.ipc_commands import Command
if TYPE_CHECKING:
from multiprocessing.synchronize import Event as EventType
from magscope.processes import ManagerProcessBase
[docs]
class Delivery(StrEnum):
[docs]
BROADCAST = "broadcast"
[docs]
MAG_SCOPE = "mag_scope"
[docs]
class CommandRegistrationError(RuntimeError):
    """Common base class for failures that occur while registering commands."""
[docs]
class CommandConflictError(CommandRegistrationError):
    """Signal that a registration clashes with an earlier, incompatible one."""
[docs]
class MissingCommandHandlerError(CommandRegistrationError):
    """Signal that a registered command has no handler available for it."""
[docs]
class UnknownCommandError(RuntimeError):
    """Signal that a lookup or dispatch referred to an unregistered command."""
@dataclass(frozen=True)
class CommandSpec:
    """Immutable routing record for one registered command type.

    Attributes:
        command_type: The command class this record describes.
        handler: Name of the attribute on the owner/process that handles
            the command.
        target: Name of the destination the command is routed to.
        delivery: Delivery mode used when routing the command.
    """

    command_type: type[Command]
    # The registry constructs specs with these three keywords and reads them
    # back when routing and validating, so they must be declared fields.
    handler: str
    target: str
    delivery: Delivery
@dataclass(frozen=True)
class HandlerRegistration:
    """Immutable description of one decorated handler found on a class.

    Attributes:
        command_type: The command class attached to the handler.
        handler: Attribute name of the handler on the declaring class.
        delivery: Delivery mode attached to the handler.
        target_override: Optional target name that, when truthy, replaces
            the default target chosen at registration time.
    """

    command_type: type[Command]
    # Handler discovery constructs registrations with ``handler`` and
    # ``delivery`` keywords, so both must be declared fields.
    handler: str
    delivery: Delivery
    target_override: str | None = None
[docs]
def register_ipc_command(
    command_type: type[Command],
    *,
    delivery: Delivery = Delivery.DIRECT,
    target: str | None = None,
):
    """Build a decorator that tags a handler method with IPC metadata.

    The returned decorator stores ``command_type``, ``delivery`` and
    ``target`` on the decorated function as ``_ipc_command``,
    ``_ipc_delivery`` and ``_ipc_target_override`` and hands back the very
    same function object (no wrapper is created).
    """

    def _tag(handler):
        # Plain attributes on the function are what handler discovery reads.
        handler._ipc_command = command_type
        handler._ipc_delivery = delivery
        handler._ipc_target_override = target
        return handler

    return _tag
[docs]
def _collect_handler_registrations(cls: Type) -> Iterable[HandlerRegistration]:
    """Lazily produce a registration for each tagged handler on ``cls``.

    Classes are visited in MRO order; an attribute counts as a handler when
    it carries a non-``None`` ``_ipc_command``. Each handler name is yielded
    at most once, so the first tagged definition found along the MRO wins.
    """
    claimed: set[str] = set()
    for klass in cls.mro():
        for attr_name, attr in klass.__dict__.items():
            bound_command = getattr(attr, "_ipc_command", None)
            if bound_command is not None and attr_name not in claimed:
                claimed.add(attr_name)
                yield HandlerRegistration(
                    command_type=bound_command,
                    handler=attr_name,
                    delivery=getattr(attr, "_ipc_delivery", Delivery.DIRECT),
                    target_override=getattr(attr, "_ipc_target_override", None),
                )
[docs]
def command_kwargs(command: Command) -> dict[str, object]:
    """Map each dataclass field name of ``command`` to its current value.

    The result is keyed in field-declaration order and is suitable for
    ``**``-expansion into a handler call.
    """
    payload: dict[str, object] = {}
    for spec in fields(command):
        payload[spec.name] = getattr(command, spec.name)
    return payload
[docs]
class CommandRegistry:
    """Registry mapping IPC command types to their handlers and destinations.

    Two indexes are kept:

    * ``_specs`` maps a command type to the ``CommandSpec`` holding its
      handler name, target and delivery mode.
    * ``_handler_index`` maps ``(owner-or-target name, handler name)`` pairs
      back to the command type bound to that handler.
    """

    def __init__(self):
        self._specs: dict[type[Command], CommandSpec] = {}
        self._handler_index: dict[tuple[str, str], type[Command]] = {}

    def register(
        self,
        *,
        command_type: type[Command],
        handler: str,
        owner: Type,
        delivery: Delivery,
        target: str,
    ) -> None:
        """Register a command handler.

        Args:
            command_type: Dataclass subclass of ``Command`` to register.
            handler: Name of the handler attribute on ``owner``.
            owner: Class that defines the handler.
            delivery: Delivery mode stored in the resulting spec.
            target: Non-empty name of the destination.

        Raises:
            TypeError: ``command_type`` is not a dataclass or does not
                subclass ``Command``.
            MissingCommandHandlerError: ``owner`` has no attribute ``handler``.
            ValueError: ``target`` is empty.
            CommandConflictError: ``owner.handler`` or ``target.handler`` is
                already mapped to a different command type.
        """
        if not is_dataclass(command_type):
            raise TypeError(f"{command_type.__name__} must be a dataclass")
        if not issubclass(command_type, Command):
            raise TypeError(f"{command_type.__name__} must subclass Command")
        if not hasattr(owner, handler):
            raise MissingCommandHandlerError(
                f"Owner {owner.__name__} missing handler {handler} for {command_type.__name__}"
            )
        if not target:
            raise ValueError("Target cannot be empty")

        handler_key = (owner.__name__, handler)
        target_key = (target, handler)

        # Check both index keys before touching any state so that a conflict
        # never leaves the registry partially updated.
        for key in (handler_key, target_key):
            mapped_command = self._handler_index.get(key)
            if mapped_command is not None and mapped_command is not command_type:
                raise CommandConflictError(
                    f"Handler {key[0]}.{handler} already mapped to {mapped_command.__name__}"
                )

        self._specs[command_type] = CommandSpec(
            command_type=command_type,
            handler=handler,
            target=target,
            delivery=delivery,
        )
        self._handler_index[handler_key] = command_type
        self._handler_index[target_key] = command_type

    def _register_decorated(self, obj: object, default_target: str) -> None:
        """Register every decorated handler on ``type(obj)``.

        A registration's ``target_override`` wins over ``default_target``
        whenever it is truthy.
        """
        owner = type(obj)
        for registration in _collect_handler_registrations(owner):
            self.register(
                command_type=registration.command_type,
                handler=registration.handler,
                owner=owner,
                delivery=registration.delivery,
                target=registration.target_override or default_target,
            )

    def register_manager(self, manager: "ManagerProcessBase") -> None:
        """Register all decorated command handlers on ``manager``.

        The default target is ``manager.name`` when that attribute exists,
        otherwise the manager's class name.
        """
        default_target = getattr(manager, "name", type(manager).__name__)
        self._register_decorated(manager, default_target)

    def register_object(self, obj: object, *, target: str | None = None) -> None:
        """Register decorated handlers on arbitrary objects (e.g., MagScope).

        The default target is ``target`` when truthy, otherwise the object's
        class name.
        """
        self._register_decorated(obj, target or type(obj).__name__)

    def route_for(self, command: Command) -> CommandSpec:
        """Return the route information for ``command``.

        Raises:
            UnknownCommandError: The exact type of ``command`` is not
                registered (subclasses are not matched).
        """
        spec = self._specs.get(type(command))
        if spec is None:
            raise UnknownCommandError(f"Command {type(command).__name__} is not registered")
        return spec

    def handlers_for_target(self, target: str) -> dict[type[Command], CommandSpec]:
        """Return handler specs applicable to ``target``.

        A spec applies when it is a broadcast or when its target equals
        ``target``.
        """
        return {
            command_type: spec
            for command_type, spec in self._specs.items()
            if spec.delivery == Delivery.BROADCAST or spec.target == target
        }

    def validate_targets(self, processes: Mapping[str, "ManagerProcessBase"]) -> None:
        """Ensure every command has a reachable target and handler.

        ``MAG_SCOPE`` specs are skipped. ``DIRECT`` specs need their target
        in ``processes`` with the handler attribute present. ``BROADCAST``
        specs need the handler attribute on every process.

        Raises:
            MissingCommandHandlerError: A target process is unknown or a
                required handler attribute is missing.
        """
        for spec in self._specs.values():
            if spec.delivery == Delivery.MAG_SCOPE:
                continue
            if spec.delivery == Delivery.DIRECT:
                process = processes.get(spec.target)
                if process is None:
                    raise MissingCommandHandlerError(
                        f"Command {spec.command_type.__name__} targets unknown process {spec.target}"
                    )
                if not hasattr(process, spec.handler):
                    raise MissingCommandHandlerError(
                        f"Process {spec.target} missing handler {spec.handler} "
                        f"for command {spec.command_type.__name__}"
                    )
            if spec.delivery == Delivery.BROADCAST:
                missing = [
                    name
                    for name, proc in processes.items()
                    if not hasattr(proc, spec.handler)
                ]
                if missing:
                    raise MissingCommandHandlerError(
                        f"Command {spec.command_type.__name__} has no handler "
                        f"{spec.handler} on processes: {', '.join(sorted(missing))}"
                    )

    def command_for_handler(self, owner: str, handler: str) -> type[Command]:
        """Return the command type bound to ``owner.handler``.

        Raises:
            UnknownCommandError: No command is indexed under that pair.
        """
        command_type = self._handler_index.get((owner, handler))
        if command_type is None:
            raise UnknownCommandError(f"No command registered for {owner}.{handler}")
        return command_type
[docs]
def create_pipes(
    processes: Mapping[str, "ManagerProcessBase"],
) -> tuple[dict[str, Connection], dict[str, Connection]]:
    """Open one duplex pipe per managed process.

    Returns ``(parent_ends, child_ends)``, two dictionaries keyed by process
    name. The parent ends are meant to stay with the coordinating
    ``MagScope`` instance; the child ends are meant to be handed to the
    individual manager processes.
    """
    # Create every pipe first, then split each pair into the two result maps.
    pairs = {name: Pipe() for name in processes}
    parent_ends: dict[str, Connection] = {name: ends[0] for name, ends in pairs.items()}
    child_ends: dict[str, Connection] = {name: ends[1] for name, ends in pairs.items()}
    return parent_ends, child_ends
[docs]
def broadcast_command(
    command: Command,
    *,
    pipes: Mapping[str, Connection],
    processes: Mapping[str, "ManagerProcessBase"],
    quitting_events: Mapping[str, "EventType"],
) -> None:
    """Send ``command`` down every pipe whose process is alive and not quitting.

    ``processes`` and ``quitting_events`` are indexed by the keys of
    ``pipes``; a key missing from either mapping raises ``KeyError``.
    """
    for name, connection in pipes.items():
        if not processes[name].is_alive():
            continue  # dead process: nothing to deliver to
        if quitting_events[name].is_set():
            continue  # process is already shutting down
        connection.send(command)
[docs]
def drain_pipe_until_quit(
    pipe: Connection,
    quitting_event: "EventType",
    *,
    poll_interval: float | None = 0.0,
) -> None:
    """Read and discard messages from ``pipe`` until ``quitting_event`` is set.

    When no message is waiting and ``poll_interval`` is truthy, the loop
    sleeps for that many seconds before checking again; with ``0.0`` or
    ``None`` it re-checks immediately.
    """
    while True:
        if quitting_event.is_set():
            return
        if pipe.poll():
            pipe.recv()  # discard the message
            continue
        if poll_interval:
            time.sleep(poll_interval)