From 898854baafc16620444e164fb83405927d813cbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 15 Apr 2024 01:13:15 +0100 Subject: [PATCH 1/3] getattr tag --- skbase/base/_base.py | 28 ++++++++++++++++++++++++++++ skbase/tests/test_base.py | 22 ++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 3b0ca27c..cf84e94f 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -104,6 +104,34 @@ def __eq__(self, other): return deep_equals(self_params, other_params) + def __getattr__(self, attr): + """Get attribute dunder, defaults to object tags if no attribute found. + + In tag names, the following characters are replaced: + + * colon by double underscore, i.e., ":": "__" + * dash by single underscore, i.e., "-": "_" + """ + tag_dict = self.get_tags() + + # if attribute is in tag_dict, return tag value + if attr in tag_dict: + return tag_dict[attr] + + # not found, now try normalized keys + + def norm_key(k): + """Replace colon by double underscore, dash by single underscore.""" + return k.replace(":", "__", ).replace("-", "_") + + tag_dict_norm = {norm_key(k): v for k, v in tag_dict.items()} + + if attr in tag_dict_norm: + return tag_dict_norm[attr] + + # otherwise raise the default AttributeError + return object.__getattribute__(self, attr) + def reset(self): """Reset the object to a clean post-init state. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index d3ed95b8..fe810c95 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -365,6 +365,28 @@ def test_get_tag_raises(fixture_tag_class_object: Child): fixture_tag_class_object.get_tag("bar") +def test_get_tag_attr( + fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any] +): + """Test get_tag mapping on get_attr. + + Raises + ------ + AssertError if inheritance logic in get_tag is incorrect + AssertError if default override logic in get_tag is incorrect + """ + object_tags = {} + object_tags_keys = fixture_object_tags.keys() + + for key in object_tags_keys: + object_tags[key] = getattr(fixture_tag_class_object, key) + + msg = "Inheritance logic in BaseObject.get_tag is incorrect" + + for key in object_tags_keys: + assert object_tags[key] == fixture_object_tags[key], msg + + def test_set_tags( fixture_object_instance_set_tags: Any, fixture_object_set_tags: Dict[str, Any], From 22cf6641619443a39cc230ca8900da7ba4edd22d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 00:15:53 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- skbase/base/_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index cf84e94f..d0c71735 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -122,7 +122,10 @@ def __getattr__(self, attr): def norm_key(k): """Replace colon by double underscore, dash by single underscore.""" - return k.replace(":", "__", ).replace("-", "_") + return k.replace( + ":", + "__", + ).replace("-", "_") tag_dict_norm = {norm_key(k): v for k, v in tag_dict.items()} From 8d858e0c33ba2b5f0dd562005de4886d05b222d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 15 Apr 2024 01:26:09 +0100 Subject: [PATCH 3/3] Update _base.py --- skbase/base/_base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index cf84e94f..7dbc3202 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -112,6 +112,12 @@ def __getattr__(self, attr): * colon by double underscore, i.e., ":": "__" * dash by single underscore, i.e., "-": "_" """ + # early stop for reserved attributes to avoid infinite recursion + reserved_attr = attr.endswith("_dynamic") + if reserved_attr: + return object.__getattribute__(self, attr) + + # get tags and normalized keys tag_dict = self.get_tags() # if attribute is in tag_dict, return tag value