@@ -20,6 +20,7 @@ def main():
2020 parser .add_argument ("--completion-params" , required = True , help = "JSON completion params (includes model)" )
2121 parser .add_argument ("--metadata" , required = True , help = "JSON serialized metadata object" )
2222 parser .add_argument ("--model-base-url" , required = True , help = "Base URL for the model API" )
23+ parser .add_argument ("--api-key" , required = True , help = "API key for the model API" )
2324
2425 args = parser .parse_args ()
2526
@@ -44,6 +45,8 @@ def main():
4445 rollout_id = metadata ["rollout_id" ]
4546 row_id = metadata ["row_id" ]
4647
48+ api_key = args .api_key
49+
4750 print (f"🚀 Starting rollout { rollout_id } " )
4851 print (f" Model: { model } " )
4952 print (f" Row ID: { row_id } " )
@@ -63,7 +66,7 @@ def main():
6366 # Build completion kwargs from completion_params
6467 completion_kwargs = {"messages" : messages , ** completion_params }
6568
66- client = OpenAI (base_url = args .model_base_url , api_key = os . environ . get ( "FIREWORKS_API_KEY" ) )
69+ client = OpenAI (base_url = args .model_base_url , api_key = api_key )
6770
6871 print ("📡 Calling OpenAI completion..." )
6972 print (f" Completion kwargs: { completion_kwargs } " )
0 commit comments