Skip to content

Commit 05d58c1

Browse files
committed
added utils_cache
1 parent c11f150 commit 05d58c1

1 file changed

Lines changed: 120 additions & 0 deletions

File tree

src/py_utils/utils_cache.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import json
2+
import pickle
3+
from functools import wraps
4+
from pathlib import Path
5+
6+
import numpy as np
7+
8+
9+
####################
10+
# Cache Primitives #
11+
####################
12+
13+
14+
def save_cache(data, cache_path, serializer="pickle"):
    """Save data to a cache file, creating parent directories as needed.

    Args:
        data: Object to store. It must be picklable for ``"pickle"``,
            accepted by ``np.save`` for ``"numpy"``, or JSON-serializable
            for ``"json"``.
        cache_path: Destination file (``str`` or ``Path``). The data is
            written at exactly this path for every serializer.
        serializer: One of ``"pickle"``, ``"numpy"`` or ``"json"``.

    Raises:
        ValueError: If ``serializer`` is not one of the supported names.
            Raised before any directory or file is created.
    """
    # Reject bad input up front so a typo does not leave empty directories.
    if serializer not in ("pickle", "numpy", "json"):
        raise ValueError(f"Unknown serializer: {serializer}")

    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)

    if serializer == "pickle":
        with open(cache_path, "wb") as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    elif serializer == "numpy":
        # Hand np.save an open file object rather than a path: given a path
        # without a ".npy" suffix it appends one, so the data would end up
        # next to ``cache_path`` instead of at it and a later existence
        # check on ``cache_path`` would never see the cached file.
        with open(cache_path, "wb") as f:
            np.save(f, data)

    else:  # "json"
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
33+
34+
35+
def load_cache(cache_path, serializer="pickle"):
    """Load data from a cache file.

    Args:
        cache_path: File to read (``str`` or ``Path``).
        serializer: Format the file was written in: ``"pickle"``,
            ``"numpy"`` or ``"json"``.

    Returns:
        The deserialized object. For ``"numpy"`` this is whatever
        ``np.load`` returns for the file.

    Raises:
        ValueError: If ``serializer`` is not one of the supported names.

    Note:
        The ``"pickle"`` format, and ``"numpy"`` files holding object
        arrays, are unpickled on load, which can execute arbitrary code.
        Only load cache files from a trusted location.
    """
    cache_path = Path(cache_path)

    if serializer == "pickle":
        # SECURITY: pickle.load must only be used on trusted files.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    elif serializer == "numpy":
        # np.save pickles object-dtype arrays by default, but np.load
        # refuses them unless allow_pickle=True; without it such a cache
        # file could be written but never read back.
        return np.load(cache_path, allow_pickle=True)

    elif serializer == "json":
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)

    else:
        raise ValueError(f"Unknown serializer: {serializer}")
54+
55+
56+
def clear_cache(cache_path):
    """Remove a cache file if it exists and print what happened.

    Args:
        cache_path: File to delete (``str`` or ``Path``).

    Prints ``Cleared cache: <path>`` after a successful removal, or
    ``Cache not found: <path>`` when there was nothing to remove.
    """
    cache_path = Path(cache_path)

    # Attempt the removal directly instead of checking exists() first: the
    # file may disappear between a check and the unlink, which would turn a
    # harmless "already gone" into an unhandled exception.
    try:
        cache_path.unlink()
    except (FileNotFoundError, NotADirectoryError):
        # NotADirectoryError: a parent component is a regular file, so the
        # path cannot exist either.
        print(f"Cache not found: {cache_path}")
    else:
        print(f"Cleared cache: {cache_path}")
65+
66+
67+
#####################
68+
#   cache_to_file   #
69+
#####################
70+
71+
72+
def cache_to_file(
    directory,
    filename,
    serializer="pickle",
):
    """Build a decorator that stores a function's return value in a file.

    On each call the wrapper computes ``Path(directory) / filename``. If
    that file exists, its contents are loaded with ``load_cache`` and
    returned without calling the wrapped function. Otherwise the function
    runs, its result is written with ``save_cache``, and the result is
    returned.

    Args:
        directory: Cache directory, or a callable that receives the wrapped
            function's positional and keyword arguments and returns it.
        filename: Cache file name, or a callable with the same calling
            convention as ``directory``.
        serializer: Serializer name forwarded to ``save_cache`` and
            ``load_cache``.

    Example — on a method::

        @cache_to_file(
            lambda self, idx: self.cache_root,
            lambda self, idx: f"result.{idx}.npy",
            serializer="numpy",
        )
        def get_result(self, idx):
            ...

    Example — on a plain function::

        @cache_to_file("/tmp/cache", "result.pkl")
        def expensive():
            ...
    """

    def _resolve(spec, call_args, call_kwargs):
        # A callable spec is evaluated with the decorated call's arguments;
        # any other value is used as given.
        return spec(*call_args, **call_kwargs) if callable(spec) else spec

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Resolve both parts before building the path (directory first).
            folder = _resolve(directory, args, kwargs)
            leaf = _resolve(filename, args, kwargs)
            target = Path(folder) / leaf

            if not target.exists():
                # Cache miss: compute, persist, and hand back the fresh value.
                fresh = func(*args, **kwargs)
                save_cache(fresh, target, serializer)
                return fresh

            return load_cache(target, serializer)

        return wrapper

    return decorator

0 commit comments

Comments
 (0)