@@ -439,6 +439,40 @@ def _get_cache_file(self, provider: str) -> Path:
439439 fname = f"{ provider } _models.json"
440440 return self .cache_dir / fname
441441
442+ def _normalize_models_payload (self , provider : str , payload : Dict ) -> Dict :
443+ """Normalize provider payloads into an OpenAI-style `{data: [{id: ...}]}`."""
444+ if not isinstance (payload , dict ):
445+ return {}
446+ if "data" in payload and isinstance (payload .get ("data" ), list ):
447+ return payload
448+ # Fireworks returns `{models: [...], nextPageToken: ..., totalSize: ...}`
449+ models = payload .get ("models" )
450+ if isinstance (models , list ):
451+ normalized = []
452+ for item in models :
453+ if not isinstance (item , dict ):
454+ continue
455+ model_id = item .get ("name" ) or item .get ("id" )
456+ if not model_id :
457+ continue
458+ record = {"id" : model_id }
459+ for key in (
460+ "max_input_tokens" ,
461+ "max_output_tokens" ,
462+ "max_tokens" ,
463+ "context_length" ,
464+ "context_window" ,
465+ "mode" ,
466+ "pricing" ,
467+ "input_cost_per_token" ,
468+ "output_cost_per_token" ,
469+ ):
470+ if key in item and item [key ] is not None :
471+ record [key ] = item [key ]
472+ normalized .append (record )
473+ return {"data" : normalized }
474+ return {}
475+
442476 def _load_cache (self , provider : str ) -> None :
443477 if self ._cache_loaded .get (provider ):
444478 return
@@ -460,9 +494,10 @@ def _update_cache(self, provider: str) -> None:
460494 payload = self ._fetch_provider_models (provider )
461495 cache_file = self ._get_cache_file (provider )
462496 if payload :
463- self ._provider_cache [provider ] = payload
497+ normalized = self ._normalize_models_payload (provider , payload )
498+ self ._provider_cache [provider ] = normalized
464499 try :
465- cache_file .write_text (json .dumps (payload , indent = 2 ))
500+ cache_file .write_text (json .dumps (normalized , indent = 2 ))
466501 except OSError :
467502 pass
468503 return
0 commit comments