Skip to content

Commit 6eeaae2

Browse files
authored
TYP: fix overloads where out: _ArrayT was typed as being the default (#29278)
1 parent bd68380 commit 6eeaae2

File tree

3 files changed

+88
-10
lines changed

3 files changed

+88
-10
lines changed

numpy/__init__.pyi

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,7 +2401,17 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24012401
axis1: SupportsIndex = ...,
24022402
axis2: SupportsIndex = ...,
24032403
dtype: DTypeLike = ...,
2404-
out: _ArrayT = ...,
2404+
*,
2405+
out: _ArrayT,
2406+
) -> _ArrayT: ...
2407+
@overload
2408+
def trace(
2409+
self, # >= 2D array
2410+
offset: SupportsIndex,
2411+
axis1: SupportsIndex,
2412+
axis2: SupportsIndex,
2413+
dtype: DTypeLike,
2414+
out: _ArrayT,
24052415
) -> _ArrayT: ...
24062416

24072417
@overload
@@ -2425,7 +2435,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24252435
self,
24262436
indices: _ArrayLikeInt_co,
24272437
axis: SupportsIndex | None = ...,
2428-
out: _ArrayT = ...,
2438+
*,
2439+
out: _ArrayT,
2440+
mode: _ModeKind = ...,
2441+
) -> _ArrayT: ...
2442+
@overload
2443+
def take(
2444+
self,
2445+
indices: _ArrayLikeInt_co,
2446+
axis: SupportsIndex | None,
2447+
out: _ArrayT,
24292448
mode: _ModeKind = ...,
24302449
) -> _ArrayT: ...
24312450

@@ -3655,7 +3674,16 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
36553674
self,
36563675
indices: _ArrayLikeInt_co,
36573676
axis: SupportsIndex | None = ...,
3658-
out: _ArrayT = ...,
3677+
*,
3678+
out: _ArrayT,
3679+
mode: _ModeKind = ...,
3680+
) -> _ArrayT: ...
3681+
@overload
3682+
def take(
3683+
self,
3684+
indices: _ArrayLikeInt_co,
3685+
axis: SupportsIndex | None,
3686+
out: _ArrayT,
36593687
mode: _ModeKind = ...,
36603688
) -> _ArrayT: ...
36613689

numpy/_core/multiarray.pyi

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,6 @@ def concatenate( # type: ignore[misc]
528528
casting: _CastingKind | None = ...
529529
) -> NDArray[_ScalarT]: ...
530530
@overload
531-
@overload
532531
def concatenate( # type: ignore[misc]
533532
arrays: SupportsLenAndGetItem[ArrayLike],
534533
/,
@@ -553,7 +552,17 @@ def concatenate(
553552
arrays: SupportsLenAndGetItem[ArrayLike],
554553
/,
555554
axis: SupportsIndex | None = ...,
556-
out: _ArrayT = ...,
555+
*,
556+
out: _ArrayT,
557+
dtype: DTypeLike = ...,
558+
casting: _CastingKind | None = ...
559+
) -> _ArrayT: ...
560+
@overload
561+
def concatenate(
562+
arrays: SupportsLenAndGetItem[ArrayLike],
563+
/,
564+
axis: SupportsIndex | None,
565+
out: _ArrayT,
557566
*,
558567
dtype: DTypeLike = ...,
559568
casting: _CastingKind | None = ...
@@ -1094,7 +1103,17 @@ def busday_count(
10941103
weekmask: ArrayLike = ...,
10951104
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None = ...,
10961105
busdaycal: busdaycalendar | None = ...,
1097-
out: _ArrayT = ...,
1106+
*,
1107+
out: _ArrayT,
1108+
) -> _ArrayT: ...
1109+
@overload
1110+
def busday_count(
1111+
begindates: ArrayLike | dt.date | _NestedSequence[dt.date],
1112+
enddates: ArrayLike | dt.date | _NestedSequence[dt.date],
1113+
weekmask: ArrayLike,
1114+
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None,
1115+
busdaycal: busdaycalendar | None,
1116+
out: _ArrayT,
10981117
) -> _ArrayT: ...
10991118

11001119
# `roll="raise"` is (more or less?) equivalent to `casting="safe"`
@@ -1126,7 +1145,18 @@ def busday_offset( # type: ignore[misc]
11261145
weekmask: ArrayLike = ...,
11271146
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None = ...,
11281147
busdaycal: busdaycalendar | None = ...,
1129-
out: _ArrayT = ...,
1148+
*,
1149+
out: _ArrayT,
1150+
) -> _ArrayT: ...
1151+
@overload
1152+
def busday_offset( # type: ignore[misc]
1153+
dates: _ArrayLike[datetime64] | dt.date | _NestedSequence[dt.date],
1154+
offsets: _ArrayLikeTD64_co | dt.timedelta | _NestedSequence[dt.timedelta],
1155+
roll: L["raise"],
1156+
weekmask: ArrayLike,
1157+
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None,
1158+
busdaycal: busdaycalendar | None,
1159+
out: _ArrayT,
11301160
) -> _ArrayT: ...
11311161
@overload
11321162
def busday_offset( # type: ignore[misc]
@@ -1156,7 +1186,18 @@ def busday_offset(
11561186
weekmask: ArrayLike = ...,
11571187
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None = ...,
11581188
busdaycal: busdaycalendar | None = ...,
1159-
out: _ArrayT = ...,
1189+
*,
1190+
out: _ArrayT,
1191+
) -> _ArrayT: ...
1192+
@overload
1193+
def busday_offset(
1194+
dates: ArrayLike | dt.date | _NestedSequence[dt.date],
1195+
offsets: ArrayLike | dt.timedelta | _NestedSequence[dt.timedelta],
1196+
roll: _RollKind,
1197+
weekmask: ArrayLike,
1198+
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None,
1199+
busdaycal: busdaycalendar | None,
1200+
out: _ArrayT,
11601201
) -> _ArrayT: ...
11611202

11621203
@overload
@@ -1181,7 +1222,16 @@ def is_busday(
11811222
weekmask: ArrayLike = ...,
11821223
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None = ...,
11831224
busdaycal: busdaycalendar | None = ...,
1184-
out: _ArrayT = ...,
1225+
*,
1226+
out: _ArrayT,
1227+
) -> _ArrayT: ...
1228+
@overload
1229+
def is_busday(
1230+
dates: ArrayLike | _NestedSequence[dt.date],
1231+
weekmask: ArrayLike,
1232+
holidays: ArrayLike | dt.date | _NestedSequence[dt.date] | None,
1233+
busdaycal: busdaycalendar | None,
1234+
out: _ArrayT,
11851235
) -> _ArrayT: ...
11861236

11871237
@overload

numpy/typing/tests/data/reveal/array_constructors.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ assert_type(np.empty([1, 5, 6], dtype='c16'), npt.NDArray[Any])
5353
assert_type(np.empty(mixed_shape), npt.NDArray[np.float64])
5454

5555
assert_type(np.concatenate(A), npt.NDArray[np.float64])
56-
assert_type(np.concatenate([A, A]), Any) # pyright correctly infers this as NDArray[float64]
56+
assert_type(np.concatenate([A, A]), npt.NDArray[Any]) # pyright correctly infers this as NDArray[float64]
5757
assert_type(np.concatenate([[1], A]), npt.NDArray[Any])
5858
assert_type(np.concatenate([[1], [1]]), npt.NDArray[Any])
5959
assert_type(np.concatenate((A, A)), npt.NDArray[np.float64])

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy