diff --git a/.vscode/launch.json b/.vscode/launch.json index 308fbea..d8d9f0e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,11 +37,11 @@ "asRoot": true, }, { - "name": "Run MCP Server", + "name": "Websocket Server", "type": "go", "request": "launch", "mode": "auto", - "program": "${workspaceFolder}/examples/mcp/", + "program": "${workspaceFolder}/examples/websocket/", "asRoot": true, }, ] diff --git a/examples/websocket/configs/config.yaml b/examples/websocket/configs/config.yaml new file mode 100644 index 0000000..7b65a32 --- /dev/null +++ b/examples/websocket/configs/config.yaml @@ -0,0 +1,30 @@ +service: + http: + h3: + enabled: true + address: + ip: "" + port: 4431 + h1: + enabled: true + address: + ip: "" + port: 4431 + h1_ssl: + enabled: true + address: + ip: "" + port: 4430 + tls: + generate_if_missing: true + certificate: + raw: "" + path: cert.pem + key: + raw: "" + path: key.pem + mcp: + enabled: false + address: + ip: "" + port: 4432 diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 0000000..8a1363d --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/ayushanand18/crazyhttp/internal/constants" + crazyserver "github.com/ayushanand18/crazyhttp/pkg/server" + "github.com/ayushanand18/crazyhttp/pkg/types" +) + +func main() { + ctx := context.Background() + + server := crazyserver.NewHttpServer(ctx) + if err := server.Initialize(ctx); err != nil { + log.Fatalf("Server failed to Initialize: %v", err) + } + + server.WebSocket("/ws-test"). + WithOptions(types.WebSocketOption{ + AllowedOrigins: []string{"*"}, + }). + Serve(func(ctx context.Context) error { + reqChanel := ctx.Value(constants.WebsocketRequestChannel).(chan types.WebsocketStreamChunk) + respChanel := ctx.Value(constants.WebsocketResponseChannel).(chan types.WebsocketStreamChunk) + + for chunk := range reqChanel { + fmt.Printf("Received chunk: ID=%d, Type=%d, Data=%s\n", chunk.Id, chunk.MessageType, string(chunk.Data)) + respChanel <- types.WebsocketStreamChunk{ + Data: []byte(fmt.Sprintf("Echo: %s", string(chunk.Data))), + } + } + + return nil + }) + + if err := server.ListenAndServe(ctx); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} diff --git a/examples/websocket/main_test.go b/examples/websocket/main_test.go new file mode 100644 index 0000000..9ed481d --- /dev/null +++ b/examples/websocket/main_test.go @@ -0,0 +1,91 @@ +package main_test + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/ayushanand18/crazyhttp/internal/constants" + crazyserver "github.com/ayushanand18/crazyhttp/pkg/server" + "github.com/ayushanand18/crazyhttp/pkg/types" + "github.com/gorilla/websocket" +) + +// waitForServer waits until TCP port is accepting connections +func waitForServer(addr string) error { + for i := 0; i < 20; i++ { + conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + conn.Close() + return nil + } + time.Sleep(50 * time.Millisecond) + } + return fmt.Errorf("server not ready on %s", addr) +} + +func TestUserRoute_WebsocketRequest(t *testing.T) { + ctx := context.Background() + addr := "localhost:4431" + + server := crazyserver.NewHttpServer(ctx) + if err := server.Initialize(ctx); err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + // WebSocket endpoint + server.WebSocket("/ws-test"). + WithOptions(types.WebSocketOption{AllowedOrigins: []string{"*"}}). + Serve(func(ctx context.Context) error { + reqCh := ctx.Value(constants.WebsocketRequestChannel).(chan types.WebsocketStreamChunk) + respCh := ctx.Value(constants.WebsocketResponseChannel).(chan types.WebsocketStreamChunk) + + // Keep the handler alive to echo messages + for chunk := range reqCh { + respCh <- types.WebsocketStreamChunk{ + Data: []byte(fmt.Sprintf("Echo: %s", chunk.Data)), + } + } + return nil + }) + + // Start server in background + go func() { + if err := server.ListenAndServe(ctx); err != nil { + t.Logf("Server stopped: %v", err) + } + }() + + // Wait for server to be ready + if err := waitForServer(addr); err != nil { + t.Fatalf("Server not ready: %v", err) + } + + // Connect WebSocket client + wsURL := fmt.Sprintf("ws://%s/ws-test", addr) + dialer := websocket.Dialer{} + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("WebSocket dial failed: %v", err) + } + defer conn.Close() + + // Send a message + msg := "hello" + if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { + t.Fatalf("WriteMessage failed: %v", err) + } + + // Read the echo + _, p, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage failed: %v", err) + } + + want := fmt.Sprintf("Echo: %s", msg) + if string(p) != want { + t.Errorf("Expected %q, got %q", want, p) + } +} diff --git a/go.mod b/go.mod index bbc4138..3b7f820 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,10 @@ toolchain go1.23.11 require ( github.com/gorilla/mux v1.8.1 + github.com/gorilla/websocket v1.5.3 github.com/pkg/errors v0.8.1 github.com/quic-go/quic-go v0.52.0 + golang.org/x/net v0.28.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -20,7 +22,6 @@ require ( go.uber.org/mock v0.5.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/mod v0.18.0 // indirect - golang.org/x/net v0.28.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/go.sum b/go.sum index 228a6df..6261633 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE0 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/internal/constants/http.go b/internal/constants/http.go index 8fab05a..64a1105 100644 --- a/internal/constants/http.go +++ b/internal/constants/http.go @@ -30,4 +30,8 @@ const ( HttpRequestURLParams ContextKeys = "request_url_params" HttpRequestPathValues ContextKeys = "request_path_values" RateLimitCustomKey ContextKeys = "rate_limit_custom_key" + + // websocket specific context keys + WebsocketRequestChannel ContextKeys = "websocket_request_channel" + WebsocketResponseChannel ContextKeys = "websocket_response_channel" ) diff --git a/internal/http/encoder.go b/internal/http/encoder.go index 61d5814..7b7c530 100644 --- a/internal/http/encoder.go +++ b/internal/http/encoder.go @@ -11,16 +11,9 @@ func DefaultHttpEncode(ctx context.Context, response interface{}) (headers map[s "Content-Type": {"application/json; charset=utf-8"}, } - switch v := response.(type) { - case string: - body = []byte(v) - case []byte: - body = v - default: - body, err = json.Marshal(v) - if err != nil { - return headers, nil, err - } + body, err = GetDefaultSerialization(response) + if err != nil { + return headers, body, err } return headers, body, nil @@ -33,3 +26,19 @@ func DefaultHttpDecode(ctx context.Context, r *http.Request) (outgoingRequest in return outgoingRequest, nil } + +func GetDefaultSerialization(req interface{}) (body []byte, err error) { + switch v := req.(type) { + case string: + body = []byte(v) + case []byte: + body = v + default: + body, err = json.Marshal(v) + if err != nil { + return nil, err + } + } + + return body, nil +} diff --git a/internal/http/headers.go b/internal/http/headers.go index af4ae0c..9853709 100644 --- a/internal/http/headers.go +++ b/internal/http/headers.go @@ -1,19 +1,22 @@ package http -import "context" +import ( + "context" + "net/http" +) -func PopulateDefaultServerHeaders(ctx context.Context, headers map[string][]string) map[string][]string { +func PopulateDefaultServerHeaders(ctx context.Context, r *http.Request, headers map[string][]string) map[string][]string { if headers == nil { headers = make(map[string][]string) } headers["X-Server"] = []string{"crazyhttp"} - headers["Access-Control-Allow-Origin"] = []string{"*"} + // relay the origin back since we check for allowed origins, earlier + headers["Access-Control-Allow-Origin"] = []string{r.Header.Get("Origin")} headers["Access-Control-Allow-Methods"] = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} headers["Access-Control-Allow-Headers"] = []string{"Content-Type", "Authorization"} headers["Access-Control-Allow-Credentials"] = []string{"true"} headers["Access-Control-Max-Age"] = []string{"86400"} - headers["Content-Type"] = []string{"text/event-stream"} return headers } diff --git a/internal/http/origin.go b/internal/http/origin.go new file mode 100644 index 0000000..813383a --- /dev/null +++ b/internal/http/origin.go @@ -0,0 +1,33 @@ +package http + +import ( + "regexp" + "strings" +) + +func IsOriginAllowed(origin string, patterns []string) bool { + for _, pattern := range patterns { + switch { + case len(pattern) > 2 && pattern[0] == '/' && pattern[len(pattern)-1] == '/': + // Treat as raw regex (/^https:\/\/foo\.com$/) + if matched, _ := regexp.MatchString(pattern[1:len(pattern)-1], origin); matched { + return true + } + + case strings.Contains(pattern, "*"): + // Convert glob-style wildcard (*) to regex + re := "^" + regexp.QuoteMeta(pattern) + re = strings.ReplaceAll(re, `\*`, ".*") + "$" + if matched, _ := regexp.MatchString(re, origin); matched { + return true + } + + default: + // Exact match + if origin == pattern { + return true + } + } + } + return false +} diff --git a/pkg/server/handler.go b/pkg/server/handler.go index 126feaa..338bf25 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -31,6 +31,12 @@ func httpDefaultHandler( return } + if len(m.options.AllowedOrigins) > 0 && !ashttp.IsOriginAllowed(r.Header.Get("Origin"), m.options.AllowedOrigins) { + w.WriteHeader(http.StatusForbidden) + slog.ErrorContext(ctx, "origin not allowed", "origin", r.Header.Get("Origin")) + return + } + if decoder != nil { request, err = decoder(ctx, r) if err != nil { @@ -84,7 +90,7 @@ func httpDefaultHandler( } } - headers = ashttp.PopulateDefaultServerHeaders(ctx, headers) + headers = ashttp.PopulateDefaultServerHeaders(ctx, r, headers) for key, value := range headers { w.Header().Del(key) diff --git a/pkg/server/interface.go b/pkg/server/interface.go index 6fd0a0c..94bc8cd 100644 --- a/pkg/server/interface.go +++ b/pkg/server/interface.go @@ -35,6 +35,9 @@ type HttpServer interface { OPTIONS(string) Method CONNECT(string) Method TRACE(string) Method + + // Websocket + WebSocket(string) WebSocket } func NewHttpServer(ctx context.Context) HttpServer { diff --git a/pkg/server/server.go b/pkg/server/server.go index a7bc678..d31ecf2 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -124,6 +124,10 @@ func (s *server) TRACE(url string) Method { return NewMethod(constants.HttpMethodTrace, url, s) } +func (s *server) WebSocket(url string) WebSocket { + return NewWebsocket(url, s) +} + // serve the HTTP request, and provide a response func (h *rootHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { diff --git a/pkg/server/streaming_handler.go b/pkg/server/streaming_handler.go index e82b9b4..e290b9f 100644 --- a/pkg/server/streaming_handler.go +++ b/pkg/server/streaming_handler.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "strings" + "sync" "github.com/ayushanand18/crazyhttp/internal/constants" ashttp "github.com/ayushanand18/crazyhttp/internal/http" @@ -19,24 +20,48 @@ func streamingDefaultHandler( decoder types.HttpDecoder, encoder types.HttpEncoder, r *http.Request, - m *method) { - + m *method, +) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("X-Accel-Buffering", "no") w.WriteHeader(http.StatusOK) + if len(m.options.AllowedOrigins) > 0 && + !ashttp.IsOriginAllowed(r.Header.Get("Origin"), m.options.AllowedOrigins) { + w.WriteHeader(http.StatusForbidden) + slog.ErrorContext(ctx, "origin not allowed", "origin", r.Header.Get("Origin")) + return + } + flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported by this server!", http.StatusInternalServerError) return } + // master cancel context + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan types.StreamChunk) + + // safe channel close + var once sync.Once + closeAll := func() { + once.Do(func() { close(ch) }) + } + + // expose channel in ctx ctx = context.WithValue(ctx, constants.StreamingResponseChannelContextKey, ch) + // worker goroutine: decode, rate-limit, call handler go func() { - defer close(ch) + defer func() { + cancel() + closeAll() + }() + var request interface{} var err error @@ -53,13 +78,13 @@ func streamingDefaultHandler( if key == nil || key == "" { key = strings.Split(r.RemoteAddr, ":")[0] } - _, ok := key.(string) + k, ok := key.(string) if !ok { w.WriteHeader(http.StatusInternalServerError) slog.ErrorContext(ctx, "rate limit key is not a string", "key:=", key) return } - m.rateLimiter.Allow(key.(string)) + m.rateLimiter.Allow(k) } _, err = handler(ctx, request) @@ -67,39 +92,48 @@ func streamingDefaultHandler( w.WriteHeader(errors.DecodeErrorToHttpErrorStatus(err)) return } - }() - for chunk := range ch { - var headers map[string][]string - var encoded []byte - var err error + // writer loop (main goroutine) + for { + select { + case chunk, ok := <-ch: + if !ok { + return + } - if encoder != nil { - headers, encoded, err = encoder(ctx, chunk.Data) - if err != nil { - w.WriteHeader(errors.DecodeErrorToHttpErrorStatus(err)) - break + var headers map[string][]string + var encoded []byte + var err error + + if encoder != nil { + headers, encoded, err = encoder(ctx, chunk.Data) + } else { + headers, encoded, err = ashttp.DefaultHttpEncode(ctx, chunk.Data) } - } else { - headers, encoded, err = ashttp.DefaultHttpEncode(ctx, chunk.Data) if err != nil { w.WriteHeader(errors.DecodeErrorToHttpErrorStatus(err)) - break + cancel() + return } - } - for key, value := range headers { - w.Header().Del(key) - for _, v := range value { - w.Header().Add(key, v) + for key, value := range headers { + w.Header().Del(key) + for _, v := range value { + w.Header().Add(key, v) + } } - } - if _, err := w.Write(encoded); err != nil { - break - } + if _, err := w.Write(encoded); err != nil { + cancel() + return + } - flusher.Flush() + flusher.Flush() + + case <-ctx.Done(): + closeAll() + return + } } } diff --git a/pkg/server/utils.go b/pkg/server/utils.go index 7c93eef..805dc2a 100644 --- a/pkg/server/utils.go +++ b/pkg/server/utils.go @@ -11,6 +11,9 @@ import ( "os" "github.com/ayushanand18/crazyhttp/internal/config" + internalhttp "github.com/ayushanand18/crazyhttp/internal/http" + "github.com/ayushanand18/crazyhttp/pkg/types" + gws "github.com/gorilla/websocket" ) func checkIfTlsCertificateIsMissing(ctx context.Context) bool { @@ -78,3 +81,27 @@ func DumpRequest(req *http.Request) { // Restore the body again to ensure downstream handlers can read it req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } + +// GetWebSocketHandlerFunc wraps a method onto websocket handler func +func (ws *websocket) GetWebSocketHandlerFunc(handler types.WebsocketHandlerFunc) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + upgrader := gws.Upgrader{ + CheckOrigin: func(req *http.Request) bool { + if len(ws.options.AllowedOrigins) > 0 && + !internalhttp.IsOriginAllowed(r.Header.Get("Origin"), ws.options.AllowedOrigins) { + return false + } + return true + }, + } + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + slog.Error("Error handling Upgrading websocket", "error", err) + return + } + defer c.Close() + + websocketHandler(r.Context(), c, w, r, ws, handler) + } +} diff --git a/pkg/server/websocket.go b/pkg/server/websocket.go new file mode 100644 index 0000000..062c129 --- /dev/null +++ b/pkg/server/websocket.go @@ -0,0 +1,105 @@ +package server + +import ( + "net/http" + "time" + + "github.com/ayushanand18/crazyhttp/internal/ratelimiter" + "github.com/ayushanand18/crazyhttp/pkg/types" +) + +type websocket struct { + Url string + s *server + + rateLimiter *ratelimiter.RateLimiter + + decoder types.HttpDecoder + encoder types.HttpEncoder + beforeServeMiddleware types.HttpRequestMiddleware + afterServeMiddleware types.HttpResponseMiddleware + + options types.WebSocketOption + + description string + name string +} + +type WebSocket interface { + Serve(types.WebsocketHandlerFunc) + + // Decoder for every message received + WithDecoder(decoder types.HttpDecoder) WebSocket + // Encoder for every message sent + WithEncoder(encoder types.HttpEncoder) WebSocket + // Middleware to run before every message is served + WithBeforeServe(middleware types.HttpRequestMiddleware) WebSocket + // Middleware to run after every message is sent + WithAfterServe(middleware types.HttpResponseMiddleware) WebSocket + // Name of the websocket endpoint - for Swagger API documentation + WithName(name string) WebSocket + // Description of the websocket endpoint - for Swagger API documentation + WithDescription(desc string) WebSocket + // WithOptions to add serve options + WithOptions(options types.WebSocketOption) WebSocket + // WithRateLimiter to add rate limiting + // rate limit will be applied on each message received + // key in context with which rate limiting will be done can be set using RateLimitOptions.ContextKey + WithRateLimit(options types.RateLimitOptions) WebSocket + // HandleHandshake to handle custom handshake + HandleHandshake(types.WebSocketHandshakeFunc) WebSocket +} + +func NewWebsocket(url string, s *server) WebSocket { + return &websocket{Url: url, s: s} +} + +func (ws *websocket) Serve(handler types.WebsocketHandlerFunc) { + fun := ws.GetWebSocketHandlerFunc(handler) + ws.s.mux.HandleFunc(ws.Url, http.HandlerFunc(fun)) +} + +func (ws *websocket) WithDecoder(decoder types.HttpDecoder) WebSocket { + ws.decoder = decoder + return ws +} + +func (ws *websocket) WithEncoder(encoder types.HttpEncoder) WebSocket { + ws.encoder = encoder + return ws +} + +func (ws *websocket) WithBeforeServe(middleware types.HttpRequestMiddleware) WebSocket { + ws.beforeServeMiddleware = middleware + return ws +} + +func (ws *websocket) WithAfterServe(middleware types.HttpResponseMiddleware) WebSocket { + ws.afterServeMiddleware = middleware + return ws +} + +func (ws *websocket) WithName(name string) WebSocket { + ws.name = name + return ws +} + +func (ws *websocket) WithDescription(desc string) WebSocket { + ws.description = desc + return ws +} + +func (ws *websocket) HandleHandshake(fn types.WebSocketHandshakeFunc) WebSocket { + return ws +} + +func (ws *websocket) WithOptions(options types.WebSocketOption) WebSocket { + ws.options = options + return ws +} + +func (ws *websocket) WithRateLimit(options types.RateLimitOptions) WebSocket { + ws.rateLimiter = ratelimiter.NewRateLimiter(options.Limit, time.Duration(options.BucketDurationInSeconds)*time.Second) + + return ws +} diff --git a/pkg/server/websocket_handler.go b/pkg/server/websocket_handler.go new file mode 100644 index 0000000..7bc4d16 --- /dev/null +++ b/pkg/server/websocket_handler.go @@ -0,0 +1,145 @@ +package server + +import ( + "context" + "log/slog" + "net/http" + "strings" + "sync" + + "github.com/ayushanand18/crazyhttp/internal/constants" + ashttp "github.com/ayushanand18/crazyhttp/internal/http" + "github.com/ayushanand18/crazyhttp/pkg/errors" + "github.com/ayushanand18/crazyhttp/pkg/types" + gws "github.com/gorilla/websocket" +) + +func websocketHandler( + ctx context.Context, + conn *gws.Conn, + w http.ResponseWriter, + r *http.Request, + ws *websocket, + handler types.WebsocketHandlerFunc, +) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + requestChannel := make(chan types.WebsocketStreamChunk) + responseChannel := make(chan types.WebsocketStreamChunk) + + // helper to close all channels once + var once sync.Once + closeAll := func() { + once.Do(func() { + close(requestChannel) + close(responseChannel) + }) + } + + // attach to context + ctx = context.WithValue(ctx, constants.WebsocketRequestChannel, requestChannel) + ctx = context.WithValue(ctx, constants.WebsocketResponseChannel, responseChannel) + + // Reader goroutine + go func() { + defer func() { + cancel() + closeAll() + }() + + for { + mt, message, err := conn.ReadMessage() + if err != nil { + if gws.IsCloseError(err, + gws.CloseGoingAway, + gws.CloseNormalClosure) { + slog.Info("WebSocket connection closed by client") + } else { + slog.Error("Error receiving WebSocket message", "error", err) + } + return + } + + if ws.rateLimiter != nil { + key := ctx.Value(constants.RateLimitCustomKey) + if key == nil || key == "" { + key = strings.Split(r.RemoteAddr, ":")[0] + } + k, ok := key.(string) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + slog.ErrorContext(ctx, "rate limit key is not a string", "key:=", key) + return + } + ws.rateLimiter.Allow(k) + } + + msg, err := ashttp.GetDefaultSerialization(message) + if err != nil { + w.WriteHeader(errors.DecodeErrorToHttpErrorStatus(err)) + return + } + + select { + case requestChannel <- types.WebsocketStreamChunk{ + MessageType: types.WebsocketMessageType(mt), + Data: msg, + }: + case <-ctx.Done(): + return + } + } + }() + + // Handler goroutine + go func() { + defer cancel() + handler(ctx) + }() + + // Writer loop (main goroutine) + for { + select { + case chunk, ok := <-responseChannel: + if !ok { + return + } + + if chunk.MessageType == types.WebsocketUnspecifiedMessage { + chunk.MessageType = types.WebsocketTextMessage + } + + var headers map[string][]string + var encoded []byte + var err error + + if ws.encoder != nil { + headers, encoded, err = ws.encoder(ctx, chunk.Data) + } else { + headers, encoded, err = ashttp.DefaultHttpEncode(ctx, chunk.Data) + } + if err != nil { + w.WriteHeader(errors.DecodeErrorToHttpErrorStatus(err)) + return + } + + for key, value := range headers { + w.Header().Del(key) + for _, v := range value { + w.Header().Add(key, v) + } + } + + if err := conn.WriteMessage(chunk.MessageType.ToInt(), encoded); err != nil { + slog.Error("Error sending WebSocket message", "error", err) + return + } + + case <-ctx.Done(): + // someone canceled (reader, handler, or connection closed) + closeAll() + return + } + } +} diff --git a/pkg/types/http.go b/pkg/types/http.go index d00e1fc..6a71567 100644 --- a/pkg/types/http.go +++ b/pkg/types/http.go @@ -7,6 +7,7 @@ import ( type MethodOptions struct { IsStreamingResponse bool + AllowedOrigins []string } // HandlerFunc defines a function for serving HTTP requests. @@ -90,8 +91,43 @@ type HttpRequestMiddleware func(ctx context.Context, incomingRequest interface{} // should not be sent. type HttpResponseMiddleware func(ctx context.Context, incomingResponse interface{}) (outgoingResponse interface{}, err error) +// RateLimitOptions specifies configuration settings for applying rate limiting +// to HTTP requests. +// +// Fields +// +// Limit: Maximum number of requests allowed within the specified +// bucket duration. +// BucketDurationInSeconds: Length of the rate-limit window in seconds during which +// the Limit applies. +// ContextKey: Context key whose associated value is used to identify +// the client (e.g., user ID or IP address) for rate limiting. type RateLimitOptions struct { Limit int // number of requests allowed in the given duration BucketDurationInSeconds int64 // duration in seconds for which the limit is applicable ContextKey string // key in context which will be checked for rate limiting } + +// WebSocketOption defines configuration options for a WebSocket endpoint. +// +// Fields +// +// AllowedOrigins: A list of allowed origin URLs for incoming WebSocket +// upgrade requests. If empty, no origin restriction is applied. +type WebSocketOption struct { + AllowedOrigins []string +} + +// WebsocketHandlerFunc defines a function type for handling a WebSocket +// connection after it has been successfully upgraded. +// +// Parameters +// +// ctx: The request-scoped context carrying deadlines, cancellation signals, +// and other metadata for the lifetime of the WebSocket session. +// +// Returns +// +// err: A non-nil error if the WebSocket session encounters a failure or +// needs to be terminated. +type WebsocketHandlerFunc func(context.Context) error diff --git a/pkg/types/streaming.go b/pkg/types/streaming.go index f7ec0e9..21f9f0d 100644 --- a/pkg/types/streaming.go +++ b/pkg/types/streaming.go @@ -12,3 +12,4 @@ type StreamChunk struct { Id uint32 Data []byte } + diff --git a/pkg/types/websocket.go b/pkg/types/websocket.go new file mode 100644 index 0000000..45a52a6 --- /dev/null +++ b/pkg/types/websocket.go @@ -0,0 +1,46 @@ +package types + +type WebSocketHandshakeFunc func() + +// WebsocketStreamChunk represents a single chunk of data in a streaming HTTP response, +// such as Server-Sent Events (SSE) or other streaming protocols. +// +// Fields +// +// Id: A unique identifier for the chunk, typically used to track ordering +// or support reconnection/resume logic. +// MessageType: The type of message being sent (e.g., text, binary, close). +// Data: The raw byte payload of the chunk to be sent to the client. +type WebsocketStreamChunk struct { + Id uint32 + MessageType WebsocketMessageType + Data []byte +} + +// WebsocketMessageType represents the type of a WebSocket frame as defined by +// the WebSocket protocol specification (RFC 6455). It is used to indicate how +// the payload of a WebSocket message should be interpreted. +// +// # Constants +// +// WebsocketUnspecifiedMessage: An unspecified message type, typically unused +// +// WebsocketTextMessage: A UTF-8 encoded text message. +// WebsocketBinaryMessage: A binary data message. +// WebsocketCloseMessage: A control message to close the WebSocket connection. +// WebsocketPingMessage: A control message to check if the peer is alive. +// WebsocketPongMessage: A control message sent in response to a ping. +type WebsocketMessageType int + +const ( + WebsocketUnspecifiedMessage WebsocketMessageType = 0 + WebsocketTextMessage WebsocketMessageType = 1 + WebsocketBinaryMessage WebsocketMessageType = 2 + WebsocketCloseMessage WebsocketMessageType = 8 + WebsocketPingMessage WebsocketMessageType = 9 + WebsocketPongMessage WebsocketMessageType = 10 +) + +func (w WebsocketMessageType) ToInt() int { + return int(w) +}