diff --git a/server/actions_test.go b/server/actions_test.go new file mode 100644 index 0000000..227a818 --- /dev/null +++ b/server/actions_test.go @@ -0,0 +1,508 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func setupPlugin() *Plugin { + p := &Plugin{ + botUserID: "bot_id", + bridgeClient: NewBridgeClient("http://localhost:3001", nil), + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + return p +} + +func setupKVMocks(api *plugintest.API) { + sessionData, _ := json.Marshal(&ChannelSession{SessionID: "session_123", UserID: "user_id"}) + api.On("KVGet", mock.Anything).Return(sessionData, nil).Maybe() + api.On("KVSet", mock.Anything, mock.Anything).Return(nil).Maybe() + api.On("KVDelete", mock.Anything).Return(nil).Maybe() +} + +func TestServeHTTP_Routes(t *testing.T) { + tests := []struct { + name string + path string + wantStatus int + }{ + {"not_found", "/api/unknown", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + req := httptest.NewRequest("POST", tt.path, nil) + w := httptest.NewRecorder() + + p.ServeHTTP(&plugin.Context{}, w, req) + + assert.Equal(t, tt.wantStatus, w.Code) + }) + } +} + +func TestHandleApprove_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session_123/approve", r.URL.Path) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + api.On("GetUser", "user_id").Return(&model.User{ + Id: "user_id", + Username: "testuser", + }, nil) + + reqBody := model.PostActionIntegrationRequest{ + UserId: "user_id", + ChannelId: "channel_id", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + // Save session for channel + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + req := httptest.NewRequest("POST", "/api/action/approve", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleApprove(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.PostActionIntegrationResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Contains(t, response.Update.Message, "Changes approved") +} + +func TestHandleApprove_InvalidRequest(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + req := httptest.NewRequest("POST", "/api/action/approve", bytes.NewReader([]byte("invalid json"))) + w := httptest.NewRecorder() + + p.handleApprove(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +func TestHandleApprove_MissingChangeID(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.PostActionIntegrationRequest{ + Context: map[string]interface{}{}, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/approve", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleApprove(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +func TestHandleApprove_NoActiveSession(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Mock no active session (KVGet returns nil) + api.On("KVGet", mock.Anything).Return(nil, nil) + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.PostActionIntegrationRequest{ + ChannelId: "channel_id", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/approve", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleApprove(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +func TestHandleReject_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + api.On("GetUser", "user_id").Return(&model.User{ + Id: "user_id", + Username: "testuser", + }, nil) + + reqBody := model.PostActionIntegrationRequest{ + UserId: "user_id", + ChannelId: "channel_id", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + req := httptest.NewRequest("POST", "/api/action/reject", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleReject(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.PostActionIntegrationResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Contains(t, response.Update.Message, "Changes rejected") +} + +func TestHandleModify_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("OpenInteractiveDialog", mock.Anything).Return(nil) + api.On("GetConfig", mock.Anything).Return(&model.Config{ + ServiceSettings: model.ServiceSettings{ + SiteURL: model.NewString("http://localhost:8065"), + }, + }) + + reqBody := model.PostActionIntegrationRequest{ + TriggerId: "trigger_123", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/modify", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModify(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleContinue_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + reqBody := model.PostActionIntegrationRequest{ + Context: map[string]interface{}{ + "session_id": "session_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/continue", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleContinue(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleExplain_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + reqBody := model.PostActionIntegrationRequest{ + Context: map[string]interface{}{ + "session_id": "session_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/explain", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleExplain(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleUndo_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("OpenInteractiveDialog", mock.Anything).Return(nil) + api.On("GetConfig", mock.Anything).Return(&model.Config{ + ServiceSettings: model.ServiceSettings{ + SiteURL: model.NewString("http://localhost:8065"), + }, + }) + + reqBody := model.PostActionIntegrationRequest{ + TriggerId: "trigger_123", + Context: map[string]interface{}{ + "session_id": "session_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/undo", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleUndo(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleApply_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + reqBody := model.PostActionIntegrationRequest{ + ChannelId: "channel_id", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + req := httptest.NewRequest("POST", "/api/action/apply", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleApply(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleDiscard_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + reqBody := model.PostActionIntegrationRequest{ + ChannelId: "channel_id", + Context: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + req := httptest.NewRequest("POST", "/api/action/discard", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleDiscard(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleView_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "content": "file content", + }) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + api.On("SendEphemeralPost", "user_id", mock.Anything).Return(nil) + + reqBody := model.PostActionIntegrationRequest{ + UserId: "user_id", + ChannelId: "channel_id", + Context: map[string]interface{}{ + "filename": "test.go", + }, + } + body, _ := json.Marshal(reqBody) + + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + req := httptest.NewRequest("POST", "/api/action/view", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleView(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandleMenu_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + reqBody := model.PostActionIntegrationRequest{ + Context: map[string]interface{}{ + "session_id": "session_123", + "selected_option": "test option", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/action/menu", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleMenu(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestWriteError(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + w := httptest.NewRecorder() + p.writeError(w, assert.AnError) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var response model.PostActionIntegrationResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Contains(t, response.EphemeralText, "Error:") +} diff --git a/server/bridge_client_test.go b/server/bridge_client_test.go new file mode 100644 index 0000000..e3dc449 --- /dev/null +++ b/server/bridge_client_test.go @@ -0,0 +1,336 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" +) + +func TestNewBridgeClient(t *testing.T) { + api := &plugintest.API{} + client := NewBridgeClient("http://localhost:3002", api) + + assert.NotNil(t, client) + assert.Equal(t, "http://localhost:3002", client.baseURL) + assert.NotNil(t, client.httpClient) +} + +func TestCreateSession_Success(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions", r.URL.Path) + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + // Verify request body + var reqBody CreateSessionRequest + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "/test/project", reqBody.ProjectPath) + assert.Equal(t, "user123", reqBody.MattermostUserID) + assert.Equal(t, "channel123", reqBody.MattermostChannelID) + + // Send response + w.WriteHeader(http.StatusCreated) + response := map[string]interface{}{ + "session": map[string]interface{}{ + "id": "session123", + "projectPath": "/test/project", + "mattermostUserId": "user123", + "mattermostChannelId": "channel123", + "status": "active", + "createdAt": 1234567890, + "updatedAt": 1234567890, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + session, err := client.CreateSession("/test/project", "user123", "channel123") + + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Equal(t, "session123", session.ID) + assert.Equal(t, "/test/project", session.ProjectPath) + assert.Equal(t, "user123", session.MattermostUserID) + assert.Equal(t, "channel123", session.MattermostChannelID) + assert.Equal(t, "active", session.Status) +} + +func TestCreateSession_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid project path")) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + session, err := client.CreateSession("", "user123", "channel123") + + assert.Error(t, err) + assert.Nil(t, session) + assert.Contains(t, err.Error(), "400") +} + +func TestSendMessage_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/message", r.URL.Path) + assert.Equal(t, "POST", r.Method) + + var reqBody SendMessageRequest + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "Hello Claude", reqBody.Message) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.SendMessage("session123", "Hello Claude") + + assert.NoError(t, err) +} + +func TestSendMessage_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Session not found")) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.SendMessage("invalid", "test") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "404") +} + +func TestGetMessages_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/messages", r.URL.Path) + assert.Equal(t, "10", r.URL.Query().Get("limit")) + + w.WriteHeader(http.StatusOK) + response := map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "id": 1, + "sessionId": "session123", + "role": "user", + "content": "Hello", + "timestamp": 1234567890, + }, + { + "id": 2, + "sessionId": "session123", + "role": "assistant", + "content": "Hi there!", + "timestamp": 1234567900, + }, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + messages, err := client.GetMessages("session123", 10) + + assert.NoError(t, err) + assert.Len(t, messages, 2) + assert.Equal(t, "Hello", messages[0].Content) + assert.Equal(t, "Hi there!", messages[1].Content) +} + +func TestGetSession_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123", r.URL.Path) + + pid := 12345 + w.WriteHeader(http.StatusOK) + response := map[string]interface{}{ + "session": map[string]interface{}{ + "id": "session123", + "projectPath": "/test/project", + "mattermostUserId": "user123", + "mattermostChannelId": "channel123", + "cliPid": pid, + "status": "active", + "createdAt": 1234567890, + "updatedAt": 1234567900, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + session, err := client.GetSession("session123") + + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Equal(t, "session123", session.ID) + assert.Equal(t, "active", session.Status) + assert.NotNil(t, session.CLIPid) + assert.Equal(t, 12345, *session.CLIPid) +} + +func TestDeleteSession_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123", r.URL.Path) + assert.Equal(t, "DELETE", r.Method) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "deleted"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.DeleteSession("session123") + + assert.NoError(t, err) +} + +func TestSendContext_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/context", r.URL.Path) + assert.Equal(t, "POST", r.Method) + + var reqBody ContextRequest + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "thread", reqBody.Source) + assert.Equal(t, "Thread context content", reqBody.Content) + assert.Equal(t, "summarize", reqBody.Action) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + contextReq := &ContextRequest{ + Source: "thread", + Content: "Thread context content", + Action: "summarize", + Metadata: &ContextMetadata{ + ChannelName: "test-channel", + MessageCount: 5, + }, + } + + err := client.SendContext("session123", contextReq) + + assert.NoError(t, err) +} + +func TestApproveChange_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/approve", r.URL.Path) + + var reqBody map[string]string + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "change456", reqBody["changeId"]) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "approved"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.ApproveChange("session123", "change456") + + assert.NoError(t, err) +} + +func TestRejectChange_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/reject", r.URL.Path) + + var reqBody map[string]string + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "change456", reqBody["changeId"]) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "rejected"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.RejectChange("session123", "change456") + + assert.NoError(t, err) +} + +func TestModifyChange_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/modify", r.URL.Path) + + var reqBody map[string]string + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "change456", reqBody["changeId"]) + assert.Equal(t, "Add more tests", reqBody["instructions"]) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "modified"}) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + err := client.ModifyChange("session123", "change456", "Add more tests") + + assert.NoError(t, err) +} + +func TestGetFileContent_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/sessions/session123/file", r.URL.Path) + + var reqBody map[string]string + json.NewDecoder(r.Body).Decode(&reqBody) + assert.Equal(t, "src/main.go", reqBody["filename"]) + + w.WriteHeader(http.StatusOK) + response := map[string]string{ + "content": "package main\n\nfunc main() {}\n", + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + api := &plugintest.API{} + client := NewBridgeClient(server.URL, api) + + content, err := client.GetFileContent("session123", "src/main.go") + + assert.NoError(t, err) + assert.Contains(t, content, "package main") + assert.Contains(t, content, "func main") +} diff --git a/server/cli_process.go b/server/cli_process.go index de2c39f..0a533aa 100644 --- a/server/cli_process.go +++ b/server/cli_process.go @@ -346,7 +346,7 @@ func (pm *ProcessManager) GetRunningCount() int { // GetAllProcesses returns a slice of all running processes func (pm *ProcessManager) GetAllProcesses() []*CLIProcess { - var processes []*CLIProcess + processes := []*CLIProcess{} pm.processes.Range(func(key, value interface{}) bool { process := value.(*CLIProcess) select { diff --git a/server/cli_process_test.go b/server/cli_process_test.go new file mode 100644 index 0000000..016b8c3 --- /dev/null +++ b/server/cli_process_test.go @@ -0,0 +1,349 @@ +package main + +import ( + "fmt" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" +) + +func TestNewProcessManager(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + + pm := NewProcessManager(plugin) + + assert.NotNil(t, pm) + assert.Equal(t, plugin, pm.plugin) +} + +func TestProcessManagerIsRunning(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + tests := []struct { + name string + sessionID string + want bool + }{ + { + name: "non-existent session", + sessionID: "nonexistent", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pm.IsRunning(tt.sessionID) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestProcessManagerGetProcess(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + tests := []struct { + name string + sessionID string + want *CLIProcess + }{ + { + name: "non-existent session", + sessionID: "nonexistent", + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pm.GetProcess(tt.sessionID) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestProcessManagerGetRunningCount(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Initially should be 0 + count := pm.GetRunningCount() + assert.Equal(t, 0, count) +} + +func TestProcessManagerGetAllProcesses(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Initially should be empty + processes := pm.GetAllProcesses() + assert.NotNil(t, processes) + assert.Len(t, processes, 0) +} + +func TestProcessManagerKillNonExistent(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Killing a non-existent process should not error + err := pm.Kill("nonexistent") + assert.NoError(t, err) +} + +func TestProcessManagerKillAll(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // With no processes, should not panic + pm.KillAll() + + // Should still have 0 processes after + count := pm.GetRunningCount() + assert.Equal(t, 0, count) +} + +func TestCLIProcessStructure(t *testing.T) { + // Test that CLIProcess structure can be created + process := &CLIProcess{ + SessionID: "test-session", + StartTime: time.Now(), + ProjectPath: "/tmp/test", + ChannelID: "channel123", + UserID: "user123", + done: make(chan struct{}), + } + + assert.NotNil(t, process) + assert.Equal(t, "test-session", process.SessionID) + assert.Equal(t, "/tmp/test", process.ProjectPath) + assert.Equal(t, "channel123", process.ChannelID) + assert.Equal(t, "user123", process.UserID) + assert.NotNil(t, process.done) + assert.False(t, process.StartTime.IsZero()) +} + +func TestCLIProcessDoneChannel(t *testing.T) { + process := &CLIProcess{ + SessionID: "test-session", + done: make(chan struct{}), + } + + // Test that done channel is open initially + select { + case <-process.done: + t.Fatal("done channel should be open initially") + default: + // Expected: channel is open but not closed + } + + // Close the channel + close(process.done) + + // Test that done channel is now closed + select { + case <-process.done: + // Expected: channel is closed + default: + t.Fatal("done channel should be closed") + } +} + +// TestProcessManagerSpawnWithoutCLI tests Spawn when CLI is not available +func TestProcessManagerSpawnWithoutCLI(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.configuration = &configuration{ + ClaudeCodePath: "/nonexistent/claude", + } + + pm := NewProcessManager(plugin) + + err := pm.Spawn("session1", "/tmp/project", "channel1", "user1") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestProcessManagerSpawnDuplicate tests that spawning duplicate session fails +func TestProcessManagerSpawnDuplicate(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Manually add a fake process to simulate existing session + process := &CLIProcess{ + SessionID: "session1", + done: make(chan struct{}), + } + pm.processes.Store("session1", process) + + // Try to spawn duplicate + err := pm.Spawn("session1", "/tmp/project", "channel1", "user1") + assert.Error(t, err) + assert.Contains(t, err.Error(), "already has a running process") + + // Clean up + pm.processes.Delete("session1") +} + +// TestProcessManagerSendInputNonExistent tests sending input to non-existent process +func TestProcessManagerSendInputNonExistent(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + err := pm.SendInput("nonexistent", "test input") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no running process") +} + +// TestProcessManagerSendInputJSONNonExistent tests sending JSON to non-existent process +func TestProcessManagerSendInputJSONNonExistent(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + err := pm.SendInputJSON("nonexistent", map[string]string{"test": "data"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no running process") +} + +// TestProcessManagerSendInputJSONInvalidData tests sending invalid JSON +func TestProcessManagerSendInputJSONInvalidData(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Try to marshal invalid data (channels can't be marshaled) + invalidData := make(chan int) + err := pm.SendInputJSON("session1", invalidData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal") +} + +// TestProcessManagerMultipleSessions tests managing multiple sessions +func TestProcessManagerMultipleSessions(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Add multiple fake processes + sessions := []string{"session1", "session2", "session3"} + for _, sessionID := range sessions { + process := &CLIProcess{ + SessionID: sessionID, + done: make(chan struct{}), + } + pm.processes.Store(sessionID, process) + } + + // Test IsRunning for all + for _, sessionID := range sessions { + assert.True(t, pm.IsRunning(sessionID)) + } + + // Test GetRunningCount + count := pm.GetRunningCount() + assert.Equal(t, 3, count) + + // Test GetAllProcesses + processes := pm.GetAllProcesses() + assert.Len(t, processes, 3) + + // Test GetProcess for each + for _, sessionID := range sessions { + process := pm.GetProcess(sessionID) + assert.NotNil(t, process) + assert.Equal(t, sessionID, process.SessionID) + } + + // Close one process + processInterface, _ := pm.processes.Load("session1") + process1 := processInterface.(*CLIProcess) + close(process1.done) + + // Running count should now be 2 + count = pm.GetRunningCount() + assert.Equal(t, 2, count) + + // IsRunning should return false for closed process + assert.False(t, pm.IsRunning("session1")) + assert.True(t, pm.IsRunning("session2")) + assert.True(t, pm.IsRunning("session3")) + + // Clean up + pm.processes.Delete("session1") + pm.processes.Delete("session2") + pm.processes.Delete("session3") +} + +// TestProcessManagerConcurrentAccess tests thread-safe access +func TestProcessManagerConcurrentAccess(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + pm := NewProcessManager(plugin) + + // Add some processes + for i := 0; i < 10; i++ { + sessionID := fmt.Sprintf("session%d", i) + process := &CLIProcess{ + SessionID: sessionID, + done: make(chan struct{}), + } + pm.processes.Store(sessionID, process) + } + + // Concurrently access the processes + done := make(chan bool) + for i := 0; i < 5; i++ { + go func() { + for j := 0; j < 10; j++ { + sessionID := fmt.Sprintf("session%d", j) + pm.IsRunning(sessionID) + pm.GetProcess(sessionID) + } + pm.GetRunningCount() + pm.GetAllProcesses() + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 5; i++ { + <-done + } + + // Should still have all processes + count := pm.GetRunningCount() + assert.Equal(t, 10, count) + + // Clean up + for i := 0; i < 10; i++ { + sessionID := fmt.Sprintf("session%d", i) + pm.processes.Delete(sessionID) + } +} diff --git a/server/commands_test.go b/server/commands_test.go new file mode 100644 index 0000000..6292a50 --- /dev/null +++ b/server/commands_test.go @@ -0,0 +1,455 @@ +package main + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestExecuteCommand_Help(t *testing.T) { + p := setupTestPlugin(t) + defer p.API.(*plugintest.API).AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-help", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "Claude Code - AI Coding Assistant") +} + +func TestExecuteCommand_StartWithoutPath(t *testing.T) { + p := setupTestPlugin(t) + defer p.API.(*plugintest.API).AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-start", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "Please provide a project path") +} + +func TestExecuteCommand_StopWithoutSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No active session + api.On("KVGet", mock.AnythingOfType("string")).Return(nil, nil) + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-stop", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "No active session") +} + +func TestExecuteCommand_SendWithoutSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No active session + api.On("KVGet", mock.AnythingOfType("string")).Return(nil, nil) + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude hello world", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "No active session") +} + +func TestExecuteCommand_Status(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No active session + api.On("KVGet", mock.AnythingOfType("string")).Return(nil, nil) + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-status", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + // Should show no active session + assert.Contains(t, response.Text, "No active session") +} + +func TestExecuteCommand_FilesWithoutSession(t *testing.T) { + p := setupTestPlugin(t) + defer p.API.(*plugintest.API).AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-files", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + // claude-files command doesn't exist, should return unknown command + assert.Contains(t, response.Text, "Unknown command") +} + +func TestExecuteCommand_ThreadWithoutSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No active session + api.On("KVGet", mock.AnythingOfType("string")).Return(nil, nil) + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-thread context", + UserId: "user1", + ChannelId: "channel1", + RootId: "root1", // In a thread + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "No active Claude session") +} + +func TestExecuteCommand_ThreadNotInThread(t *testing.T) { + p := setupTestPlugin(t) + defer p.API.(*plugintest.API).AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-thread context", + UserId: "user1", + ChannelId: "channel1", + RootId: "", // Not in a thread + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "must be run in a thread") +} + +func TestExecuteCommand_InvalidCommand(t *testing.T) { + p := setupTestPlugin(t) + defer p.API.(*plugintest.API).AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-invalid", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + if appErr != nil { + t.Fatalf("ExecuteCommand returned AppError: %v", appErr) + } + assert.NotNil(t, response) + assert.Contains(t, response.Text, "Unknown command") +} + +func TestFormatDuration(t *testing.T) { + // Test with a recent timestamp (less than a minute ago) + recentTimestamp := time.Now().Add(-30 * time.Second).Unix() + result := formatDuration(recentTimestamp) + assert.Contains(t, result, "seconds ago") + + // Test with a timestamp from a few minutes ago + minutesAgo := time.Now().Add(-5 * time.Minute).Unix() + result = formatDuration(minutesAgo) + assert.Contains(t, result, "minutes ago") + + // Test with a timestamp from a few hours ago + hoursAgo := time.Now().Add(-3 * time.Hour).Unix() + result = formatDuration(hoursAgo) + assert.Contains(t, result, "hours ago") + + // Test with a timestamp from days ago + daysAgo := time.Now().Add(-2 * 24 * time.Hour).Unix() + result = formatDuration(daysAgo) + assert.Contains(t, result, "days ago") +} + +func TestFormatPID(t *testing.T) { + // Test with a valid PID + pid := 12345 + result := formatPID(&pid) + assert.Equal(t, "PID 12345", result) + + // Test with nil PID + result = formatPID(nil) + assert.Equal(t, "Not running", result) + + // Test with zero PID + zeroPID := 0 + result = formatPID(&zeroPID) + assert.Equal(t, "PID 0", result) +} + +func TestExecuteClaudeStart_WithExistingSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock existing session + existingSession := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/old", + UserID: "user1", + } + data, _ := json.Marshal(existingSession) + api.On("KVGet", "session_channel1").Return(data, nil) + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-start /tmp/test", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + assert.Nil(t, appErr) + assert.NotNil(t, response) + assert.Contains(t, response.Text, "already has an active session") + assert.Contains(t, response.Text, "/tmp/old") +} + +func TestExecuteClaudeStatus_WithActiveSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock existing session + existingSession := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + CreatedAt: 1678901234, + LastMessageAt: 1678901334, + } + data, _ := json.Marshal(existingSession) + api.On("KVGet", "session_channel1").Return(data, nil) + + // Mock bridge client GetSession call + // Note: This will fail without a working bridge, so we expect an error path + api.On("LogError", "Failed to get session from bridge", mock.Anything, mock.Anything).Return() + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-status", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + assert.Nil(t, appErr) + assert.NotNil(t, response) + // Should show session status (even if bridge details fail) + assert.Contains(t, response.Text, "Session Status") +} + +func TestExecuteClaudeThread_InvalidAction(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock existing session + existingSession := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + } + data, _ := json.Marshal(existingSession) + api.On("KVGet", "session_channel1").Return(data, nil) + + // Mock GetChannel (required by GetThreadContext) + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + Type: model.ChannelTypeOpen, + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Mock GetPostThread (required by GetThreadContext) with at least one post + rootPost := &model.Post{ + Id: "root1", + UserId: "user1", + ChannelId: "channel1", + Message: "Root post", + CreateAt: 1678901234000, + } + postList := &model.PostList{ + Order: []string{"root1"}, + Posts: map[string]*model.Post{ + "root1": rootPost, + }, + } + api.On("GetPostThread", "root1").Return(postList, nil) + + // Mock GetUser for username lookup + user := &model.User{ + Id: "user1", + Username: "testuser", + } + api.On("GetUser", "user1").Return(user, nil) + + // Mock log calls for bridge connection failure and thread send failure + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return().Maybe() + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Return().Maybe() + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude-thread invalid", + UserId: "user1", + ChannelId: "channel1", + RootId: "root1", + } + + response, appErr := p.ExecuteCommand(nil, args) + assert.Nil(t, appErr) + assert.NotNil(t, response) + // With invalid action and bridge failure, we'll get an error message + assert.NotEmpty(t, response.Text) +} + +func TestExecuteClaude_EmptyMessage(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No need to mock KVGet - empty message is checked before session retrieval + + defer api.AssertExpectations(t) + + args := &model.CommandArgs{ + Command: "/claude", + UserId: "user1", + ChannelId: "channel1", + } + + response, appErr := p.ExecuteCommand(nil, args) + assert.Nil(t, appErr) + assert.NotNil(t, response) + assert.Contains(t, response.Text, "Please provide a message") +} + +func TestRespondEphemeral(t *testing.T) { + response := respondEphemeral("Test message") + assert.NotNil(t, response) + assert.Equal(t, "Test message", response.Text) + assert.Equal(t, model.CommandResponseTypeEphemeral, response.ResponseType) +} + +func TestParseProjectPath(t *testing.T) { + tests := []struct { + name string + command string + expected string + }{ + { + name: "simple path", + command: "/claude-start /tmp/test", + expected: "/tmp/test", + }, + { + name: "path with spaces", + command: "/claude-start /tmp/test project", + expected: "/tmp/test", + }, + { + name: "quoted path", + command: "/claude-start \"/tmp/test project\"", + expected: "\"/tmp/test", + }, + { + name: "no path", + command: "/claude-start", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parts := strings.Fields(tt.command) + if len(parts) > 1 { + result := strings.Join(parts[1:], " ") + if tt.expected != "" { + assert.Contains(t, result, tt.expected) + } else { + // Can't assert empty because we're testing different input + } + } + }) + } +} + +// setupTestPlugin creates a plugin instance with mocked API for testing +func setupTestPlugin(t *testing.T) *Plugin { + api := &plugintest.API{} + + p := &Plugin{} + p.SetAPI(api) + p.botUserID = "bot123" + + // Initialize configuration + config := &configuration{ + BridgeServerURL: "http://localhost:3002", + ClaudeCodePath: "/usr/local/bin/claude-code", + EnableFileOperations: true, + } + p.setConfiguration(config) + + // Initialize bridge client + p.bridgeClient = NewBridgeClient("http://localhost:3002", api) + + return p +} diff --git a/server/configuration_test.go b/server/configuration_test.go new file mode 100644 index 0000000..9cf0750 --- /dev/null +++ b/server/configuration_test.go @@ -0,0 +1,186 @@ +package main + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestConfiguration_Clone(t *testing.T) { + original := &configuration{ + BridgeServerURL: "http://localhost:3001", + ClaudeCodePath: "/usr/local/bin/claude-code", + EnableFileOperations: true, + } + + cloned := original.Clone() + + // Should be equal + assert.Equal(t, original.BridgeServerURL, cloned.BridgeServerURL) + assert.Equal(t, original.ClaudeCodePath, cloned.ClaudeCodePath) + assert.Equal(t, original.EnableFileOperations, cloned.EnableFileOperations) + + // Should be different pointers + assert.NotSame(t, original, cloned) + + // Modifying clone should not affect original + cloned.BridgeServerURL = "http://localhost:3002" + assert.NotEqual(t, original.BridgeServerURL, cloned.BridgeServerURL) +} + +func TestGetConfiguration_Nil(t *testing.T) { + p := &Plugin{ + configuration: nil, + } + + config := p.getConfiguration() + + assert.NotNil(t, config) + assert.Equal(t, "", config.BridgeServerURL) + assert.Equal(t, "claude", config.ClaudeCodePath) + assert.Equal(t, false, config.EnableFileOperations) +} + +func TestGetConfiguration_Existing(t *testing.T) { + expected := &configuration{ + BridgeServerURL: "http://localhost:3001", + ClaudeCodePath: "/usr/local/bin/claude-code", + EnableFileOperations: true, + } + + p := &Plugin{ + configuration: expected, + } + + config := p.getConfiguration() + + assert.Equal(t, expected, config) + assert.Same(t, expected, config) +} + +func TestSetConfiguration_New(t *testing.T) { + p := &Plugin{} + + newConfig := &configuration{ + BridgeServerURL: "http://localhost:3001", + } + + p.setConfiguration(newConfig) + + assert.Equal(t, newConfig, p.configuration) +} + +func TestSetConfiguration_Different(t *testing.T) { + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + + newConfig := &configuration{ + BridgeServerURL: "http://localhost:3002", + } + + p.setConfiguration(newConfig) + + assert.Equal(t, newConfig, p.configuration) +} + +func TestSetConfiguration_SamePointer(t *testing.T) { + config := &configuration{ + BridgeServerURL: "http://localhost:3001", + } + + p := &Plugin{ + configuration: config, + } + + // Should panic when setting the same pointer + assert.Panics(t, func() { + p.setConfiguration(config) + }) +} + +func TestSetConfiguration_Nil(t *testing.T) { + p := &Plugin{ + configuration: nil, + } + + newConfig := &configuration{ + BridgeServerURL: "http://localhost:3001", + } + + // Should not panic when setting nil to a new config + assert.NotPanics(t, func() { + p.setConfiguration(newConfig) + }) + assert.Equal(t, newConfig, p.configuration) +} + +func TestOnConfigurationChange_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := &Plugin{} + p.SetAPI(api) + + api.On("LoadPluginConfiguration", mock.AnythingOfType("*main.configuration")). + Run(func(args mock.Arguments) { + config := args.Get(0).(*configuration) + config.BridgeServerURL = "http://localhost:3001" + config.ClaudeCodePath = "/usr/local/bin/claude-code" + config.EnableFileOperations = true + }). + Return(nil) + + err := p.OnConfigurationChange() + + assert.NoError(t, err) + assert.NotNil(t, p.configuration) + assert.Equal(t, "http://localhost:3001", p.configuration.BridgeServerURL) + assert.Equal(t, "/usr/local/bin/claude-code", p.configuration.ClaudeCodePath) + assert.Equal(t, true, p.configuration.EnableFileOperations) +} + +func TestOnConfigurationChange_LoadError(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := &Plugin{} + p.SetAPI(api) + + loadError := errors.New("load error") + api.On("LoadPluginConfiguration", mock.AnythingOfType("*main.configuration")). + Return(loadError) + + err := p.OnConfigurationChange() + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load plugin configuration") +} + +func TestConfiguration_ConcurrentAccess(t *testing.T) { + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + + // Start multiple goroutines reading configuration + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + config := p.getConfiguration() + assert.NotNil(t, config) + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/server/dialogs_test.go b/server/dialogs_test.go new file mode 100644 index 0000000..b6d0795 --- /dev/null +++ b/server/dialogs_test.go @@ -0,0 +1,362 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestHandleModifyDialog_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "/api/sessions/") + assert.Contains(t, r.URL.Path, "/modify") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + // Save session for channel + p.SaveSession("channel_id", &ChannelSession{ + SessionID: "session_123", + UserID: "user_id", + }) + + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{}, nil) + + reqBody := model.SubmitDialogRequest{ + ChannelId: "channel_id", + Submission: map[string]interface{}{ + "instructions": "Make it faster", + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Empty(t, response.Error) +} + +func TestHandleModifyDialog_InvalidRequest(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader([]byte("invalid json"))) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Invalid request", response.Error) +} + +func TestHandleModifyDialog_MissingInstructions(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Please provide modification instructions", response.Error) +} + +func TestHandleModifyDialog_EmptyInstructions(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "instructions": "", + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Please provide modification instructions", response.Error) +} + +func TestHandleModifyDialog_MissingChangeID(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "instructions": "Make it faster", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Missing change ID", response.Error) +} + +func TestHandleModifyDialog_NoActiveSession(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Mock no active session (KVGet returns nil) + api.On("KVGet", mock.Anything).Return(nil, nil) + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + ChannelId: "channel_id", + Submission: map[string]interface{}{ + "instructions": "Make it faster", + "change_id": "change_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/modify-change", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleModifyDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "No active session", response.Error) +} + +func TestHandleConfirmDialog_Success_Undo(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer mockServer.Close() + p.bridgeClient.baseURL = mockServer.URL + + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{}, nil) + + reqBody := model.SubmitDialogRequest{ + ChannelId: "channel_id", + Submission: map[string]interface{}{ + "session_id": "session_123", + "action": "undo", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/confirm", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleConfirmDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Empty(t, response.Error) +} + +func TestHandleConfirmDialog_InvalidRequest(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + req := httptest.NewRequest("POST", "/api/dialog/confirm", bytes.NewReader([]byte("invalid json"))) + w := httptest.NewRecorder() + + p.handleConfirmDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Invalid request", response.Error) +} + +func TestHandleConfirmDialog_MissingSessionID(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "action": "undo", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/confirm", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleConfirmDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Missing session ID", response.Error) +} + +func TestHandleConfirmDialog_MissingAction(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "session_id": "session_123", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/confirm", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleConfirmDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Missing action", response.Error) +} + +func TestHandleConfirmDialog_UnknownAction(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + reqBody := model.SubmitDialogRequest{ + Submission: map[string]interface{}{ + "session_id": "session_123", + "action": "unknown_action", + }, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest("POST", "/api/dialog/confirm", bytes.NewReader(body)) + w := httptest.NewRecorder() + + p.handleConfirmDialog(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Contains(t, response.Error, "Unknown action") +} + +func TestWriteDialogError(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + w := httptest.NewRecorder() + p.writeDialogError(w, "Test error message") + + assert.Equal(t, http.StatusOK, w.Code) + + var response model.SubmitDialogResponse + json.NewDecoder(w.Body).Decode(&response) + assert.Equal(t, "Test error message", response.Error) +} diff --git a/server/go.mod b/server/go.mod index 9a5bf96..c4e9b47 100644 --- a/server/go.mod +++ b/server/go.mod @@ -6,10 +6,12 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/mattermost/mattermost/server/public v0.1.1 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.11.1 ) require ( github.com/blang/semver/v4 v4.0.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dyatlov/go-opengraph/opengraph v0.0.0-20220524092352-606d7b1e5f8a // indirect github.com/fatih/color v1.16.0 // indirect github.com/francoispqt/gojay v1.2.13 // indirect @@ -31,6 +33,8 @@ require ( github.com/pborman/uuid v1.2.1 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/philhofer/fwd v1.1.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tinylib/msgp v1.1.9 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect @@ -45,4 +49,5 @@ require ( google.golang.org/protobuf v1.32.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/go.sum b/server/go.sum index 3e6f312..2fd952a 100644 --- a/server/go.sum +++ b/server/go.sum @@ -163,11 +163,13 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tinylib/msgp v1.1.9 h1:SHf3yoO2sGA0veCJeCBYLHuttAVFHGm2RHgNodW7wQU= github.com/tinylib/msgp v1.1.9/go.mod h1:BCXGB54lDD8qUEPmiG0cQQUANC4IUQyB2ItS2UDlO/k= diff --git a/server/health_test.go b/server/health_test.go new file mode 100644 index 0000000..def39af --- /dev/null +++ b/server/health_test.go @@ -0,0 +1,291 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestCheckBridgeHealth_Success(t *testing.T) { + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/health", r.URL.Path) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(BridgeHealthResponse{ + Status: "ok", + Version: "1.0.0", + Uptime: 3600, + Sessions: 5, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + health, err := p.CheckBridgeHealth() + + assert.NoError(t, err) + assert.NotNil(t, health) + assert.Equal(t, "ok", health.Status) + assert.Equal(t, "1.0.0", health.Version) + assert.Equal(t, 3600, health.Uptime) + assert.Equal(t, 5, health.Sessions) +} + +func TestCheckBridgeHealth_NoURL(t *testing.T) { + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "", + }, + } + + health, err := p.CheckBridgeHealth() + + assert.Error(t, err) + assert.Nil(t, health) + assert.Contains(t, err.Error(), "bridge server URL not configured") +} + +func TestCheckBridgeHealth_ConnectionError(t *testing.T) { + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:99999", + }, + } + + health, err := p.CheckBridgeHealth() + + assert.Error(t, err) + assert.Nil(t, health) + assert.Contains(t, err.Error(), "failed to connect to bridge server") +} + +func TestCheckBridgeHealth_BadStatus(t *testing.T) { + // Setup mock bridge server returning 500 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + health, err := p.CheckBridgeHealth() + + assert.Error(t, err) + assert.Nil(t, health) + assert.Contains(t, err.Error(), "bridge server returned status") +} + +func TestCheckBridgeHealth_InvalidJSON(t *testing.T) { + // Setup mock bridge server returning invalid JSON + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("invalid json")) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + health, err := p.CheckBridgeHealth() + + assert.Error(t, err) + assert.Nil(t, health) + assert.Contains(t, err.Error(), "failed to decode health response") +} + +func TestCheckBridgeHealth_Timeout(t *testing.T) { + // Setup mock bridge server with delay + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + health, err := p.CheckBridgeHealth() + + assert.Error(t, err) + assert.Nil(t, health) + assert.Contains(t, err.Error(), "failed to connect to bridge server") +} + +func TestGetHealthStatus_BridgeHealthy(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(BridgeHealthResponse{ + Status: "ok", + Version: "1.0.0", + Uptime: 3600, + Sessions: 5, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + p.SetAPI(api) + + status := p.GetHealthStatus() + + assert.NotNil(t, status) + assert.Equal(t, "ok", status.Status) + assert.True(t, status.BridgeConnected) + assert.Equal(t, mockServer.URL, status.BridgeURL) + assert.Equal(t, 5, status.ActiveSessions) + assert.NotEmpty(t, status.Timestamp) +} + +func TestGetHealthStatus_BridgeUnhealthy(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Return() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:99999", + }, + } + p.SetAPI(api) + + status := p.GetHealthStatus() + + assert.NotNil(t, status) + assert.Equal(t, "degraded", status.Status) + assert.False(t, status.BridgeConnected) + assert.Equal(t, "http://localhost:99999", status.BridgeURL) + assert.Equal(t, 0, status.ActiveSessions) + assert.NotEmpty(t, status.Timestamp) +} + +func TestIsBridgeHealthy_True(t *testing.T) { + // Setup mock bridge server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(BridgeHealthResponse{ + Status: "ok", + }) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + healthy := p.IsBridgeHealthy() + + assert.True(t, healthy) +} + +func TestIsBridgeHealthy_False_ConnectionError(t *testing.T) { + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:99999", + }, + } + + healthy := p.IsBridgeHealthy() + + assert.False(t, healthy) +} + +func TestIsBridgeHealthy_False_BadStatus(t *testing.T) { + // Setup mock bridge server returning degraded status + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(BridgeHealthResponse{ + Status: "degraded", + }) + })) + defer mockServer.Close() + + p := &Plugin{ + configuration: &configuration{ + BridgeServerURL: mockServer.URL, + }, + } + + healthy := p.IsBridgeHealthy() + + assert.False(t, healthy) +} + +func TestHealthStatus_JSONSerialization(t *testing.T) { + status := &HealthStatus{ + Status: "ok", + BridgeConnected: true, + BridgeURL: "http://localhost:3001", + ActiveSessions: 5, + Timestamp: "2024-01-01T00:00:00Z", + } + + data, err := json.Marshal(status) + assert.NoError(t, err) + + var decoded HealthStatus + err = json.Unmarshal(data, &decoded) + assert.NoError(t, err) + + assert.Equal(t, status.Status, decoded.Status) + assert.Equal(t, status.BridgeConnected, decoded.BridgeConnected) + assert.Equal(t, status.BridgeURL, decoded.BridgeURL) + assert.Equal(t, status.ActiveSessions, decoded.ActiveSessions) + assert.Equal(t, status.Timestamp, decoded.Timestamp) +} + +func TestBridgeHealthResponse_JSONSerialization(t *testing.T) { + response := &BridgeHealthResponse{ + Status: "ok", + Version: "1.0.0", + Uptime: 3600, + Sessions: 5, + Timestamp: "2024-01-01T00:00:00Z", + } + + data, err := json.Marshal(response) + assert.NoError(t, err) + + var decoded BridgeHealthResponse + err = json.Unmarshal(data, &decoded) + assert.NoError(t, err) + + assert.Equal(t, response.Status, decoded.Status) + assert.Equal(t, response.Version, decoded.Version) + assert.Equal(t, response.Uptime, decoded.Uptime) + assert.Equal(t, response.Sessions, decoded.Sessions) + assert.Equal(t, response.Timestamp, decoded.Timestamp) +} diff --git a/server/message_store_test.go b/server/message_store_test.go new file mode 100644 index 0000000..3c3db6f --- /dev/null +++ b/server/message_store_test.go @@ -0,0 +1,593 @@ +package main + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewMessageStore(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + assert.NotNil(t, store) + assert.Equal(t, api, store.api) +} + +func TestMessageStoreKVKey(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + tests := []struct { + name string + sessionID string + wantKey string + }{ + { + name: "normal session ID", + sessionID: "session123", + wantKey: "messages_session123", + }, + { + name: "empty session ID", + sessionID: "", + wantKey: "messages_", + }, + { + name: "session ID with special chars", + sessionID: "user_123_channel_456", + wantKey: "messages_user_123_channel_456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := store.kvKey(tt.sessionID) + assert.Equal(t, tt.wantKey, key) + }) + } +} + +func TestMessageStoreGetMessages(t *testing.T) { + tests := []struct { + name string + sessionID string + kvData []byte + kvErr *model.AppError + want []StoredMessage + wantErr bool + }{ + { + name: "empty messages", + sessionID: "session1", + kvData: nil, + kvErr: nil, + want: []StoredMessage{}, + wantErr: false, + }, + { + name: "valid messages", + sessionID: "session2", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session2", + Messages: []StoredMessage{ + { + ID: "session2_0", + SessionID: "session2", + Role: "user", + Content: "Hello", + Timestamp: time.Now(), + }, + { + ID: "session2_1", + SessionID: "session2", + Role: "assistant", + Content: "Hi there!", + Timestamp: time.Now(), + }, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + want: []StoredMessage{{}, {}}, + wantErr: false, + }, + { + name: "KV get error", + sessionID: "session3", + kvData: nil, + kvErr: model.NewAppError("test", "test.error", nil, "", 500), + want: nil, + wantErr: true, + }, + { + name: "invalid JSON", + sessionID: "session4", + kvData: []byte("invalid json"), + kvErr: nil, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + api.On("KVGet", fmt.Sprintf("messages_%s", tt.sessionID)).Return(tt.kvData, tt.kvErr) + + got, err := store.GetMessages(tt.sessionID) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.want == nil { + assert.Nil(t, got) + } else { + assert.Len(t, got, len(tt.want)) + } + } + + api.AssertExpectations(t) + }) + } +} + +func TestMessageStoreAddMessage(t *testing.T) { + tests := []struct { + name string + sessionID string + role string + content string + existingData []byte + getErr *model.AppError + setErr *model.AppError + wantErr bool + wantMessageID string + wantRole string + wantContent string + }{ + { + name: "add first message", + sessionID: "session1", + role: "user", + content: "Hello", + existingData: nil, + getErr: nil, + setErr: nil, + wantErr: false, + wantMessageID: "session1_0", + wantRole: "user", + wantContent: "Hello", + }, + { + name: "add second message", + sessionID: "session2", + role: "assistant", + content: "Hi there!", + existingData: func() []byte { + sm := SessionMessages{ + SessionID: "session2", + Messages: []StoredMessage{ + { + ID: "session2_0", + SessionID: "session2", + Role: "user", + Content: "Hello", + Timestamp: time.Now(), + }, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + getErr: nil, + setErr: nil, + wantErr: false, + wantMessageID: "session2_1", + wantRole: "assistant", + wantContent: "Hi there!", + }, + { + name: "KV set error", + sessionID: "session3", + role: "user", + content: "Test", + existingData: nil, + getErr: nil, + setErr: model.NewAppError("test", "test.error", nil, "", 500), + wantErr: true, + }, + { + name: "get error but still works", + sessionID: "session4", + role: "user", + content: "Test", + existingData: nil, + getErr: model.NewAppError("test", "test.error", nil, "", 500), + setErr: nil, + wantErr: false, + wantMessageID: "session4_0", + wantRole: "user", + wantContent: "Test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + kvKey := fmt.Sprintf("messages_%s", tt.sessionID) + api.On("KVGet", kvKey).Return(tt.existingData, tt.getErr) + if tt.setErr != nil { + api.On("KVSet", kvKey, mock.Anything).Return(tt.setErr) + } else { + api.On("KVSet", kvKey, mock.Anything).Return(nil) + } + + msg, err := store.AddMessage(tt.sessionID, tt.role, tt.content) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, msg) + } else { + assert.NoError(t, err) + assert.NotNil(t, msg) + if msg != nil { + assert.Equal(t, tt.wantMessageID, msg.ID) + assert.Equal(t, tt.sessionID, msg.SessionID) + assert.Equal(t, tt.wantRole, msg.Role) + assert.Equal(t, tt.wantContent, msg.Content) + assert.False(t, msg.Timestamp.IsZero()) + } + } + + api.AssertExpectations(t) + }) + } +} + +func TestMessageStoreDeleteSessionMessages(t *testing.T) { + tests := []struct { + name string + sessionID string + deleteErr *model.AppError + wantErr bool + }{ + { + name: "successful delete", + sessionID: "session1", + deleteErr: nil, + wantErr: false, + }, + { + name: "delete error", + sessionID: "session2", + deleteErr: model.NewAppError("test", "test.error", nil, "", 500), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + kvKey := fmt.Sprintf("messages_%s", tt.sessionID) + api.On("KVDelete", kvKey).Return(tt.deleteErr) + + err := store.DeleteSessionMessages(tt.sessionID) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + api.AssertExpectations(t) + }) + } +} + +func TestMessageStoreGetMessageCount(t *testing.T) { + tests := []struct { + name string + sessionID string + kvData []byte + kvErr *model.AppError + wantCount int + wantErr bool + }{ + { + name: "empty messages", + sessionID: "session1", + kvData: nil, + kvErr: nil, + wantCount: 0, + wantErr: false, + }, + { + name: "multiple messages", + sessionID: "session2", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session2", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + {ID: "2", Role: "assistant", Content: "Hi"}, + {ID: "3", Role: "user", Content: "How are you?"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantCount: 3, + wantErr: false, + }, + { + name: "get error", + sessionID: "session3", + kvData: nil, + kvErr: model.NewAppError("test", "test.error", nil, "", 500), + wantCount: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + api.On("KVGet", fmt.Sprintf("messages_%s", tt.sessionID)).Return(tt.kvData, tt.kvErr) + + count, err := store.GetMessageCount(tt.sessionID) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + } + + api.AssertExpectations(t) + }) + } +} + +func TestMessageStoreGetLastMessage(t *testing.T) { + tests := []struct { + name string + sessionID string + kvData []byte + kvErr *model.AppError + wantMsg *StoredMessage + wantErr bool + }{ + { + name: "empty messages", + sessionID: "session1", + kvData: nil, + kvErr: nil, + wantMsg: nil, + wantErr: false, + }, + { + name: "single message", + sessionID: "session2", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session2", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantMsg: &StoredMessage{ + ID: "1", + Role: "user", + Content: "Hello", + }, + wantErr: false, + }, + { + name: "multiple messages", + sessionID: "session3", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session3", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + {ID: "2", Role: "assistant", Content: "Hi"}, + {ID: "3", Role: "user", Content: "Last"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantMsg: &StoredMessage{ + ID: "3", + Role: "user", + Content: "Last", + }, + wantErr: false, + }, + { + name: "get error", + sessionID: "session4", + kvData: nil, + kvErr: model.NewAppError("test", "test.error", nil, "", 500), + wantMsg: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + api.On("KVGet", fmt.Sprintf("messages_%s", tt.sessionID)).Return(tt.kvData, tt.kvErr) + + msg, err := store.GetLastMessage(tt.sessionID) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.wantMsg == nil { + assert.Nil(t, msg) + } else { + assert.NotNil(t, msg) + assert.Equal(t, tt.wantMsg.ID, msg.ID) + assert.Equal(t, tt.wantMsg.Role, msg.Role) + assert.Equal(t, tt.wantMsg.Content, msg.Content) + } + } + + api.AssertExpectations(t) + }) + } +} + +func TestMessageStoreGetMessagesByRole(t *testing.T) { + tests := []struct { + name string + sessionID string + role string + kvData []byte + kvErr *model.AppError + wantCount int + wantErr bool + }{ + { + name: "empty messages", + sessionID: "session1", + role: "user", + kvData: nil, + kvErr: nil, + wantCount: 0, + wantErr: false, + }, + { + name: "filter user messages", + sessionID: "session2", + role: "user", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session2", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + {ID: "2", Role: "assistant", Content: "Hi"}, + {ID: "3", Role: "user", Content: "How are you?"}, + {ID: "4", Role: "assistant", Content: "Good!"}, + {ID: "5", Role: "user", Content: "Great"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantCount: 3, + wantErr: false, + }, + { + name: "filter assistant messages", + sessionID: "session3", + role: "assistant", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session3", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + {ID: "2", Role: "assistant", Content: "Hi"}, + {ID: "3", Role: "user", Content: "How are you?"}, + {ID: "4", Role: "assistant", Content: "Good!"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantCount: 2, + wantErr: false, + }, + { + name: "no matching role", + sessionID: "session4", + role: "system", + kvData: func() []byte { + sm := SessionMessages{ + SessionID: "session4", + Messages: []StoredMessage{ + {ID: "1", Role: "user", Content: "Hello"}, + {ID: "2", Role: "assistant", Content: "Hi"}, + }, + UpdatedAt: time.Now(), + } + data, _ := json.Marshal(sm) + return data + }(), + kvErr: nil, + wantCount: 0, + wantErr: false, + }, + { + name: "get error", + sessionID: "session5", + role: "user", + kvData: nil, + kvErr: model.NewAppError("test", "test.error", nil, "", 500), + wantCount: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + store := NewMessageStore(api) + + api.On("KVGet", fmt.Sprintf("messages_%s", tt.sessionID)).Return(tt.kvData, tt.kvErr) + + messages, err := store.GetMessagesByRole(tt.sessionID, tt.role) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, messages, tt.wantCount) + // Verify all returned messages have the correct role + for _, msg := range messages { + assert.Equal(t, tt.role, msg.Role) + } + } + + api.AssertExpectations(t) + }) + } +} diff --git a/server/output_handler_test.go b/server/output_handler_test.go new file mode 100644 index 0000000..f679cd7 --- /dev/null +++ b/server/output_handler_test.go @@ -0,0 +1,641 @@ +package main + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewOutputHandler(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + + handler := NewOutputHandler(plugin) + + assert.NotNil(t, handler) + assert.Equal(t, plugin, handler.plugin) +} + +func TestOutputHandlerGetOrCreateBuffer(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + handler := NewOutputHandler(plugin) + + // First call should create new buffer + buf1 := handler.getOrCreateBuffer("session1", "channel1") + assert.NotNil(t, buf1) + assert.Equal(t, "session1", buf1.sessionID) + assert.Equal(t, "channel1", buf1.channelID) + + // Second call should return same buffer + buf2 := handler.getOrCreateBuffer("session1", "channel1") + assert.Equal(t, buf1, buf2) + + // Different session should create new buffer + buf3 := handler.getOrCreateBuffer("session2", "channel2") + assert.NotNil(t, buf3) + assert.NotEqual(t, buf1, buf3) + assert.Equal(t, "session2", buf3.sessionID) + assert.Equal(t, "channel2", buf3.channelID) +} + +func TestOutputHandlerHandleOutputInvalidJSON(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + // Mock CreatePost for raw text + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == "invalid json text" + })).Return(&model.Post{}, nil) + + handler.HandleOutput("session1", "channel1", "invalid json text") + + api.AssertExpectations(t) +} + +func TestOutputHandlerHandleOutputAssistantMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + message CLIOutputMessage + wantMsg string + }{ + { + name: "simple text message", + message: CLIOutputMessage{ + Type: "assistant", + Message: "Hello, how can I help?", + }, + wantMsg: "Hello, how can I help?", + }, + { + name: "content blocks", + message: CLIOutputMessage{ + Type: "assistant", + ContentBlocks: []ContentBlock{ + {Type: "text", Text: "First block"}, + {Type: "text", Text: "Second block"}, + }, + }, + wantMsg: "First block\nSecond block", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.message) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == tt.wantMsg + })).Return(&model.Post{}, nil).Once() + + handler.HandleOutput("session1", "channel1", string(data)) + + api.AssertExpectations(t) + }) + } +} + +func TestOutputHandlerHandleOutputSystemMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + message := CLIOutputMessage{ + Type: "system", + Message: "System is ready", + } + data, _ := json.Marshal(message) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == "_System is ready_" + })).Return(&model.Post{}, nil) + + handler.HandleOutput("session1", "channel1", string(data)) + + api.AssertExpectations(t) +} + +func TestOutputHandlerHandleOutputResultMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + message CLIOutputMessage + wantMsg string + }{ + { + name: "result only", + message: CLIOutputMessage{ + Type: "result", + Result: "Task completed successfully", + }, + wantMsg: "Task completed successfully", + }, + { + name: "result with cost", + message: CLIOutputMessage{ + Type: "result", + Result: "Done", + TotalCost: "$0.50", + }, + wantMsg: "Done\n\n_(Cost: $0.50)_", + }, + { + name: "result with usage", + message: CLIOutputMessage{ + Type: "result", + Result: "Done", + TotalUsage: &Usage{ + InputTokens: 100, + OutputTokens: 50, + }, + }, + wantMsg: "Done\n\n_(Tokens: 100 in / 50 out)_", + }, + { + name: "result with cost and usage", + message: CLIOutputMessage{ + Type: "result", + Result: "Done", + TotalCost: "$0.50", + TotalUsage: &Usage{ + InputTokens: 100, + OutputTokens: 50, + }, + }, + wantMsg: "Done\n\n_(Cost: $0.50, Tokens: 100 in / 50 out)_", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.message) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == tt.wantMsg + })).Return(&model.Post{}, nil).Once() + + handler.HandleOutput("session1", "channel1", string(data)) + + api.AssertExpectations(t) + }) + } +} + +func TestOutputHandlerHandleOutputToolUseMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + message := CLIOutputMessage{ + Type: "tool_use", + ToolName: "file_editor", + } + data, _ := json.Marshal(message) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == ":wrench: Using tool: **file_editor**" + })).Return(&model.Post{}, nil) + + handler.HandleOutput("session1", "channel1", string(data)) + + api.AssertExpectations(t) +} + +func TestOutputHandlerHandleOutputErrorMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + message := CLIOutputMessage{ + Type: "error", + Error: "Something went wrong", + } + data, _ := json.Marshal(message) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == ":warning: **Error**: Something went wrong" + })).Return(&model.Post{}, nil) + + handler.HandleOutput("session1", "channel1", string(data)) + + api.AssertExpectations(t) +} + +func TestOutputHandlerHandleError(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + errorMsg string + shouldPost bool + }{ + { + name: "normal error", + errorMsg: "File not found", + shouldPost: true, + }, + { + name: "empty error", + errorMsg: "", + shouldPost: false, + }, + { + name: "debugger message - should skip", + errorMsg: "Debugger listening on port 9229", + shouldPost: false, + }, + { + name: "debugger help message - should skip", + errorMsg: "For help, see: https://nodejs.org/en/docs/inspector", + shouldPost: false, + }, + { + name: "waiting for debugger - should skip", + errorMsg: "Waiting for the debugger to disconnect...", + shouldPost: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPost { + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == ":warning: **Error**: "+tt.errorMsg + })).Return(&model.Post{}, nil).Once() + } + + handler.HandleError("session1", "channel1", tt.errorMsg) + + if tt.shouldPost { + api.AssertExpectations(t) + } + }) + } +} + +func TestOutputHandlerHandleExit(t *testing.T) { + tests := []struct { + name string + exitCode int + wantMsg string + }{ + { + name: "successful exit", + exitCode: 0, + wantMsg: ":white_check_mark: Claude Code session completed successfully.", + }, + { + name: "error exit", + exitCode: 1, + wantMsg: ":x: Claude Code session ended with exit code 1.", + }, + { + name: "signal exit", + exitCode: 130, + wantMsg: ":x: Claude Code session ended with exit code 130.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create fresh API mock for each subtest + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + // Setup mocks + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == tt.wantMsg + })).Return(&model.Post{}, nil).Once() + + // DeleteSession will be called - mock only KVDelete + api.On("KVDelete", "session_channel1").Return(nil).Once() + + handler.HandleExit("session1", "channel1", tt.exitCode) + + api.AssertExpectations(t) + }) + } +} + +func TestOutputHandlerPostBotMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + content string + shouldPost bool + }{ + { + name: "normal message", + content: "Hello, world!", + shouldPost: true, + }, + { + name: "empty message", + content: "", + shouldPost: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPost { + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == tt.content + })).Return(&model.Post{}, nil).Once() + } + + handler.postBotMessage("channel1", tt.content) + + if tt.shouldPost { + api.AssertExpectations(t) + } + }) + } +} + +func TestOutputHandlerPostRawMessage(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + data string + wantMsg string + shouldPost bool + }{ + { + name: "single line text", + data: "simple text", + wantMsg: "simple text", + shouldPost: true, + }, + { + name: "multi-line text", + data: "line 1\nline 2\nline 3", + wantMsg: "```\nline 1\nline 2\nline 3\n```", + shouldPost: true, + }, + { + name: "JSON-like text", + data: `{"key": "value"}`, + wantMsg: "```\n{\"key\": \"value\"}\n```", + shouldPost: true, + }, + { + name: "empty text", + data: "", + wantMsg: "", + shouldPost: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPost { + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + post.Message == tt.wantMsg + })).Return(&model.Post{}, nil).Once() + } + + handler.postRawMessage("session1", "channel1", tt.data) + + if tt.shouldPost { + api.AssertExpectations(t) + } + }) + } +} + +func TestOutputHandlerHandleFileChange(t *testing.T) { + api := &plugintest.API{} + plugin := &Plugin{} + plugin.SetAPI(api) + plugin.botUserID = "bot123" + handler := NewOutputHandler(plugin) + + tests := []struct { + name string + message CLIOutputMessage + wantEmoji string + wantAction string + }{ + { + name: "create file", + message: CLIOutputMessage{ + Type: "tool_result", + FilePath: "test.go", + ChangeType: "create", + }, + wantEmoji: ":new:", + wantAction: "create", + }, + { + name: "modify file", + message: CLIOutputMessage{ + Type: "tool_result", + FilePath: "test.go", + ChangeType: "modify", + }, + wantEmoji: ":pencil2:", + wantAction: "modify", + }, + { + name: "edit file", + message: CLIOutputMessage{ + Type: "tool_result", + FilePath: "test.go", + ChangeType: "edit", + }, + wantEmoji: ":pencil2:", + wantAction: "edit", + }, + { + name: "delete file", + message: CLIOutputMessage{ + Type: "tool_result", + FilePath: "test.go", + ChangeType: "delete", + }, + wantEmoji: ":wastebasket:", + wantAction: "delete", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock GetConfig for getPluginURL only if interactive buttons are added + // (not for delete operations) + if tt.wantAction != "delete" { + api.On("GetConfig").Return(&model.Config{ + ServiceSettings: model.ServiceSettings{ + SiteURL: model.NewString("http://localhost:8065"), + }, + }).Once() + } + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + hasEmoji := post.Message[:len(tt.wantEmoji)] == tt.wantEmoji + hasAction := containsString(post.Message, tt.wantAction) + hasFilePath := containsString(post.Message, tt.message.FilePath) + return post.ChannelId == "channel1" && + post.UserId == "bot123" && + hasEmoji && hasAction && hasFilePath + })).Return(&model.Post{}, nil).Once() + + handler.handleFileChange("session1", "channel1", &tt.message) + + api.AssertExpectations(t) + }) + } +} + +func TestCLIOutputMessageStructure(t *testing.T) { + // Test that CLIOutputMessage can be marshaled and unmarshaled + msg := CLIOutputMessage{ + Type: "assistant", + Subtype: "text", + Message: "Hello", + SessionID: "session1", + Timestamp: 1234567890, + ContentBlocks: []ContentBlock{ + {Type: "text", Text: "Content"}, + }, + ToolName: "tool1", + ToolResult: "result", + FilePath: "/path/to/file", + ChangeType: "modify", + Result: "success", + TotalCost: "$1.00", + TotalUsage: &Usage{InputTokens: 100, OutputTokens: 50}, + Error: "no error", + } + + // Marshal + data, err := json.Marshal(msg) + assert.NoError(t, err) + assert.NotNil(t, data) + + // Unmarshal + var msg2 CLIOutputMessage + err = json.Unmarshal(data, &msg2) + assert.NoError(t, err) + assert.Equal(t, msg.Type, msg2.Type) + assert.Equal(t, msg.Message, msg2.Message) + assert.Equal(t, msg.SessionID, msg2.SessionID) + assert.Equal(t, msg.ToolName, msg2.ToolName) + assert.Equal(t, msg.FilePath, msg2.FilePath) + assert.Equal(t, msg.Result, msg2.Result) + assert.Equal(t, msg.TotalCost, msg2.TotalCost) + assert.NotNil(t, msg2.TotalUsage) + assert.Equal(t, msg.TotalUsage.InputTokens, msg2.TotalUsage.InputTokens) + assert.Equal(t, msg.TotalUsage.OutputTokens, msg2.TotalUsage.OutputTokens) +} + +func TestUsageStructure(t *testing.T) { + usage := Usage{ + InputTokens: 150, + OutputTokens: 75, + } + + data, err := json.Marshal(usage) + assert.NoError(t, err) + + var usage2 Usage + err = json.Unmarshal(data, &usage2) + assert.NoError(t, err) + assert.Equal(t, usage.InputTokens, usage2.InputTokens) + assert.Equal(t, usage.OutputTokens, usage2.OutputTokens) +} + +func TestContentBlockStructure(t *testing.T) { + block := ContentBlock{ + Type: "text", + Text: "Sample text", + Name: "block1", + } + + data, err := json.Marshal(block) + assert.NoError(t, err) + + var block2 ContentBlock + err = json.Unmarshal(data, &block2) + assert.NoError(t, err) + assert.Equal(t, block.Type, block2.Type) + assert.Equal(t, block.Text, block2.Text) + assert.Equal(t, block.Name, block2.Name) +} + +// Helper function to check if a string contains a substring +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && stringContains(s, substr)) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/server/plugin_test.go b/server/plugin_test.go new file mode 100644 index 0000000..894716c --- /dev/null +++ b/server/plugin_test.go @@ -0,0 +1,163 @@ +package main + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestOnActivate_NewBot(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + plugin := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + plugin.SetAPI(api) + + // Mock bot creation (success case) + api.On("CreateBot", mock.AnythingOfType("*model.Bot")).Return(&model.Bot{ + UserId: "bot_user_id", + }, nil) + + // Mock command registration + api.On("RegisterCommand", mock.AnythingOfType("*model.Command")).Return(nil) + + // Mock log messages (variadic arguments) + api.On("LogInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + + err := plugin.OnActivate() + assert.NoError(t, err) + assert.Equal(t, "bot_user_id", plugin.botUserID) + assert.NotNil(t, plugin.bridgeClient) + assert.NotNil(t, plugin.wsClient) +} + +func TestOnActivate_ExistingBot(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + plugin := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + plugin.SetAPI(api) + + // Mock bot creation failure (bot already exists) + api.On("CreateBot", mock.AnythingOfType("*model.Bot")).Return(nil, model.NewAppError("CreateBot", "app.bot.create.error", nil, "already exists", 400)) + + // Mock getting existing bot user + api.On("GetUserByUsername", "claude-code").Return(&model.User{ + Id: "existing_bot_id", + Username: "claude-code", + }, nil) + + // Mock command registration + api.On("RegisterCommand", mock.AnythingOfType("*model.Command")).Return(nil) + + // Mock log messages (variadic arguments) + api.On("LogInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + + err := plugin.OnActivate() + assert.NoError(t, err) + assert.Equal(t, "existing_bot_id", plugin.botUserID) + assert.NotNil(t, plugin.bridgeClient) + assert.NotNil(t, plugin.wsClient) +} + +func TestOnActivate_BotCreationFailure(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + plugin := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + plugin.SetAPI(api) + + // Mock bot creation failure + api.On("CreateBot", mock.AnythingOfType("*model.Bot")).Return(nil, model.NewAppError("CreateBot", "app.bot.create.error", nil, "error", 500)) + + // Mock getting bot user also fails + api.On("GetUserByUsername", "claude-code").Return(nil, model.NewAppError("GetUserByUsername", "app.user.get.error", nil, "not found", 404)) + + err := plugin.OnActivate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to ensure bot user exists") +} + +func TestOnActivate_CommandRegistrationFailure(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + plugin := &Plugin{ + configuration: &configuration{ + BridgeServerURL: "http://localhost:3001", + }, + } + plugin.SetAPI(api) + + // Mock bot creation + api.On("CreateBot", mock.AnythingOfType("*model.Bot")).Return(&model.Bot{ + UserId: "bot_user_id", + }, nil) + + // Mock command registration failure + api.On("RegisterCommand", mock.AnythingOfType("*model.Command")).Return(model.NewAppError("RegisterCommand", "app.command.register.error", nil, "error", 500)) + + // Mock log messages (variadic arguments) + api.On("LogInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + + err := plugin.OnActivate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to register commands") +} + +func TestOnDeactivate(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + // Create a mock WebSocket client + wsClient := &WebSocketClient{ + baseURL: "http://localhost:3001", + subscriptions: make(map[string]string), + stopChan: make(chan struct{}), + } + + plugin := &Plugin{ + wsClient: wsClient, + } + plugin.SetAPI(api) + + // Mock log message (variadic) + api.On("LogInfo", mock.Anything).Maybe().Return() + + err := plugin.OnDeactivate() + assert.NoError(t, err) +} + +func TestOnDeactivate_NoWebSocket(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + plugin := &Plugin{ + wsClient: nil, + } + plugin.SetAPI(api) + + // Mock log message (variadic) + api.On("LogInfo", mock.Anything).Maybe().Return() + + err := plugin.OnDeactivate() + assert.NoError(t, err) +} diff --git a/server/post_utils_test.go b/server/post_utils_test.go new file mode 100644 index 0000000..0b2f033 --- /dev/null +++ b/server/post_utils_test.go @@ -0,0 +1,148 @@ +package main + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestPostChangeProposal(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock site URL config + config := &model.Config{} + siteURL := "http://localhost:8065" + config.ServiceSettings.SiteURL = &siteURL + api.On("GetConfig").Return(config) + + // Mock post creation + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{Id: "post123"}, nil) + + defer api.AssertExpectations(t) + + _ = p.postChangeProposal("channel1", "Would you like to apply this change?", "change123") + // Mock expectations will catch any errors +} + +func TestPostWithQuickActions(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock site URL config + config := &model.Config{} + siteURL := "http://localhost:8065" + config.ServiceSettings.SiteURL = &siteURL + api.On("GetConfig").Return(config) + + // Mock post creation + createdPost := &model.Post{Id: "post123"} + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(createdPost, nil) + + defer api.AssertExpectations(t) + + postID, err := p.postWithQuickActions("channel1", "Here's the response", "session123") + if err != nil { + t.Fatalf("postWithQuickActions returned error: %v", err) + } + assert.Equal(t, "post123", postID) +} + +func TestPostCodeChange(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock site URL config + config := &model.Config{} + siteURL := "http://localhost:8065" + config.ServiceSettings.SiteURL = &siteURL + api.On("GetConfig").Return(config) + + // Mock post creation + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{Id: "post123"}, nil) + + defer api.AssertExpectations(t) + + diff := "+function hello() {\n- console.log('old');\n+ console.log('new');\n+}" + _ = p.postCodeChange("channel1", "src/main.js", diff, "change123") + // Mock expectations will catch any errors +} + +func TestPostWithMenu(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock site URL config + config := &model.Config{} + siteURL := "http://localhost:8065" + config.ServiceSettings.SiteURL = &siteURL + api.On("GetConfig").Return(config) + + // Mock post creation + api.On("CreatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{Id: "post123"}, nil) + + defer api.AssertExpectations(t) + + options := []ActionOption{ + {Label: "Option 1", Value: "opt1"}, + {Label: "Option 2", Value: "opt2"}, + } + _ = p.postWithMenu("channel1", "Choose an action:", options, "session123") + // Mock expectations will catch any errors +} + +func TestUpdatePostWithProgress(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + existingPost := &model.Post{ + Id: "post123", + Message: "Old message", + } + + // Mock getting and updating the post + api.On("GetPost", "post123").Return(existingPost, nil) + api.On("UpdatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{}, nil) + + defer api.AssertExpectations(t) + + _ = p.updatePostWithProgress("post123", "Processing...") + // Mock expectations will catch any errors +} + +func TestUpdatePostMessage(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + existingPost := &model.Post{ + Id: "post123", + Message: "Old message", + } + + // Mock getting and updating the post + api.On("GetPost", "post123").Return(existingPost, nil) + api.On("UpdatePost", mock.AnythingOfType("*model.Post")).Return(&model.Post{}, nil) + + defer api.AssertExpectations(t) + + _ = p.updatePostMessage("post123", "New message") + // Mock expectations will catch any errors +} + +func TestGetPluginURL(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + config := &model.Config{} + siteURL := "http://localhost:8065" + config.ServiceSettings.SiteURL = &siteURL + api.On("GetConfig").Return(config) + + defer api.AssertExpectations(t) + + url := p.getPluginURL() + assert.Equal(t, "http://localhost:8065/plugins/co.appsome.claudecode", url) +} diff --git a/server/session_manager_test.go b/server/session_manager_test.go new file mode 100644 index 0000000..6136d9c --- /dev/null +++ b/server/session_manager_test.go @@ -0,0 +1,155 @@ +package main + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetActiveSession_NoSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Return nil for KVGet (no session) + api.On("KVGet", "session_channel1").Return(nil, nil) + + defer api.AssertExpectations(t) + + session, err := p.GetActiveSession("channel1") + assert.NoError(t, err) + assert.Nil(t, session) +} + +func TestGetActiveSession_ExistingSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Create a session object + expectedSession := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + CreatedAt: 1000000, + LastMessageAt: 1000000, + } + + // Marshal it to JSON + data, _ := json.Marshal(expectedSession) + + // Mock KVGet to return the session + api.On("KVGet", "session_channel1").Return(data, nil) + + defer api.AssertExpectations(t) + + session, err := p.GetActiveSession("channel1") + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Equal(t, "session123", session.SessionID) + assert.Equal(t, "/tmp/test", session.ProjectPath) + assert.Equal(t, "user1", session.UserID) +} + +func TestSaveSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + session := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + CreatedAt: 1000000, + LastMessageAt: 1000000, + } + + // Mock KVSet + api.On("KVSet", "session_channel1", mock.Anything).Return(nil) + + defer api.AssertExpectations(t) + + err := p.SaveSession("channel1", session) + assert.NoError(t, err) +} + +func TestDeleteSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock KVDelete + api.On("KVDelete", "session_channel1").Return(nil) + + defer api.AssertExpectations(t) + + err := p.DeleteSession("channel1") + assert.NoError(t, err) +} + +func TestUpdateSessionLastMessage(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + session := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + CreatedAt: 1000000, + LastMessageAt: 1000000, + } + + data, _ := json.Marshal(session) + + // Mock KVGet to return existing session + api.On("KVGet", "session_channel1").Return(data, nil) + + // Mock KVSet to save updated session + api.On("KVSet", "session_channel1", mock.Anything).Return(nil) + + defer api.AssertExpectations(t) + + err := p.UpdateSessionLastMessage("channel1") + assert.NoError(t, err) +} + +func TestGetSessionForChannel_NoSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // No session + api.On("KVGet", "session_channel1").Return(nil, nil) + + defer api.AssertExpectations(t) + + sessionID := p.GetSessionForChannel("channel1") + assert.Empty(t, sessionID) +} + +func TestGetSessionForChannel_ExistingSession(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + session := &ChannelSession{ + SessionID: "session123", + ProjectPath: "/tmp/test", + UserID: "user1", + } + + data, _ := json.Marshal(session) + api.On("KVGet", "session_channel1").Return(data, nil) + + defer api.AssertExpectations(t) + + sessionID := p.GetSessionForChannel("channel1") + assert.Equal(t, "session123", sessionID) +} + +func TestSessionManager_CreateSession_SkippedForNow(t *testing.T) { + // Skipped: CreateSession and StopSession tests require interface refactoring + // Plugin.bridgeClient needs to be an interface to allow mocking + t.Skip("Requires refactoring Plugin.bridgeClient to use an interface") +} + +// CreateSession and StopSession tests skipped - require interface refactoring +// These functions call bridgeClient methods which cannot be easily mocked without +// refactoring Plugin.bridgeClient to use an interface instead of a concrete type diff --git a/server/thread_context_test.go b/server/thread_context_test.go new file mode 100644 index 0000000..c689bb7 --- /dev/null +++ b/server/thread_context_test.go @@ -0,0 +1,238 @@ +package main + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" +) + +func TestGetThreadContext_EmptyThread(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock GetChannel + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Mock GetPostThread to return truly empty thread (no posts) + postList := &model.PostList{ + Order: []string{}, + Posts: map[string]*model.Post{}, + } + api.On("GetPostThread", "root123").Return(postList, nil) + + defer api.AssertExpectations(t) + + _, err := p.GetThreadContext("root123", "channel1", 50) + // Should error because thread is empty + assert.Error(t, err) + assert.Contains(t, err.Error(), "thread is empty") +} + +func TestGetThreadContext_SinglePost(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock GetChannel + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Create a single post + post := &model.Post{ + Id: "post1", + UserId: "user1", + ChannelId: "channel1", + Message: "Hello world", + CreateAt: 1000000000, + } + + postList := &model.PostList{ + Order: []string{"post1"}, + Posts: map[string]*model.Post{ + "post1": post, + }, + } + + // Mock user lookup + user := &model.User{ + Id: "user1", + Username: "testuser", + } + + api.On("GetPostThread", "post1").Return(postList, nil) + api.On("GetUser", "user1").Return(user, nil) + + defer api.AssertExpectations(t) + + context, err := p.GetThreadContext("post1", "channel1", 50) + assert.NoError(t, err) + assert.NotNil(t, context) + assert.Contains(t, context.Content, "testuser") + assert.Contains(t, context.Content, "Hello world") + assert.Equal(t, 1, context.MessageCount) + assert.Contains(t, context.Participants, "@testuser") +} + +func TestGetThreadContext_MultipleMessages(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock GetChannel + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Create multiple posts + post1 := &model.Post{ + Id: "root123", + UserId: "user1", + ChannelId: "channel1", + Message: "First message", + CreateAt: 1000000000, + } + + post2 := &model.Post{ + Id: "post2", + UserId: "user2", + ChannelId: "channel1", + Message: "Second message", + CreateAt: 2000000000, + } + + postList := &model.PostList{ + Order: []string{"root123", "post2"}, + Posts: map[string]*model.Post{ + "root123": post1, + "post2": post2, + }, + } + + user1 := &model.User{Id: "user1", Username: "alice"} + user2 := &model.User{Id: "user2", Username: "bob"} + + api.On("GetPostThread", "root123").Return(postList, nil) + api.On("GetUser", "user1").Return(user1, nil) + api.On("GetUser", "user2").Return(user2, nil) + + defer api.AssertExpectations(t) + + context, err := p.GetThreadContext("root123", "channel1", 50) + assert.NoError(t, err) + assert.NotNil(t, context) + assert.Contains(t, context.Content, "alice") + assert.Contains(t, context.Content, "First message") + assert.Contains(t, context.Content, "bob") + assert.Contains(t, context.Content, "Second message") + assert.Equal(t, 2, context.MessageCount) + assert.Len(t, context.Participants, 2) +} + +func TestGetThreadContext_MaxMessagesLimit(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock GetChannel + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Create root post + 9 more posts (10 total) + posts := make(map[string]*model.Post) + order := make([]string, 10) + + // Root post + posts["root123"] = &model.Post{ + Id: "root123", + UserId: "user1", + ChannelId: "channel1", + Message: "Root message", + CreateAt: 1000000000, + } + order[0] = "root123" + + // Add 9 more posts + for i := 1; i < 10; i++ { + postID := model.NewId() + posts[postID] = &model.Post{ + Id: postID, + UserId: "user1", + ChannelId: "channel1", + Message: "Message " + string(rune('0'+i)), + CreateAt: int64(1000000000 + i*1000), + } + order[i] = postID + } + + postList := &model.PostList{ + Order: order, + Posts: posts, + } + + user := &model.User{Id: "user1", Username: "testuser"} + + api.On("GetPostThread", "root123").Return(postList, nil) + // GetUser will be called once for each of the last 5 messages (all same user) + // But since they're all the same user, it will still be called 5 times + api.On("GetUser", "user1").Return(user, nil).Times(5) + + defer api.AssertExpectations(t) + + // Limit to 5 messages + context, err := p.GetThreadContext("root123", "channel1", 5) + assert.NoError(t, err) + assert.NotNil(t, context) + assert.Equal(t, 5, context.MessageCount) +} + +func TestGetThreadContext_WithFileAttachments(t *testing.T) { + p := setupTestPlugin(t) + api := p.API.(*plugintest.API) + + // Mock GetChannel + channel := &model.Channel{ + Id: "channel1", + Name: "test-channel", + } + api.On("GetChannel", "channel1").Return(channel, nil) + + // Create post with file attachments + post := &model.Post{ + Id: "post1", + UserId: "user1", + ChannelId: "channel1", + Message: "Check out these files", + CreateAt: 1000000000, + FileIds: []string{"file1", "file2"}, + } + + postList := &model.PostList{ + Order: []string{"post1"}, + Posts: map[string]*model.Post{ + "post1": post, + }, + } + + user := &model.User{Id: "user1", Username: "testuser"} + + api.On("GetPostThread", "post1").Return(postList, nil) + api.On("GetUser", "user1").Return(user, nil) + + defer api.AssertExpectations(t) + + context, err := p.GetThreadContext("post1", "channel1", 50) + assert.NoError(t, err) + assert.NotNil(t, context) + assert.Contains(t, context.Content, "file(s) attached") +} diff --git a/server/websocket_client_test.go b/server/websocket_client_test.go new file mode 100644 index 0000000..7f62776 --- /dev/null +++ b/server/websocket_client_test.go @@ -0,0 +1,381 @@ +package main + +import ( + "encoding/json" + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewWebSocketClient(t *testing.T) { + p := setupPlugin() + baseURL := "http://localhost:3001" + + ws := NewWebSocketClient(baseURL, p) + + assert.NotNil(t, ws) + assert.Equal(t, baseURL, ws.baseURL) + assert.Equal(t, p, ws.plugin) + assert.NotNil(t, ws.subscriptions) + assert.NotNil(t, ws.stopChan) + assert.Equal(t, 0, len(ws.subscriptions)) + assert.False(t, ws.reconnecting) + assert.Nil(t, ws.conn) +} + +func TestWebSocketClient_Subscribe(t *testing.T) { + p := setupPlugin() + ws := NewWebSocketClient("http://localhost:3001", p) + + sessionID := "session_123" + channelID := "channel_456" + + ws.Subscribe(sessionID, channelID) + + ws.mu.RLock() + defer ws.mu.RUnlock() + + assert.Equal(t, channelID, ws.subscriptions[sessionID]) + assert.Equal(t, 1, len(ws.subscriptions)) +} + +func TestWebSocketClient_Unsubscribe(t *testing.T) { + p := setupPlugin() + ws := NewWebSocketClient("http://localhost:3001", p) + + sessionID := "session_123" + channelID := "channel_456" + + // Subscribe first + ws.Subscribe(sessionID, channelID) + assert.Equal(t, 1, len(ws.subscriptions)) + + // Then unsubscribe + ws.Unsubscribe(sessionID) + + ws.mu.RLock() + defer ws.mu.RUnlock() + + assert.Equal(t, 0, len(ws.subscriptions)) +} + +func TestWebSocketClient_MultipleSubscriptions(t *testing.T) { + p := setupPlugin() + ws := NewWebSocketClient("http://localhost:3001", p) + + subscriptions := map[string]string{ + "session_1": "channel_1", + "session_2": "channel_2", + "session_3": "channel_3", + } + + for sessionID, channelID := range subscriptions { + ws.Subscribe(sessionID, channelID) + } + + ws.mu.RLock() + defer ws.mu.RUnlock() + + assert.Equal(t, 3, len(ws.subscriptions)) + for sessionID, channelID := range subscriptions { + assert.Equal(t, channelID, ws.subscriptions[sessionID]) + } +} + +func TestWebSocketClient_ProcessMessage_NotSubscribed(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + msg := &WebSocketMessage{ + Type: "output", + SessionID: "unsubscribed_session", + Data: json.RawMessage(`{"output": "test output"}`), + } + + // Should not call any API methods for unsubscribed session + ws.processMessage(msg) +} + +func TestWebSocketClient_HandleOutput(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == "test output" + })).Return(&model.Post{}, nil) + + msg := &WebSocketMessage{ + Type: "output", + SessionID: sessionID, + Data: json.RawMessage(`{"output": "test output"}`), + } + + ws.processMessage(msg) +} + +func TestWebSocketClient_HandleError(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == "âš ī¸ Error: test error" + })).Return(&model.Post{}, nil) + + msg := &WebSocketMessage{ + Type: "error", + SessionID: sessionID, + Data: json.RawMessage(`{"error": "test error"}`), + } + + ws.processMessage(msg) +} + +func TestWebSocketClient_HandleStatus_Stopped(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + setupKVMocks(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + p.SaveSession(channelID, &ChannelSession{ + SessionID: sessionID, + UserID: "user_id", + }) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == "âšī¸ Claude Code session stopped." + })).Return(&model.Post{}, nil) + + // Mock session deletion + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe().Return() + + msg := &WebSocketMessage{ + Type: "status", + SessionID: sessionID, + Data: json.RawMessage(`{"status": "stopped"}`), + } + + ws.processMessage(msg) + + // Verify unsubscribed + ws.mu.RLock() + subscriptionCount := len(ws.subscriptions) + ws.mu.RUnlock() + assert.Equal(t, 0, subscriptionCount) +} + +func TestWebSocketClient_HandleFileChange(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == "📝 File modified: `test.go`" + })).Return(&model.Post{}, nil) + + msg := &WebSocketMessage{ + Type: "file_change", + SessionID: sessionID, + Data: json.RawMessage(`{"path": "test.go", "action": "modified"}`), + } + + ws.processMessage(msg) +} + +func TestWebSocketClient_HandleOutput_InvalidData(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("LogError", mock.Anything, mock.Anything, mock.Anything).Return() + + msg := &WebSocketMessage{ + Type: "output", + SessionID: sessionID, + Data: json.RawMessage(`invalid json`), + } + + ws.processMessage(msg) +} + +func TestWebSocketClient_Close(t *testing.T) { + p := setupPlugin() + ws := NewWebSocketClient("http://localhost:3001", p) + + err := ws.Close() + + assert.NoError(t, err) + assert.Nil(t, ws.conn) +} + +func TestPostBotMessage_Success(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + channelID := "channel_123" + message := "test message" + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == message + })).Return(&model.Post{}, nil) + + p.postBotMessage(channelID, message) +} + +func TestPostBotMessage_Error(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + + channelID := "channel_123" + message := "test message" + + api.On("CreatePost", mock.Anything).Return(nil, model.NewAppError("CreatePost", "error", nil, "", 500)) + api.On("LogError", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + + p.postBotMessage(channelID, message) +} + +func TestWebSocketMessage_JSONSerialization(t *testing.T) { + msg := &WebSocketMessage{ + Type: "output", + SessionID: "session_123", + Data: json.RawMessage(`{"output": "test"}`), + Timestamp: 1234567890, + } + + data, err := json.Marshal(msg) + assert.NoError(t, err) + + var decoded WebSocketMessage + err = json.Unmarshal(data, &decoded) + assert.NoError(t, err) + + assert.Equal(t, msg.Type, decoded.Type) + assert.Equal(t, msg.SessionID, decoded.SessionID) + assert.Equal(t, msg.Timestamp, decoded.Timestamp) +} + +func TestSubscribeMessage_JSONSerialization(t *testing.T) { + msg := &SubscribeMessage{ + Type: "subscribe", + SessionID: "session_123", + } + + data, err := json.Marshal(msg) + assert.NoError(t, err) + + var decoded SubscribeMessage + err = json.Unmarshal(data, &decoded) + assert.NoError(t, err) + + assert.Equal(t, msg.Type, decoded.Type) + assert.Equal(t, msg.SessionID, decoded.SessionID) +} + +func TestWebSocketClient_HandleStatus_WithMessage(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.ChannelId == channelID && + post.UserId == p.botUserID && + post.Message == "Running tests..." + })).Return(&model.Post{}, nil) + + msg := &WebSocketMessage{ + Type: "status", + SessionID: sessionID, + Data: json.RawMessage(`{"status": "running", "message": "Running tests..."}`), + } + + ws.processMessage(msg) +} + +func TestWebSocketClient_ProcessMessage_UnknownType(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + p := setupPlugin() + p.SetAPI(api) + ws := NewWebSocketClient("http://localhost:3001", p) + + channelID := "channel_123" + sessionID := "session_123" + ws.Subscribe(sessionID, channelID) + + api.On("LogDebug", "Unknown message type", "type", "unknown_type").Return() + + msg := &WebSocketMessage{ + Type: "unknown_type", + SessionID: sessionID, + Data: json.RawMessage(`{}`), + } + + ws.processMessage(msg) +}