from __future__ import annotations
from abc import ABCMeta, abstractmethod
from datetime import datetime
from time import sleep, time
from typing import TYPE_CHECKING
import warnings
import matplotlib
import matplotlib.dates as mdates
import matplotlib.style as mplstyle
import matplotlib.ticker as mticker
import numpy as np
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt6.QtCore import QMutex, QObject, pyqtSignal
from PyQt6.QtGui import QImage
from magscope.datatypes import MatrixBuffer
from magscope.ui.theme import PANEL_BACKGROUND_COLOR
if TYPE_CHECKING:
    # Only needed for type annotations; TYPE_CHECKING is False at runtime,
    # so this import never executes.
    from multiprocessing.synchronize import Lock as LockType

# Render through the Qt Agg backend.
matplotlib.use('QtAgg')
# Dark theme, plus matplotlib's 'fast' style (path simplification/chunking
# settings that favour redraw speed).
mplstyle.use('dark_background')
mplstyle.use('fast')
class PlotWorker(QObject):
    """Render the stacked time-series plots off-screen and emit them as QImages.

    The worker is constructed before the parent process starts (``__init__``),
    builds its matplotlib figure in :meth:`setup`, and then :meth:`run` calls
    :meth:`do_main_loop` repeatedly. Each iteration refreshes every plot,
    renders the figure with the Agg canvas and emits the result on
    ``image_signal``. Settings from the GUI arrive through the ``*_signal``
    signals, which ``__init__`` connects to the private setters below.

    NOTE(review): ``figure_size_signal``, ``_update_figure_size``,
    ``_recreate_figure_if_needed``, ``_apply_time_axis_format``,
    ``fig_width``, ``fig_height`` and ``dpi`` are used here but are not defined
    in the visible source. Confirm they are provided elsewhere.
    """

    # Rendered figure, emitted once per main-loop iteration.
    image_signal = pyqtSignal(QImage)
    # New axis limits: dict keyed by axis label (e.g. 'Time', 'X (nm)').
    limits_signal = pyqtSignal(object)
    # Bead whose track is plotted.
    selected_bead_signal = pyqtSignal(int)
    # Bead subtracted as a reference (the plots treat -1 as "none").
    reference_bead_signal = pyqtSignal(int)
    # Ends the run() loop.
    stop_signal = pyqtSignal()
    # Time-axis mode: "absolute" or "relative".
    time_mode_signal = pyqtSignal(str)
    # Relative-mode window length in seconds (None or 0 shows all data).
    relative_window_signal = pyqtSignal(object)

    # Sleep per do_main_loop() call while plotting is disabled, so run()
    # does not busy-wait.
    idle_sleep_seconds: float = 0.05

    def __init__(self):
        """ Called before the parent process is started """
        super().__init__()
        # 1-D array of matplotlib Axes, one per plot; created in setup().
        self.axes: np.ndarray
        self.locks: dict[str, LockType]
        self.canvas: FigureCanvas
        self._is_running: bool = False
        self.limits: dict[str, tuple[float | None, float | None]] = {}
        self.selected_bead: int | None = 0
        self.reference_bead: int | None = None
        self.update_on: bool = True
        self._update_last_time: float
        self.device_pixel_ratio = 1.0
        self.time_mode = "absolute"
        self.relative_window_seconds: float | None = 300
        # Per-frame copy of the tracks buffer shared by all tracks plots.
        self._tracks_snapshot: np.ndarray | None = None
        # Must exist before add_plot() is called below.
        self.plots: list[TimeSeriesPlotBase] = []

        # Connect internal signals to slots
        self.limits_signal.connect(self._set_limits)
        self.selected_bead_signal.connect(self._set_selected_bead)
        self.reference_bead_signal.connect(self._set_reference_bead)
        self.stop_signal.connect(self._stop)
        self.figure_size_signal.connect(self._update_figure_size)
        self.time_mode_signal.connect(self._set_time_mode)
        self.relative_window_signal.connect(self._set_relative_window)

        # Thread safety: the QMutex is created in setup().

        # Add plots for bead tracks
        self.add_plot(TracksTimeSeriesPlot('X'))
        self.add_plot(TracksTimeSeriesPlot('Y'))
        self.add_plot(TracksTimeSeriesPlot('Z'))

    def setup(self):
        """Create the figure, canvas and one stacked axes per plot, then set up each plot."""
        self.n_plots = len(self.plots)
        self.mutex = QMutex()

        # Create figure and axes
        self.figure = Figure(
            figsize=(self.fig_width, self.fig_height),
            dpi=self.dpi,
            facecolor=PANEL_BACKGROUND_COLOR,
            constrained_layout=True,
        )
        self.figure.set_constrained_layout_pads(
            w_pad=0.02,
            h_pad=0.0,
            hspace=0.0,
            wspace=0.0,
        )
        self.canvas = FigureCanvas(self.figure)
        # squeeze=False always returns a 2-D array, so self.axes is a 1-D
        # array even with a single plot (iteration and [-1] keep working).
        self.axes = self.figure.subplots(
            nrows=self.n_plots,
            ncols=1,
            sharex=True,
            sharey=False,
            squeeze=False,
        )[:, 0]

        # Formatting to make it look good
        for ax in self.axes:
            ax.set_facecolor(PANEL_BACKGROUND_COLOR)
            ax.margins(x=0)
        # Only the bottom axes shows x tick labels (the x axis is shared).
        for ax in self.axes[:-1]:
            ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
        self.axes[-1].set_xlabel('Time (h:m:s)')
        self.axes[-1].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))

        # Pass complex objects to each plot (self, axes, etc.). Every plot gets
        # its axes and parent before any plot's setup() runs.
        for plot, ax in zip(self.plots, self.axes):
            plot.set_axes(ax)
        for plot in self.plots:
            plot.set_parent(self)
        for plot in self.plots:
            plot.setup()
        self._apply_time_axis_format()

    def run(self):
        """Call do_main_loop() until stop_signal or dispose() clears the running flag."""
        self._is_running = True
        self._update_last_time = time()
        while self._is_running:
            self.do_main_loop()

    def do_main_loop(self):
        """Run one frame: throttle, refresh every plot, render and emit the image."""
        if not self.update_on:
            # Plotting is disabled. Sleep instead of spinning in run(), and
            # restart the throttle clock so the idle time is not mistaken for
            # render time (and slept off tenfold) when plotting resumes.
            sleep(self.idle_sleep_seconds)
            self._update_last_time = time()
            return

        # Adaptive throttle: sleep ten times as long as the previous frame
        # took, so rendering occupies roughly 1/11 of wall-clock time. Clamp
        # at 0 because a backwards wall-clock jump would make the difference
        # negative, and sleep() rejects negative values.
        duration = max(0.0, time() - self._update_last_time)
        sleep(10 * duration)
        self._update_last_time = time()

        # Check if we need to recreate the figure
        self._recreate_figure_if_needed()

        # Read the tracks buffer once per frame and share it with every tracks
        # plot, so all of them draw the same data.
        self._tracks_snapshot = None
        for plot in self.plots:
            if isinstance(plot, TracksTimeSeriesPlot):
                self._tracks_snapshot = plot.buffer.peak_unsorted()
                break
        try:
            for plot in self.plots:
                plot.update()
        finally:
            # Never let a stale snapshot leak into a later frame.
            self._tracks_snapshot = None

        # Render the figure. buffer_rgba() exposes the Agg buffer as a
        # (height, width, 4) view, so take the size from it directly rather
        # than from get_width_height(), which can report logical instead of
        # physical pixels.
        self.canvas.draw()
        buf = np.asarray(self.canvas.buffer_rgba())
        h, w = buf.shape[:2]

        # Convert numpy RGBA -> QImage; copy() detaches it from the Agg buffer.
        img = QImage(buf.data, w, h, QImage.Format.Format_RGBA8888).copy()
        img.setDevicePixelRatio(self.device_pixel_ratio)

        # Emit figure as a buffer to the main GUI
        self.image_signal.emit(img)

    def add_plot(self, plot: TimeSeriesPlotBase):
        """ Used to add plots before the process has started """
        self.plots.append(plot)

    def _set_limits(self, limits: dict[str, tuple[float | None, float | None]]):
        """Replace the axis-limits dict (keys are axis labels such as 'Time')."""
        self.limits = limits

    def _set_selected_bead(self, bead: int):
        self.selected_bead = bead

    def _set_reference_bead(self, bead: int | None):
        self.reference_bead = bead

    def set_locks(self, locks: dict[str, LockType]):
        """Store the locks handed to each plot's MatrixBuffer in setup()."""
        self.locks = locks

    def _stop(self):
        self._is_running = False

    def dispose(self) -> None:
        """Stop the loop and release the canvas, figure and plot references.

        Every teardown step is best-effort. ``RuntimeError`` from the canvas
        calls (presumably a Qt object that was already deleted) and any
        exception from ``figure.clear()`` are ignored, so disposal completes.
        """
        self._is_running = False
        # canvas/figure only exist once setup() has run.
        canvas = getattr(self, 'canvas', None)
        figure = getattr(self, 'figure', None)
        if canvas is not None:
            try:
                canvas.hide()
            except RuntimeError:
                pass
            try:
                canvas.setParent(None)
            except RuntimeError:
                pass
        if figure is not None:
            try:
                figure.clear()
            except Exception:
                pass
        if canvas is not None:
            try:
                canvas.close()
            except RuntimeError:
                pass
            try:
                canvas.deleteLater()
            except RuntimeError:
                pass
        self.axes = None
        self.canvas = None
        self.figure = None
        self._tracks_snapshot = None
        self.plots = []

    def _set_time_mode(self, time_mode: str):
        self.time_mode = time_mode
        self._apply_time_axis_format()

    def _set_relative_window(self, window_seconds: float | None):
        self.relative_window_seconds = window_seconds
class TimeSeriesPlotBase(metaclass=ABCMeta):
    """Base class for one subplot fed from a named shared ``MatrixBuffer``.

    ``PlotWorker`` assigns the axes (:meth:`set_axes`) and itself as parent
    (:meth:`set_parent`), then calls :meth:`setup`. After that it calls
    :meth:`update` once per frame, and subclasses must implement it.
    """

    def __init__(self, buffer_name: str, ylabel: str):
        # Attached in setup(), after the parent process has started.
        self.buffer: MatrixBuffer
        # Name of the existing shared buffer to attach to.
        self.buffer_name = buffer_name
        # Y-axis label; also the key looked up in parent.limits by subclasses.
        self.ylabel = ylabel
        self.parent: PlotWorker
        self.axes: matplotlib.axes.Axes

    def setup(self):
        """ Called after the parent process is started """
        # Attach to (not create) the buffer, using the parent's locks.
        self.buffer = MatrixBuffer(
            create=False,
            name=self.buffer_name,
            locks=self.parent.locks,
        )
        # Format plot
        self.axes.set_ylabel(self.ylabel)

    def set_parent(self, parent: PlotWorker):
        self.parent = parent

    def set_axes(self, axes: matplotlib.axes.Axes):
        self.axes = axes

    @abstractmethod
    def update(self):
        """Redraw this plot from the latest data; called once per frame."""
class TracksTimeSeriesPlot(TimeSeriesPlotBase):
    """Plot one coordinate (X, Y or Z) of the selected bead's track over time.

    Rows come from the ``'TracksBuffer'`` buffer, or from the parent's
    per-frame snapshot of it. Column 0 is the timestamp in seconds (converted
    with ``datetime.fromtimestamp``), columns 1-3 hold X/Y/Z, and column 4 is
    the bead id. When the parent has a reference bead, its track is subtracted
    at the timestamps both beads share.
    """

    def __init__(self, axis_name: str):
        super().__init__('TracksBuffer', ylabel=axis_name+' (nm)')
        # One of 'X', 'Y', 'Z' (anything else makes list.index raise ValueError).
        self.axis_name = axis_name
        # Buffer column holding this coordinate: X -> 1, Y -> 2, Z -> 3.
        self.axis_index = ['X', 'Y', 'Z'].index(axis_name) + 1
        # Line artist; created in setup().
        self.line: matplotlib.lines.Line2D

    def setup(self):
        """Attach to the buffer (via the base class) and create an empty red line."""
        super().setup()
        self.line, = self.axes.plot([], [], 'r')

    def _clear_line(self) -> None:
        """Show no data and rescale the axes to what remains."""
        self.line.set_xdata([])
        self.line.set_ydata([])
        self.axes.relim()
        self.axes.autoscale_view()

    def update(self):
        """Redraw the line from the latest track data and apply the axis limits.

        Limits come from ``parent.limits``. The entry keyed by this plot's
        y-label bounds the values. In absolute mode, the ``'Time'`` entry
        (epoch seconds) bounds the timestamps. ``None`` on either side means
        that side is left unbounded/autoscaled. In relative mode the x data are
        seconds since the start of the ``relative_window_seconds`` window.
        """
        parent = self.parent

        # Get selected and reference bead (-1 means "no reference bead")
        sel = parent.selected_bead
        ref = parent.reference_bead
        if ref == -1:
            ref = None

        # Prefer the snapshot the parent took for this frame; otherwise read
        # the buffer directly. Sorting via fancy indexing returns a copy, so
        # the shared snapshot itself is never modified.
        data = parent._tracks_snapshot
        if data is None:
            data = self.buffer.peak_unsorted()
        data = data[np.argsort(data[:, 0], kind='stable')]
        t = data[:, 0]
        b = data[:, 4]
        v = data[:, self.axis_index]

        # Get selected bead values
        is_sel = b == sel
        t_sel = t[is_sel]
        v_sel = v[is_sel]

        if ref is None:
            t = t_sel
            v = v_sel
        else:
            # Get reference bead values
            is_ref = b == ref
            t_ref = t[is_ref]
            v_ref = v[is_ref]
            has_duplicates = (
                np.unique(t_sel).size != t_sel.size
                or np.unique(t_ref).size != t_ref.size
            )
            if has_duplicates:
                warnings.warn(
                    'Duplicate timestamps detected while plotting referenced bead tracks.',
                    RuntimeWarning,
                    stacklevel=2,
                )
            try:
                # Pair the samples both beads have at the same timepoints.
                # assume_unique=True is only valid without duplicates;
                # otherwise numpy de-duplicates first (first occurrence wins).
                t, index_sel, index_ref = np.intersect1d(
                    t_sel,
                    t_ref,
                    assume_unique=not has_duplicates,
                    return_indices=True,
                )
                v = v_sel[index_sel] - v_ref[index_ref]
            except Exception as exc:
                warnings.warn(
                    f'Skipping referenced bead track plot update: {exc}',
                    RuntimeWarning,
                    stacklevel=2,
                )
                self._clear_line()
                return

        # Correct for ZLUT upside-down order (new array; inputs untouched)
        if self.axis_name == 'Z':
            v = -v

        # Drop samples with non-finite timestamps
        finite = np.isfinite(t)
        t = t[finite]
        v = v[finite]

        y_limits = parent.limits.get(self.ylabel, (None, None))
        ymin, ymax = y_limits[0], y_limits[1]
        ymin_limit = -np.inf if ymin is None else ymin
        ymax_limit = np.inf if ymax is None else ymax

        relative = parent.time_mode == "relative"
        if relative:
            if t.size == 0:
                self._clear_line()
                return
            window = parent.relative_window_seconds
            # Window start: `window` seconds before the newest sample, or the
            # oldest sample when no (or a zero) window is set.
            xmin_value = np.max(t) - window if window else np.min(t)
            in_window = t >= xmin_value
            t = t[in_window]
            v = v[in_window]
            in_range = (ymin_limit <= v) & (v <= ymax_limit)
            t = t[in_range]
            v = v[in_range]
            xdata = t - xmin_value
            xmin = 0
            xmax = window if window else None
        else:
            x_limits = parent.limits.get('Time', (None, None))
            xmin, xmax = x_limits[0], x_limits[1]
            xmin_limit = -np.inf if xmin is None else xmin
            xmax_limit = np.inf if xmax is None else xmax
            keep = (xmin_limit <= t) & (t <= xmax_limit)
            keep &= (ymin_limit <= v) & (v <= ymax_limit)
            t = t[keep]
            v = v[keep]
            xdata = [datetime.fromtimestamp(t_) for t_ in t]

        self.line.set_xdata(xdata)
        self.line.set_ydata(v)

        # Avoid a zero-width view when both limits are equal.
        if xmin is not None and xmin == xmax:
            xmax = xmin + 1
        if ymin is not None and ymin == ymax:
            ymax = ymin + 1
        # With an open limit, reset any inversion left by earlier explicit
        # limits (matplotlib inverts an axis when min > max is set).
        if xmin is None or xmax is None:
            self.axes.xaxis.set_inverted(False)
        if ymin is None or ymax is None:
            self.axes.yaxis.set_inverted(False)
        if not relative:
            # Absolute limits are epoch seconds, while the x data are datetimes.
            xmin, xmax = [
                None if t_ is None else datetime.fromtimestamp(t_)
                for t_ in (xmin, xmax)
            ]

        # Recompute data limits from the new line data *before* autoscaling;
        # otherwise autoscale_view() works from the previous frame's limits.
        self.axes.relim()
        self.axes.autoscale()
        self.axes.autoscale_view()
        # None leaves that side at the autoscaled value.
        self.axes.set_xlim(xmin=xmin, xmax=xmax)
        self.axes.set_ylim(ymin=ymin, ymax=ymax)