diff --git a/evojax/algo/__init__.py b/evojax/algo/__init__.py index 14d8ee6b..43c9e7d2 100644 --- a/evojax/algo/__init__.py +++ b/evojax/algo/__init__.py @@ -18,6 +18,7 @@ from .ars import ARS from .simple_ga import SimpleGA from .open_es import OpenES +from .cma_evosax import CMA_ES from .sep_cma_es import Sep_CMA_ES Strategies = { @@ -26,6 +27,7 @@ "SimpleGA": SimpleGA, "ARS": ARS, "OpenES": OpenES, + "CMA_ES": CMA_ES, "Sep_CMA_ES": Sep_CMA_ES, } @@ -36,6 +38,7 @@ "ARS", "SimpleGA", "OpenES", + "CMA_ES", "Sep_CMA_ES", "Strategies", ] diff --git a/evojax/algo/cma_evosax.py b/evojax/algo/cma_evosax.py new file mode 100644 index 00000000..1ba44716 --- /dev/null +++ b/evojax/algo/cma_evosax.py @@ -0,0 +1,107 @@ +import sys + +import logging +from typing import Union +import numpy as np +import jax +import jax.numpy as jnp + +from evojax.algo.base import NEAlgorithm +from evojax.util import create_logger + + +class CMA_ES(NEAlgorithm): + """A wrapper around evosax's CMA-ES. + Implementation: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/cma_es.py + Reference: Hansen & Ostermeier (2008) - http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf + """ + + def __init__( + self, + param_size: int, + pop_size: int, + elite_ratio: float = 0.5, + init_stdev: float = 0.1, + w_decay: float = 0.0, + seed: int = 0, + logger: logging.Logger = None, + ): + """Initialization function. + + Args: + param_size - Parameter size. + pop_size - Population size. + elite_ratio - Population elite fraction used for gradient estimate. + init_stdev - Initial scale of istropic part of covariance. + w_decay - L2 weight regularization coefficient. + seed - Random seed for parameters sampling. + logger - Logger. + """ + + # Delayed importing of evosax + + if sys.version_info.minor < 7: + print("evosax, which is needed by CMA-ES, requires python>=3.7") + print(" please consider upgrading your Python version.") + sys.exit(1) + + try: + import evosax + except ModuleNotFoundError: + print("You need to install evosax for its CMA-ES:") + print(" pip install evosax") + sys.exit(1) + + # Set up object variables. + + if logger is None: + self.logger = create_logger(name="CMA_ES") + else: + self.logger = logger + + self.param_size = param_size + self.pop_size = abs(pop_size) + self.elite_ratio = elite_ratio + self.rand_key = jax.random.PRNGKey(seed=seed) + + # Instantiate evosax's ARS strategy + self.es = evosax.CMA_ES( + popsize=pop_size, + num_dims=param_size, + elite_ratio=elite_ratio, + ) + + # Set hyperparameters according to provided inputs + self.es_params = self.es.default_params + self.es_params["sigma_init"] = init_stdev + + # Initialize the evolution strategy state + self.rand_key, init_key = jax.random.split(self.rand_key) + self.es_state = self.es.initialize(init_key, self.es_params) + + # By default evojax assumes maximization of fitness score! + # Evosax, on the other hand, minimizes! + self.fit_shaper = evosax.FitnessShaper(w_decay=w_decay, maximize=True) + + def ask(self) -> jnp.ndarray: + self.rand_key, ask_key = jax.random.split(self.rand_key) + self.params, self.es_state = self.es.ask( + ask_key, self.es_state, self.es_params + ) + return self.params + + def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None: + # Reshape fitness to conform with evosax minimization + fit_re = self.fit_shaper.apply(self.params, fitness) + self.es_state = self.es.tell( + self.params, fit_re, self.es_state, self.es_params + ) + + @property + def best_params(self) -> jnp.ndarray: + return jnp.array(self.es_state["mean"], copy=True) + + @best_params.setter + def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None: + self.es_state["best_member"] = jnp.array(params, copy=True) + self.es_state["mean"] = jnp.array(params, copy=True) diff --git a/scripts/benchmarks/Readme.md b/scripts/benchmarks/Readme.md index 8f15f502..027d4d22 100644 --- a/scripts/benchmarks/Readme.md +++ b/scripts/benchmarks/Readme.md @@ -59,8 +59,18 @@ This will sequentially execute 25 ARS-MNIST evolution runs for a grid of differe ## Benchmark Results -### Sep-CMA-ES +### CMA-ES + +| | Benchmarks | Parameters | Results (Avg) | +|---|---|---|---| +CartPole (easy) | 900 (max_iter=1000)|[Link](configs/CMA_ES/cartpole_easy.yaml)| 927.3208 | +CartPole (hard) | 600 (max_iter=1000)|[Link](configs/CMA_ES/cartpole_hard.yaml)| 625.9829 | +MNIST | 90.0 (max_iter=2000) | [Link](configs/CMA_ES/mnist.yaml)| 0.9581 | +Brax Ant | 3000 (max_iter=1200) |[Link](configs/CMA_ES/brax_ant.yaml)| 3174.0608 | +Waterworld | 6 (max_iter=500) | [Link](configs/CMA_ES/waterworld.yaml)| 9.44 | +Waterworld (MA) | 2 (max_iter=2000) | [Link](configs/CMA_ES/waterworld_ma.yaml) | 0.5625 | +### Sep-CMA-ES | | Benchmarks | Parameters | Results (Avg) | |---|---|---|---| @@ -74,7 +84,6 @@ Waterworld (MA) | 2 (max_iter=2000) | [Link](configs/Sep_CMA_ES/waterworld_ma.ya ### PGPE - | | Benchmarks | Parameters | Results (Avg) | |---|---|---|---| CartPole (easy) | 900 (max_iter=1000)|[Link](configs/PGPE/cartpole_easy.yaml)| 935.4268 | diff --git a/scripts/benchmarks/configs/CMA_ES/brax_ant.yaml b/scripts/benchmarks/configs/CMA_ES/brax_ant.yaml new file mode 100644 index 00000000..2922bf59 --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/brax_ant.yaml @@ -0,0 +1,16 @@ +es_name: "CMA_ES" +problem_type: "brax" +env_name: "ant" +normalize: true +es_config: + pop_size: 512 + elite_ratio: 0.5 + init_stdev: 0.1 +num_tests: 128 +n_repeats: 16 +max_iter: 600 +test_interval: 100 +log_interval: 20 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CMA_ES/cartpole_easy.yaml b/scripts/benchmarks/configs/CMA_ES/cartpole_easy.yaml new file mode 100644 index 00000000..ff34be97 --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/cartpole_easy.yaml @@ -0,0 +1,16 @@ +es_name: "CMA_ES" +problem_type: "cartpole_easy" +normalize: false +es_config: + pop_size: 100 + elite_ratio: 0.5 + init_stdev: 0.15 +hidden_size: 64 +num_tests: 100 +n_repeats: 16 +max_iter: 1000 +test_interval: 100 +log_interval: 50 +seed: 42 +gpu_id: 0 +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CMA_ES/cartpole_hard.yaml b/scripts/benchmarks/configs/CMA_ES/cartpole_hard.yaml new file mode 100644 index 00000000..0b0077a7 --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/cartpole_hard.yaml @@ -0,0 +1,16 @@ +es_name: "CMA_ES" +problem_type: "cartpole_hard" +normalize: false +es_config: + pop_size: 100 + elite_ratio: 0.5 + init_stdev: 0.15 +hidden_size: 64 +num_tests: 100 +n_repeats: 16 +max_iter: 1000 +test_interval: 100 +log_interval: 50 +seed: 42 +gpu_id: 0 +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CMA_ES/mnist.yaml b/scripts/benchmarks/configs/CMA_ES/mnist.yaml new file mode 100644 index 00000000..761aeb27 --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/mnist.yaml @@ -0,0 +1,17 @@ +es_name: "CMA_ES" +problem_type: "mnist" +normalize: false +es_config: + pop_size: 100 + elite_ratio: 0.5 + init_stdev: 0.065 +hidden_size: 100 +batch_size: 1024 +max_iter: 2000 +test_interval: 500 +log_interval: 100 +num_tests: 1 +n_repeats: 1 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false diff --git a/scripts/benchmarks/configs/CMA_ES/waterworld.yaml b/scripts/benchmarks/configs/CMA_ES/waterworld.yaml new file mode 100644 index 00000000..d6a31694 --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/waterworld.yaml @@ -0,0 +1,16 @@ +es_name: "CMA_ES" +problem_type: "waterworld" +normalize: false +es_config: + pop_size: 256 + elite_ratio: 0.5 + init_stdev: 0.065 +hidden_size: 100 +num_tests: 100 +n_repeats: 32 +max_iter: 1000 +test_interval: 50 +log_interval: 10 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CMA_ES/waterworld_ma.yaml b/scripts/benchmarks/configs/CMA_ES/waterworld_ma.yaml new file mode 100644 index 00000000..69e90b0d --- /dev/null +++ b/scripts/benchmarks/configs/CMA_ES/waterworld_ma.yaml @@ -0,0 +1,16 @@ +es_name: "CMA_ES" +problem_type: "waterworld_ma" +normalize: false +es_config: + pop_size: 16 + elite_ratio: 0.5 + init_stdev: 0.065 +hidden_size: 100 +num_tests: 16 +n_repeats: 64 +max_iter: 2000 +test_interval: 100 +log_interval: 10 +seed: 42 +gpu_id: 0 +debug: false \ No newline at end of file