Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@
"contributions": [
"infra"
]
},
{
"login": "Deep-Axe",
"name": "Deep-Axe",
"avatar_url": "https://avatars.githubusercontent.com/u/152912444?v=4",
"profile": "https://github.com/Deep-Axe",
"contributions": [
"code",
"test"
]
}
],
"projectName": "skbase",
Expand Down
54 changes: 32 additions & 22 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class name: BaseEstimator
from skbase._exceptions import NotFittedError
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
from skbase.base._tagmanager import _DEFAULT_SENTINEL, _FlagManager

__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos", "Deep-Axe"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]


Expand Down Expand Up @@ -601,7 +601,7 @@ class attribute via nested inheritance and then any overrides
"""
return self._get_flags(flag_attr_name="_tags")

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
def get_tag(self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True):
"""Get tag value from instance, with tag level inheritance and overrides.

Every ``scikit-base`` compatible object has a dictionary of tags.
Expand All @@ -627,23 +627,28 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
----------
tag_name : str
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
tag_value_default : any type, optional
Default/fallback value if tag is not found. When provided (including
``None``), it is returned on a missing tag regardless of ``raise_error``.
When omitted, behaviour is controlled by ``raise_error``.
raise_error : bool
whether a ``ValueError`` is raised when the tag is not found
Whether a ``ValueError`` is raised when the tag is not found.
Ignored when ``tag_value_default`` is explicitly provided.

Returns
-------
tag_value : Any
Value of the ``tag_name`` tag in ``self``.
If not found, raises an error if
``raise_error`` is True, otherwise it returns ``tag_value_default``.
Value of the ``tag_name`` tag in ``self``. If not found:
- ``tag_value_default`` is returned when it was explicitly provided,
- a ``ValueError`` is raised when ``raise_error=True`` and no default
was given,
- ``None`` is returned otherwise.

Raises
------
ValueError, if ``raise_error`` is ``True``.
The ``ValueError`` is then raised if ``tag_name`` is
not in ``self.get_tags().keys()``.
ValueError
If ``tag_name`` is not in ``self.get_tags().keys()``, no default was
supplied, and ``raise_error`` is ``True``.
"""
return self._get_flag(
flag_name=tag_name,
Expand Down Expand Up @@ -1405,7 +1410,7 @@ class attribute via nested inheritance and then any overrides
collected_tags = self._complete_dict(collected_tags)
return collected_tags

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
def get_tag(self, tag_name, tag_value_default=_DEFAULT_SENTINEL, raise_error=True):
"""Get tag value from instance, with tag level inheritance and overrides.

Every ``scikit-base`` compatible object has a dictionary of tags.
Expand All @@ -1431,23 +1436,28 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
----------
tag_name : str
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
tag_value_default : any type, optional
Default/fallback value if tag is not found. When provided (including
``None``), it is returned on a missing tag regardless of ``raise_error``.
When omitted, behaviour is controlled by ``raise_error``.
raise_error : bool
whether a ``ValueError`` is raised when the tag is not found
Whether a ``ValueError`` is raised when the tag is not found.
Ignored when ``tag_value_default`` is explicitly provided.

Returns
-------
tag_value : Any
Value of the ``tag_name`` tag in ``self``.
If not found, raises an error if
``raise_error`` is True, otherwise it returns ``tag_value_default``.
Value of the ``tag_name`` tag in ``self``. If not found:
- ``tag_value_default`` is returned when it was explicitly provided,
- a ``ValueError`` is raised when ``raise_error=True`` and no default
was given,
- ``None`` is returned otherwise.

Raises
------
ValueError, if ``raise_error`` is ``True``.
The ``ValueError`` is then raised if ``tag_name`` is
not in ``self.get_tags().keys()``.
ValueError
If ``tag_name`` is not in ``self.get_tags().keys()``, no default was
supplied, and ``raise_error`` is ``True``.
"""
self._deprecate_tag_warn([tag_name])
alias_dict = self.alias_dict
Expand Down
36 changes: 25 additions & 11 deletions skbase/base/_tagmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import inspect
from copy import deepcopy

# Sentinel distinguishes "caller passed no default" from "caller passed None".
_DEFAULT_SENTINEL = object()


class _FlagManager:
"""Mixin class for flag and configuration settings management."""
Expand Down Expand Up @@ -111,7 +114,7 @@ class attribute via nested inheritance and then any overrides
def _get_flag(
self,
flag_name,
flag_value_default=None,
flag_value_default=_DEFAULT_SENTINEL,
raise_error=True,
flag_attr_name="_flags",
):
Expand All @@ -121,33 +124,44 @@ def _get_flag(
----------
flag_name : str
Name of flag to be retrieved.
flag_value_default : any type, default=None
Default/fallback value if flag is not found
flag_value_default : any type, default=_DEFAULT_SENTINEL
Default/fallback value if flag is not found. If provided (including
``None``), it is returned when the flag is missing regardless of
``raise_error``.
raise_error : bool
Whether a `ValueError` is raised when the flag is not found.
Whether a ``ValueError`` is raised when the flag is not found.
Ignored if ``flag_value_default`` was explicitly supplied.
flag_attr_name : str, default = "_flags"
Name of the flag attribute that is read.

Returns
-------
flag_value :
Value of the `flag_name` flag in self. If not found, returns an error if
raise_error is True, otherwise it returns `flag_value_default`.
Value of the ``flag_name`` flag in self. If not found:
- returns ``flag_value_default`` when it was explicitly provided, or
- raises ``ValueError`` when ``raise_error=True`` and no default given,
- returns ``None`` otherwise.

Raises
------
ValueError
if `raise_error` is `True`, i.e.,
if `flag_name` is not in `self.get_flags().keys()`
if ``flag_name`` is not found, no default was supplied, and
``raise_error`` is ``True``.
"""
collected_flags = self._get_flags(flag_attr_name=flag_attr_name)

flag_value = collected_flags.get(flag_name, flag_value_default)
if flag_name in collected_flags:
return collected_flags[flag_name]

# Flag not found: determine what to return.
if flag_value_default is not _DEFAULT_SENTINEL:
# Caller explicitly supplied a default; respect it unconditionally.
return flag_value_default

if raise_error and flag_name not in collected_flags.keys():
if raise_error:
raise ValueError(f"Tag with name {flag_name} could not be found.")

return flag_value
return None

def _set_flags(self, flag_attr_name="_flags", **flag_dict):
"""Set dynamic flags to given values.
Expand Down
23 changes: 23 additions & 0 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"test_get_tags",
"test_get_tag",
"test_get_tag_raises",
"test_get_tag_default_bypasses_raise_error",
"test_set_tags",
"test_set_tags_works_with_missing_tags_dynamic_attribute",
"test_clone_tags",
Expand Down Expand Up @@ -363,6 +364,28 @@ def test_get_tag_raises(fixture_tag_class_object: Child):
fixture_tag_class_object.get_tag("bar")


def test_get_tag_default_bypasses_raise_error(fixture_tag_class_object: Child):
"""Test that supplying tag_value_default bypasses raise_error.

Providing a default should return it for missing tags even when
raise_error=True (the default), matching dict.get() / getattr() semantics.

Raises
------
AssertError if explicit default does not suppress ValueError.
AssertError if explicit None default is not returned.
"""
# Explicit string default, no raise_error=False needed
val = fixture_tag_class_object.get_tag("bar", "fallback")
assert (
val == "fallback"
), "Expected 'fallback' when default supplied without raise_error=False"

# Explicitly passing None as default should return None, not raise
val_none = fixture_tag_class_object.get_tag("bar", None)
assert val_none is None, "Expected None when None is explicitly supplied as default"


def test_set_tags(
fixture_object_instance_set_tags: Any,
fixture_object_set_tags: Dict[str, Any],
Expand Down