Skip to content

Commit f228b58

Browse files
committed
Add fft support for numpy and cupy
This is based off of numpy/numpy#25317
1 parent d235910 commit f228b58

File tree

5 files changed

+245
-0
lines changed

5 files changed

+245
-0
lines changed

array_api_compat/common/_fft.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Union, Optional, Literal
4+
5+
if TYPE_CHECKING:
6+
from ._typing import Device, ndarray
7+
from collections.abc import Sequence
8+
9+
# Note: NumPy fft functions improperly upcast float32 and complex64 to
10+
# complex128, which is why we require wrapping them all here.
11+
12+
def fft(
13+
x: ndarray,
14+
/,
15+
xp,
16+
*,
17+
n: Optional[int] = None,
18+
axis: int = -1,
19+
norm: Literal["backward", "ortho", "forward"] = "backward",
20+
) -> ndarray:
21+
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
22+
if x.dtype in [xp.float32, xp.complex64]:
23+
return res.astype(xp.complex64)
24+
return res
25+
26+
def ifft(
27+
x: ndarray,
28+
/,
29+
xp,
30+
*,
31+
n: Optional[int] = None,
32+
axis: int = -1,
33+
norm: Literal["backward", "ortho", "forward"] = "backward",
34+
) -> ndarray:
35+
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
36+
if x.dtype in [xp.float32, xp.complex64]:
37+
return res.astype(xp.complex64)
38+
return res
39+
40+
def fftn(
41+
x: ndarray,
42+
/,
43+
xp,
44+
*,
45+
s: Sequence[int] = None,
46+
axes: Sequence[int] = None,
47+
norm: Literal["backward", "ortho", "forward"] = "backward",
48+
) -> ndarray:
49+
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
50+
if x.dtype in [xp.float32, xp.complex64]:
51+
return res.astype(xp.complex64)
52+
return res
53+
54+
def ifftn(
55+
x: ndarray,
56+
/,
57+
xp,
58+
*,
59+
s: Sequence[int] = None,
60+
axes: Sequence[int] = None,
61+
norm: Literal["backward", "ortho", "forward"] = "backward",
62+
) -> ndarray:
63+
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
64+
if x.dtype in [xp.float32, xp.complex64]:
65+
return res.astype(xp.complex64)
66+
return res
67+
68+
def rfft(
69+
x: ndarray,
70+
/,
71+
xp,
72+
*,
73+
n: Optional[int] = None,
74+
axis: int = -1,
75+
norm: Literal["backward", "ortho", "forward"] = "backward",
76+
) -> ndarray:
77+
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
78+
if x.dtype == xp.float32:
79+
return res.astype(xp.complex64)
80+
return res
81+
82+
def irfft(
83+
x: ndarray,
84+
/,
85+
xp,
86+
*,
87+
n: Optional[int] = None,
88+
axis: int = -1,
89+
norm: Literal["backward", "ortho", "forward"] = "backward",
90+
) -> ndarray:
91+
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
92+
if x.dtype == xp.complex64:
93+
return res.astype(xp.float32)
94+
return res
95+
96+
def rfftn(
97+
x: ndarray,
98+
/,
99+
xp,
100+
*,
101+
s: Sequence[int] = None,
102+
axes: Sequence[int] = None,
103+
norm: Literal["backward", "ortho", "forward"] = "backward",
104+
) -> ndarray:
105+
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
106+
if x.dtype == xp.float32:
107+
return res.astype(xp.complex64)
108+
return res
109+
110+
def irfftn(
111+
x: ndarray,
112+
/,
113+
xp,
114+
*,
115+
s: Sequence[int] = None,
116+
axes: Sequence[int] = None,
117+
norm: Literal["backward", "ortho", "forward"] = "backward",
118+
) -> ndarray:
119+
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
120+
if x.dtype == xp.complex64:
121+
return res.astype(xp.float32)
122+
return res
123+
124+
def hfft(
125+
x: ndarray,
126+
/,
127+
xp,
128+
*,
129+
n: Optional[int] = None,
130+
axis: int = -1,
131+
norm: Literal["backward", "ortho", "forward"] = "backward",
132+
) -> ndarray:
133+
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
134+
if x.dtype in [xp.float32, xp.complex64]:
135+
return res.astype(xp.complex64)
136+
return res
137+
138+
def ihfft(
139+
x: ndarray,
140+
/,
141+
xp,
142+
*,
143+
n: Optional[int] = None,
144+
axis: int = -1,
145+
norm: Literal["backward", "ortho", "forward"] = "backward",
146+
) -> ndarray:
147+
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
148+
if x.dtype in [xp.float32, xp.complex64]:
149+
return res.astype(xp.complex64)
150+
return res
151+
152+
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
153+
if device not in ["cpu", None]:
154+
raise ValueError(f"Unsupported device {device!r}")
155+
return xp.fft.fftfreq(n, d=d)
156+
157+
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
158+
if device not in ["cpu", None]:
159+
raise ValueError(f"Unsupported device {device!r}")
160+
return xp.fft.rfftfreq(n, d=d)
161+
162+
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
163+
return xp.fft.fftshift(x, axes=axes)
164+
165+
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
166+
return xp.fft.ifftshift(x, axes=axes)
167+
168+
__all__ = [
169+
"fft",
170+
"ifft",
171+
"fftn",
172+
"ifftn",
173+
"rfft",
174+
"irfft",
175+
"rfftn",
176+
"irfftn",
177+
"hfft",
178+
"ihfft",
179+
"fftfreq",
180+
"rfftfreq",
181+
"fftshift",
182+
"ifftshift",
183+
]

array_api_compat/cupy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the comment in the numpy __init__.py
1010
__import__(__package__ + '.linalg')
1111

12+
__import__(__package__ + '.fft')
13+
1214
from .linalg import matrix_transpose, vecdot
1315

1416
from ..common._helpers import *

array_api_compat/cupy/fft.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from cupy.fft import *
2+
from cupy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import cupy as cp
8+
9+
fft = get_xp(cp)(_fft.fft),
10+
ifft = get_xp(cp)(_fft.ifft),
11+
fftn = get_xp(cp)(_fft.fftn),
12+
ifftn = get_xp(cp)(_fft.ifftn),
13+
rfft = get_xp(cp)(_fft.rfft),
14+
irfft = get_xp(cp)(_fft.irfft),
15+
rfftn = get_xp(cp)(_fft.rfftn),
16+
irfftn = get_xp(cp)(_fft.irfftn),
17+
hfft = get_xp(cp)(_fft.hfft),
18+
ihfft = get_xp(cp)(_fft.ihfft),
19+
fftfreq = get_xp(cp)(_fft.fftfreq),
20+
rfftfreq = get_xp(cp)(_fft.rfftfreq),
21+
fftshift = get_xp(cp)(_fft.fftshift),
22+
ifftshift = get_xp(cp)(_fft.ifftshift),
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del cp
28+
del fft_all
29+
del _fft

array_api_compat/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
1717

18+
__import__(__package__ + '.fft')
19+
1820
from .linalg import matrix_transpose, vecdot
1921

2022
from ..common._helpers import *

array_api_compat/numpy/fft.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from numpy.fft import *
2+
from numpy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import numpy as np
8+
9+
fft = get_xp(np)(_fft.fft)
10+
ifft = get_xp(np)(_fft.ifft)
11+
fftn = get_xp(np)(_fft.fftn)
12+
ifftn = get_xp(np)(_fft.ifftn)
13+
rfft = get_xp(np)(_fft.rfft)
14+
irfft = get_xp(np)(_fft.irfft)
15+
rfftn = get_xp(np)(_fft.rfftn)
16+
irfftn = get_xp(np)(_fft.irfftn)
17+
hfft = get_xp(np)(_fft.hfft)
18+
ihfft = get_xp(np)(_fft.ihfft)
19+
fftfreq = get_xp(np)(_fft.fftfreq)
20+
rfftfreq = get_xp(np)(_fft.rfftfreq)
21+
fftshift = get_xp(np)(_fft.fftshift)
22+
ifftshift = get_xp(np)(_fft.ifftshift)
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del np
28+
del fft_all
29+
del _fft

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy