mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-07-01 06:18:08 +00:00
69 lines
2.2 KiB
Diff
69 lines
2.2 KiB
Diff
diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
|
|
index f5fcb0e..87b198f 100644
|
|
--- a/brax/training/agents/apg/train.py
|
|
+++ b/brax/training/agents/apg/train.py
|
|
@@ -310,7 +310,7 @@ def train(
|
|
specs.Array((env.observation_size,), jnp.dtype(dtype))
|
|
),
|
|
)
|
|
- training_state = jax.device_put_replicated(
|
|
+ training_state = pmap.device_put_replicated(
|
|
training_state, jax.local_devices()[:local_devices_to_use]
|
|
)
|
|
|
|
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
|
|
index 9aec960..6624733 100644
|
|
--- a/brax/training/agents/ppo/train.py
|
|
+++ b/brax/training/agents/ppo/train.py
|
|
@@ -753,7 +753,7 @@ def train(
|
|
{},
|
|
)
|
|
|
|
- training_state = jax.device_put_replicated(
|
|
+ training_state = pmap.device_put_replicated(
|
|
training_state, jax.local_devices()[:local_devices_to_use]
|
|
)
|
|
|
|
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
|
|
index be716e9..8dcf3bf 100644
|
|
--- a/brax/training/agents/sac/train.py
|
|
+++ b/brax/training/agents/sac/train.py
|
|
@@ -108,7 +108,7 @@ def _init_training_state(
|
|
alpha_params=log_alpha,
|
|
normalizer_params=normalizer_params,
|
|
)
|
|
- return jax.device_put_replicated(
|
|
+ return pmap.device_put_replicated(
|
|
training_state, jax.local_devices()[:local_devices_to_use]
|
|
)
|
|
|
|
diff --git a/brax/training/pmap.py b/brax/training/pmap.py
|
|
index 82760fc..af62ef8 100644
|
|
--- a/brax/training/pmap.py
|
|
+++ b/brax/training/pmap.py
|
|
@@ -19,12 +19,23 @@ from typing import Any
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
|
+import numpy as np
|
|
+
|
|
+
|
|
+def device_put_replicated(x, devices):
|
|
+ """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
|
|
+ mesh = Mesh(np.array(devices), ('x',))
|
|
+ sharding = NamedSharding(mesh, P('x'))
|
|
+ return jax.tree.map(
|
|
+ lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
|
|
+ )
|
|
|
|
|
|
def bcast_local_devices(value, local_devices_to_use=1):
|
|
"""Broadcasts an object to all local devices."""
|
|
devices = jax.local_devices()[:local_devices_to_use]
|
|
- return jax.device_put_replicated(value, devices)
|
|
+ return device_put_replicated(value, devices)
|
|
|
|
|
|
def synchronize_hosts():
|