Skip to content

Commit 8b2321d

Browse files
TYP: Type MaskedArray.reshape (#29404)
Co-authored-by: Joren Hammudoglu <jhammudoglu@gmail.com>
1 parent e766381 commit 8b2321d

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

numpy/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,6 +2468,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24682468
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DTypeT_co]: ...
24692469
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DTypeT_co]: ...
24702470

2471+
# Keep in sync with `MaskedArray.reshape`
24712472
# NOTE: reshape also accepts negative integers, so we can't use integer literals
24722473
@overload # (None)
24732474
def reshape(self, shape: None, /, *, order: _OrderACF = "C", copy: builtins.bool | None = None) -> Self: ...

numpy/ma/core.pyi

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ from typing_extensions import TypeIs, TypeVar
1717

1818
import numpy as np
1919
from numpy import (
20+
_AnyShapeT,
2021
_HasDType,
2122
_HasDTypeWithRealAndImag,
2223
_ModeKind,
24+
_OrderACF,
2325
_OrderKACF,
2426
_PartitionKind,
2527
_SortKind,
@@ -1126,7 +1128,90 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
11261128
def count(self, axis: _ShapeLike | None, keepdims: Literal[True]) -> NDArray[int_]: ...
11271129

11281130
def ravel(self, order: _OrderKACF = "C") -> MaskedArray[tuple[int], _DTypeT_co]: ...
1129-
def reshape(self, *s, **kwargs): ...
1131+
1132+
# Keep in sync with `ndarray.reshape`
1133+
# NOTE: reshape also accepts negative integers, so we can't use integer literals
1134+
@overload # (None)
1135+
def reshape(self, shape: None, /, *, order: _OrderACF = "C", copy: bool | None = None) -> Self: ...
1136+
@overload # (empty_sequence)
1137+
def reshape( # type: ignore[overload-overlap] # mypy false positive
1138+
self,
1139+
shape: Sequence[Never],
1140+
/,
1141+
*,
1142+
order: _OrderACF = "C",
1143+
copy: bool | None = None,
1144+
) -> MaskedArray[tuple[()], _DTypeT_co]: ...
1145+
@overload # (() | (int) | (int, int) | ....) # up to 8-d
1146+
def reshape(
1147+
self,
1148+
shape: _AnyShapeT,
1149+
/,
1150+
*,
1151+
order: _OrderACF = "C",
1152+
copy: bool | None = None,
1153+
) -> MaskedArray[_AnyShapeT, _DTypeT_co]: ...
1154+
@overload # (index)
1155+
def reshape(
1156+
self,
1157+
size1: SupportsIndex,
1158+
/,
1159+
*,
1160+
order: _OrderACF = "C",
1161+
copy: bool | None = None,
1162+
) -> MaskedArray[tuple[int], _DTypeT_co]: ...
1163+
@overload # (index, index)
1164+
def reshape(
1165+
self,
1166+
size1: SupportsIndex,
1167+
size2: SupportsIndex,
1168+
/,
1169+
*,
1170+
order: _OrderACF = "C",
1171+
copy: bool | None = None,
1172+
) -> MaskedArray[tuple[int, int], _DTypeT_co]: ...
1173+
@overload # (index, index, index)
1174+
def reshape(
1175+
self,
1176+
size1: SupportsIndex,
1177+
size2: SupportsIndex,
1178+
size3: SupportsIndex,
1179+
/,
1180+
*,
1181+
order: _OrderACF = "C",
1182+
copy: bool | None = None,
1183+
) -> MaskedArray[tuple[int, int, int], _DTypeT_co]: ...
1184+
@overload # (index, index, index, index)
1185+
def reshape(
1186+
self,
1187+
size1: SupportsIndex,
1188+
size2: SupportsIndex,
1189+
size3: SupportsIndex,
1190+
size4: SupportsIndex,
1191+
/,
1192+
*,
1193+
order: _OrderACF = "C",
1194+
copy: bool | None = None,
1195+
) -> MaskedArray[tuple[int, int, int, int], _DTypeT_co]: ...
1196+
@overload # (int, *(index, ...))
1197+
def reshape(
1198+
self,
1199+
size0: SupportsIndex,
1200+
/,
1201+
*shape: SupportsIndex,
1202+
order: _OrderACF = "C",
1203+
copy: bool | None = None,
1204+
) -> MaskedArray[_AnyShape, _DTypeT_co]: ...
1205+
@overload # (sequence[index])
1206+
def reshape(
1207+
self,
1208+
shape: Sequence[SupportsIndex],
1209+
/,
1210+
*,
1211+
order: _OrderACF = "C",
1212+
copy: bool | None = None,
1213+
) -> MaskedArray[_AnyShape, _DTypeT_co]: ...
1214+
11301215
def resize(self, newshape: Never, refcheck: bool = True, order: bool = False) -> NoReturn: ...
11311216
def put(self, indices: _ArrayLikeInt_co, values: ArrayLike, mode: _ModeKind = "raise") -> None: ...
11321217
def ids(self) -> tuple[int, int]: ...

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ AR_LIKE_dt64: list[np.datetime64]
2727
AR_LIKE_o: list[np.object_]
2828
AR_number: NDArray[np.number]
2929

30+
MAR_c8: MaskedArray[np.complex64]
3031
MAR_c16: MaskedArray[np.complex128]
3132
MAR_b: MaskedArray[np.bool]
3233
MAR_f4: MaskedArray[np.float32]
@@ -399,6 +400,13 @@ assert_type(MAR_f8.trace(out=MAR_subclass, dtype=None), MaskedArraySubclass)
399400
assert_type(MAR_f8.round(), MaskedArray[np.float64])
400401
assert_type(MAR_f8.round(out=MAR_subclass), MaskedArraySubclass)
401402

403+
assert_type(MAR_i8.reshape(None), MaskedArray[np.int64])
404+
assert_type(MAR_f8.reshape(-1), np.ma.MaskedArray[tuple[int], np.dtype[np.float64]])
405+
assert_type(MAR_c8.reshape(2, 3, 4, 5), np.ma.MaskedArray[tuple[int, int, int, int], np.dtype[np.complex64]])
406+
assert_type(MAR_td64.reshape(()), np.ma.MaskedArray[tuple[()], np.dtype[np.timedelta64]])
407+
assert_type(MAR_s.reshape([]), np.ma.MaskedArray[tuple[()], np.dtype[np.str_]])
408+
assert_type(MAR_V.reshape((480, 720, 4)), np.ma.MaskedArray[tuple[int, int, int], np.dtype[np.void]])
409+
402410
assert_type(MAR_f8.cumprod(), MaskedArray[Any])
403411
assert_type(MAR_f8.cumprod(out=MAR_subclass), MaskedArraySubclass)
404412

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy