diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index c59fdf0..0a0932c 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -219,11 +219,23 @@ var OpenAiPerThousandCallsToolCost = map[string]float64{
 	"web_search":                   10.0,
 	"web_search_preview":           25.0,
 	"web_search_preview_reasoning": 10.0,
+	"file_search":                  2.5,
+}
+
+// OpenAiCodeInterpreterContainerCost is the per-container cost, keyed by the
+// container memory limit.
+var OpenAiCodeInterpreterContainerCost = map[string]float64{
+	"1g":  0.03,
+	"4g":  0.12,
+	"16g": 0.48,
+	"64g": 1.92,
 }
 
 var AllowedTools = []string{
 	"web_search",
 	"web_search_preview",
+	"code_interpreter",
+	"file_search",
 }
 
 type tokenCounter interface {
@@ -571,6 +583,10 @@ func (ce *CostEstimator) EstimateResponseApiToolCallsCost(tools []responsesOpena
 	totalCost := 0.0
 	for _, tool := range tools {
 		toolType := tool.Type
+		// code_interpreter is billed per container (see EstimateResponseApiToolCreateContainerCost), not per call.
+		if toolType == "code_interpreter" {
+			continue
+		}
 		cost, ok := OpenAiPerThousandCallsToolCost[extendedToolType(toolType, model)]
 		if !ok {
 			return 0, fmt.Errorf("tool type %s is not present in the tool cost map provided", toolType)
@@ -580,6 +596,28 @@
 	return totalCost / 1000, nil
 }
 
+// EstimateResponseApiToolCreateContainerCost sums the per-container cost of
+// every code interpreter container attached to the request.
+func (ce *CostEstimator) EstimateResponseApiToolCreateContainerCost(req *ResponseRequest) (float64, error) {
+	if req == nil {
+		return 0, nil
+	}
+	totalCost := 0.0
+	for _, tool := range req.Tools {
+		c := tool.GetContainerAsResponseRequestToolContainer()
+		if c == nil {
+			continue
+		}
+		limit := c.GetMemoryLimit()
+		cost, ok := OpenAiCodeInterpreterContainerCost[limit]
+		if !ok {
+			return 0, fmt.Errorf("container with memory limit %s is not present in the code interpreter container cost map", limit)
+		}
+		totalCost += cost
+	}
+	return totalCost, nil
+}
+
 var reasoningModelPrefix = []string{"gpt-5", "o1", "o2", "o3"}
 
 func extendedToolType(toolType, model string) string {
diff --git a/internal/provider/openai/types.go b/internal/provider/openai/types.go
index 5d1111b..b898e75 100644
--- a/internal/provider/openai/types.go
+++ b/internal/provider/openai/types.go
@@ -30,6 +30,37 @@ type ResponseRequest struct {
 	//User *string `json:"user,omitzero"` //Deprecated
 }
 
-type ResponseRequestToolUnion struct {
+type ResponseRequestToolContainer struct {
 	Type string `json:"type"`
+	// MemoryLimit mirrors the "memory_limit" JSON field; nil means the default "1g".
+	MemoryLimit *string `json:"memory_limit,omitzero"`
+}
+
+func (c *ResponseRequestToolContainer) GetMemoryLimit() string {
+	if c.MemoryLimit != nil {
+		return *c.MemoryLimit
+	}
+	return "1g"
+}
+
+type ResponseRequestToolUnion struct {
+	Type      string `json:"type"`
+	Container any    `json:"container"`
+}
+
+func (u *ResponseRequestToolUnion) GetContainerAsResponseRequestToolContainer() *ResponseRequestToolContainer {
+	if container, ok := u.Container.(map[string]interface{}); ok {
+		cType := "auto"
+		if typeStr, ok := container["type"].(string); ok {
+			cType = typeStr
+		}
+		toolContainer := &ResponseRequestToolContainer{
+			Type: cType,
+		}
+		if memoryLimit, ok := container["memory_limit"].(string); ok {
+			toolContainer.MemoryLimit = &memoryLimit
+		}
+		return toolContainer
+	}
+	return nil
 }
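The sketch below is not part of the patch; it shows how the two new pieces compose: the union accessor recovers a typed container from the decoded JSON map, and the estimator prices it via `OpenAiCodeInterpreterContainerCost`. It assumes `ResponseRequest.Tools` is declared as `[]ResponseRequestToolUnion` (the field itself is outside this diff) and that a zero-value `CostEstimator` is usable for this path, since the method only reads package-level maps.

```go
package openai

import "testing"

// Not part of the patch: a pricing sketch for the new container cost path.
// Assumes ResponseRequest.Tools is []ResponseRequestToolUnion and that a
// zero-value CostEstimator works here (the method only reads package maps).
func TestEstimateResponseApiToolCreateContainerCost(t *testing.T) {
	req := &ResponseRequest{
		Tools: []ResponseRequestToolUnion{
			{
				Type: "code_interpreter",
				// Container arrives as a generic map after JSON decoding.
				Container: map[string]interface{}{
					"type":         "auto",
					"memory_limit": "4g",
				},
			},
		},
	}

	ce := &CostEstimator{}
	cost, err := ce.EstimateResponseApiToolCreateContainerCost(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if cost != 0.12 { // the "4g" entry in OpenAiCodeInterpreterContainerCost
		t.Fatalf("want 0.12, got %v", cost)
	}
	// A nil request deliberately prices to zero; the responses handler
	// relies on this when no request was stashed in the gin context.
	if cost, err := ce.EstimateResponseApiToolCreateContainerCost(nil); err != nil || cost != 0 {
		t.Fatalf("want 0 and no error for a nil request, got %v, %v", cost, err)
	}
}
```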
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index c69691c..e841dc3 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -60,6 +60,7 @@ type estimator interface {
 	EstimateChatCompletionPromptTokenCounts(model string, r *goopenai.ChatCompletionRequest) (int, error)
 	EstimateResponseApiTotalCost(model string, usage responsesOpenai.ResponseUsage) (float64, error)
 	EstimateResponseApiToolCallsCost(tools []responsesOpenai.ToolUnion, model string) (float64, error)
+	EstimateResponseApiToolCreateContainerCost(req *openai.ResponseRequest) (float64, error)
 }
 
 type azureEstimator interface {
@@ -808,6 +809,9 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
 			return
 		}
 
+		// Stash the decoded request so the responses handler can price code interpreter containers later.
+		ginCtxSetResponsesRequest(c, responsesReq)
+
 		if gopointer.ToValueOrDefault(responsesReq.Background, false) {
 			telemetry.Incr("bricksllm.proxy.get_middleware.background_not_allowed", nil, 1)
 			JSON(c, http.StatusForbidden, "[BricksLLM] background is not allowed")
@@ -830,6 +834,20 @@
 			return
 		}
 
+		// Reject any code interpreter container whose memory limit we cannot price.
+		for _, tool := range responsesReq.Tools {
+			tc := tool.GetContainerAsResponseRequestToolContainer()
+			if tc == nil {
+				continue
+			}
+			if _, ok := openai.OpenAiCodeInterpreterContainerCost[tc.GetMemoryLimit()]; !ok {
+				telemetry.Incr("bricksllm.proxy.get_middleware.container_memory_limit_not_allowed", nil, 1)
+				JSON(c, http.StatusForbidden, "[BricksLLM] container memory limit is not allowed")
+				c.Abort()
+				return
+			}
+		}
+
 		userId = gopointer.ToValueOrDefault(responsesReq.SafetyIdentifier, "")
 		enrichedEvent.Request = responsesReq
 		c.Set("model", gopointer.ToValueOrDefault(responsesReq.Model, ""))
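For context on what the new guard rejects, here is a sketch (not part of the patch) that decodes a hypothetical request body carrying an unpriced memory limit and mirrors the middleware check that now answers 403. The `tools` JSON field name is an assumption; this diff does not show the `ResponseRequest` field tags.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/bricks-cloud/bricksllm/internal/provider/openai"
)

func main() {
	// Hypothetical body: a code_interpreter tool with a memory limit that
	// has no entry in OpenAiCodeInterpreterContainerCost.
	body := []byte(`{
		"model": "gpt-5",
		"tools": [
			{"type": "code_interpreter", "container": {"type": "auto", "memory_limit": "2g"}}
		]
	}`)

	var req openai.ResponseRequest
	if err := json.Unmarshal(body, &req); err != nil {
		panic(err)
	}

	// Mirrors the middleware guard: every container tool must carry a
	// memory limit that exists in the pricing map.
	for _, tool := range req.Tools {
		if tc := tool.GetContainerAsResponseRequestToolContainer(); tc != nil {
			if _, ok := openai.OpenAiCodeInterpreterContainerCost[tc.GetMemoryLimit()]; !ok {
				fmt.Printf("403: container memory limit %q is not allowed\n", tc.GetMemoryLimit())
			}
		}
	}
}
```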
diff --git a/internal/server/web/proxy/responses.go b/internal/server/web/proxy/responses.go
index 0860a65..6fd79b1 100644
--- a/internal/server/web/proxy/responses.go
+++ b/internal/server/web/proxy/responses.go
@@ -11,6 +11,7 @@ import (
 	"net/http"
 	"time"
 
+	"github.com/bricks-cloud/bricksllm/internal/provider/openai"
 	"github.com/bricks-cloud/bricksllm/internal/telemetry"
 	"github.com/bricks-cloud/bricksllm/internal/util"
 	"github.com/gin-gonic/gin"
@@ -98,6 +99,13 @@ func getResponsesHandler(prod, private bool, client http.Client, e estimator) gi
 			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_total_cost_error", nil, 1)
 			logError(log, "error when estimating openai cost", prod, err)
 		}
+		reqResp, _ := ginCtxGetResponsesRequest(c)
+		containerCost, err := e.EstimateResponseApiToolCreateContainerCost(reqResp)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_tool_container_cost_error", nil, 1)
+			logError(log, "error when estimating openai tool container cost", prod, err)
+		}
+		cost += containerCost
 		toolsCost, err := e.EstimateResponseApiToolCallsCost(resp.Tools, model)
 		if err != nil {
 			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_tool_calls_cost_error", nil, 1)
@@ -237,6 +245,13 @@
 			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_total_cost_error", nil, 1)
 			logError(log, "error when estimating openai cost", prod, err)
 		}
+		reqResp, _ := ginCtxGetResponsesRequest(c)
+		containerCost, err := e.EstimateResponseApiToolCreateContainerCost(reqResp)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_tool_container_cost_error", nil, 1)
+			logError(log, "error when estimating openai tool container cost", prod, err)
+		}
+		streamCost += containerCost
 		toolsCost, err := e.EstimateResponseApiToolCallsCost(responsesStreamResp.Response.Tools, model)
 		if err != nil {
 			telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_tool_calls_cost_error", nil, 1)
@@ -268,3 +283,21 @@ func int64ToInt(src int64) (int, error) {
 	}
 	return int(src), nil
 }
+
+func ginCtxSetResponsesRequest(c *gin.Context, req *openai.ResponseRequest) {
+	c.Set("responses_request", req)
+}
+
+func ginCtxGetResponsesRequest(c *gin.Context) (*openai.ResponseRequest, error) {
+	reqAny, exists := c.Get("responses_request")
+	if !exists {
+		return nil, errors.New("responses request not found in gin context")
+	}
+
+	req, ok := reqAny.(*openai.ResponseRequest)
+	if !ok {
+		return nil, errors.New("responses request in gin context has invalid type")
+	}
+
+	return req, nil
+}
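A round-trip sketch for the new gin-context helpers (not part of the patch), written as a test inside package proxy since the helpers are unexported; gin.CreateTestContext provides a context without running a server.

```go
package proxy

import (
	"net/http/httptest"
	"testing"

	"github.com/bricks-cloud/bricksllm/internal/provider/openai"
	"github.com/gin-gonic/gin"
)

// Not part of the patch: round-trips a request through the new helpers.
func TestGinCtxResponsesRequestRoundTrip(t *testing.T) {
	c, _ := gin.CreateTestContext(httptest.NewRecorder())

	// Before anything is set, the getter reports an error instead of
	// panicking; the responses handler swallows this error and lets the
	// estimator treat a nil request as zero container cost.
	if _, err := ginCtxGetResponsesRequest(c); err == nil {
		t.Fatal("expected an error when no request has been set")
	}

	req := &openai.ResponseRequest{}
	ginCtxSetResponsesRequest(c, req)

	got, err := ginCtxGetResponsesRequest(c)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != req {
		t.Fatal("expected the same *openai.ResponseRequest pointer back")
	}
}
```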