From c8818ed749a86a38be47ac05c2027865a57a1e85 Mon Sep 17 00:00:00 2001 From: Deep-Axe <133337403+Deep-Axe@users.noreply.github.com> Date: Sat, 2 May 2026 16:06:06 +0530 Subject: [PATCH 1/2] [ENH] respect explicit default in get_tag to match getattr semantics, addresses #291 --- .all-contributorsrc | 10 +++++++ skbase/base/_base.py | 58 +++++++++++++++++++++++--------------- skbase/base/_tagmanager.py | 36 +++++++++++++++-------- skbase/tests/test_base.py | 21 ++++++++++++++ 4 files changed, 92 insertions(+), 33 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 31755537..17527a87 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -181,6 +181,16 @@ "contributions": [ "infra" ] + }, + { + "login": "Deep-Axe", + "name": "Deep-Axe", + "avatar_url": "https://avatars.githubusercontent.com/u/152912444?v=4", + "profile": "https://github.com/Deep-Axe", + "contributions": [ + "code", + "test" + ] } ], "projectName": "skbase", diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c6a0d16a..7d4f07ae 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,9 +63,9 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._clone_base import _check_clone, _clone from skbase.base._pretty_printing._object_html_repr import _object_html_repr -from skbase.base._tagmanager import _FlagManager +from skbase.base._tagmanager import _DEFAULT_SENTINEL, _FlagManager -__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"] +__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos", "Deep-Axe"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -601,7 +601,9 @@ class attribute via nested inheritance and then any overrides """ return self._get_flags(flag_attr_name="_tags") - def get_tag(self, tag_name, tag_value_default=None, raise_error=True): + def get_tag( + self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True + ): """Get tag value from instance, with tag level inheritance and overrides. Every ``scikit-base`` compatible object has a dictionary of tags. @@ -627,23 +629,28 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): ---------- tag_name : str Name of tag to be retrieved - tag_value_default : any type, optional; default=None - Default/fallback value if tag is not found + tag_value_default : any type, optional + Default/fallback value if tag is not found. When provided (including + ``None``), it is returned on a missing tag regardless of ``raise_error``. + When omitted, behaviour is controlled by ``raise_error``. raise_error : bool - whether a ``ValueError`` is raised when the tag is not found + Whether a ``ValueError`` is raised when the tag is not found. + Ignored when ``tag_value_default`` is explicitly provided. Returns ------- tag_value : Any - Value of the ``tag_name`` tag in ``self``. - If not found, raises an error if - ``raise_error`` is True, otherwise it returns ``tag_value_default``. + Value of the ``tag_name`` tag in ``self``. If not found: + - ``tag_value_default`` is returned when it was explicitly provided, + - a ``ValueError`` is raised when ``raise_error=True`` and no default + was given, + - ``None`` is returned otherwise. Raises ------ - ValueError, if ``raise_error`` is ``True``. - The ``ValueError`` is then raised if ``tag_name`` is - not in ``self.get_tags().keys()``. + ValueError + If ``tag_name`` is not in ``self.get_tags().keys()``, no default was + supplied, and ``raise_error`` is ``True``. """ return self._get_flag( flag_name=tag_name, @@ -1405,7 +1412,9 @@ class attribute via nested inheritance and then any overrides collected_tags = self._complete_dict(collected_tags) return collected_tags - def get_tag(self, tag_name, tag_value_default=None, raise_error=True): + def get_tag( + self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True + ): """Get tag value from instance, with tag level inheritance and overrides. Every ``scikit-base`` compatible object has a dictionary of tags. @@ -1431,23 +1440,28 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): ---------- tag_name : str Name of tag to be retrieved - tag_value_default : any type, optional; default=None - Default/fallback value if tag is not found + tag_value_default : any type, optional + Default/fallback value if tag is not found. When provided (including + ``None``), it is returned on a missing tag regardless of ``raise_error``. + When omitted, behaviour is controlled by ``raise_error``. raise_error : bool - whether a ``ValueError`` is raised when the tag is not found + Whether a ``ValueError`` is raised when the tag is not found. + Ignored when ``tag_value_default`` is explicitly provided. Returns ------- tag_value : Any - Value of the ``tag_name`` tag in ``self``. - If not found, raises an error if - ``raise_error`` is True, otherwise it returns ``tag_value_default``. + Value of the ``tag_name`` tag in ``self``. If not found: + - ``tag_value_default`` is returned when it was explicitly provided, + - a ``ValueError`` is raised when ``raise_error=True`` and no default + was given, + - ``None`` is returned otherwise. Raises ------ - ValueError, if ``raise_error`` is ``True``. - The ``ValueError`` is then raised if ``tag_name`` is - not in ``self.get_tags().keys()``. + ValueError + If ``tag_name`` is not in ``self.get_tags().keys()``, no default was + supplied, and ``raise_error`` is ``True``. """ self._deprecate_tag_warn([tag_name]) alias_dict = self.alias_dict diff --git a/skbase/base/_tagmanager.py b/skbase/base/_tagmanager.py index 087c925b..c0507a86 100644 --- a/skbase/base/_tagmanager.py +++ b/skbase/base/_tagmanager.py @@ -10,6 +10,9 @@ import inspect from copy import deepcopy +# Sentinel distinguishes "caller passed no default" from "caller passed None". +_DEFAULT_SENTINEL = object() + class _FlagManager: """Mixin class for flag and configuration settings management.""" @@ -111,7 +114,7 @@ class attribute via nested inheritance and then any overrides def _get_flag( self, flag_name, - flag_value_default=None, + flag_value_default=_DEFAULT_SENTINEL, raise_error=True, flag_attr_name="_flags", ): @@ -121,33 +124,44 @@ def _get_flag( ---------- flag_name : str Name of flag to be retrieved. - flag_value_default : any type, default=None - Default/fallback value if flag is not found + flag_value_default : any type, default=_DEFAULT_SENTINEL + Default/fallback value if flag is not found. If provided (including + ``None``), it is returned when the flag is missing regardless of + ``raise_error``. raise_error : bool - Whether a `ValueError` is raised when the flag is not found. + Whether a ``ValueError`` is raised when the flag is not found. + Ignored if ``flag_value_default`` was explicitly supplied. flag_attr_name : str, default = "_flags" Name of the flag attribute that is read. Returns ------- flag_value : - Value of the `flag_name` flag in self. If not found, returns an error if - raise_error is True, otherwise it returns `flag_value_default`. + Value of the ``flag_name`` flag in self. If not found: + - returns ``flag_value_default`` when it was explicitly provided, or + - raises ``ValueError`` when ``raise_error=True`` and no default given, + - returns ``None`` otherwise. Raises ------ ValueError - if `raise_error` is `True`, i.e., - if `flag_name` is not in `self.get_flags().keys()` + if ``flag_name`` is not found, no default was supplied, and + ``raise_error`` is ``True``. """ collected_flags = self._get_flags(flag_attr_name=flag_attr_name) - flag_value = collected_flags.get(flag_name, flag_value_default) + if flag_name in collected_flags: + return collected_flags[flag_name] + + # Flag not found: determine what to return. + if flag_value_default is not _DEFAULT_SENTINEL: + # Caller explicitly supplied a default; respect it unconditionally. + return flag_value_default - if raise_error and flag_name not in collected_flags.keys(): + if raise_error: raise ValueError(f"Tag with name {flag_name} could not be found.") - return flag_value + return None def _set_flags(self, flag_attr_name="_flags", **flag_dict): """Set dynamic flags to given values. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 83b2739a..466fb2ab 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -26,6 +26,7 @@ "test_get_tags", "test_get_tag", "test_get_tag_raises", + "test_get_tag_default_bypasses_raise_error", "test_set_tags", "test_set_tags_works_with_missing_tags_dynamic_attribute", "test_clone_tags", @@ -363,6 +364,26 @@ def test_get_tag_raises(fixture_tag_class_object: Child): fixture_tag_class_object.get_tag("bar") +def test_get_tag_default_bypasses_raise_error(fixture_tag_class_object: Child): + """Test that supplying tag_value_default bypasses raise_error. + + Providing a default should return it for missing tags even when + raise_error=True (the default), matching dict.get() / getattr() semantics. + + Raises + ------ + AssertError if explicit default does not suppress ValueError. + AssertError if explicit None default is not returned. + """ + # Explicit string default, no raise_error=False needed + val = fixture_tag_class_object.get_tag("bar", "fallback") + assert val == "fallback", "Expected 'fallback' when default supplied without raise_error=False" + + # Explicitly passing None as default should return None, not raise + val_none = fixture_tag_class_object.get_tag("bar", None) + assert val_none is None, "Expected None when None is explicitly supplied as default" + + def test_set_tags( fixture_object_instance_set_tags: Any, fixture_object_set_tags: Dict[str, Any], From 141c322e127f1bb3f29ee37ddb29d7f5a970da93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 May 2026 11:18:26 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- skbase/base/_base.py | 8 ++------ skbase/tests/test_base.py | 4 +++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 7d4f07ae..84a84712 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -601,9 +601,7 @@ class attribute via nested inheritance and then any overrides """ return self._get_flags(flag_attr_name="_tags") - def get_tag( - self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True - ): + def get_tag(self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True): """Get tag value from instance, with tag level inheritance and overrides. Every ``scikit-base`` compatible object has a dictionary of tags. @@ -1412,9 +1410,7 @@ class attribute via nested inheritance and then any overrides collected_tags = self._complete_dict(collected_tags) return collected_tags - def get_tag( - self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True - ): + def get_tag(self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True): """Get tag value from instance, with tag level inheritance and overrides. Every ``scikit-base`` compatible object has a dictionary of tags. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 466fb2ab..d27cfd1d 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -377,7 +377,9 @@ def test_get_tag_default_bypasses_raise_error(fixture_tag_class_object: Child): """ # Explicit string default, no raise_error=False needed val = fixture_tag_class_object.get_tag("bar", "fallback") - assert val == "fallback", "Expected 'fallback' when default supplied without raise_error=False" + assert ( + val == "fallback" + ), "Expected 'fallback' when default supplied without raise_error=False" # Explicitly passing None as default should return None, not raise val_none = fixture_tag_class_object.get_tag("bar", None)