Skip to content

Commit cbb42d1

Browse files
committed
Simplify scatter code
1 parent 58223ca commit cbb42d1

File tree

3 files changed

+66
-69
lines changed

3 files changed

+66
-69
lines changed

CHANGELOG.rst

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
0.0.2
1+
0.4.0
22
=====
33

4-
New features
5-
------------
6-
- `HistogramWidget` now shows individual histograms for RGB channels when
7-
present.
8-
9-
10-
Bug fixes
11-
---------
12-
- `HistogramWidget` now works properly with 2D images.
4+
Changes
5+
-------
6+
- The scatter widgets no longer use a LogNorm() for 2D histogram scaling.
7+
This is to move the widget in line with the philosophy of using Matplotlib default
8+
settings throughout ``napari-matplotlib``. This still leaves open the option of
9+
adding the option to change the normalization in the future. If this is something
10+
you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib.
11+
- Labels plotting with the features scatter widget no longer have underscores
12+
replaced with spaces.

src/napari_matplotlib/scatter.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, List, Optional, Tuple
22

3-
import matplotlib.colors as mcolor
43
import napari
54
import numpy.typing as npt
65
from magicgui import magicgui
@@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget):
1716
Base class for widgets that scatter two datasets against each other.
1817
"""
1918

20-
# opacity value for the markers
21-
_marker_alpha = 0.5
22-
23-
# flag set to True if histogram should be used
24-
# for plotting large points
25-
_histogram_for_large_data = True
26-
2719
# if the number of points is greater than this value,
28-
# the scatter is plotted as a 2dhist
20+
# the scatter is plotted as a 2D histogram
2921
_threshold_to_switch_to_histogram = 500
3022

3123
def __init__(self, napari_viewer: napari.viewer.Viewer):
@@ -44,40 +36,32 @@ def draw(self) -> None:
4436
"""
4537
Scatter the currently selected layers.
4638
"""
47-
data, x_axis_name, y_axis_name = self._get_data()
48-
49-
if len(data) == 0:
50-
# don't plot if there isn't data
51-
return
39+
x, y, x_axis_name, y_axis_name = self._get_data()
5240

53-
if self._histogram_for_large_data and (
54-
data[0].size > self._threshold_to_switch_to_histogram
55-
):
41+
if x.size > self._threshold_to_switch_to_histogram:
5642
self.axes.hist2d(
57-
data[0].ravel(),
58-
data[1].ravel(),
43+
x.ravel(),
44+
y.ravel(),
5945
bins=100,
60-
norm=mcolor.LogNorm(),
6146
)
6247
else:
63-
self.axes.scatter(data[0], data[1], alpha=self._marker_alpha)
48+
self.axes.scatter(x, y, alpha=0.5)
6449

6550
self.axes.set_xlabel(x_axis_name)
6651
self.axes.set_ylabel(y_axis_name)
6752

68-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
69-
"""Get the plot data.
53+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
54+
"""
55+
Get the plot data.
7056
7157
This must be implemented on the subclass.
7258
7359
Returns
7460
-------
75-
data : np.ndarray
76-
The list containing the scatter plot data.
77-
x_axis_name : str
78-
The label to display on the x axis
79-
y_axis_name: str
80-
The label to display on the y axis
61+
x, y : np.ndarray
62+
x and y values of plot data.
63+
x_axis_name, y_axis_name : str
64+
Label to display on the x/y axis
8165
"""
8266
raise NotImplementedError
8367

@@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget):
9377
n_layers_input = Interval(2, 2)
9478
input_layer_types = (napari.layers.Image,)
9579

96-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
80+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
9781
"""
9882
Get the plot data.
9983
@@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
10690
y_axis_name: str
10791
The title to display on the y axis
10892
"""
109-
data = [layer.data[self.current_z] for layer in self.layers]
93+
x = self.layers[0].data[self.current_z]
94+
y = self.layers[1].data[self.current_z]
11095
x_axis_name = self.layers[0].name
11196
y_axis_name = self.layers[1].name
11297

113-
return data, x_axis_name, y_axis_name
98+
return x, y, x_axis_name, y_axis_name
11499

115100

116101
class FeaturesScatterWidget(ScatterBaseWidget):
@@ -191,9 +176,33 @@ def _get_valid_axis_keys(
191176
else:
192177
return self.layers[0].features.keys()
193178

194-
def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
179+
def _ready_to_scatter(self) -> bool:
195180
"""
196-
Get the plot data.
181+
Return True if selected layer has a feature table we can scatter with,
182+
and the two columns to be scatterd have been selected.
183+
"""
184+
if not hasattr(self.layers[0], "features"):
185+
return False
186+
187+
feature_table = self.layers[0].features
188+
return (
189+
feature_table is not None
190+
and len(feature_table) > 0
191+
and self.x_axis_key is not None
192+
and self.y_axis_key is not None
193+
)
194+
195+
def draw(self) -> None:
196+
"""
197+
Scatter two features from the currently selected layer.
198+
"""
199+
if self._ready_to_scatter():
200+
super().draw()
201+
202+
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
203+
"""
204+
Get the plot data from the ``features`` attribute of the first
205+
selected layer.
197206
198207
Returns
199208
-------
@@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
207216
The title to display on the y axis. Returns
208217
an empty string if nothing to plot.
209218
"""
210-
if not hasattr(self.layers[0], "features"):
211-
# if the selected layer doesn't have a featuretable,
212-
# skip draw
213-
return [], "", ""
214-
215219
feature_table = self.layers[0].features
216220

217-
if (
218-
(len(feature_table) == 0)
219-
or (self.x_axis_key is None)
220-
or (self.y_axis_key is None)
221-
):
222-
return [], "", ""
223-
224-
data_x = feature_table[self.x_axis_key]
225-
data_y = feature_table[self.y_axis_key]
226-
data = [data_x, data_y]
221+
x = feature_table[self.x_axis_key]
222+
y = feature_table[self.y_axis_key]
227223

228-
x_axis_name = self.x_axis_key.replace("_", " ")
229-
y_axis_name = self.y_axis_key.replace("_", " ")
224+
x_axis_name = str(self.x_axis_key)
225+
y_axis_name = str(self.y_axis_key)
230226

231-
return data, x_axis_name, y_axis_name
227+
return x, y, x_axis_name, y_axis_name
232228

233229
def _on_update_layers(self) -> None:
234230
"""

src/napari_matplotlib/tests/test_scatter.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def make_labels_layer_with_features() -> (
3939

4040

4141
def test_features_scatter_get_data(make_napari_viewer):
42-
"""Test the get data method"""
42+
"""
43+
Test the get data method.
44+
"""
4345
# make the label image
4446
label_image, feature_table = make_labels_layer_with_features()
4547

@@ -55,17 +57,16 @@ def test_features_scatter_get_data(make_napari_viewer):
5557
y_column = "feature_2"
5658
scatter_widget.y_axis_key = y_column
5759

58-
data, x_axis_name, y_axis_name = scatter_widget._get_data()
59-
np.testing.assert_allclose(
60-
data, np.stack((feature_table[x_column], feature_table[y_column]))
61-
)
60+
x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
61+
np.testing.assert_allclose(x, feature_table[x_column])
62+
np.testing.assert_allclose(y, np.stack(feature_table[y_column]))
6263
assert x_axis_name == x_column.replace("_", " ")
6364
assert y_axis_name == y_column.replace("_", " ")
6465

6566

6667
def test_get_valid_axis_keys(make_napari_viewer):
67-
"""Test the values returned from
68-
FeaturesScatterWidget._get_valid_keys() when there
68+
"""
69+
Test the values returned from _get_valid_keys() when there
6970
are valid keys.
7071
"""
7172
# make the label image

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy