From 00cffb990f1432c048e6e8aa6c41f581b27702e2 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Mon, 15 May 2023 11:34:34 +0100 Subject: [PATCH] Simplify scatter code --- CHANGELOG.rst | 20 ++-- src/napari_matplotlib/scatter.py | 100 ++++++++++---------- src/napari_matplotlib/tests/test_scatter.py | 19 ++-- 3 files changed, 68 insertions(+), 71 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fafd2e0b..f6199857 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,12 +1,12 @@ -0.0.2 +0.4.0 ===== -New features ------------- -- `HistogramWidget` now shows individual histograms for RGB channels when - present. - - -Bug fixes ---------- -- `HistogramWidget` now works properly with 2D images. +Changes +------- +- The scatter widgets no longer use a LogNorm() for 2D histogram scaling. + This is to move the widget in line with the philosophy of using Matplotlib default + settings throughout ``napari-matplotlib``. This still leaves open the option of + adding the option to change the normalization in the future. If this is something + you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib. +- Labels plotting with the features scatter widget no longer have underscores + replaced with spaces. diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index cb1e8498..405b7b09 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,6 +1,5 @@ from typing import Any, List, Optional, Tuple -import matplotlib.colors as mcolor import napari import numpy.typing as npt from magicgui import magicgui @@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget): Base class for widgets that scatter two datasets against each other. """ - # opacity value for the markers - _marker_alpha = 0.5 - - # flag set to True if histogram should be used - # for plotting large points - _histogram_for_large_data = True - # if the number of points is greater than this value, - # the scatter is plotted as a 2dhist + # the scatter is plotted as a 2D histogram _threshold_to_switch_to_histogram = 500 def __init__(self, napari_viewer: napari.viewer.Viewer): @@ -44,40 +36,32 @@ def draw(self) -> None: """ Scatter the currently selected layers. """ - data, x_axis_name, y_axis_name = self._get_data() - - if len(data) == 0: - # don't plot if there isn't data - return + x, y, x_axis_name, y_axis_name = self._get_data() - if self._histogram_for_large_data and ( - data[0].size > self._threshold_to_switch_to_histogram - ): + if x.size > self._threshold_to_switch_to_histogram: self.axes.hist2d( - data[0].ravel(), - data[1].ravel(), + x.ravel(), + y.ravel(), bins=100, - norm=mcolor.LogNorm(), ) else: - self.axes.scatter(data[0], data[1], alpha=self._marker_alpha) + self.axes.scatter(x, y, alpha=0.5) self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: - """Get the plot data. + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data. This must be implemented on the subclass. Returns ------- - data : np.ndarray - The list containing the scatter plot data. - x_axis_name : str - The label to display on the x axis - y_axis_name: str - The label to display on the y axis + x, y : np.ndarray + x and y values of plot data. + x_axis_name, y_axis_name : str + Label to display on the x/y axis """ raise NotImplementedError @@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget): n_layers_input = Interval(2, 2) input_layer_types = (napari.layers.Image,) - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: """ Get the plot data. @@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: y_axis_name: str The title to display on the y axis """ - data = [layer.data[self.current_z] for layer in self.layers] + x = self.layers[0].data[self.current_z] + y = self.layers[1].data[self.current_z] x_axis_name = self.layers[0].name y_axis_name = self.layers[1].name - return data, x_axis_name, y_axis_name + return x, y, x_axis_name, y_axis_name class FeaturesScatterWidget(ScatterBaseWidget): @@ -191,9 +176,33 @@ def _get_valid_axis_keys( else: return self.layers[0].features.keys() - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: + def _ready_to_scatter(self) -> bool: """ - Get the plot data. + Return True if selected layer has a feature table we can scatter with, + and the two columns to be scatterd have been selected. + """ + if not hasattr(self.layers[0], "features"): + return False + + feature_table = self.layers[0].features + return ( + feature_table is not None + and len(feature_table) > 0 + and self.x_axis_key is not None + and self.y_axis_key is not None + ) + + def draw(self) -> None: + """ + Scatter two features from the currently selected layer. + """ + if self._ready_to_scatter(): + super().draw() + + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data from the ``features`` attribute of the first + selected layer. Returns ------- @@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: The title to display on the y axis. Returns an empty string if nothing to plot. """ - if not hasattr(self.layers[0], "features"): - # if the selected layer doesn't have a featuretable, - # skip draw - return [], "", "" - feature_table = self.layers[0].features - if ( - (len(feature_table) == 0) - or (self.x_axis_key is None) - or (self.y_axis_key is None) - ): - return [], "", "" - - data_x = feature_table[self.x_axis_key] - data_y = feature_table[self.y_axis_key] - data = [data_x, data_y] + x = feature_table[self.x_axis_key] + y = feature_table[self.y_axis_key] - x_axis_name = self.x_axis_key.replace("_", " ") - y_axis_name = self.y_axis_key.replace("_", " ") + x_axis_name = str(self.x_axis_key) + y_axis_name = str(self.y_axis_key) - return data, x_axis_name, y_axis_name + return x, y, x_axis_name, y_axis_name def _on_update_layers(self) -> None: """ diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/test_scatter.py index fe07655d..88e0584c 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/test_scatter.py @@ -39,7 +39,9 @@ def make_labels_layer_with_features() -> ( def test_features_scatter_get_data(make_napari_viewer): - """Test the get data method""" + """ + Test the get data method. + """ # make the label image label_image, feature_table = make_labels_layer_with_features() @@ -55,17 +57,16 @@ def test_features_scatter_get_data(make_napari_viewer): y_column = "feature_2" scatter_widget.y_axis_key = y_column - data, x_axis_name, y_axis_name = scatter_widget._get_data() - np.testing.assert_allclose( - data, np.stack((feature_table[x_column], feature_table[y_column])) - ) - assert x_axis_name == x_column.replace("_", " ") - assert y_axis_name == y_column.replace("_", " ") + x, y, x_axis_name, y_axis_name = scatter_widget._get_data() + np.testing.assert_allclose(x, feature_table[x_column]) + np.testing.assert_allclose(y, np.stack(feature_table[y_column])) + assert x_axis_name == x_column + assert y_axis_name == y_column def test_get_valid_axis_keys(make_napari_viewer): - """Test the values returned from - FeaturesScatterWidget._get_valid_keys() when there + """ + Test the values returned from _get_valid_keys() when there are valid keys. """ # make the label image pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy