diff --git a/dm_env/specs.py b/dm_env/specs.py index 0dc989a..1d08d98 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -289,7 +289,7 @@ class DiscreteArray(BoundedArray): __slots__ = ('_num_values',) - def __init__(self, num_values, dtype=np.int32, name=None): + def __init__(self, num_values, dtype=np.int32, name=None, shape=()): """Initializes a new `DiscreteArray` spec. Args: @@ -316,7 +316,7 @@ def __init__(self, num_values, dtype=np.int32, name=None): raise ValueError(_DTYPE_OVERFLOW.format(dtype, num_values)) super(DiscreteArray, self).__init__( - shape=(), + shape=shape, dtype=dtype, minimum=0, maximum=maximum,