diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index caf38d16..b78e43e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,11 +6,6 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/asottile/setup-cfg-fmt - rev: v2.4.0 - hooks: - - id: setup-cfg-fmt - - repo: https://github.com/psf/black rev: 23.7.0 hooks: @@ -22,14 +17,14 @@ repos: - id: napari-plugin-checks - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.5.1 hooks: - id: mypy additional_dependencies: [numpy, matplotlib] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.280' + rev: 'v0.0.285' hooks: - id: ruff diff --git a/baseline/test_feature_histogram2.png b/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/baseline/test_feature_histogram2.png differ diff --git a/docs/changelog.rst b/docs/changelog.rst index 103cd80e..6f77e0c3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,5 +1,23 @@ Changelog ========= +1.1.0 +----- +Additions +~~~~~~~~~ +- Added a widget to draw a histogram of features. + +Changes +~~~~~~~ +- The slice widget is now limited to slicing along the x/y dimensions. Support + for slicing along z has been removed for now to make the code simpler. +- The slice widget now uses a slider to select the slice value. + +Bug fixes +~~~~~~~~~ +- Fixed creating 1D slices of 2D images. +- Removed the limitation that only the first 99 indices could be sliced using + the slice widget. + 1.0.2 ----- Bug fixes diff --git a/docs/user_guide.rst b/docs/user_guide.rst index 0872e540..fbd48db1 100644 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -30,6 +30,7 @@ These widgets plot the data stored in the ``.features`` attribute of individual Currently available are: - 2D scatter plots of two features against each other. +- Histograms of individual features. To use these: diff --git a/setup.cfg b/setup.cfg index 229b5777..e1fc9e73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ docs = sphinx-automodapi sphinx-gallery testing = - napari[pyqt6-experimental]>=0.4.18 + napari[pyqt6_experimental]>=0.4.18 pooch pyqt6 pytest diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py new file mode 100644 index 00000000..34abf104 --- /dev/null +++ b/src/napari_matplotlib/features.py @@ -0,0 +1,9 @@ +from napari.layers import Labels, Points, Shapes, Tracks, Vectors + +FEATURES_LAYER_TYPES = ( + Labels, + Points, + Shapes, + Tracks, + Vectors, +) diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index ab098f38..66aa7acc 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,13 +1,15 @@ -from typing import Optional +from typing import Any, List, Optional, Tuple import napari import numpy as np -from qtpy.QtWidgets import QWidget +import numpy.typing as npt +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval -__all__ = ["HistogramWidget"] +__all__ = ["HistogramWidget", "FeaturesHistogramWidget"] _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -61,3 +63,112 @@ def draw(self) -> None: self.axes.hist(data.ravel(), bins=bins, label=layer.name) self.axes.legend() + + +class FeaturesHistogramWidget(SingleAxesWidget): + """ + Display a histogram of selected feature attached to selected layer. + """ + + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = FEATURES_LAYER_TYPES + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) + + self.layout().addLayout(QVBoxLayout()) + self._key_selection_widget = QComboBox() + self.layout().addWidget(QLabel("Key:")) + self.layout().addWidget(self._key_selection_widget) + + self._key_selection_widget.currentTextChanged.connect( + self._set_axis_keys + ) + + self._update_layers(None) + + @property + def x_axis_key(self) -> Optional[str]: + """Key to access x axis data from the FeaturesTable""" + return self._x_axis_key + + @x_axis_key.setter + def x_axis_key(self, key: Optional[str]) -> None: + self._x_axis_key = key + self._draw() + + def _set_axis_keys(self, x_axis_key: str) -> None: + """Set both axis keys and then redraw the plot""" + self._x_axis_key = x_axis_key + self._draw() + + def _get_valid_axis_keys(self) -> List[str]: + """ + Get the valid axis keys from the layer FeatureTable. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _get_data(self) -> Tuple[Optional[npt.NDArray[Any]], str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + x_axis_name : str + The title to display on the x axis. Returns + an empty string if nothing to plot. + """ + if not hasattr(self.layers[0], "features"): + # if the selected layer doesn't have a featuretable, + # skip draw + return None, "" + + feature_table = self.layers[0].features + + if (len(feature_table) == 0) or (self.x_axis_key is None): + return None, "" + + data = feature_table[self.x_axis_key] + x_axis_name = self.x_axis_key.replace("_", " ") + + return data, x_axis_name + + def on_update_layers(self) -> None: + """ + Called when the layer selection changes by ``self.update_layers()``. + """ + # reset the axis keys + self._x_axis_key = None + + # Clear combobox + self._key_selection_widget.clear() + self._key_selection_widget.addItems(self._get_valid_axis_keys()) + + def draw(self) -> None: + """Clear the axes and histogram the currently selected layer/slice.""" + data, x_axis_name = self._get_data() + + if data is None: + return + + self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3) + + # set ax labels + self.axes.set_xlabel(x_axis_name) + self.axes.set_ylabel("Counts [#]") diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index b736592b..71af0ca6 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -14,6 +14,10 @@ contributions: python_name: napari_matplotlib:FeaturesScatterWidget title: Make a scatter plot of layer features + - id: napari-matplotlib.features_histogram + python_name: napari_matplotlib:FeaturesHistogramWidget + title: Plot feature histograms + - id: napari-matplotlib.slice python_name: napari_matplotlib:SliceWidget title: Plot a 1D slice @@ -28,5 +32,8 @@ contributions: - command: napari-matplotlib.features_scatter display_name: FeaturesScatter + - command: napari-matplotlib.features_histogram + display_name: FeaturesHistogram + - command: napari-matplotlib.slice display_name: 1D slice diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index db86c7f3..a4148bd2 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -5,6 +5,7 @@ from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] @@ -94,13 +95,7 @@ class FeaturesScatterWidget(ScatterBaseWidget): n_layers_input = Interval(1, 1) # All layers that have a .features attributes - input_layer_types = ( - napari.layers.Labels, - napari.layers.Points, - napari.layers.Shapes, - napari.layers.Tracks, - napari.layers.Vectors, - ) + input_layer_types = FEATURES_LAYER_TYPES def __init__( self, diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index e3aa80b2..393f2e45 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -1,19 +1,23 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, List, Optional, Tuple import matplotlib.ticker as mticker import napari import numpy as np import numpy.typing as npt -from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget +from qtpy.QtCore import Qt +from qtpy.QtWidgets import ( + QComboBox, + QLabel, + QSlider, + QVBoxLayout, + QWidget, +) from .base import SingleAxesWidget from .util import Interval __all__ = ["SliceWidget"] -_dims_sel = ["x", "y"] -_dims = ["x", "y", "z"] - class SliceWidget(SingleAxesWidget): """ @@ -31,28 +35,46 @@ def __init__( # Setup figure/axes super().__init__(napari_viewer, parent=parent) - button_layout = QHBoxLayout() - self.layout().addLayout(button_layout) - self.dim_selector = QComboBox() + self.dim_selector.addItems(["x", "y"]) + + self.slice_selector = QSlider(orientation=Qt.Orientation.Horizontal) + + # Create widget layout + button_layout = QVBoxLayout() button_layout.addWidget(QLabel("Slice axis:")) button_layout.addWidget(self.dim_selector) - self.dim_selector.addItems(_dims) - - self.slice_selectors = {} - for d in _dims_sel: - self.slice_selectors[d] = QSpinBox() - button_layout.addWidget(QLabel(f"{d}:")) - button_layout.addWidget(self.slice_selectors[d]) + button_layout.addWidget(self.slice_selector) + self.layout().addLayout(button_layout) # Setup callbacks - # Re-draw when any of the combon/spin boxes are updated + # Re-draw when any of the combo/slider is updated self.dim_selector.currentTextChanged.connect(self._draw) - for d in _dims_sel: - self.slice_selectors[d].textChanged.connect(self._draw) + self.slice_selector.valueChanged.connect(self._draw) self._update_layers(None) + def on_update_layers(self) -> None: + """ + Called when layer selection is updated. + """ + if not len(self.layers): + return + if self.current_dim_name == "x": + max = self._layer.data.shape[-2] + elif self.current_dim_name == "y": + max = self._layer.data.shape[-1] + else: + raise RuntimeError("dim name must be x or y") + self.slice_selector.setRange(0, max - 1) + + @property + def _slice_width(self) -> int: + """ + Width of the slice being plotted. + """ + return self._layer.data.shape[self.current_dim_index] + @property def _layer(self) -> napari.layers.Layer: """ @@ -61,7 +83,7 @@ def _layer(self) -> napari.layers.Layer: return self.layers[0] @property - def current_dim(self) -> str: + def current_dim_name(self) -> str: """ Currently selected slice dimension. """ @@ -74,36 +96,40 @@ def current_dim_index(self) -> int: """ # Note the reversed list because in napari the z-axis is the first # numpy axis - return _dims[::-1].index(self.current_dim) + return self._dim_names.index(self.current_dim_name) @property - def _selector_values(self) -> Dict[str, int]: + def _dim_names(self) -> List[str]: """ - Values of the slice selectors. + List of dimension names. This is a property as it varies depending on the + dimensionality of the currently selected data. """ - return {d: self.slice_selectors[d].value() for d in _dims_sel} + if self._layer.data.ndim == 2: + return ["y", "x"] + elif self._layer.data.ndim == 3: + return ["z", "y", "x"] + else: + raise RuntimeError("Don't know how to handle ndim != 2 or 3") def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]: """ Get data for plotting. """ - x = np.arange(self._layer.data.shape[self.current_dim_index]) - - vals = self._selector_values - vals.update({"z": self.current_z}) + val = self.slice_selector.value() slices = [] - for d in _dims: - if d == self.current_dim: + for dim_name in self._dim_names: + if dim_name == self.current_dim_name: # Select all data along this axis slices.append(slice(None)) + elif dim_name == "z": + # Only select the currently viewed z-index + slices.append(slice(self.current_z, self.current_z + 1)) else: # Select specific index - val = vals[d] slices.append(slice(val, val + 1)) - # Reverse since z is the first axis in napari - slices = slices[::-1] + x = np.arange(self._slice_width) y = self._layer.data[tuple(slices)].ravel() return x, y @@ -115,7 +141,7 @@ def draw(self) -> None: x, y = self._get_xy() self.axes.plot(x, y) - self.axes.set_xlabel(self.current_dim) + self.axes.set_xlabel(self.current_dim_name) self.axes.set_title(self._layer.name) # Make sure all ticks lie on integer values self.axes.xaxis.set_major_locator( diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png new file mode 100644 index 00000000..1892af44 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_2D.png b/src/napari_matplotlib/tests/baseline/test_slice_2D.png index 639f25b8..ee3ce3b6 100644 Binary files a/src/napari_matplotlib/tests/baseline/test_slice_2D.png and b/src/napari_matplotlib/tests/baseline/test_slice_2D.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_3D.png b/src/napari_matplotlib/tests/baseline/test_slice_3D.png index 43c8c3b6..c30211da 100644 Binary files a/src/napari_matplotlib/tests/baseline/test_slice_3D.png and b/src/napari_matplotlib/tests/baseline/test_slice_3D.png differ diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py index 4d170014..006c042f 100644 --- a/src/napari_matplotlib/tests/test_histogram.py +++ b/src/napari_matplotlib/tests/test_histogram.py @@ -1,8 +1,13 @@ from copy import deepcopy +import numpy as np import pytest -from napari_matplotlib import HistogramWidget +from napari_matplotlib import FeaturesHistogramWidget, HistogramWidget +from napari_matplotlib.tests.helpers import ( + assert_figures_equal, + assert_figures_not_equal, +) @pytest.mark.mpl_image_compare @@ -28,3 +33,90 @@ def test_histogram_3D(make_napari_viewer, brain_data): # Need to return a copy, as original figure is too eagerley garbage # collected by the widget return deepcopy(fig) + + +def test_feature_histogram(make_napari_viewer): + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + + # Check whether changing the selected key changes the plot + widget._set_axis_keys("feature1") + fig1 = deepcopy(widget.figure) + + widget._set_axis_keys("feature2") + assert_figures_not_equal(widget.figure, fig1) + + # check whether selecting a different layer produces the same plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_equal(widget.figure, fig1) + + +@pytest.mark.mpl_image_compare +def test_feature_histogram2(make_napari_viewer): + import numpy as np + + np.random.seed(0) + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + widget._set_axis_keys("feature1") + + fig = FeaturesHistogramWidget(viewer).figure + return deepcopy(fig) + + +def test_change_layer(make_napari_viewer, brain_data, astronaut_data): + viewer = make_napari_viewer() + widget = HistogramWidget(viewer) + + viewer.add_image(brain_data[0], **brain_data[1]) + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + # Select first layer + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + fig1 = deepcopy(widget.figure) + + # Re-selecting first layer should produce identical plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + assert_figures_equal(widget.figure, fig1) + + # Plotting the second layer should produce a different plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_not_equal(widget.figure, fig1) diff --git a/src/napari_matplotlib/tests/test_slice.py b/src/napari_matplotlib/tests/test_slice.py index 412e71c3..368a7ded 100644 --- a/src/napari_matplotlib/tests/test_slice.py +++ b/src/napari_matplotlib/tests/test_slice.py @@ -9,9 +9,13 @@ def test_slice_3D(make_napari_viewer, brain_data): viewer = make_napari_viewer() viewer.theme = "light" - viewer.add_image(brain_data[0], **brain_data[1]) + + data = brain_data[0] + assert data.ndim == 3, data.shape + viewer.add_image(data, **brain_data[1]) + axis = viewer.dims.last_used - slice_no = brain_data[0].shape[0] - 1 + slice_no = data.shape[0] - 1 viewer.dims.set_current_step(axis, slice_no) fig = SliceWidget(viewer).figure # Need to return a copy, as original figure is too eagerley garbage @@ -23,8 +27,37 @@ def test_slice_3D(make_napari_viewer, brain_data): def test_slice_2D(make_napari_viewer, astronaut_data): viewer = make_napari_viewer() viewer.theme = "light" - viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + # Take first RGB channel + data = astronaut_data[0][:, :, 0] + assert data.ndim == 2, data.shape + viewer.add_image(data) + fig = SliceWidget(viewer).figure # Need to return a copy, as original figure is too eagerley garbage # collected by the widget return deepcopy(fig) + + +def test_slice_axes(make_napari_viewer, astronaut_data): + viewer = make_napari_viewer() + viewer.theme = "light" + + # Take first RGB channel + data = astronaut_data[0][:256, :, 0] + # Shape: + # x: 0 > 512 + # y: 0 > 256 + assert data.ndim == 2, data.shape + # Make sure data isn't square for later tests + assert data.shape[0] != data.shape[1] + viewer.add_image(data) + + widget = SliceWidget(viewer) + assert widget._dim_names == ["y", "x"] + assert widget.current_dim_name == "x" + assert widget.slice_selector.value() == 0 + assert widget.slice_selector.minimum() == 0 + assert widget.slice_selector.maximum() == data.shape[0] - 1 + # x/y are flipped in napari + assert widget._slice_width == data.shape[1] pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy