diff --git a/CHANGELOG.md b/CHANGELOG.md index 5581e232..a697219c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ Changelog follow https://keepachangelog.com/ format. Windows nor when `tf.io.gfile` is used. * `etree`: * Fix `etree.map` for `collections.defaultdict` +* `internal`: + * Add a `unwrap_on_reload` to save/restore original function after a + module is reloaded (e.g. on colab) ## [1.3.0] - 2023-05-12 diff --git a/etils/epy/_internal.py b/etils/epy/_internal.py index f137ae66..dc085775 100644 --- a/etils/epy/_internal.py +++ b/etils/epy/_internal.py @@ -15,10 +15,12 @@ """`etils` internal utils.""" import contextlib -from typing import Iterator +from typing import Iterator, TypeVar from etils.epy import reraise_utils +_FnT = TypeVar('_FnT') + @contextlib.contextmanager def check_missing_deps() -> Iterator[None]: @@ -48,3 +50,13 @@ def check_missing_deps() -> Iterator[None]: '(e.g. `from etils import ecolab` -> `pip install etils[ecolab]`)' ), ) + + +def unwrap_on_reload(fn: _FnT) -> _FnT: + """Unwrap the function to support colab module reload.""" + if hasattr(fn, '__original_fn__'): + fn = fn.__original_fn__ + + # Save the original function (to support reload) + fn.__original_fn__ = fn + return fn