Description
This RFC proposes to add an API to the standard for returning the list of devices. Currently, the only standardized means for an array API consumer to access a device object is by creating an array and accessing its .device attribute:
>>> x = xp.zeros((2, 2))
>>> d = x.device
More generally, the standard currently lacks APIs for device inspection, and device objects are unspecified apart from the required support for equality comparison. To aid introspection in scripts and REPL environments, a means of getting the list of devices would be useful, especially when probing the capabilities of an unknown platform.
Prior art
JAX
jax.devices(backend=None) -> List[Device]
Returns the list of devices from the default backend (e.g., gpu, tpu, or cpu).
jax.local_devices(process_index=None, backend=None, host_id=None)
Returns a list of devices local to a given process.
Proposal
devices() -> List[device]
The proposed API would be a nullary function which returns a list of device objects.
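To make the proposal concrete, here is a minimal sketch of how a library might implement the nullary devices() function. Everything in it is hypothetical: the Device class, its kind and index attributes, the _AVAILABLE registry, and the toy Array and zeros stand-ins are illustrative assumptions, not APIs from the standard or from any existing library. The sketch relies only on equality comparison, which is the one device behavior the standard already requires.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Device:
    # Hypothetical device type. frozen=True gives value-based __eq__ (the only
    # device behavior the standard currently requires) plus __hash__.
    kind: str   # illustrative attribute, e.g. "cpu" or "gpu"; not part of the standard
    index: int


# Hypothetical set of devices visible to this toy library.
_AVAILABLE = (Device("cpu", 0), Device("gpu", 0), Device("gpu", 1))


class Array:
    # Minimal stand-in for an array object; it only records shape and device.
    def __init__(self, shape, device: Device):
        self.shape = tuple(shape)
        self.device = device


def zeros(shape, *, device: Optional[Device] = None) -> Array:
    # Default to the first listed device when none is requested.
    return Array(shape, _AVAILABLE[0] if device is None else device)


def devices() -> List[Device]:
    # Proposed nullary API: return a new list so callers cannot mutate the registry.
    return list(_AVAILABLE)


# Usage: the device of any array created by the library appears in devices(),
# and membership testing uses only __eq__.
x = zeros((2, 2))
assert x.device in devices()
gpus = [d for d in devices() if d.kind == "gpu"]
```

Returning a fresh list on each call is a design choice in this sketch, not something the proposal specifies; it keeps the library's internal registry safe from accidental mutation by callers.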
Notes
- Currently, when not provided a backend kwarg, JAX returns the list of devices for the default backend. I am not sure whether this is preferable to returning a list of devices across all backends.
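The open question above can be illustrated with a small sketch contrasting the two possible semantics. The registry, the backend names, the choice of default backend, and the functions devices(backend=None) and all_devices() are all hypothetical. The first function mirrors the JAX behavior described above (listing only the default backend's devices when no backend is given), and the second returns devices across every backend. Devices are plain strings here purely for brevity.

```python
from typing import Dict, List, Optional

# Hypothetical per-backend device registry; names are illustrative only.
_DEVICES_BY_BACKEND: Dict[str, List[str]] = {
    "cpu": ["cpu:0"],
    "gpu": ["gpu:0", "gpu:1"],
}
_DEFAULT_BACKEND = "gpu"  # assumption: an accelerator backend is the default when present


def devices(backend: Optional[str] = None) -> List[str]:
    # JAX-style semantics: with no backend, list only the default backend's devices.
    name = _DEFAULT_BACKEND if backend is None else backend
    return list(_DEVICES_BY_BACKEND[name])


def all_devices() -> List[str]:
    # Alternative semantics: flatten the registry into one list across all backends.
    return [d for ds in _DEVICES_BY_BACKEND.values() for d in ds]
```

Listing only the default backend may keep the result homogeneous, which can be convenient when a consumer just wants somewhere to place arrays; listing all backends gives a complete inventory for probing an unknown platform, at the cost of mixing device kinds that may not be interchangeable.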