From 953271f09f4eeb8af8068ecfc4c5d922ff6951d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 11 Mar 2026 00:38:30 +0100 Subject: [PATCH] add a load_string function --- rbms/bernoulli_gaussian/classes.py | 1 - rbms/custom_fn.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index 4c375dc..e634f58 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -1,5 +1,4 @@ from __future__ import annotations -from botocore.vendored.six import u import numpy as np import torch diff --git a/rbms/custom_fn.py b/rbms/custom_fn.py index eabb39e..7bc1fcf 100644 --- a/rbms/custom_fn.py +++ b/rbms/custom_fn.py @@ -1,3 +1,5 @@ +import h5py +import numpy as np import torch from torch import Tensor @@ -47,3 +49,13 @@ def check_keys_dict(d: dict, names: list[str]): raise ValueError( f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}""" ) + + +def load_string(f: h5py.Dataset, k: str | bytes) -> str: + # Fix 1: Ensure key is a string + # key = k.decode("utf-8") if isinstance(k, bytes) else k + val = np.asarray(f[k]) + # Fix 2: Ensure string values (like 'Reservoir') are strings, not bytes + if val.dtype.kind in ["S", "V", "O"]: # Bytes, Void, or Object (StringDType) + val = val.astype(str) + return str(val)