Skip to content

Commit 06d86c9

Browse files
authored
Merge pull request #267 from tdegeus/qad
Adding possibility to 'cast' or copy to `xt::xarray` etc
2 parents 43b244e + af91def commit 06d86c9

File tree

9 files changed

+276
-18
lines changed

9 files changed

+276
-18
lines changed

.azure-pipelines/unix-build.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ steps:
4545
displayName: Example - readme 1
4646
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/readme_example_1
4747
48+
- script: |
49+
source activate xtensor-python
50+
cmake -Bbuild -DPython_EXECUTABLE=`which python`
51+
cd build
52+
cmake --build .
53+
cp ../example.py .
54+
python example.py
55+
cd ..
56+
displayName: Example - Copy 'cast'
57+
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/copy_cast
58+
4859
- script: |
4960
source activate xtensor-python
5061
cmake -Bbuild -DPython_EXECUTABLE=`which python`

docs/source/examples.rst

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,54 @@ Then we can test the module:
143143
Since we did not install the module,
144144
we should compile and run the example from the same folder.
145145
To install, please consult
146-
`this *pybind11* / *CMake* example <https://github.com/pybind/cmake_example>`_.
146+
`this pybind11 / CMake example <https://github.com/pybind/cmake_example>`_.
147147
**Tip**: take care to modify that example with the correct *CMake* case ``Python_EXECUTABLE``.
148+
149+
Fall-back cast
150+
==============
151+
152+
The previous example showed you how to design your module to be flexible in accepting data.
153+
From C++ we used ``xt::xarray<double>``,
154+
whereas for the Python API we used ``xt::pyarray<double>`` to operate directly on the memory
155+
of a NumPy array from Python (without copying the data).
156+
157+
Sometimes, you might not have the flexibility to design your module's methods
158+
with template parameters.
159+
This might occur when you want to ``override`` functions
160+
(though it is recommended to use CRTP to still use templates).
161+
In this case we can still bind the module in Python using *xtensor-python*,
162+
however, we have to copy the data from a (NumPy) array.
163+
This means that although the following signatures are quite different when used from C++,
164+
as follows:
165+
166+
1. *Constant reference*: read from the data, without copying it.
167+
168+
.. code-block:: cpp
169+
170+
void foo(const xt::xarray<double>& a);
171+
172+
2. *Reference*: read from and/or write to the data, without copying it.
173+
174+
.. code-block:: cpp
175+
176+
void foo(xt::xarray<double>& a);
177+
178+
3. *Copy*: copy the data.
179+
180+
.. code-block:: cpp
181+
182+
void foo(xt::xarray<double> a);
183+
184+
The Python will all cases result in a copy to a temporary variable
185+
(though the last signature will lead to a copy to a temporary variable, and another copy to ``a``).
186+
On the one hand, this is more costly than when using ``xt::pyarray`` and ``xt::pyxtensor``,
187+
on the other hand, it means that all changes you make to a reference, are made to the temporary
188+
copy, and are thus lost.
189+
190+
Still, it might be a convenient way to create Python bindings, using a minimal effort.
191+
Consider this example:
192+
193+
:download:`main.cpp <examples/copy_cast/main.cpp>`
194+
195+
.. literalinclude:: examples/copy_cast/main.cpp
196+
:language: cpp
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cmake_minimum_required(VERSION 3.1..3.19)
2+
3+
project(mymodule)
4+
5+
find_package(pybind11 CONFIG REQUIRED)
6+
find_package(xtensor REQUIRED)
7+
find_package(xtensor-python REQUIRED)
8+
find_package(Python REQUIRED COMPONENTS NumPy)
9+
10+
pybind11_add_module(mymodule main.cpp)
11+
target_link_libraries(mymodule PUBLIC pybind11::module xtensor-python Python::NumPy)
12+
13+
target_compile_definitions(mymodule PRIVATE VERSION_INFO=0.1.0)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import mymodule
2+
import numpy as np
3+
4+
c = np.array([[1, 2, 3], [4, 5, 6]])
5+
assert np.isclose(np.sum(np.sin(c)), mymodule.sum_of_sines(c))
6+
assert np.isclose(np.sum(np.cos(c)), mymodule.sum_of_cosines(c))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <numeric>
2+
#include <xtensor.hpp>
3+
#include <pybind11/pybind11.h>
4+
#define FORCE_IMPORT_ARRAY
5+
#include <xtensor-python/pyarray.hpp>
6+
7+
template <class T>
8+
double sum_of_sines(T& m)
9+
{
10+
auto sines = xt::sin(m); // sines does not actually hold values.
11+
return std::accumulate(sines.begin(), sines.end(), 0.0);
12+
}
13+
14+
// In the Python API this a reference to a temporary variable
15+
double sum_of_cosines(const xt::xarray<double>& m)
16+
{
17+
auto cosines = xt::cos(m); // cosines does not actually hold values.
18+
return std::accumulate(cosines.begin(), cosines.end(), 0.0);
19+
}
20+
21+
PYBIND11_MODULE(mymodule, m)
22+
{
23+
xt::import_numpy();
24+
m.doc() = "Test module for xtensor python bindings";
25+
m.def("sum_of_sines", sum_of_sines<xt::pyarray<double>>, "Sum the sines of the input values");
26+
m.def("sum_of_cosines", sum_of_cosines, "Sum the cosines of the input values");
27+
}

include/xtensor-python/pynative_casters.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "xtensor_type_caster_base.hpp"
1414

15-
1615
namespace pybind11
1716
{
1817
namespace detail

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,97 @@ namespace pybind11
2323
{
2424
namespace detail
2525
{
26+
template <typename T, xt::layout_type L>
27+
struct pybind_array_getter_impl
28+
{
29+
static auto run(handle src)
30+
{
31+
return array_t<T, array::c_style | array::forcecast>::ensure(src);
32+
}
33+
};
34+
35+
template <typename T>
36+
struct pybind_array_getter_impl<T, xt::layout_type::column_major>
37+
{
38+
static auto run(handle src)
39+
{
40+
return array_t<T, array::f_style | array::forcecast>::ensure(src);
41+
}
42+
};
43+
44+
template <class T>
45+
struct pybind_array_getter
46+
{
47+
};
48+
49+
template <class T, xt::layout_type L>
50+
struct pybind_array_getter<xt::xarray<T, L>>
51+
{
52+
static auto run(handle src)
53+
{
54+
return pybind_array_getter_impl<T, L>::run(src);
55+
}
56+
};
57+
58+
template <class T, std::size_t N, xt::layout_type L>
59+
struct pybind_array_getter<xt::xtensor<T, N, L>>
60+
{
61+
static auto run(handle src)
62+
{
63+
return pybind_array_getter_impl<T, L>::run(src);
64+
}
65+
};
66+
67+
template <class CT, class S, xt::layout_type L, class FST>
68+
struct pybind_array_getter<xt::xstrided_view<CT, S, L, FST>>
69+
{
70+
static auto run(handle /*src*/)
71+
{
72+
return false;
73+
}
74+
};
75+
76+
template <class EC, xt::layout_type L, class SC, class Tag>
77+
struct pybind_array_getter<xt::xarray_adaptor<EC, L, SC, Tag>>
78+
{
79+
static auto run(handle src)
80+
{
81+
auto buf = pybind_array_getter_impl<EC, L>::run(src);
82+
return buf;
83+
}
84+
};
85+
86+
template <class EC, std::size_t N, xt::layout_type L, class Tag>
87+
struct pybind_array_getter<xt::xtensor_adaptor<EC, N, L, Tag>>
88+
{
89+
static auto run(handle /*src*/)
90+
{
91+
return false;
92+
}
93+
};
94+
95+
96+
template <class T>
97+
struct pybind_array_dim_checker
98+
{
99+
template <class B>
100+
static bool run(const B& buf)
101+
{
102+
return true;
103+
}
104+
};
105+
106+
template <class T, std::size_t N, xt::layout_type L>
107+
struct pybind_array_dim_checker<xt::xtensor<T, N, L>>
108+
{
109+
template <class B>
110+
static bool run(const B& buf)
111+
{
112+
return buf.ndim() == N;
113+
}
114+
};
115+
116+
26117
// Casts a strided expression type to numpy array.If given a base,
27118
// the numpy array references the src data, otherwise it'll make a copy.
28119
// The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +165,6 @@ namespace pybind11
74165
template <class Type>
75166
struct xtensor_type_caster_base
76167
{
77-
bool load(handle /*src*/, bool)
78-
{
79-
return false;
80-
}
81168

82169
private:
83170

@@ -106,6 +193,36 @@ namespace pybind11
106193

107194
public:
108195

196+
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<typename Type::value_type>::name + _("]"));
197+
198+
bool load(handle src, bool convert)
199+
{
200+
using T = typename Type::value_type;
201+
202+
if (!convert && !array_t<T>::check_(src))
203+
{
204+
return false;
205+
}
206+
207+
auto buf = pybind_array_getter<Type>::run(src);
208+
209+
if (!buf)
210+
{
211+
return false;
212+
}
213+
if (!pybind_array_dim_checker<Type>::run(buf))
214+
{
215+
return false;
216+
}
217+
218+
std::vector<size_t> shape(buf.ndim());
219+
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
220+
value = Type::from_shape(shape);
221+
std::copy(buf.data(), buf.data() + buf.size(), value.data());
222+
223+
return true;
224+
}
225+
109226
// Normal returned non-reference, non-const value:
110227
static handle cast(Type&& src, return_value_policy /* policy */, handle parent)
111228
{
@@ -151,18 +268,6 @@ namespace pybind11
151268
{
152269
return cast_impl(src, policy, parent);
153270
}
154-
155-
#ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3
156-
static PYBIND11_DESCR name()
157-
{
158-
return _("xt::xtensor");
159-
}
160-
#else
161-
static constexpr auto name = _("xt::xtensor");
162-
#endif
163-
164-
template <typename T>
165-
using cast_op_type = cast_op_type<T>;
166271
};
167272
}
168273
}

test_python/main.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ xt::pyarray<double> example2(xt::pyarray<double>& m)
3333
return m + 2;
3434
}
3535

36+
xt::xarray<int> example3_xarray(const xt::xarray<int>& m)
37+
{
38+
return xt::transpose(m) + 2;
39+
}
40+
41+
xt::xarray<int, xt::layout_type::column_major> example3_xarray_colmajor(
42+
const xt::xarray<int, xt::layout_type::column_major>& m)
43+
{
44+
return xt::transpose(m) + 2;
45+
}
46+
47+
xt::xtensor<int, 3> example3_xtensor3(const xt::xtensor<int, 3>& m)
48+
{
49+
return xt::transpose(m) + 2;
50+
}
51+
52+
xt::xtensor<int, 2> example3_xtensor2(const xt::xtensor<int, 2>& m)
53+
{
54+
return xt::transpose(m) + 2;
55+
}
56+
57+
xt::xtensor<int, 2, xt::layout_type::column_major> example3_xtensor2_colmajor(
58+
const xt::xtensor<int, 2, xt::layout_type::column_major>& m)
59+
{
60+
return xt::transpose(m) + 2;
61+
}
62+
3663
// Readme Examples
3764

3865
double readme_example1(xt::pyarray<double>& m)
@@ -249,6 +276,11 @@ PYBIND11_MODULE(xtensor_python_test, m)
249276

250277
m.def("example1", example1);
251278
m.def("example2", example2);
279+
m.def("example3_xarray", example3_xarray);
280+
m.def("example3_xarray_colmajor", example3_xarray_colmajor);
281+
m.def("example3_xtensor3", example3_xtensor3);
282+
m.def("example3_xtensor2", example3_xtensor2);
283+
m.def("example3_xtensor2_colmajor", example3_xtensor2_colmajor);
252284

253285
m.def("complex_overload", no_complex_overload);
254286
m.def("complex_overload", complex_overload);

test_python/test_pyarray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def test_example2(self):
3636
y = xt.example2(x)
3737
np.testing.assert_allclose(y, res, 1e-12)
3838

39+
def test_example3(self):
40+
x = np.arange(2 * 3).reshape(2, 3)
41+
xc = np.asfortranarray(x)
42+
y = np.arange(2 * 3 * 4).reshape(2, 3, 4)
43+
v = y[1:, 1:, 0]
44+
z = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
45+
np.testing.assert_array_equal(xt.example3_xarray(x), x.T + 2)
46+
np.testing.assert_array_equal(xt.example3_xarray_colmajor(xc), xc.T + 2)
47+
np.testing.assert_array_equal(xt.example3_xtensor3(y), y.T + 2)
48+
np.testing.assert_array_equal(xt.example3_xtensor2(x), x.T + 2)
49+
np.testing.assert_array_equal(xt.example3_xtensor2(y[1:, 1:, 0]), v.T + 2)
50+
np.testing.assert_array_equal(xt.example3_xtensor2_colmajor(xc), xc.T + 2)
51+
52+
with self.assertRaises(TypeError):
53+
xt.example3_xtensor3(x)
54+
3955
def test_vectorize(self):
4056
x1 = np.array([[0, 1], [2, 3]])
4157
x2 = np.array([0, 1])

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy