Skip to content

Commit c548724

Browse files
committed
Make prompt registry a static singleton instance
1 parent b43df0d commit c548724

3 files changed

Lines changed: 187 additions & 113 deletions

File tree

cecli/coders/base_coder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from cecli.utils import format_tokens, is_image_file
6262

6363
from ..dump import dump # noqa: F401
64-
from ..prompts.utils.registry import registry
64+
from ..prompts.utils.registry import PromptRegistry
6565
from .chat_chunks import ChatChunks
6666

6767

@@ -600,7 +600,7 @@ def gpt_prompts(self):
600600
return Coder._prompt_cache[prompt_name]
601601

602602
# Get prompts from registry
603-
prompts = registry.get_prompt(prompt_name)
603+
prompts = PromptRegistry.get_prompt(prompt_name)
604604

605605
# Create a simple object that allows attribute access
606606
class PromptObject:

cecli/prompts/utils/registry.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,12 @@
1919
class PromptRegistry:
2020
"""Central registry for loading and managing prompts from YAML files."""
2121

22-
_instance = None
22+
# Class-level state for singleton pattern
2323
_prompts_cache: Dict[str, Dict[str, Any]] = {}
2424
_base_prompts: Optional[Dict[str, Any]] = None
2525

26-
def __new__(cls):
27-
if cls._instance is None:
28-
cls._instance = super(PromptRegistry, cls).__new__(cls)
29-
return cls._instance
30-
31-
def __init__(self):
32-
if not hasattr(self, "_initialized"):
33-
self._initialized = True
34-
35-
def _load_yaml_file(self, file_name: str) -> Dict[str, Any]:
26+
@staticmethod
27+
def _load_yaml_file(file_name: str) -> Dict[str, Any]:
3628
"""Load a YAML file and return its contents."""
3729
try:
3830
# Use importlib_resources to access package files
@@ -43,30 +35,46 @@ def _load_yaml_file(self, file_name: str) -> Dict[str, Any]:
4335
)
4436
return yaml.safe_load(file_content) or {}
4537
except FileNotFoundError:
46-
return {}
38+
# If not found via importlib_resources, try local file system
39+
# Treat file_name as absolute path relative to current working directory
40+
try:
41+
import os
42+
43+
file_path = os.path.abspath(file_name)
44+
if os.path.exists(file_path):
45+
with open(file_path, "r", encoding="utf-8") as f:
46+
file_content = f.read()
47+
return yaml.safe_load(file_content) or {}
48+
else:
49+
raise ValueError(f"Prompt YAML file not found {file_name}")
50+
except (FileNotFoundError, OSError) as e:
51+
raise ValueError(f"Error parsing YAML file {file_name}: {e}")
4752
except yaml.YAMLError as e:
4853
raise ValueError(f"Error parsing YAML file {file_name}: {e}")
4954

50-
def _get_base_prompts(self) -> Dict[str, Any]:
55+
@classmethod
56+
def _get_base_prompts(cls) -> Dict[str, Any]:
5157
"""Load and cache base.yml prompts."""
52-
if self._base_prompts is None:
53-
self._base_prompts = self._load_yaml_file("base.yml")
54-
return self._base_prompts
58+
if cls._base_prompts is None:
59+
cls._base_prompts = cls._load_yaml_file("base.yml")
60+
return cls._base_prompts
5561

56-
def _merge_prompts(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
62+
@staticmethod
63+
def _merge_prompts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
5764
"""Recursively merge override dict into base dict."""
5865
result = base.copy()
5966

6067
for key, value in override.items():
6168
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
62-
result[key] = self._merge_prompts(result[key], value)
69+
result[key] = PromptRegistry._merge_prompts(result[key], value)
6370
else:
6471
result[key] = value
6572

6673
return result
6774

75+
@classmethod
6876
def _resolve_inheritance_chain(
69-
self, prompt_name: str, visited: Optional[set] = None
77+
cls, prompt_name: str, visited: Optional[set] = None
7078
) -> List[str]:
7179
"""
7280
Resolve the full inheritance chain for a prompt type.
@@ -100,13 +108,13 @@ def _resolve_inheritance_chain(
100108
except FileNotFoundError:
101109
raise FileNotFoundError(f"Prompt file not found: {prompt_file_name}")
102110

103-
prompt_data = self._load_yaml_file(prompt_file_name)
111+
prompt_data = cls._load_yaml_file(prompt_file_name)
104112
inherits = prompt_data.get("_inherits", [])
105113

106114
# Resolve inheritance chain recursively
107115
inheritance_chain = []
108116
for parent in inherits:
109-
parent_chain = self._resolve_inheritance_chain(parent, visited.copy())
117+
parent_chain = cls._resolve_inheritance_chain(parent, visited.copy())
110118
# Add parent chain, avoiding duplicates while preserving order
111119
for item in parent_chain:
112120
if item not in inheritance_chain:
@@ -118,7 +126,8 @@ def _resolve_inheritance_chain(
118126

119127
return inheritance_chain
120128

121-
def get_prompt(self, prompt_name: str) -> Dict[str, Any]:
129+
@classmethod
130+
def get_prompt(cls, prompt_name: str) -> Dict[str, Any]:
122131
"""
123132
Get prompts for a specific prompt type.
124133
@@ -128,40 +137,43 @@ def get_prompt(self, prompt_name: str) -> Dict[str, Any]:
128137
Returns:
129138
Dictionary containing all prompt attributes for the specified type
130139
"""
140+
prompt_name = prompt_name.replace(".yml", "")
131141
# Check cache first
132-
if prompt_name in self._prompts_cache:
133-
return self._prompts_cache[prompt_name]
142+
if prompt_name in cls._prompts_cache:
143+
return cls._prompts_cache[prompt_name]
134144

135145
# Resolve inheritance chain
136-
inheritance_chain = self._resolve_inheritance_chain(prompt_name)
146+
inheritance_chain = cls._resolve_inheritance_chain(prompt_name)
137147

138148
# Start with empty dict and merge in inheritance order
139149
merged_prompts: Dict[str, Any] = {}
140150

141151
for current_name in inheritance_chain:
142152
# Load prompts for this level
143153
if current_name == "base":
144-
current_prompts = self._get_base_prompts()
154+
current_prompts = cls._get_base_prompts()
145155
else:
146-
current_prompts = self._load_yaml_file(f"{current_name}.yml")
156+
current_prompts = cls._load_yaml_file(f"{current_name}.yml")
147157

148158
# Merge current prompts into accumulated result
149-
merged_prompts = self._merge_prompts(merged_prompts, current_prompts)
159+
merged_prompts = cls._merge_prompts(merged_prompts, current_prompts)
150160

151161
# Remove _inherits key from final result (it's metadata, not a prompt)
152162
merged_prompts.pop("_inherits", None)
153163

154164
# Cache the result
155-
self._prompts_cache[prompt_name] = merged_prompts
165+
cls._prompts_cache[prompt_name] = merged_prompts
156166

157167
return merged_prompts
158168

159-
def reload_prompts(self):
169+
@classmethod
170+
def reload_prompts(cls):
160171
"""Clear cache and reload all prompts from disk."""
161-
self._prompts_cache.clear()
162-
self._base_prompts = None
172+
cls._prompts_cache.clear()
173+
cls._base_prompts = None
163174

164-
def list_available_prompts(self) -> list[str]:
175+
@staticmethod
176+
def list_available_prompts() -> list[str]:
165177
"""List all available prompt types."""
166178
prompts = []
167179
for path in importlib_resources.files("cecli.prompts").iterdir():
@@ -170,5 +182,5 @@ def list_available_prompts(self) -> list[str]:
170182
return sorted(prompts)
171183

172184

173-
# Global instance for easy access
174-
registry = PromptRegistry()
185+
# All methods are static/class methods, so no instance is needed
186+
# Use PromptRegistry.get_prompt() directly

0 commit comments

Comments
 (0)