Skip to content

Commit 62a96ff

Browse files
committed
TST: Calculate RMS and diff image in C++
The current implementation is not slow, but uses a lot of memory per image. In `compare_images`, we have: - one actual and one expected image as uint8 (2×image) - both converted to int16 (though original is thrown away) (4×) which adds up to 4× the image allocated in this function. Then it calls `calculate_rms`, which has: - a difference between them as int16 (2×) - the difference cast to 64-bit float (8×) - the square of the difference as 64-bit float (though possibly the original difference was thrown away) (8×) which at its peak has 16× the image allocated in parallel. If the RMS is over the desired tolerance, then `save_diff_image` is called, which: - loads the actual and expected images _again_ as uint8 (2× image) - converts both to 64-bit float (throwing away the original) (16×) - calculates the difference (8×) - calculates the absolute value (8×) - multiples that by 10 (in-place, so no allocation) - clips to 0-255 (8×) - casts to uint8 (1×) which at peak uses 32× the image. So at their peak, `compare_images`→`calculate_rms` will have 20× the image allocated, and then `compare_images`→`save_diff_image` will have 36× the image allocated. This is generally not a problem, but on resource-constrained places like WASM, it can sometimes run out of memory just in `calculate_rms`. This implementation in C++ always allocates the diff image, even when not needed, but doesn't have all the temporaries, so it's a maximum of 3× the image size (plus a few scalar temporaries).
1 parent 38a8e15 commit 62a96ff

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

lib/matplotlib/testing/compare.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from PIL import Image
2020

2121
import matplotlib as mpl
22-
from matplotlib import cbook
22+
from matplotlib import cbook, _image
2323
from matplotlib.testing.exceptions import ImageComparisonFailure
2424

2525
_log = logging.getLogger(__name__)
@@ -398,7 +398,7 @@ def compare_images(expected, actual, tol, in_decorator=False):
398398
399399
The two given filenames may point to files which are convertible to
400400
PNG via the `.converter` dictionary. The underlying RMS is calculated
401-
with the `.calculate_rms` function.
401+
in a similar way to the `.calculate_rms` function.
402402
403403
Parameters
404404
----------
@@ -469,17 +469,12 @@ def compare_images(expected, actual, tol, in_decorator=False):
469469
if np.array_equal(expected_image, actual_image):
470470
return None
471471

472-
# convert to signed integers, so that the images can be subtracted without
473-
# overflow
474-
expected_image = expected_image.astype(np.int16)
475-
actual_image = actual_image.astype(np.int16)
476-
477-
rms = calculate_rms(expected_image, actual_image)
472+
rms, abs_diff = _image.calculate_rms_and_diff(expected_image, actual_image)
478473

479474
if rms <= tol:
480475
return None
481476

482-
save_diff_image(expected, actual, diff_image)
477+
Image.fromarray(abs_diff).save(diff_image, format="png")
483478

484479
results = dict(rms=rms, expected=str(expected),
485480
actual=str(actual), diff=str(diff_image), tol=tol)

src/_image_wrapper.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <pybind11/pybind11.h>
22
#include <pybind11/numpy.h>
33

4+
#include <algorithm>
5+
46
#include "_image_resample.h"
57
#include "py_converters.h"
68

@@ -200,6 +202,76 @@ image_resample(py::array input_array,
200202
}
201203

202204

205+
// This is used by matplotlib.testing.compare to calculate RMS and a difference image.
206+
static py::tuple
207+
calculate_rms_and_diff(py::array_t<unsigned char> expected_image,
208+
py::array_t<unsigned char> actual_image)
209+
{
210+
if (expected_image.ndim() != 3) {
211+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
212+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
213+
py::set_error(
214+
ImageComparisonFailure,
215+
"Expected image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
216+
"ndim"_a=expected_image.ndim()));
217+
throw py::error_already_set();
218+
}
219+
220+
if (actual_image.ndim() != 3) {
221+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
222+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
223+
py::set_error(
224+
ImageComparisonFailure,
225+
"Actual image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
226+
"ndim"_a=actual_image.ndim()));
227+
throw py::error_already_set();
228+
}
229+
230+
auto height = expected_image.shape(0);
231+
auto width = expected_image.shape(1);
232+
auto depth = expected_image.shape(2);
233+
234+
if (height != actual_image.shape(0) || width != actual_image.shape(1) ||
235+
depth != actual_image.shape(2)) {
236+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
237+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
238+
py::set_error(
239+
ImageComparisonFailure,
240+
"Image sizes do not match expected size: {expected_image.shape} "_s
241+
"actual size {actual_image.shape}"_s.format(
242+
"expected_image"_a=expected_image, "actual_image"_a=actual_image));
243+
throw py::error_already_set();
244+
}
245+
auto expected = expected_image.unchecked<3>();
246+
auto actual = actual_image.unchecked<3>();
247+
248+
py::ssize_t diff_dims[3] = {height, width, 3};
249+
py::array_t<unsigned char> diff_image(diff_dims);
250+
auto diff = diff_image.mutable_unchecked<3>();
251+
252+
double total = 0.0;
253+
for (auto i = 0; i < height; i++) {
254+
for (auto j = 0; j < width; j++) {
255+
for (auto k = 0; k < depth; k++) {
256+
auto pixel_diff = static_cast<double>(expected(i, j, k)) -
257+
static_cast<double>(actual(i, j, k));
258+
259+
total += pixel_diff*pixel_diff;
260+
261+
if (k != 3) { // Hard-code a fully solid alpha channel by omitting it.
262+
diff(i, j, k) = static_cast<unsigned char>(std::clamp(
263+
abs(pixel_diff) * 10, // Expand differences in luminance domain.
264+
0.0, 255.0));
265+
}
266+
}
267+
}
268+
}
269+
total = total / (width * height * depth);
270+
271+
return py::make_tuple(sqrt(total), diff_image);
272+
}
273+
274+
203275
PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
204276
{
205277
py::enum_<interpolation_e>(m, "_InterpolationType")
@@ -232,4 +304,7 @@ PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
232304
"norm"_a = false,
233305
"radius"_a = 1,
234306
image_resample__doc__);
307+
308+
m.def("calculate_rms_and_diff", &calculate_rms_and_diff,
309+
"expected_image"_a, "actual_image"_a);
235310
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy