From 52fedb8fce6b8ed4f3b0006e7473f6b19f9ebfb3 Mon Sep 17 00:00:00 2001 From: Alexander Cai Date: Tue, 7 Oct 2025 21:25:59 -0400 Subject: [PATCH] Add optional `shape` parameter to `DiscreteArray` --- dm_env/specs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dm_env/specs.py b/dm_env/specs.py index 0dc989a..1d08d98 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -289,7 +289,7 @@ class DiscreteArray(BoundedArray): __slots__ = ('_num_values',) - def __init__(self, num_values, dtype=np.int32, name=None): + def __init__(self, num_values, dtype=np.int32, name=None, shape=()): """Initializes a new `DiscreteArray` spec. Args: @@ -316,7 +316,7 @@ def __init__(self, num_values, dtype=np.int32, name=None): raise ValueError(_DTYPE_OVERFLOW.format(dtype, num_values)) super(DiscreteArray, self).__init__( - shape=(), + shape=shape, dtype=dtype, minimum=0, maximum=maximum,