
ENH/API: xp-bound namespaces, array-api-compat #6

@lucascolley

Currently, functions of this package require passing a standard-compatible namespace as xp=xp. This works fine, but there have been suggestions that it might be nice to avoid this requirement. There are at least a few ways we could go about this:

(1) xpx.bind_namespace

Usage:

import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.bind_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xp.sum(x)
z = xpx.some_func(y)

A potential implementation:

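import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other extra functions

def bind_namespace(xp: ModuleType) -> ModuleType:
    class BoundNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                # Pre-bind xp so callers can omit the xp=xp argument.
                return functools.partial(extra_funcs[name], xp=xp)
            # Raise (not return) the error so hasattr() and attribute lookup behave correctly.
            raise AttributeError(...)

    # The class closes over xp, so it takes no constructor arguments.
    return BoundNamespace()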

I like this idea. If we encounter use cases where a library wants to call several xpx functions in the same local scope and finds the xp=xp pattern too cumbersome, we should add it. Until that situation arises, I think we can leave it out.
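A minimal, self-contained sketch of option (1), assuming a toy stand-in for `xp` (a `types.SimpleNamespace` exposing only `asarray` and `sum`) and a hypothetical extra function `total_plus_one`; neither is part of array-api-extra. It shows that calling through the bound namespace is equivalent to passing `xp=xp`, and that standard functions are deliberately not reachable through it.

```python
import functools
import types
from typing import Any, Callable

# Toy stand-in for a standard-compatible namespace (assumption for this demo only).
toy_xp = types.SimpleNamespace(asarray=list, sum=sum)


def total_plus_one(x: Any, /, *, xp: Any) -> Any:
    """Hypothetical 'extra' function that needs a namespace to do its work."""
    return xp.sum(xp.asarray(x)) + 1


extra_funcs: dict[str, Callable[..., Any]] = {"total_plus_one": total_plus_one}


def bind_namespace(xp: Any) -> Any:
    class BoundNamespace:
        def __getattr__(self, name: str) -> Any:
            # Only extra functions are exposed; each one gets xp pre-bound.
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            raise AttributeError(f"bound namespace has no attribute {name!r}")

    return BoundNamespace()


xpx = bind_namespace(toy_xp)
bound_result = xpx.total_plus_one([1, 2, 3])               # no xp argument needed
explicit_result = total_plus_one([1, 2, 3], xp=toy_xp)     # today's xp=xp pattern
```

Because `__getattr__` raises for anything outside `extra_funcs`, `hasattr(xpx, "sum")` is `False`: standard functions still have to be called on `xp`, which keeps the two namespaces separate.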

(2) xpx.extra_namespace

Usage:

import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.extra_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xpx.sum(x)  # XXX: xpx instead of xp
z = xpx.some_func(y)

A potential implementation:

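import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other extra functions

def extra_namespace(xp: ModuleType) -> ModuleType:
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                # Pre-bind xp so callers can omit the xp=xp argument.
                return functools.partial(extra_funcs[name], xp=xp)
            return getattr(xp, name)  # XXX: delegate to xp instead of error

    # The class closes over xp, so it takes no constructor arguments.
    return ExtraNamespace()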

I would not want to add this yet. I think we should keep the standard namespace and the 'extra' namespace separate, at least until this library matures.
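A minimal sketch of option (2) under the same assumptions (a toy `types.SimpleNamespace` stand-in for `xp` and a hypothetical extra function `total_plus_one`). Extra functions are resolved first; every other name falls through to `xp`.

```python
import functools
import types
from typing import Any, Callable

# Toy stand-in for a standard-compatible namespace (assumption for this demo only).
toy_xp = types.SimpleNamespace(asarray=list, sum=sum)


def total_plus_one(x: Any, /, *, xp: Any) -> Any:
    """Hypothetical 'extra' function that needs a namespace to do its work."""
    return xp.sum(xp.asarray(x)) + 1


extra_funcs: dict[str, Callable[..., Any]] = {"total_plus_one": total_plus_one}


def extra_namespace(xp: Any) -> Any:
    class ExtraNamespace:
        def __getattr__(self, name: str) -> Any:
            # Extra functions take priority and get xp pre-bound ...
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            # ... everything else is looked up on the wrapped namespace.
            return getattr(xp, name)

    return ExtraNamespace()


xpx = extra_namespace(toy_xp)
standard_result = xpx.sum([1, 2])         # delegated to toy_xp.sum
extra_result = xpx.total_plus_one([1, 2])  # extra function with xp bound
```

One consequence of this design: if `xp` ever provides a function with the same name as an extra function, the extra version silently shadows it. That is one concrete reason to keep the two namespaces separate for now.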

(3) Use array_api_compat.array_namespace internally

This would provide the most flexible API and require the fewest lines of code to use. One could call xpx functions on standard-incompatible arrays and let array-api-compat handle the compatibility, without having to pass an xp argument.

We don't yet have a use case where it is clearly beneficial to be able to pass standard-incompatible arrays. Consumer libraries using array-api-extra would already be computing with standard-compatible arrays internally. I don't see the need to support the following use case:

import torch
import array_api_extra as xpx
...
x = torch.asarray([1, 2, 3])
xpx.some_func(x)             # works
torch.some_standard_func(x)  # does not work

Another complication is that consumer libraries like SciPy wrap array_namespace to provide custom behaviour for scalars and other types. We would want the internal array_namespace to be the consumer library's wrapped version rather than the base one from array-api-compat.
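A minimal sketch of what option (3) could look like, with `xp` made optional and resolved internally through `array_api_compat.array_namespace` when omitted. The function `total_plus_one`, the helper `_resolve_namespace`, and the toy namespace are hypothetical. The comment in the fallback branch marks exactly where the wrapping issue above bites: the internal call uses the base `array_namespace`, so a consumer that wants its own wrapped version still has to pass `xp` explicitly.

```python
import types
from typing import Any


def _resolve_namespace(*arrays: Any) -> Any:
    # Lazy import: array-api-compat is only needed when the caller omits xp.
    from array_api_compat import array_namespace

    return array_namespace(*arrays)


def total_plus_one(x: Any, /, *, xp: Any = None) -> Any:
    """Hypothetical extra function written in the style of option (3)."""
    if xp is None:
        # This is the *base* array_namespace from array-api-compat, not a
        # consumer library's wrapped version.
        xp = _resolve_namespace(x)
    return xp.sum(xp.asarray(x)) + 1


# A consumer with its own wrapped namespace can still pass xp explicitly;
# a toy stand-in (assumption) is used here so no array library is needed.
toy_xp = types.SimpleNamespace(asarray=list, sum=sum)
result = total_plus_one([1, 2, 3], xp=toy_xp)
```

With array-api-compat and PyTorch installed, `total_plus_one(torch.asarray([1, 2, 3]))` would resolve a compatible namespace automatically. That convenience is exactly what adds the dependency and the co-vendoring complications discussed below.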

I'm also not sure that saving one line of code over option (1) for standard-compatible arrays is worth introducing a dependency on array-api-compat.

Overall, this would considerably complicate situations where array-api-compat and array-api-extra are co-vendored, which is the primary use case for the library right now. It might be a better idea in the future if a need for handling standard-incompatible arrays arises (for example, if one wants to use functions from xpx with just a single library).
