5c370c0b2a
GitOrigin-RevId: 33d1e753c82ffc557b4a585c77de43d4c922ebb5
14 lines
573 B
Diff
14 lines
573 B
Diff
diff --git a/objax/util/util.py b/objax/util/util.py
index c31a356..344cf9a 100644
--- a/objax/util/util.py
+++ b/objax/util/util.py
@@ -117,7 +117,8 @@ def get_local_devices():
     if _local_devices is None:
         x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
         sharded_x = map_to_device(x)
-        _local_devices = [b.device() for b in sharded_x.device_buffers]
+        device_buffers = [buf.data for buf in sharded_x.addressable_shards]
+        _local_devices = [list(b.devices())[0] for b in device_buffers]
     return _local_devices
 
 