Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions evojax/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .ars import ARS
from .simple_ga import SimpleGA
from .open_es import OpenES
from .cma_evosax import CMA_ES
from .sep_cma_es import Sep_CMA_ES

Strategies = {
Expand All @@ -26,6 +27,7 @@
"SimpleGA": SimpleGA,
"ARS": ARS,
"OpenES": OpenES,
"CMA_ES": CMA_ES,
"Sep_CMA_ES": Sep_CMA_ES,
}

Expand All @@ -36,6 +38,7 @@
"ARS",
"SimpleGA",
"OpenES",
"CMA_ES",
"Sep_CMA_ES",
"Strategies",
]
107 changes: 107 additions & 0 deletions evojax/algo/cma_evosax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import sys

import logging
from typing import Union
import numpy as np
import jax
import jax.numpy as jnp

from evojax.algo.base import NEAlgorithm
from evojax.util import create_logger


class CMA_ES(NEAlgorithm):
    """A wrapper around evosax's CMA-ES.

    The wrapper owns the JAX random key, the evosax strategy object, its
    hyperparameters and its state, and exposes them through the ask/tell
    interface of ``NEAlgorithm``.

    Implementation: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/cma_es.py
    Reference: Hansen & Ostermeier (2001) - http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf
    """

    def __init__(
        self,
        param_size: int,
        pop_size: int,
        elite_ratio: float = 0.5,
        init_stdev: float = 0.1,
        w_decay: float = 0.0,
        seed: int = 0,
        logger: Union[logging.Logger, None] = None,
    ):
        """Initialization function.

        Args:
            param_size - Parameter size.
            pop_size - Population size (its absolute value is used).
            elite_ratio - Population elite fraction, forwarded to evosax.
            init_stdev - Initial scale of isotropic part of covariance.
            w_decay - L2 weight regularization coefficient.
            seed - Random seed for parameters sampling.
            logger - Logger. A logger named "CMA_ES" is created if None.

        Exits the process with status 1 if the Python version is older
        than 3.7 or if evosax cannot be imported.
        """

        # Delayed importing of evosax, so that evojax itself does not
        # depend on it unless this algorithm is actually used.

        # Compare (major, minor) as a tuple; looking at the minor version
        # alone gives the wrong answer whenever the major version is not 3.
        if sys.version_info < (3, 7):
            print("evosax, which is needed by CMA-ES, requires python>=3.7")
            print("  please consider upgrading your Python version.")
            sys.exit(1)

        try:
            import evosax
        except ModuleNotFoundError:
            print("You need to install evosax for its CMA-ES:")
            print("  pip install evosax")
            sys.exit(1)

        # Set up object variables.

        if logger is None:
            self.logger = create_logger(name="CMA_ES")
        else:
            self.logger = logger

        self.param_size = param_size
        self.pop_size = abs(pop_size)
        self.elite_ratio = elite_ratio
        self.rand_key = jax.random.PRNGKey(seed=seed)

        # Instantiate evosax's CMA-ES strategy.
        # Use the sanitized self.pop_size so that the wrapper and the
        # underlying strategy always agree on the population size.
        self.es = evosax.CMA_ES(
            popsize=self.pop_size,
            num_dims=param_size,
            elite_ratio=elite_ratio,
        )

        # Set hyperparameters according to provided inputs
        self.es_params = self.es.default_params
        self.es_params["sigma_init"] = init_stdev

        # Initialize the evolution strategy state
        self.rand_key, init_key = jax.random.split(self.rand_key)
        self.es_state = self.es.initialize(init_key, self.es_params)

        # By default evojax assumes maximization of fitness score!
        # Evosax, on the other hand, minimizes!
        self.fit_shaper = evosax.FitnessShaper(w_decay=w_decay, maximize=True)

    def ask(self) -> jnp.ndarray:
        """Sample a new population of parameters from the strategy.

        The sampled population is cached in ``self.params`` because
        ``tell`` needs it to update the strategy state.
        """
        self.rand_key, ask_key = jax.random.split(self.rand_key)
        self.params, self.es_state = self.es.ask(
            ask_key, self.es_state, self.es_params
        )
        return self.params

    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
        """Update the strategy state with the fitness of the last population.

        Must be called after ``ask``, since it uses the population stored
        in ``self.params`` by that call.
        """
        # Reshape fitness to conform with evosax minimization
        fit_re = self.fit_shaper.apply(self.params, fitness)
        self.es_state = self.es.tell(
            self.params, fit_re, self.es_state, self.es_params
        )

    @property
    def best_params(self) -> jnp.ndarray:
        """Return a copy of the strategy state's current mean."""
        return jnp.array(self.es_state["mean"], copy=True)

    @best_params.setter
    def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
        """Overwrite both the best member and the mean with copies of params."""
        self.es_state["best_member"] = jnp.array(params, copy=True)
        self.es_state["mean"] = jnp.array(params, copy=True)
13 changes: 11 additions & 2 deletions scripts/benchmarks/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,18 @@ This will sequentially execute 25 ARS-MNIST evolution runs for a grid of differe

## Benchmark Results

### Sep-CMA-ES
### CMA-ES

| | Benchmarks | Parameters | Results (Avg) |
|---|---|---|---|
CartPole (easy) | 900 (max_iter=1000)|[Link](configs/CMA_ES/cartpole_easy.yaml)| 927.3208 |
CartPole (hard) | 600 (max_iter=1000)|[Link](configs/CMA_ES/cartpole_hard.yaml)| 625.9829 |
MNIST | 90.0 (max_iter=2000) | [Link](configs/CMA_ES/mnist.yaml)| 0.9581 |
Brax Ant | 3000 (max_iter=1200) |[Link](configs/CMA_ES/brax_ant.yaml)| 3174.0608 |
Waterworld | 6 (max_iter=500) | [Link](configs/CMA_ES/waterworld.yaml)| 9.44 |
Waterworld (MA) | 2 (max_iter=2000) | [Link](configs/CMA_ES/waterworld_ma.yaml) | 0.5625 |

### Sep-CMA-ES

| | Benchmarks | Parameters | Results (Avg) |
|---|---|---|---|
Expand All @@ -74,7 +84,6 @@ Waterworld (MA) | 2 (max_iter=2000) | [Link](configs/Sep_CMA_ES/waterworld_ma.ya

### PGPE


| | Benchmarks | Parameters | Results (Avg) |
|---|---|---|---|
CartPole (easy) | 900 (max_iter=1000)|[Link](configs/PGPE/cartpole_easy.yaml)| 935.4268 |
Expand Down
16 changes: 16 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/brax_ant.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
es_name: "CMA_ES"
problem_type: "brax"
env_name: "ant"
normalize: true
es_config:
pop_size: 512
elite_ratio: 0.5
init_stdev: 0.1
num_tests: 128
n_repeats: 16
max_iter: 600
test_interval: 100
log_interval: 20
seed: 42
gpu_id: [0, 1, 2, 3]
debug: false
16 changes: 16 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/cartpole_easy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
es_name: "CMA_ES"
problem_type: "cartpole_easy"
normalize: false
es_config:
pop_size: 100
elite_ratio: 0.5
init_stdev: 0.15
hidden_size: 64
num_tests: 100
n_repeats: 16
max_iter: 1000
test_interval: 100
log_interval: 50
seed: 42
gpu_id: 0
debug: false
16 changes: 16 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/cartpole_hard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
es_name: "CMA_ES"
problem_type: "cartpole_hard"
normalize: false
es_config:
pop_size: 100
elite_ratio: 0.5
init_stdev: 0.15
hidden_size: 64
num_tests: 100
n_repeats: 16
max_iter: 1000
test_interval: 100
log_interval: 50
seed: 42
gpu_id: 0
debug: false
17 changes: 17 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/mnist.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
es_name: "CMA_ES"
problem_type: "mnist"
normalize: false
es_config:
pop_size: 100
elite_ratio: 0.5
init_stdev: 0.065
hidden_size: 100
batch_size: 1024
max_iter: 2000
test_interval: 500
log_interval: 100
num_tests: 1
n_repeats: 1
seed: 42
gpu_id: [0, 1, 2, 3]
debug: false
16 changes: 16 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/waterworld.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
es_name: "CMA_ES"
problem_type: "waterworld"
normalize: false
es_config:
pop_size: 256
elite_ratio: 0.5
init_stdev: 0.065
hidden_size: 100
num_tests: 100
n_repeats: 32
max_iter: 1000
test_interval: 50
log_interval: 10
seed: 42
gpu_id: [0, 1, 2, 3]
debug: false
16 changes: 16 additions & 0 deletions scripts/benchmarks/configs/CMA_ES/waterworld_ma.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
es_name: "CMA_ES"
problem_type: "waterworld_ma"
normalize: false
es_config:
pop_size: 16
elite_ratio: 0.5
init_stdev: 0.065
hidden_size: 100
num_tests: 16
n_repeats: 64
max_iter: 2000
test_interval: 100
log_interval: 10
seed: 42
gpu_id: 0
debug: false