Skip to content

Commit 0d470da

Browse files
authored
docs: move notes below the fold and highlight RFC 2119 keywords
PR-URL: #955 Ref: #397
1 parent a52285e commit 0d470da

File tree

1 file changed

+62
-56
lines changed

1 file changed

+62
-56
lines changed

src/array_api_stubs/_draft/linear_algebra_functions.py

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,57 @@ def matmul(x1: array, x2: array, /) -> array:
88
"""
99
Computes the matrix product.
1010
11-
.. note::
12-
The ``matmul`` function must implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
13-
1411
Parameters
1512
----------
1613
x1: array
17-
first input array. Should have a numeric data type. Must have at least one dimension. If ``x1`` is one-dimensional having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` must be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., must have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array must be removed. If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` must be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
18-
x2: array
19-
second input array. Should have a numeric data type. Must have at least one dimension. If ``x2`` is one-dimensional having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` must be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., must have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array must be removed. If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
14+
first input array. **Should** have a numeric data type. **Must** have at least one dimension.
15+
16+
- If ``x1`` is a one-dimensional array having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` **must** be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., **must** have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array **must** be removed.
17+
- If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` **must** be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
18+
- If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
2019
20+
x2: array
21+
second input array. **Should** have a numeric data type. **Must** have at least one dimension.
2122
22-
.. note::
23-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the matrix product.
23+
- If ``x2`` is one-dimensional array having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` **must** be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., **must** have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array **must** be removed.
24+
- If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` **must** be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
25+
- If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
2426
2527
Returns
2628
-------
2729
out: array
28-
- if both ``x1`` and ``x2`` are one-dimensional arrays having shape ``(N,)``, a zero-dimensional array containing the inner product as its only element.
29-
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, a two-dimensional array containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ and having shape ``(M, N)``.
30-
- if ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
31-
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
32-
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
33-
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
34-
- if either ``x1`` or ``x2`` has more than two dimensions, an array having a shape determined by :ref:`broadcasting` ``shape(x1)[:-2]`` against ``shape(x2)[:-2]`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
30+
output array.
31+
32+
- If both ``x1`` and ``x2`` are one-dimensional arrays having shape ``(N,)``, the returned array **must** be a zero-dimensional array and **must** contain the inner product as its only element.
33+
- If ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, the returned array **must** be a two-dimensional array and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ and having shape ``(M, N)``.
34+
- If ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, the returned array **must** be an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion **must** be removed) and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
35+
- If ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, the returned array **must** be an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion **must** be removed) and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
36+
- If ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is an array having shape ``(..., K, N)``, the returned array **must** be an array having shape ``(..., M, N)`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
37+
- If ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, the returned array **must** be an array having shape ``(..., M, N)`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
38+
- If either ``x1`` or ``x2`` has more than two dimensions, the returned array **must** be an array having a shape determined by :ref:`broadcasting` ``shape(x1)[:-2]`` against ``shape(x2)[:-2]`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
39+
40+
The returned array **must** have a data type determined by :ref:`type-promotion`.
3541
36-
The returned array must have a data type determined by :ref:`type-promotion`.
42+
Raises
43+
------
44+
Exception
45+
an exception **should** be raised in the following circumstances:
46+
47+
- if either ``x1`` or ``x2`` is a zero-dimensional array.
48+
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
49+
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
50+
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
51+
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
3752
3853
Notes
3954
-----
4055
41-
.. versionchanged:: 2022.12
42-
Added complex data type support.
56+
- The ``matmul`` function **must** implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
4357
44-
**Raises**
45-
46-
- if either ``x1`` or ``x2`` is a zero-dimensional array.
47-
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
48-
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
49-
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
50-
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
58+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or tranpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the matrix product.
5159
60+
.. versionchanged:: 2022.12
61+
Added complex data type support.
5262
"""
5363

5464

@@ -64,7 +74,7 @@ def matrix_transpose(x: array, /) -> array:
6474
Returns
6575
-------
6676
out: array
67-
an array containing the transpose for each matrix and having shape ``(..., N, M)``. The returned array must have the same data type as ``x``.
77+
an array containing the transpose for each matrix. The returned array **must** have shape ``(..., N, M)``. The returned array **must** have the same data type as ``x``.
6878
"""
6979

7080

@@ -78,42 +88,37 @@ def tensordot(
7888
"""
7989
Returns a tensor contraction of ``x1`` and ``x2`` over specific axes.
8090
81-
.. note::
82-
The ``tensordot`` function corresponds to the generalized matrix product.
83-
8491
Parameters
8592
----------
8693
x1: array
87-
first input array. Should have a numeric data type.
94+
first input array. **Should** have a numeric data type.
8895
x2: array
89-
second input array. Should have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
90-
91-
.. note::
92-
Contracted axes (dimensions) must not be broadcasted.
96+
second input array. **Should** have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` **must** be equal.
9397
9498
axes: Union[int, Tuple[Sequence[int], Sequence[int]]]
95-
number of axes (dimensions) to contract or explicit sequences of axis (dimension) indices for ``x1`` and ``x2``, respectively.
99+
number of axes to contract or explicit sequences of axis indices for ``x1`` and ``x2``, respectively.
96100
97-
If ``axes`` is an ``int`` equal to ``N``, then contraction must be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative.
101+
If ``axes`` is an ``int`` equal to ``N``, then contraction **must** be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis **must** match. An integer ``axes`` value **must** be nonnegative.
98102
99-
- If ``N`` equals ``0``, the result is the tensor (outer) product.
100-
- If ``N`` equals ``1``, the result is the tensor dot product.
101-
- If ``N`` equals ``2``, the result is the tensor double contraction (default).
103+
- If ``N`` equals ``0``, the result **must** be the tensor (outer) product.
104+
- If ``N`` equals ``1``, the result **must** be the tensor dot product.
105+
- If ``N`` equals ``2``, the result **must** be the tensor double contraction (default).
102106
103-
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence must be unique. If ``x1`` has rank (i.e, number of dimensions) ``N``, a valid ``x1`` axis must reside on the half-open interval ``[-N, N)``. If ``x2`` has rank ``M``, a valid ``x2`` axis must reside on the half-open interval ``[-M, M)``.
107+
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence **must** apply to ``x1`` and the second sequence **must** apply to ``x2``. Both sequences **must** have the same length. Each axis ``x1_axes[i]`` for ``x1`` **must** have the same size as the respective axis ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence **must** be unique. A valid axis **must** be an integer on the interval ``[-S, S)``, where ``S`` is the number of axes in respective array. Hence, if ``x1`` has ``N`` axes, a valid ``x1`` axes **must** be an integer on the interval ``[-N, N)``. If ``x2`` has ``M`` axes, a valid ``x2`` axes **must** be an integer on the interval ``[-M, M)``. If an axis is specified as a negative integer, the function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). If provided an invalid axis, the function **must** raise an exception.
104108
105109
106-
.. note::
107-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the generalized matrix product.
108-
109110
Returns
110111
-------
111112
out: array
112-
an array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array ``x1``, followed by the non-contracted axes (dimensions) of the second array ``x2``. The returned array must have a data type determined by :ref:`type-promotion`.
113+
an array containing the tensor contraction. The returned array **must** have a shape which consists of the non-contracted axes of the first array ``x1``, followed by the non-contracted axes of the second array ``x2``. The returned array **must** have a data type determined by :ref:`type-promotion`.
113114
114115
Notes
115116
-----
116117
118+
- The ``tensordot`` function corresponds to the generalized matrix product.
119+
- Contracted axes **must** not be broadcasted.
120+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or transpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the generalized matrix product.
121+
117122
.. versionchanged:: 2022.12
118123
Added complex data type support.
119124
@@ -131,32 +136,33 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
131136
.. math::
132137
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
133138
134-
over the dimension specified by ``axis`` and where :math:`n` is the dimension size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
139+
over the axis specified by ``axis`` and where :math:`n` is the axis size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
135140
136141
Parameters
137142
----------
138143
x1: array
139-
first input array. Should have a floating-point data type.
144+
first input array. **Should** have a floating-point data type.
140145
x2: array
141-
second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a floating-point data type.
142-
143-
.. note::
144-
The contracted axis (dimension) must not be broadcasted.
145-
146+
second input array. **Must** be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product **must** be the same size as the respective axis in ``x1``. **Should** have a floating-point data type.
146147
axis: int
147-
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
148+
axis of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. **Should** be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). By default, the function **must** compute the dot product over the last axis. Default: ``-1``.
148149
149150
Returns
150151
-------
151152
out: array
152-
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array must have a data type determined by :ref:`type-promotion`.
153+
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having ``N-1`` axes, where ``N`` is number of axes in the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array **must** have a data type determined by :ref:`type-promotion`.
154+
155+
Raises
156+
------
157+
Exception
158+
an exception **should** be raised in the following circumstances:
159+
160+
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
153161
154162
Notes
155163
-----
156164
157-
**Raises**
158-
159-
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
165+
- The contracted axis **must** not be broadcasted.
160166
161167
.. versionchanged:: 2022.12
162168
Added complex data type support.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy