diff --git a/docs/api.rst b/docs/api.rst index dba583af..ae0d78d2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,3 +8,5 @@ plots though the napari user interface. .. automodapi:: napari_matplotlib .. automodapi:: napari_matplotlib.base + +.. automodapi:: napari_matplotlib.features diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 8c717d6a..5687895e 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -281,7 +281,9 @@ def __init__( napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None, ): - super().__init__(napari_viewer=napari_viewer, parent=parent) + NapariMPLWidget.__init__( + self, napari_viewer=napari_viewer, parent=parent + ) self.add_single_axes() def clear(self) -> None: diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py new file mode 100644 index 00000000..3e1eb9ba --- /dev/null +++ b/src/napari_matplotlib/features.py @@ -0,0 +1,153 @@ +from typing import Any, Dict, List, Optional, Tuple + +import napari +import napari.layers +import numpy as np +import numpy.typing as npt +import pandas as pd +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout + +from napari_matplotlib.base import NapariMPLWidget +from napari_matplotlib.util import Interval + +__all__ = ["FeaturesMixin"] + + +class FeaturesMixin(NapariMPLWidget): + """ + Mixin to help with widgets that plot data from a features table stored + in a single napari layer. + + This provides: + + - Setup for one or two combo boxes to select features to be plotted. + - An ``on_update_layers()`` callback that updates the combo box options + when the napari layer selection changes. + """ + + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = ( + napari.layers.Labels, + napari.layers.Points, + napari.layers.Shapes, + napari.layers.Tracks, + napari.layers.Vectors, + ) + + def __init__(self, *, ndim: int) -> None: + """ + Parameters + ---------- + ndim : int + Number of dimensions that are plotted by the widget. + Must be 1 or 2. + """ + assert ndim in [1, 2] + self.dims = ["x", "y"][:ndim] + # Set up selection boxes + self.layout().addLayout(QVBoxLayout()) + + self._selectors: Dict[str, QComboBox] = {} + for dim in self.dims: + self._selectors[dim] = QComboBox() + # Re-draw when combo boxes are updated + self._selectors[dim].currentTextChanged.connect(self._draw) + + self.layout().addWidget(QLabel(f"{dim}-axis:")) + self.layout().addWidget(self._selectors[dim]) + + def get_key(self, dim: str) -> Optional[str]: + """ + Get key for a given dimension. + + Parameters + ---------- + dim : str + "x" or "y" + """ + if self._selectors[dim].count() == 0: + return None + else: + return self._selectors[dim].currentText() + + def set_key(self, dim: str, value: str) -> None: + """ + Set key for a given dimension. + + Parameters + ---------- + dim : str + "x" or "y" + value : str + Value to set. + """ + assert value in self._get_valid_axis_keys(), ( + "value must be on of the columns " + "of the feature table on the currently seleted layer" + ) + self._selectors[dim].setCurrentText(value) + self._draw() + + def _get_valid_axis_keys(self) -> List[str]: + """ + Get the valid axis keys from the features table column names. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _ready_to_plot(self) -> bool: + """ + Return True if selected layer has a feature table we can plot with, + and the columns to plot have been selected. + """ + if not hasattr(self.layers[0], "features"): + return False + + feature_table = self.layers[0].features + valid_keys = self._get_valid_axis_keys() + return ( + feature_table is not None + and len(feature_table) > 0 + and all([self.get_key(dim) in valid_keys for dim in self.dims]) + ) + + def _get_data_names( + self, + ) -> Tuple[List[npt.NDArray[Any]], List[str]]: + """ + Get the plot data from the ``features`` attribute of the first + selected layer. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + names : List[str] + Names for each axis. + """ + feature_table: pd.DataFrame = self.layers[0].features + + names = [str(self.get_key(dim)) for dim in self.dims] + data = [np.array(feature_table[key]) for key in names] + return data, names + + def on_update_layers(self) -> None: + """ + Called when the layer selection changes by ``self.update_layers()``. + """ + # Clear combobox + for dim in self.dims: + while self._selectors[dim].count() > 0: + self._selectors[dim].removeItem(0) + # Add keys for newly selected layer + self._selectors[dim].addItems(self._get_valid_axis_keys()) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 334f941c..4fa45798 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Tuple import napari import numpy.typing as npt -from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget +from qtpy.QtWidgets import QWidget from .base import SingleAxesWidget +from .features import FeaturesMixin from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] @@ -85,144 +86,27 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: return x, y, x_axis_name, y_axis_name -class FeaturesScatterWidget(ScatterBaseWidget): +class FeaturesScatterWidget(ScatterBaseWidget, FeaturesMixin): """ Widget to scatter data stored in two layer feature attributes. """ - n_layers_input = Interval(1, 1) - # All layers that have a .features attributes - input_layer_types = ( - napari.layers.Labels, - napari.layers.Points, - napari.layers.Shapes, - napari.layers.Tracks, - napari.layers.Vectors, - ) - def __init__( self, napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None, ): - super().__init__(napari_viewer, parent=parent) - - self.layout().addLayout(QVBoxLayout()) - - self._selectors: Dict[str, QComboBox] = {} - for dim in ["x", "y"]: - self._selectors[dim] = QComboBox() - # Re-draw when combo boxes are updated - self._selectors[dim].currentTextChanged.connect(self._draw) - - self.layout().addWidget(QLabel(f"{dim}-axis:")) - self.layout().addWidget(self._selectors[dim]) - + ScatterBaseWidget.__init__(self, napari_viewer, parent=parent) + FeaturesMixin.__init__(self, ndim=2) self._update_layers(None) - @property - def x_axis_key(self) -> Union[str, None]: - """ - Key for the x-axis data. - """ - if self._selectors["x"].count() == 0: - return None - else: - return self._selectors["x"].currentText() - - @x_axis_key.setter - def x_axis_key(self, key: str) -> None: - self._selectors["x"].setCurrentText(key) - self._draw() - - @property - def y_axis_key(self) -> Union[str, None]: - """ - Key for the y-axis data. - """ - if self._selectors["y"].count() == 0: - return None - else: - return self._selectors["y"].currentText() - - @y_axis_key.setter - def y_axis_key(self, key: str) -> None: - self._selectors["y"].setCurrentText(key) - self._draw() - - def _get_valid_axis_keys(self) -> List[str]: - """ - Get the valid axis keys from the layer FeatureTable. - - Returns - ------- - axis_keys : List[str] - The valid axis keys in the FeatureTable. If the table is empty - or there isn't a table, returns an empty list. - """ - if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): - return [] - else: - return self.layers[0].features.keys() - - def _ready_to_scatter(self) -> bool: - """ - Return True if selected layer has a feature table we can scatter with, - and the two columns to be scatterd have been selected. - """ - if not hasattr(self.layers[0], "features"): - return False - - feature_table = self.layers[0].features - valid_keys = self._get_valid_axis_keys() - return ( - feature_table is not None - and len(feature_table) > 0 - and self.x_axis_key in valid_keys - and self.y_axis_key in valid_keys - ) - def draw(self) -> None: """ Scatter two features from the currently selected layer. """ - if self._ready_to_scatter(): + if self._ready_to_plot(): super().draw() def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: - """ - Get the plot data from the ``features`` attribute of the first - selected layer. - - Returns - ------- - data : List[np.ndarray] - List contains X and Y columns from the FeatureTable. Returns - an empty array if nothing to plot. - x_axis_name : str - The title to display on the x axis. Returns - an empty string if nothing to plot. - y_axis_name: str - The title to display on the y axis. Returns - an empty string if nothing to plot. - """ - feature_table = self.layers[0].features - - x = feature_table[self.x_axis_key] - y = feature_table[self.y_axis_key] - - x_axis_name = str(self.x_axis_key) - y_axis_name = str(self.y_axis_key) - - return x, y, x_axis_name, y_axis_name - - def on_update_layers(self) -> None: - """ - Called when the layer selection changes by ``self.update_layers()``. - """ - # Clear combobox - for dim in ["x", "y"]: - while self._selectors[dim].count() > 0: - self._selectors[dim].removeItem(0) - # Add keys for newly selected layer - self._selectors[dim].addItems(self._get_valid_axis_keys()) + data, names = self._get_data_names() + return data[0], data[1], names[0], names[1] diff --git a/src/napari_matplotlib/tests/scatter/test_scatter_features.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py index c211a064..0b3f7638 100644 --- a/src/napari_matplotlib/tests/scatter/test_scatter_features.py +++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py @@ -25,8 +25,8 @@ def test_features_scatter_widget_2D( # Select points data and chosen features viewer.layers.selection.add(viewer.layers[0]) # images need to be selected - widget.x_axis_key = "feature_0" - widget.y_axis_key = "feature_1" + widget.set_key("x", "feature_0") + widget.set_key("y", "feature_1") fig = widget.figure @@ -64,9 +64,9 @@ def test_features_scatter_get_data(make_napari_viewer): viewer.layers.selection = [labels_layer] x_column = "feature_0" - scatter_widget.x_axis_key = x_column y_column = "feature_2" - scatter_widget.y_axis_key = y_column + scatter_widget.set_key("x", x_column) + scatter_widget.set_key("y", y_column) x, y, x_axis_name, y_axis_name = scatter_widget._get_data() np.testing.assert_allclose(x, feature_table[x_column]) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy