Skip to content

Commit 5ee3f62

Browse files
committed
ENH: Extend numpy.pad to handle pad_width dictionary argument.
1 parent b89320a commit 5ee3f62

File tree

4 files changed

+53
-5
lines changed

4 files changed

+53
-5
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Extend ``numpy.pad`` to accept a dictionary for the ``pad_width`` argument.

numpy/lib/_arraypad_impl.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
of an n-dimensional array.
44
55
"""
6+
import typing
7+
68
import numpy as np
79
from numpy._core.overrides import array_function_dispatch
810
from numpy.lib._index_tricks_impl import ndindex
@@ -550,14 +552,16 @@ def pad(array, pad_width, mode='constant', **kwargs):
550552
----------
551553
array : array_like of rank N
552554
The array to pad.
553-
pad_width : {sequence, array_like, int}
555+
pad_width : {sequence, array_like, int, dict}
554556
Number of values padded to the edges of each axis.
555557
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
556558
for each axis.
557559
``(before, after)`` or ``((before, after),)`` yields same before
558560
and after pad for each axis.
559561
``(pad,)`` or ``int`` is a shortcut for before = after = pad width
560562
for all axes.
563+
If a ``dict``, each key is an axis and its corresponding value is an ``int`` or
564+
``int`` pair describing the padding width for that axis.
561565
mode : str or function, optional
562566
One of the following string values or a user supplied function.
563567
@@ -745,8 +749,39 @@ def pad(array, pad_width, mode='constant', **kwargs):
745749
[100, 100, 3, 4, 5, 100, 100],
746750
[100, 100, 100, 100, 100, 100, 100],
747751
[100, 100, 100, 100, 100, 100, 100]])
752+
753+
>>> a = np.arange(2 * 3).reshape(2, 3)
754+
>>> np.pad(a, {1: (1, 2)})
755+
array([[0, 0, 1, 2, 0, 0],
756+
[0, 3, 4, 5, 0, 0]])
757+
>>> np.pad(a, {-1: 2})
758+
array([[0, 0, 0, 1, 2, 0, 0],
759+
[0, 0, 3, 4, 5, 0, 0]])
760+
>>> np.pad(a, {0: (3, 0)})
761+
array([[0, 0, 0],
762+
[0, 0, 0],
763+
[0, 0, 0],
764+
[0, 1, 2],
765+
[3, 4, 5]])
766+
>>> np.pad(a, {0: (3, 0), 1: 2})
767+
array([[0, 0, 0, 0, 0, 0, 0],
768+
[0, 0, 0, 0, 0, 0, 0],
769+
[0, 0, 0, 0, 0, 0, 0],
770+
[0, 0, 0, 1, 2, 0, 0],
771+
[0, 0, 3, 4, 5, 0, 0]])
748772
"""
749773
array = np.asarray(array)
774+
if isinstance(pad_width, dict):
775+
seq = [(0, 0)] * array.ndim
776+
for axis, width in pad_width.items():
777+
match width:
778+
case int(both):
779+
seq[axis] = both, both
780+
case tuple((int(before), int(after))):
781+
seq[axis] = before, after
782+
case _ as invalid:
783+
typing.assert_never(invalid)
784+
pad_width = seq
750785
pad_width = np.asarray(pad_width)
751786

752787
if not pad_width.dtype.kind == 'i':

numpy/lib/_arraypad_impl.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ _ModeKind: TypeAlias = L[
4747
@overload
4848
def pad(
4949
array: _ArrayLike[_ScalarT],
50-
pad_width: _ArrayLikeInt,
50+
pad_width: _ArrayLikeInt | dict,
5151
mode: _ModeKind = ...,
5252
*,
5353
stat_length: _ArrayLikeInt | None = ...,
@@ -58,7 +58,7 @@ def pad(
5858
@overload
5959
def pad(
6060
array: ArrayLike,
61-
pad_width: _ArrayLikeInt,
61+
pad_width: _ArrayLikeInt | dict,
6262
mode: _ModeKind = ...,
6363
*,
6464
stat_length: _ArrayLikeInt | None = ...,
@@ -69,14 +69,14 @@ def pad(
6969
@overload
7070
def pad(
7171
array: _ArrayLike[_ScalarT],
72-
pad_width: _ArrayLikeInt,
72+
pad_width: _ArrayLikeInt | dict,
7373
mode: _ModeFunc,
7474
**kwargs: Any,
7575
) -> NDArray[_ScalarT]: ...
7676
@overload
7777
def pad(
7878
array: ArrayLike,
79-
pad_width: _ArrayLikeInt,
79+
pad_width: _ArrayLikeInt | dict,
8080
mode: _ModeFunc,
8181
**kwargs: Any,
8282
) -> NDArray[Any]: ...

numpy/lib/tests/test_arraypad.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,3 +1413,15 @@ def test_dtype_persistence(dtype, mode):
14131413
arr = np.zeros((3, 2, 1), dtype=dtype)
14141414
result = np.pad(arr, 1, mode=mode)
14151415
assert result.dtype == dtype
1416+
1417+
1418+
@pytest.mark.parametrize("input_shape, pad_width, expected_shape", [
1419+
((3, 4, 5), {-2: (1, 3)}, (3, 4 + 1 + 3, 5)),
1420+
((3, 4, 5), {0: (5, 2)}, (3 + 5 + 2, 4, 5)),
1421+
((3, 4, 5), {0: (5, 2), -1: (3, 4)}, (3 + 5 + 2, 4, 5 + 3 + 4)),
1422+
((3, 4, 5), {1: 5}, (3, 4 + 2 * 5, 5)),
1423+
])
1424+
def test_pad_dict_pad_width(input_shape, pad_width, expected_shape):
1425+
a = np.zeros(input_shape)
1426+
result = np.pad(a, pad_width)
1427+
assert result.shape == expected_shape

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy