Skip to content

Add async oindex and vindex methods to AsyncArray #3083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4f51d23
add oindex method to AsyncArray
TomNicholas May 23, 2025
535ebaa
fix
TomNicholas May 23, 2025
6f25f82
add support for async vindex
TomNicholas May 29, 2025
e595f76
Merge branch 'main' into async_oindex
TomNicholas Jul 22, 2025
bdbdd61
remove outdated @_deprecate_positional_args
TomNicholas Jul 22, 2025
fec243d
correct return type hint
TomNicholas Jul 22, 2025
320e6d2
Merge branch 'main' into async_oindex
TomNicholas Jul 22, 2025
ea0f657
add type parameter to generic
TomNicholas Jul 22, 2025
870b6b6
actually import type
TomNicholas Jul 22, 2025
a7e9e43
add hypothesis tests for supported async indexing
TomNicholas Jul 22, 2025
102e411
make tests of async into async functions
TomNicholas Jul 22, 2025
0cd96aa
release notes
TomNicholas Jul 22, 2025
b503969
linting
TomNicholas Jul 23, 2025
9b8ebde
fix type hint issues with T_ArrayMetadata and Generics
TomNicholas Jul 23, 2025
125ebdf
copy docstring for async vindex
TomNicholas Jul 23, 2025
b6d5b6d
broaden return type to include scalar
TomNicholas Jul 23, 2025
e7cbaef
satisfied mypy by adding get_mask_selection to AsyncArray
TomNicholas Jul 23, 2025
d5d5494
move T_ArrayMetadata import back inside type checking block
TomNicholas Jul 23, 2025
c0026e9
resolve circular import by moving ceildiv to common.py
TomNicholas Jul 23, 2025
b9197e5
merge sync and async tests into one function
TomNicholas Jul 24, 2025
8c13259
Merge branch 'main' into async_oindex
d-v-b Jul 25, 2025
9e60062
sketch out async oindex test
TomNicholas Jul 28, 2025
18ea042
add ellipsis tests
TomNicholas Jul 28, 2025
7fe1ffd
add tests for arrays of ints
TomNicholas Jul 28, 2025
b0af4a7
all async oindex tests
TomNicholas Jul 28, 2025
79f78cc
add vindex test
TomNicholas Jul 28, 2025
da37026
Merge branch 'async_oindex' of https://github.com/TomNicholas/zarr-py…
TomNicholas Jul 28, 2025
4a1ca09
Merge branch 'main' into async_oindex
d-v-b Jul 28, 2025
3b62dfa
satisfy mypy
TomNicholas Jul 28, 2025
b8b7c09
Merge branch 'async_oindex' of https://github.com/TomNicholas/zarr-py…
TomNicholas Jul 28, 2025
6fa9f37
linting
TomNicholas Jul 28, 2025
01ac722
add test for indexing with zarr array
TomNicholas Jul 29, 2025
c7a1000
add test case for masked boolean vectorized indexing
TomNicholas Jul 30, 2025
7e9681d
add test to cover invalid indexer passed to vindex
TomNicholas Jul 30, 2025
1469093
also cover invalid indexer to oindex
TomNicholas Jul 30, 2025
6fbb6b1
test vindexing with zarr array
TomNicholas Jul 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3083.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for async vectorized and orthogonal indexing.
67 changes: 66 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ZarrFormat,
_default_zarr_format,
_warn_order_kwarg,
ceildiv,
concurrent_map,
parse_shapelike,
product,
Expand All @@ -76,6 +77,8 @@
)
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
from zarr.core.indexing import (
AsyncOIndex,
AsyncVIndex,
BasicIndexer,
BasicSelection,
BlockIndex,
Expand All @@ -92,7 +95,6 @@
Selection,
VIndex,
_iter_grid,
ceildiv,
check_fields,
check_no_multi_fields,
is_pure_fancy_indexing,
Expand Down Expand Up @@ -1425,6 +1427,56 @@ async def getitem(
)
return await self._get_selection(indexer, prototype=prototype)

async def get_orthogonal_selection(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_basic_selection also doesn't exist on AsyncArray - should I add that too?

self,
selection: OrthogonalSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_mask_selection(
self,
mask: MaskSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_coordinate_selection(
self,
selection: CoordinateSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

if hasattr(out_array, "shape"):
# restore shape
out_array = np.array(out_array).reshape(indexer.sel_shape)
return out_array

async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
"""
Asynchronously save the array metadata.
Expand Down Expand Up @@ -1556,6 +1608,19 @@ async def setitem(
)
return await self._set_selection(indexer, value, prototype=prototype)

@property
def oindex(self) -> AsyncOIndex[T_ArrayMetadata]:
Comment on lines +1611 to +1612
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose this API to try to follow this pattern:

  • Array.__getitem__ (exists)
  • Array.oindex.__getitem__ (exists)
  • Array.vindex.__getitem__ (exists)
  • AsyncArray.getitem (exists)
  • AsyncArray.oindex.getitem (new)
  • AsyncArray.vindex.getitem (new)

because python doesn't let you make an async version of the __getitem__ magic method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the sync versions to use sync(self._async_array.oindex.getitem) instead of going to _get_selection directly? That will get us some test coverage

Copy link
Member Author

@TomNicholas TomNicholas Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds very sensible, but when I try to refactor to do that I find that it causes a problem where sync() gets called from within an async coroutine. I wouldn't have thought this should happen, but here's claude's explanation of why it does:

> Explain how exactly sync is being called from within an async coroutineLooking at the error and the code, here's exactly how sync() is being called from within an async coroutine:

  The Call Chain

  1. Test calls zarr array indexing: za[zix] in test_indexing_with_zarr_array
  2. Array.getitem (src/zarr/core/array.py:2566):
  return self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
  3. VIndex.getitem (src/zarr/core/indexing.py:1248):
  return sync(self.array._async_array.vindex.getitem(selection))
  4. sync() tries to run the coroutine but detects it's already in a running loop

  The Problem

  The issue is in the __array__ method call chain:

  1. VIndex.__getitem__ calls sync() with self.array._async_array.vindex.getitem(selection)
  2. This eventually leads to AsyncArray._get_selection() being called
  3. But somewhere in the process, Array.__array__() gets called (src/zarr/core/array.py:2413)
  4. __array__() calls self[...] which goes back to Array.__getitem__
  5. This creates a nested call where sync() is called while already inside an async context

  The Root Cause

  The function _zarr_array_to_int_or_bool_array() at line 85 calls np.asarray(arr), which triggers the __array__ protocol on the zarr array. This causes:

  def _zarr_array_to_int_or_bool_array(arr: Array) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]:
      if arr.dtype.kind in ("i", "b"):
          return np.asarray(arr)  # <-- This calls arr.__array__()

  When np.asarray() is called on a zarr Array, it calls Array.__array__(), which calls self[...], which eventually calls sync() again - but we're already
  inside a sync() call from the VIndex, creating the nested async context error.

  The original code before the changes avoided this by handling the zarr array conversion within the sync methods directly, rather than delegating to async
  methods that would create this nested sync situation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess for indexing with a Zarr array, we should convert to numpy array before the sync call

"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and
:func:`set_orthogonal_selection` for documentation and examples."""
return AsyncOIndex(self)

@property
def vindex(self) -> AsyncVIndex[T_ArrayMetadata]:
"""Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection`,
:func:`set_coordinate_selection`, :func:`get_mask_selection` and
:func:`set_mask_selection` for documentation and examples."""
return AsyncVIndex(self)

async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
"""
Asynchronously resize the array to a new shape.
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/core/chunk_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
ChunkCoords,
ChunkCoordsLike,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)
from zarr.core.indexing import ceildiv

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import functools
import math
import operator
import warnings
from collections.abc import Iterable, Mapping, Sequence
Expand Down Expand Up @@ -69,6 +70,12 @@ def product(tup: ChunkCoords) -> int:
return functools.reduce(operator.mul, tup, 1)


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


T = TypeVar("T", bound=tuple[Any, ...])
V = TypeVar("V")

Expand Down
58 changes: 50 additions & 8 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
NamedTuple,
Protocol,
Expand All @@ -25,14 +26,16 @@
import numpy as np
import numpy.typing as npt

from zarr.core.common import product
from zarr.core.common import ceildiv, product
from zarr.core.metadata import T_ArrayMetadata

if TYPE_CHECKING:
from zarr.core.array import Array
from zarr.core.array import Array, AsyncArray
from zarr.core.buffer import NDArrayLikeOrScalar
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import ChunkCoords


IntSequence = list[int] | npt.NDArray[np.intp]
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
BasicSelector = int | slice | EllipsisType
Expand Down Expand Up @@ -93,12 +96,6 @@ class Indexer(Protocol):
def __iter__(self) -> Iterator[ChunkProjection]: ...


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"]


Expand Down Expand Up @@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
)


@dataclass(frozen=True)
class AsyncOIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)

fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
return await self.array.get_orthogonal_selection(
cast(OrthogonalSelection, new_selection), fields=fields
)


@dataclass(frozen=True)
class BlockIndexer(Indexer):
dim_indexers: list[SliceDimIndexer]
Expand Down Expand Up @@ -1268,6 +1284,32 @@ def __setitem__(
raise VindexInvalidSelectionError(new_selection)


@dataclass(frozen=True)
class AsyncVIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
async def getitem(
self, selection: CoordinateSelection | MaskSelection | Array
) -> NDArrayLikeOrScalar:
# TODO deduplicate these internals with the sync version of getitem
# TODO requires solving this circular sync issue: https://github.com/zarr-developers/zarr-python/pull/3083#discussion_r2230737448
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)
fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
if is_coordinate_selection(new_selection, self.array.shape):
return await self.array.get_coordinate_selection(new_selection, fields=fields)
elif is_mask_selection(new_selection, self.array.shape):
return await self.array.get_mask_selection(new_selection, fields=fields)
else:
raise VindexInvalidSelectionError(new_selection)


def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
# early out
if fields is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
from zarr.core.chunk_grids import _auto_partition
from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams
from zarr.core.common import JSON, ZarrFormat
from zarr.core.common import JSON, ZarrFormat, ceildiv
from zarr.core.dtype import (
DateTime64,
Float32,
Expand All @@ -59,7 +59,7 @@
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
from zarr.core.dtype.npy.string import UTF8Base
from zarr.core.group import AsyncGroup
from zarr.core.indexing import BasicIndexer, ceildiv
from zarr.core.indexing import BasicIndexer
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.sync import sync
Expand Down
107 changes: 107 additions & 0 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,3 +1994,110 @@ def test_iter_chunk_regions():
assert_array_equal(a[region], np.ones_like(a[region]))
a[region] = 0
assert_array_equal(a[region], np.zeros_like(a[region]))


class TestAsync:
@pytest.mark.parametrize(
("indexer", "expected"),
[
# int
((0,), np.array([1, 2])),
((1,), np.array([3, 4])),
((0, 1), np.array(2)),
# slice
((slice(None),), np.array([[1, 2], [3, 4]])),
((slice(0, 1),), np.array([[1, 2]])),
((slice(1, 2),), np.array([[3, 4]])),
((slice(0, 2),), np.array([[1, 2], [3, 4]])),
((slice(0, 0),), np.empty(shape=(0, 2), dtype="i8")),
# ellipsis
((...,), np.array([[1, 2], [3, 4]])),
((0, ...), np.array([1, 2])),
((..., 0), np.array([1, 3])),
((0, 1, ...), np.array(2)),
# combined
((0, slice(None)), np.array([1, 2])),
((slice(None), 0), np.array([1, 3])),
((slice(None), slice(None)), np.array([[1, 2], [3, 4]])),
# array of ints
(([0]), np.array([[1, 2]])),
(([1]), np.array([[3, 4]])),
(([0], [1]), np.array(2)),
(([0, 1], [0]), np.array([[1], [3]])),
(([0, 1], [0, 1]), np.array([[1, 2], [3, 4]])),
# boolean array
(np.array([True, True]), np.array([[1, 2], [3, 4]])),
(np.array([True, False]), np.array([[1, 2]])),
(np.array([False, True]), np.array([[3, 4]])),
(np.array([False, False]), np.empty(shape=(0, 2), dtype="i8")),
],
)
@pytest.mark.asyncio
async def test_async_oindex(self, store, indexer, expected):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

result = await async_zarr.oindex.getitem(indexer)
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_oindex_with_zarr_array(self, store):
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z1[...] = np.array([[1, 2], [3, 4]])
async_zarr = z1._async_array

# create boolean zarr array to index with
z2 = zarr.create_array(
store=store, name="z2", shape=(2,), chunks=(1,), zarr_format=3, dtype="?"
)
z2[...] = np.array([True, False])

result = await async_zarr.oindex.getitem(z2)
expected = np.array([[1, 2]])
assert_array_equal(result, expected)

@pytest.mark.parametrize(
("indexer", "expected"),
[
(([0], [0]), np.array(1)),
(([0, 1], [0, 1]), np.array([1, 4])),
(np.array([[False, True], [False, True]]), np.array([2, 4])),
],
)
@pytest.mark.asyncio
async def test_async_vindex(self, store, indexer, expected):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

result = await async_zarr.vindex.getitem(indexer)
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_vindex_with_zarr_array(self, store):
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z1[...] = np.array([[1, 2], [3, 4]])
async_zarr = z1._async_array

# create boolean zarr array to index with
z2 = zarr.create_array(
store=store, name="z2", shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="?"
)
z2[...] = np.array([[False, True], [False, True]])

result = await async_zarr.vindex.getitem(z2)
expected = np.array([2, 4])
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_invalid_indexer(self, store):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

with pytest.raises(IndexError):
await async_zarr.vindex.getitem("invalid_indexer")

with pytest.raises(IndexError):
await async_zarr.oindex.getitem("invalid_indexer")
Loading
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy