diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index 4c375dc..e634f58 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -1,5 +1,4 @@ from __future__ import annotations -from botocore.vendored.six import u import numpy as np import torch diff --git a/rbms/custom_fn.py b/rbms/custom_fn.py index eabb39e..7bc1fcf 100644 --- a/rbms/custom_fn.py +++ b/rbms/custom_fn.py @@ -1,3 +1,5 @@ +import h5py +import numpy as np import torch from torch import Tensor @@ -47,3 +49,13 @@ def check_keys_dict(d: dict, names: list[str]): raise ValueError( f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}""" ) + + +def load_string(f: h5py.Dataset, k: str | bytes) -> str: + # Fix 1: Ensure key is a string + # key = k.decode("utf-8") if isinstance(k, bytes) else k + val = np.asarray(f[k]) + # Fix 2: Ensure string values (like 'Reservoir') are strings, not bytes + if val.dtype.kind in ["S", "V", "O"]: # Bytes, Void, or Object (StringDType) + val = val.astype(str) + return str(val)