diff --git a/acceptance/cmd/workspace/apps/output.txt b/acceptance/cmd/workspace/apps/output.txt
index e1524affa4..a1645dd38d 100644
--- a/acceptance/cmd/workspace/apps/output.txt
+++ b/acceptance/cmd/workspace/apps/output.txt
@@ -75,6 +75,7 @@
 Global Flags:
   -o, --output type    output type: text or json (default text)
   -p, --profile string ~/.databrickscfg profile
   -t, --target string  bundle target to use (if applicable)
+      --var strings    set values for variables defined in bundle config. Example: --var="key=value"
 
 Exit code: 1
diff --git a/cmd/apps/apps.go b/cmd/apps/apps.go
new file mode 100644
index 0000000000..999161a55b
--- /dev/null
+++ b/cmd/apps/apps.go
@@ -0,0 +1,19 @@
+package apps
+
+import "github.com/spf13/cobra"
+
+// ManagementGroupID is the ID of the command group that holds the auto-generated
+// Apps API commands, which are kept separate from the custom commands defined in Commands.
+const ManagementGroupID = "management"
+
+// Commands returns the list of custom app commands to be added
+// to the auto-generated apps command group.
+func Commands() []*cobra.Command {
+	return []*cobra.Command{
+		newInitCmd(),
+		newDevRemoteCmd(),
+		newLogsCommand(),
+		newRunLocal(),
+		newValidateCmd(),
+	}
+}
diff --git a/cmd/apps/deploy_bundle.go b/cmd/apps/deploy_bundle.go
new file mode 100644
index 0000000000..a3b7b3b71e
--- /dev/null
+++ b/cmd/apps/deploy_bundle.go
@@ -0,0 +1,213 @@
+package apps
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+
+	"github.com/databricks/cli/bundle"
+	"github.com/databricks/cli/bundle/resources"
+	"github.com/databricks/cli/bundle/run"
+	"github.com/databricks/cli/cmd/bundle/utils"
+	"github.com/databricks/cli/libs/apps/validation"
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/cli/libs/log"
+	"github.com/databricks/databricks-sdk-go/service/apps"
+	"github.com/spf13/cobra"
+)
+
+// ErrorWrapper is a function type for wrapping deployment errors.
+type ErrorWrapper func(cmd *cobra.Command, appName string, err error) error
+
+// isBundleDirectory checks if the current directory contains a databricks.yml file.
+func isBundleDirectory() bool {
+	_, err := os.Stat("databricks.yml")
+	return err == nil
+}
+
+// BundleDeployOverrideWithWrapper creates a deploy override function that uses
+// the provided error wrapper for API fallback errors.
+func BundleDeployOverrideWithWrapper(wrapError ErrorWrapper) func(*cobra.Command, *apps.CreateAppDeploymentRequest) {
+	return func(deployCmd *cobra.Command, deployReq *apps.CreateAppDeploymentRequest) {
+		var (
+			force          bool
+			skipValidation bool
+		)
+
+		deployCmd.Flags().BoolVar(&force, "force", false, "Force-override Git branch validation")
+		deployCmd.Flags().BoolVar(&skipValidation, "skip-validation", false, "Skip project validation (build, typecheck, lint)")
+
+		// Update the command usage to reflect that APP_NAME is optional when in bundle mode
+		deployCmd.Use = "deploy [APP_NAME]"
+
+		// Override Args to allow 0 or 1 arguments (bundle mode vs API mode)
+		deployCmd.Args = func(cmd *cobra.Command, args []string) error {
+			// In bundle mode, no arguments needed
+			if isBundleDirectory() {
+				if len(args) > 0 {
+					return errors.New("APP_NAME argument is not allowed when deploying from a bundle directory")
+				}
+				return nil
+			}
+			// In API mode, exactly 1 argument required
+			if len(args) != 1 {
+				return fmt.Errorf("accepts 1 arg(s), received %d", len(args))
+			}
+			return nil
+		}
+
+		originalRunE := deployCmd.RunE
+		deployCmd.RunE = func(cmd *cobra.Command, args []string) error {
+			// If we're in a bundle directory, use the enhanced deploy flow
+			if isBundleDirectory() {
+				return runBundleDeploy(cmd, force, skipValidation)
+			}
+
+			// Otherwise, fall back to the original API deploy command
+			err := originalRunE(cmd, args)
+			return wrapError(cmd, deployReq.AppName, err)
+		}
+
+		// Update the help text to explain the dual behavior
+		deployCmd.Long = `Create an app deployment.
+
+When run from a directory containing a databricks.yml bundle configuration,
+this command runs an enhanced deployment pipeline:
+1. Validates the project (build, typecheck, lint for Node.js projects)
+2. Deploys the bundle to the workspace
+3. Runs the app
+
+When run from a non-bundle directory, this command creates an app deployment using the API.
+
+Arguments:
+  APP_NAME: The name of the app (required only when not in a bundle directory).
+
+Examples:
+  # Deploy from a bundle directory (no app name required)
+  databricks apps deploy
+
+  # Deploy a specific app using the API
+  databricks apps deploy my-app
+
+  # Deploy from a bundle, skipping validation
+  databricks apps deploy --skip-validation
+
+  # Force deploy (override git branch validation)
+  databricks apps deploy --force`
+	}
+}
+
+// runBundleDeploy executes the enhanced deployment flow for bundle directories.
+func runBundleDeploy(cmd *cobra.Command, force, skipValidation bool) error {
+	ctx := cmd.Context()
+
+	// Get current working directory for validation
+	workDir, err := os.Getwd()
+	if err != nil {
+		return fmt.Errorf("failed to get working directory: %w", err)
+	}
+
+	// Step 1: Validate project (unless skipped)
+	if !skipValidation {
+		validator := validation.GetProjectValidator(workDir)
+		if validator != nil {
+			result, err := validator.Validate(ctx, workDir)
+			if err != nil {
+				return fmt.Errorf("validation error: %w", err)
+			}
+
+			if !result.Success {
+				// Show error details
+				if result.Details != nil {
+					cmdio.LogString(ctx, result.Details.Error())
+				}
+				return errors.New("validation failed - fix errors before deploying")
+			}
+			cmdio.LogString(ctx, "āœ… "+result.Message)
+		} else {
+			log.Debugf(ctx, "No validator found for project type, skipping validation")
+		}
+	}
+
+	// Step 2: Deploy bundle
+	cmdio.LogString(ctx, "Deploying bundle...")
+	b, err := utils.ProcessBundle(cmd, utils.ProcessOptions{
+		InitFunc: func(b *bundle.Bundle) {
+			b.Config.Bundle.Force = force
+		},
+		// Context is already initialized by the workspace command's PreRunE
+		SkipInitContext: true,
+		AlwaysPull:      true,
+		FastValidate:    true,
+		Build:           true,
+		Deploy:          true,
+	})
+	if err != nil {
+		return fmt.Errorf("deploy failed: %w", err)
+	}
+	log.Infof(ctx, "Deploy completed")
+
+	// Step 3: Detect and run app
+	appKey, err := detectBundleApp(b)
+	if err != nil {
+		return err
+	}
+
+	log.Infof(ctx, "Running app: %s", appKey)
+	if err := runBundleApp(ctx, b, appKey); err != nil {
+		cmdio.LogString(ctx, "āœ” Deployment succeeded, but failed to start app")
+		appName := b.Config.Resources.Apps[appKey].Name
+		return fmt.Errorf("failed to run app: %w. Run `databricks apps logs %s` to view logs", err, appName)
+	}
+
+	cmdio.LogString(ctx, "āœ” Deployment complete!")
+	return nil
+}
+
+// detectBundleApp finds the single app in the bundle configuration.
+func detectBundleApp(b *bundle.Bundle) (string, error) {
+	bundleApps := b.Config.Resources.Apps
+
+	if len(bundleApps) == 0 {
+		return "", errors.New("no apps found in bundle configuration")
+	}
+
+	if len(bundleApps) > 1 {
+		return "", errors.New("multiple apps found in bundle, cannot auto-detect")
+	}
+
+	for key := range bundleApps {
+		return key, nil
+	}
+
+	return "", errors.New("unexpected error detecting app")
+}
+
+// runBundleApp runs the specified app using the runner interface.
+func runBundleApp(ctx context.Context, b *bundle.Bundle, appKey string) error {
+	ref, err := resources.Lookup(b, appKey, run.IsRunnable)
+	if err != nil {
+		return fmt.Errorf("failed to lookup app: %w", err)
+	}
+
+	runner, err := run.ToRunner(b, ref)
+	if err != nil {
+		return fmt.Errorf("failed to create runner: %w", err)
+	}
+
+	output, err := runner.Run(ctx, &run.Options{})
+	if err != nil {
+		return fmt.Errorf("failed to run app: %w", err)
+	}
+
+	if output != nil {
+		resultString, err := output.String()
+		if err != nil {
+			return err
+		}
+		log.Infof(ctx, "App output: %s", resultString)
+	}
+
+	return nil
+}
diff --git a/cmd/apps/dev.go b/cmd/apps/dev.go
new file mode 100644
index 0000000000..c9dcb5c305
--- /dev/null
+++ b/cmd/apps/dev.go
@@ -0,0 +1,245 @@
+package apps
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/url"
+	"os"
+	"os/exec"
+	"os/signal"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/cmd/root"
+	"github.com/databricks/cli/libs/apps/prompt"
+	"github.com/databricks/cli/libs/apps/vite"
+	"github.com/databricks/cli/libs/cmdctx"
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/spf13/cobra"
+)
+
+const (
+	vitePort               = 5173
+	viteReadyCheckInterval = 100 * time.Millisecond
+	viteReadyMaxAttempts   = 50
+)
+
+func isViteReady(port int) bool {
+	conn, err := net.DialTimeout("tcp", "localhost:"+strconv.Itoa(port), viteReadyCheckInterval)
+	if err != nil {
+		return false
+	}
+	conn.Close()
+	return true
+}
+
+// detectAppNameFromBundle tries to extract the app name from a databricks.yml bundle config.
+// Returns the app name if found, or empty string if no bundle or no apps found.
+func detectAppNameFromBundle() string {
+	const bundleFile = "databricks.yml"
+
+	// Check if databricks.yml exists
+	if _, err := os.Stat(bundleFile); os.IsNotExist(err) {
+		return ""
+	}
+
+	// Get current working directory
+	cwd, err := os.Getwd()
+	if err != nil {
+		return ""
+	}
+
+	// Load the bundle configuration directly
+	rootConfig, diags := config.Load(filepath.Join(cwd, bundleFile))
+	if diags.HasError() {
+		return ""
+	}
+
+	// Check for apps in the bundle
+	bundleApps := rootConfig.Resources.Apps
+	if len(bundleApps) == 0 {
+		return ""
+	}
+
+	// If there's exactly one app, return its name
+	if len(bundleApps) == 1 {
+		for _, app := range bundleApps {
+			return app.Name
+		}
+	}
+
+	// Multiple apps - can't auto-detect
+	return ""
+}
+
+func startViteDevServer(ctx context.Context, appURL string, port int) (*exec.Cmd, chan error, error) {
+	// Pass script through stdin, and pass arguments in order
+	viteCmd := exec.Command("node", "-", appURL, strconv.Itoa(port))
+	viteCmd.Stdin = bytes.NewReader(vite.ServerScript)
+	viteCmd.Stdout = os.Stdout
+	viteCmd.Stderr = os.Stderr
+
+	err := viteCmd.Start()
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to start Vite server: %w", err)
+	}
+
+	cmdio.LogString(ctx, fmt.Sprintf("šŸš€ Starting Vite development server on port %d...", port))
+
+	viteErr := make(chan error, 1)
+	go func() {
+		if err := viteCmd.Wait(); err != nil {
+			viteErr <- fmt.Errorf("vite server exited with error: %w", err)
+		} else {
+			viteErr <- errors.New("vite server exited unexpectedly")
+		}
+	}()
+
+	for range viteReadyMaxAttempts {
+		select {
+		case err := <-viteErr:
+			return nil, nil, err
+		default:
+			if isViteReady(port) {
+				return viteCmd, viteErr, nil
+			}
+			time.Sleep(viteReadyCheckInterval)
+		}
+	}
+
+	_ = viteCmd.Process.Kill()
+	return nil, nil, errors.New("timeout waiting for Vite server to be ready")
+}
+
+func newDevRemoteCmd() *cobra.Command {
+	var (
+		appName    string
+		clientPath string
+		port       int
+	)
+
+	cmd := &cobra.Command{
+		Use:   "dev-remote",
+		Short: "Run AppKit app locally with WebSocket bridge to remote server",
+		Long: `Run AppKit app locally with WebSocket bridge to remote server.
+
+Starts a local Vite development server and establishes a WebSocket bridge
+to the remote Databricks app for development with hot module replacement.
+
+Examples:
+  # Interactive mode - select app from picker
+  databricks apps dev-remote
+
+  # Start development server for a specific app
+  databricks apps dev-remote --name my-app
+
+  # Use a custom client path
+  databricks apps dev-remote --name my-app --client-path ./frontend
+
+  # Use a custom port
+  databricks apps dev-remote --name my-app --port 3000`,
+		Args:    root.NoArgs,
+		PreRunE: root.MustWorkspaceClient,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+
+			// Validate client path early (before any network calls)
+			if _, err := os.Stat(clientPath); os.IsNotExist(err) {
+				return fmt.Errorf("client directory not found: %s", clientPath)
+			}
+
+			// Check if port is already in use
+			if isViteReady(port) {
+				return fmt.Errorf("port %d is already in use; use --port to choose a different port", port)
+			}
+
+			w := cmdctx.WorkspaceClient(ctx)
+
+			// Resolve app name with priority: flag > bundle config > prompt
+			if appName == "" {
+				// Try to detect from bundle config
+				appName = detectAppNameFromBundle()
+				if appName != "" {
+					cmdio.LogString(ctx, fmt.Sprintf("Using app '%s' from bundle configuration", appName))
+				}
+			}
+
+			if appName == "" {
+				// Fall back to interactive prompt
+				selected, err := prompt.PromptForAppSelection(ctx, "Select an app to connect to")
+				if err != nil {
+					return err
+				}
+				appName = selected
+			}
+
+			bridge := vite.NewBridge(ctx, w, appName, port)
+
+			// Validate app exists and get domain before starting Vite
+			var appDomain *url.URL
+			err := prompt.RunWithSpinnerCtx(ctx, "Connecting to app...", func() error {
+				var domainErr error
+				appDomain, domainErr = bridge.GetAppDomain()
+				return domainErr
+			})
+			if err != nil {
+				if strings.Contains(err.Error(), "does not exist") || strings.Contains(err.Error(), "is deleted") {
+					return fmt.Errorf("application '%s' has not been deployed yet.
Run `databricks apps deploy` to deploy and then try again", appName) + } + return fmt.Errorf("failed to get app domain: %w", err) + } + + viteCmd, viteErr, err := startViteDevServer(ctx, appDomain.String(), port) + if err != nil { + return err + } + + done := make(chan error, 1) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + go func() { + done <- bridge.Start() + }() + + select { + case err := <-viteErr: + bridge.Stop() + <-done + return err + case err := <-done: + cmdio.LogString(ctx, "Bridge stopped") + if viteCmd.Process != nil { + _ = viteCmd.Process.Signal(os.Interrupt) + <-viteErr + } + return err + case <-sigChan: + cmdio.LogString(ctx, "\nšŸ›‘ Shutting down...") + bridge.Stop() + <-done + if viteCmd.Process != nil { + if err := viteCmd.Process.Signal(os.Interrupt); err != nil { + cmdio.LogString(ctx, fmt.Sprintf("Failed to interrupt Vite: %v", err)) + _ = viteCmd.Process.Kill() + } + <-viteErr + } + return nil + } + }, + } + + cmd.Flags().StringVar(&appName, "name", "", "Name of the app to connect to (prompts if not provided)") + cmd.Flags().StringVar(&clientPath, "client-path", "./client", "Path to the Vite client directory") + cmd.Flags().IntVar(&port, "port", vitePort, "Port to run the Vite server on") + + return cmd +} diff --git a/cmd/workspace/apps/dev_test.go b/cmd/apps/dev_test.go similarity index 89% rename from cmd/workspace/apps/dev_test.go rename to cmd/apps/dev_test.go index 2133d0440f..8aa0224798 100644 --- a/cmd/workspace/apps/dev_test.go +++ b/cmd/apps/dev_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/databricks/cli/libs/apps/vite" "github.com/databricks/cli/libs/cmdio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,10 +47,10 @@ func TestIsViteReady(t *testing.T) { func TestViteServerScriptContent(t *testing.T) { // Verify the embedded script is not empty - assert.NotEmpty(t, viteServerScript) + assert.NotEmpty(t, vite.ServerScript) // Verify it's a JavaScript file with expected content - assert.Contains(t, string(viteServerScript), "startViteServer") + assert.Contains(t, string(vite.ServerScript), "startViteServer") } func TestStartViteDevServerNoNode(t *testing.T) { @@ -81,9 +82,9 @@ func TestStartViteDevServerNoNode(t *testing.T) { } func TestViteServerScriptEmbedded(t *testing.T) { - assert.NotEmpty(t, viteServerScript) + assert.NotEmpty(t, vite.ServerScript) - scriptContent := string(viteServerScript) + scriptContent := string(vite.ServerScript) assert.Contains(t, scriptContent, "startViteServer") assert.Contains(t, scriptContent, "createServer") assert.Contains(t, scriptContent, "queriesHMRPlugin") diff --git a/cmd/apps/init.go b/cmd/apps/init.go new file mode 100644 index 0000000000..c0a73e595e --- /dev/null +++ b/cmd/apps/init.go @@ -0,0 +1,1036 @@ +package apps + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "text/template" + + "github.com/charmbracelet/huh" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/apps/features" + "github.com/databricks/cli/libs/apps/prompt" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/spf13/cobra" +) + +const ( + templatePathEnvVar = "DATABRICKS_APPKIT_TEMPLATE_PATH" + defaultTemplateURL = "https://github.com/databricks/appkit/tree/main/template" +) + +func newInitCmd() *cobra.Command { + var ( + templatePath string + branch string + name string + warehouseID 
string + description string + outputDir string + featuresFlag []string + deploy bool + run string + ) + + cmd := &cobra.Command{ + Use: "init", + Short: "Initialize a new AppKit application from a template", + Long: `Initialize a new AppKit application from a template. + +When run without arguments, uses the default AppKit template and an interactive prompt +guides you through the setup. When run with --name, runs in non-interactive mode +(all required flags must be provided). + +Examples: + # Interactive mode with default template (recommended) + databricks apps init + + # Non-interactive with flags + databricks apps init --name my-app + + # With analytics feature (requires --warehouse-id) + databricks apps init --name my-app --features=analytics --warehouse-id=abc123 + + # Create, deploy, and run with dev-remote + databricks apps init --name my-app --deploy --run=dev-remote + + # With a custom template from a local path + databricks apps init --template /path/to/template --name my-app + + # With a GitHub URL + databricks apps init --template https://github.com/user/repo --name my-app + +Feature dependencies: + Some features require additional flags: + - analytics: requires --warehouse-id (SQL Warehouse ID) + +Environment variables: + DATABRICKS_APPKIT_TEMPLATE_PATH Override the default template source`, + Args: cobra.NoArgs, + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + return runCreate(ctx, createOptions{ + templatePath: templatePath, + branch: branch, + name: name, + warehouseID: warehouseID, + description: description, + outputDir: outputDir, + features: featuresFlag, + deploy: deploy, + run: run, + }) + }, + } + + cmd.Flags().StringVar(&templatePath, "template", "", "Template path (local directory or GitHub URL)") + cmd.Flags().StringVar(&branch, "branch", "", "Git branch or tag (for GitHub templates)") + cmd.Flags().StringVar(&name, "name", "", "Project name (prompts if not provided)") + cmd.Flags().StringVar(&warehouseID, "warehouse-id", "", "SQL warehouse ID") + cmd.Flags().StringVar(&description, "description", "", "App description") + cmd.Flags().StringVar(&outputDir, "output-dir", "", "Directory to write the project to") + cmd.Flags().StringSliceVar(&featuresFlag, "features", nil, "Features to enable (comma-separated). Available: "+strings.Join(features.GetFeatureIDs(), ", ")) + cmd.Flags().BoolVar(&deploy, "deploy", false, "Deploy the app after creation") + cmd.Flags().StringVar(&run, "run", "", "Run the app after creation (none, dev, dev-remote)") + + return cmd +} + +type createOptions struct { + templatePath string + branch string + name string + warehouseID string + description string + outputDir string + features []string + deploy bool + run string +} + +// templateVars holds the variables for template substitution. +type templateVars struct { + ProjectName string + SQLWarehouseID string + AppDescription string + Profile string + WorkspaceHost string + PluginImport string + PluginUsage string + // Feature resource fragments (aggregated from selected features) + BundleVariables string + BundleResources string + TargetVariables string + AppEnv string + DotEnv string + DotEnvExample string +} + +// featureFragments holds aggregated content from feature resource files. 
+type featureFragments struct { + BundleVariables string + BundleResources string + TargetVariables string + AppEnv string + DotEnv string + DotEnvExample string +} + +// parseDeployAndRunFlags parses the deploy and run flag values into typed values. +func parseDeployAndRunFlags(deploy bool, run string) (bool, prompt.RunMode, error) { + var runMode prompt.RunMode + switch run { + case "dev": + runMode = prompt.RunModeDev + case "dev-remote": + runMode = prompt.RunModeDevRemote + case "", "none": + runMode = prompt.RunModeNone + default: + return false, prompt.RunModeNone, fmt.Errorf("invalid --run value: %q (must be none, dev, or dev-remote)", run) + } + return deploy, runMode, nil +} + +// promptForFeaturesAndDeps prompts for features and their dependencies. +// Used when the template uses the feature-fragment system. +func promptForFeaturesAndDeps(ctx context.Context, preSelectedFeatures []string) (*prompt.CreateProjectConfig, error) { + config := &prompt.CreateProjectConfig{ + Dependencies: make(map[string]string), + Features: preSelectedFeatures, + } + theme := prompt.AppkitTheme() + + // Step 1: Feature selection (skip if features already provided via flag) + if len(config.Features) == 0 && len(features.AvailableFeatures) > 0 { + options := make([]huh.Option[string], 0, len(features.AvailableFeatures)) + for _, f := range features.AvailableFeatures { + label := f.Name + " - " + f.Description + options = append(options, huh.NewOption(label, f.ID)) + } + + err := huh.NewMultiSelect[string](). + Title("Select features"). + Description("space to toggle, enter to confirm"). + Options(options...). + Value(&config.Features). + Height(8). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + if len(config.Features) == 0 { + prompt.PrintAnswered("Features", "None") + } else { + prompt.PrintAnswered("Features", fmt.Sprintf("%d selected", len(config.Features))) + } + } + + // Step 2: Prompt for feature dependencies + deps := features.CollectDependencies(config.Features) + for _, dep := range deps { + // Special handling for SQL warehouse - show picker instead of text input + if dep.ID == "sql_warehouse_id" { + warehouseID, err := prompt.PromptForWarehouse(ctx) + if err != nil { + return nil, err + } + config.Dependencies[dep.ID] = warehouseID + continue + } + + var value string + description := dep.Description + if !dep.Required { + description += " (optional)" + } + + input := huh.NewInput(). + Title(dep.Title). + Description(description). + Placeholder(dep.Placeholder). + Value(&value) + + if dep.Required { + input = input.Validate(func(s string) error { + if s == "" { + return errors.New("this field is required") + } + return nil + }) + } + + if err := input.WithTheme(theme).Run(); err != nil { + return nil, err + } + prompt.PrintAnswered(dep.Title, value) + config.Dependencies[dep.ID] = value + } + + // Step 3: Description + config.Description = prompt.DefaultAppDescription + err := huh.NewInput(). + Title("Description"). + Placeholder(prompt.DefaultAppDescription). + Value(&config.Description). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + if config.Description == "" { + config.Description = prompt.DefaultAppDescription + } + prompt.PrintAnswered("Description", config.Description) + + // Step 4: Deploy and run options + config.Deploy, config.RunMode, err = prompt.PromptForDeployAndRun() + if err != nil { + return nil, err + } + + return config, nil +} + +// loadFeatureFragments reads and aggregates resource fragments for selected features. 
+// templateDir is the path to the template directory (containing the "features" subdirectory). +func loadFeatureFragments(templateDir string, featureIDs []string, vars templateVars) (*featureFragments, error) { + featuresDir := filepath.Join(templateDir, "features") + + resourceFiles := features.CollectResourceFiles(featureIDs) + if len(resourceFiles) == 0 { + return &featureFragments{}, nil + } + + var bundleVarsList, bundleResList, targetVarsList, appEnvList, dotEnvList, dotEnvExampleList []string + + for _, rf := range resourceFiles { + if rf.BundleVariables != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.BundleVariables), vars) + if err != nil { + return nil, fmt.Errorf("read bundle variables: %w", err) + } + bundleVarsList = append(bundleVarsList, content) + } + if rf.BundleResources != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.BundleResources), vars) + if err != nil { + return nil, fmt.Errorf("read bundle resources: %w", err) + } + bundleResList = append(bundleResList, content) + } + if rf.TargetVariables != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.TargetVariables), vars) + if err != nil { + return nil, fmt.Errorf("read target variables: %w", err) + } + targetVarsList = append(targetVarsList, content) + } + if rf.AppEnv != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.AppEnv), vars) + if err != nil { + return nil, fmt.Errorf("read app env: %w", err) + } + appEnvList = append(appEnvList, content) + } + if rf.DotEnv != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.DotEnv), vars) + if err != nil { + return nil, fmt.Errorf("read dotenv: %w", err) + } + dotEnvList = append(dotEnvList, content) + } + if rf.DotEnvExample != "" { + content, err := readAndSubstitute(filepath.Join(featuresDir, rf.DotEnvExample), vars) + if err != nil { + return nil, fmt.Errorf("read dotenv example: %w", err) + } + dotEnvExampleList = append(dotEnvExampleList, content) + } + } + + // Join fragments (they already have proper indentation from the fragment files) + return &featureFragments{ + BundleVariables: strings.TrimSuffix(strings.Join(bundleVarsList, ""), "\n"), + BundleResources: strings.TrimSuffix(strings.Join(bundleResList, ""), "\n"), + TargetVariables: strings.TrimSuffix(strings.Join(targetVarsList, ""), "\n"), + AppEnv: strings.TrimSuffix(strings.Join(appEnvList, ""), "\n"), + DotEnv: strings.TrimSuffix(strings.Join(dotEnvList, ""), "\n"), + DotEnvExample: strings.TrimSuffix(strings.Join(dotEnvExampleList, ""), "\n"), + }, nil +} + +// readAndSubstitute reads a file and applies variable substitution. +func readAndSubstitute(path string, vars templateVars) (string, error) { + content, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return "", nil // Fragment file doesn't exist, skip it + } + return "", err + } + return substituteVars(string(content), vars), nil +} + +// parseGitHubURL extracts the repository URL, subdirectory, and branch from a GitHub URL. 
+// Input: https://github.com/user/repo/tree/main/templates/starter +// Output: repoURL="https://github.com/user/repo", subdir="templates/starter", branch="main" +func parseGitHubURL(url string) (repoURL, subdir, branch string) { + // Remove trailing slash + url = strings.TrimSuffix(url, "/") + + // Check for /tree/branch/path pattern + if idx := strings.Index(url, "/tree/"); idx != -1 { + repoURL = url[:idx] + rest := url[idx+6:] // Skip "/tree/" + + // Split into branch and path + parts := strings.SplitN(rest, "/", 2) + branch = parts[0] + if len(parts) > 1 { + subdir = parts[1] + } + return repoURL, subdir, branch + } + + // No /tree/ pattern, just a repo URL + return url, "", "" +} + +// cloneRepo clones a git repository to a temporary directory. +func cloneRepo(ctx context.Context, repoURL, branch string) (string, error) { + tempDir, err := os.MkdirTemp("", "appkit-template-*") + if err != nil { + return "", fmt.Errorf("create temp dir: %w", err) + } + + args := []string{"clone", "--depth", "1"} + if branch != "" { + args = append(args, "--branch", branch) + } + args = append(args, repoURL, tempDir) + + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Stdout = nil + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + os.RemoveAll(tempDir) + if stderr.Len() > 0 { + return "", fmt.Errorf("git clone failed: %s: %w", strings.TrimSpace(stderr.String()), err) + } + return "", fmt.Errorf("git clone failed: %w", err) + } + + return tempDir, nil +} + +// resolveTemplate resolves a template path, handling both local paths and GitHub URLs. +// Returns the local path to use, a cleanup function (for temp dirs), and any error. +func resolveTemplate(ctx context.Context, templatePath, branch string) (localPath string, cleanup func(), err error) { + // Case 1: Local path - return as-is + if !strings.HasPrefix(templatePath, "https://") { + return templatePath, nil, nil + } + + // Case 2: GitHub URL - parse and clone + repoURL, subdir, urlBranch := parseGitHubURL(templatePath) + if branch == "" { + branch = urlBranch // Use branch from URL if not overridden by flag + } + + // Clone to temp dir with spinner + var tempDir string + err = prompt.RunWithSpinnerCtx(ctx, "Cloning template...", func() error { + var cloneErr error + tempDir, cloneErr = cloneRepo(ctx, repoURL, branch) + return cloneErr + }) + if err != nil { + return "", nil, err + } + + cleanup = func() { os.RemoveAll(tempDir) } + + // Return path to subdirectory if specified + if subdir != "" { + return filepath.Join(tempDir, subdir), cleanup, nil + } + return tempDir, cleanup, nil +} + +func runCreate(ctx context.Context, opts createOptions) error { + var selectedFeatures []string + var dependencies map[string]string + var shouldDeploy bool + var runMode prompt.RunMode + isInteractive := cmdio.IsPromptSupported(ctx) + + // Use features from flags if provided + if len(opts.features) > 0 { + selectedFeatures = opts.features + } + + // Resolve template path (supports local paths and GitHub URLs) + templateSrc := opts.templatePath + if templateSrc == "" { + templateSrc = os.Getenv(templatePathEnvVar) + } + if templateSrc == "" { + // Use default template from GitHub + templateSrc = defaultTemplateURL + } + + // Step 1: Get project name first (needed before we can check destination) + // Determine output directory for validation + destDir := opts.name + if opts.outputDir != "" { + destDir = filepath.Join(opts.outputDir, opts.name) + } + + if opts.name == "" { + if !isInteractive { + return errors.New("--name is 
required in non-interactive mode") + } + // Prompt includes validation for name format AND directory existence + name, err := prompt.PromptForProjectName(opts.outputDir) + if err != nil { + return err + } + opts.name = name + // Update destDir with the actual name + destDir = opts.name + if opts.outputDir != "" { + destDir = filepath.Join(opts.outputDir, opts.name) + } + } else { + // Non-interactive mode: validate name and directory existence + if err := prompt.ValidateProjectName(opts.name); err != nil { + return err + } + if _, err := os.Stat(destDir); err == nil { + return fmt.Errorf("directory %s already exists", destDir) + } + } + + // Step 2: Resolve template (handles GitHub URLs by cloning) + resolvedPath, cleanup, err := resolveTemplate(ctx, templateSrc, opts.branch) + if err != nil { + return err + } + if cleanup != nil { + defer cleanup() + } + + // Check for generic subdirectory first (default for multi-template repos) + templateDir := filepath.Join(resolvedPath, "generic") + if _, err := os.Stat(templateDir); os.IsNotExist(err) { + // Fall back to the provided path directly + templateDir = resolvedPath + if _, err := os.Stat(templateDir); os.IsNotExist(err) { + return fmt.Errorf("template not found at %s (also checked %s/generic)", resolvedPath, resolvedPath) + } + } + + // Step 3: Determine template type and gather configuration + usesFeatureFragments := features.HasFeaturesDirectory(templateDir) + + if usesFeatureFragments { + // Feature-fragment template: prompt for features and their dependencies + if isInteractive && len(selectedFeatures) == 0 { + // Need to prompt for features (but we already have the name) + config, err := promptForFeaturesAndDeps(ctx, selectedFeatures) + if err != nil { + return err + } + selectedFeatures = config.Features + dependencies = config.Dependencies + if config.Description != "" { + opts.description = config.Description + } + shouldDeploy = config.Deploy + runMode = config.RunMode + + // Get warehouse from dependencies if provided + if wh, ok := dependencies["sql_warehouse_id"]; ok && wh != "" { + opts.warehouseID = wh + } + } else { + // Non-interactive or features provided via flag + flagValues := map[string]string{ + "warehouse-id": opts.warehouseID, + } + if len(selectedFeatures) > 0 { + if err := features.ValidateFeatureDependencies(selectedFeatures, flagValues); err != nil { + return err + } + } + dependencies = make(map[string]string) + if opts.warehouseID != "" { + dependencies["sql_warehouse_id"] = opts.warehouseID + } + var err error + shouldDeploy, runMode, err = parseDeployAndRunFlags(opts.deploy, opts.run) + if err != nil { + return err + } + } + + // Validate feature IDs + if err := features.ValidateFeatureIDs(selectedFeatures); err != nil { + return err + } + } else { + // Pre-assembled template: detect plugins and prompt for their dependencies + detectedPlugins, err := features.DetectPluginsFromServer(templateDir) + if err != nil { + return fmt.Errorf("failed to detect plugins: %w", err) + } + + log.Debugf(ctx, "Detected plugins: %v", detectedPlugins) + + // Map detected plugins to feature IDs for ApplyFeatures + selectedFeatures = features.MapPluginsToFeatures(detectedPlugins) + log.Debugf(ctx, "Mapped to features: %v", selectedFeatures) + + pluginDeps := features.GetPluginDependencies(detectedPlugins) + + log.Debugf(ctx, "Plugin dependencies: %d", len(pluginDeps)) + + if isInteractive && len(pluginDeps) > 0 { + // Prompt for plugin dependencies + dependencies, err = prompt.PromptForPluginDependencies(ctx, pluginDeps) + if 
err != nil { + return err + } + if wh, ok := dependencies["sql_warehouse_id"]; ok && wh != "" { + opts.warehouseID = wh + } + } else { + // Non-interactive: check flags + dependencies = make(map[string]string) + if opts.warehouseID != "" { + dependencies["sql_warehouse_id"] = opts.warehouseID + } + + // Validate required dependencies are provided + for _, dep := range pluginDeps { + if dep.Required { + if _, ok := dependencies[dep.ID]; !ok { + return fmt.Errorf("missing required flag --%s for detected plugin", dep.FlagName) + } + } + } + } + + // Prompt for description and post-creation actions + if isInteractive { + if opts.description == "" { + opts.description = prompt.DefaultAppDescription + } + var deployVal bool + var runVal prompt.RunMode + deployVal, runVal, err = prompt.PromptForDeployAndRun() + if err != nil { + return err + } + shouldDeploy = deployVal + runMode = runVal + } else { + var err error + shouldDeploy, runMode, err = parseDeployAndRunFlags(opts.deploy, opts.run) + if err != nil { + return err + } + } + } + + // Track whether we started creating the project for cleanup on failure + var projectCreated bool + var runErr error + defer func() { + if runErr != nil && projectCreated { + // Clean up partially created project on failure + os.RemoveAll(destDir) + } + }() + + // Set description default + if opts.description == "" { + opts.description = prompt.DefaultAppDescription + } + + // Get workspace host and profile from context + workspaceHost := "" + profile := "" + if w := cmdctx.WorkspaceClient(ctx); w != nil && w.Config != nil { + workspaceHost = w.Config.Host + profile = w.Config.Profile + } + + // Build plugin imports and usages from selected features + pluginImport, pluginUsage := features.BuildPluginStrings(selectedFeatures) + + // Template variables (initial, without feature fragments) + vars := templateVars{ + ProjectName: opts.name, + SQLWarehouseID: opts.warehouseID, + AppDescription: opts.description, + Profile: profile, + WorkspaceHost: workspaceHost, + PluginImport: pluginImport, + PluginUsage: pluginUsage, + } + + // Load feature resource fragments + fragments, err := loadFeatureFragments(templateDir, selectedFeatures, vars) + if err != nil { + return fmt.Errorf("load feature fragments: %w", err) + } + vars.BundleVariables = fragments.BundleVariables + vars.BundleResources = fragments.BundleResources + vars.TargetVariables = fragments.TargetVariables + vars.AppEnv = fragments.AppEnv + vars.DotEnv = fragments.DotEnv + vars.DotEnvExample = fragments.DotEnvExample + + // Copy template with variable substitution + var fileCount int + runErr = prompt.RunWithSpinnerCtx(ctx, "Creating project...", func() error { + var copyErr error + fileCount, copyErr = copyTemplate(templateDir, destDir, vars) + return copyErr + }) + if runErr != nil { + return runErr + } + projectCreated = true // From here on, cleanup on failure + + // Get absolute path + absOutputDir, err := filepath.Abs(destDir) + if err != nil { + absOutputDir = destDir + } + + // Apply features (adds selected features, removes unselected feature files) + runErr = prompt.RunWithSpinnerCtx(ctx, "Configuring features...", func() error { + return features.ApplyFeatures(absOutputDir, selectedFeatures) + }) + if runErr != nil { + return runErr + } + + // Run npm install + runErr = runNpmInstall(ctx, absOutputDir) + if runErr != nil { + return runErr + } + + // Run npm run setup + runErr = runNpmSetup(ctx, absOutputDir) + if runErr != nil { + return runErr + } + + // Show next steps only if user didn't choose to 
deploy or run + showNextSteps := !shouldDeploy && runMode == prompt.RunModeNone + prompt.PrintSuccess(opts.name, absOutputDir, fileCount, showNextSteps) + + // Execute post-creation actions (deploy and/or run) + if shouldDeploy || runMode != prompt.RunModeNone { + // Change to project directory for subsequent commands + if err := os.Chdir(absOutputDir); err != nil { + return fmt.Errorf("failed to change to project directory: %w", err) + } + } + + if shouldDeploy { + cmdio.LogString(ctx, "") + cmdio.LogString(ctx, "Deploying app...") + if err := runPostCreateDeploy(ctx); err != nil { + cmdio.LogString(ctx, fmt.Sprintf("⚠ Deploy failed: %v", err)) + cmdio.LogString(ctx, " You can deploy manually with: databricks apps deploy") + } + } + + if runMode != prompt.RunModeNone { + cmdio.LogString(ctx, "") + if err := runPostCreateDev(ctx, runMode); err != nil { + return err + } + } + + return nil +} + +// runPostCreateDeploy runs the deploy command in the current directory. +func runPostCreateDeploy(ctx context.Context) error { + // Use os.Args[0] to get the path to the current executable + executable := os.Args[0] + cmd := exec.CommandContext(ctx, executable, "apps", "deploy") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd.Run() +} + +// runPostCreateDev runs the dev or dev-remote command in the current directory. +func runPostCreateDev(ctx context.Context, mode prompt.RunMode) error { + switch mode { + case prompt.RunModeDev: + cmdio.LogString(ctx, "Starting development server (npm run dev)...") + cmd := exec.CommandContext(ctx, "npm", "run", "dev") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd.Run() + case prompt.RunModeDevRemote: + cmdio.LogString(ctx, "Starting remote development server...") + // Use os.Args[0] to get the path to the current executable + executable := os.Args[0] + cmd := exec.CommandContext(ctx, executable, "apps", "dev-remote") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd.Run() + default: + return nil + } +} + +// runNpmInstall runs npm install in the project directory. +func runNpmInstall(ctx context.Context, projectDir string) error { + // Check if npm is available + if _, err := exec.LookPath("npm"); err != nil { + cmdio.LogString(ctx, "⚠ npm not found. Please install Node.js and run 'npm install' manually.") + return nil + } + + return prompt.RunWithSpinnerCtx(ctx, "Installing dependencies...", func() error { + cmd := exec.CommandContext(ctx, "npm", "install") + cmd.Dir = projectDir + cmd.Stdout = nil // Suppress output + cmd.Stderr = nil + return cmd.Run() + }) +} + +// runNpmSetup runs npx appkit-setup in the project directory. +func runNpmSetup(ctx context.Context, projectDir string) error { + // Check if npx is available + if _, err := exec.LookPath("npx"); err != nil { + return nil + } + + return prompt.RunWithSpinnerCtx(ctx, "Running setup...", func() error { + cmd := exec.CommandContext(ctx, "npx", "appkit-setup", "--write") + cmd.Dir = projectDir + cmd.Stdout = nil // Suppress output + cmd.Stderr = nil + return cmd.Run() + }) +} + +// renameFiles maps source file names to destination names (for files that can't use special chars). +var renameFiles = map[string]string{ + "_gitignore": ".gitignore", + "_env": ".env", + "_env.local": ".env.local", + "_npmrc": ".npmrc", + "_prettierrc": ".prettierrc", + "_eslintrc": ".eslintrc", +} + +// copyTemplate copies the template directory to dest, substituting variables. 
+func copyTemplate(src, dest string, vars templateVars) (int, error) { + fileCount := 0 + + // Find the project_name placeholder directory + srcProjectDir := "" + entries, err := os.ReadDir(src) + if err != nil { + return 0, err + } + for _, e := range entries { + if e.IsDir() && strings.Contains(e.Name(), "{{.project_name}}") { + srcProjectDir = filepath.Join(src, e.Name()) + break + } + } + + // If no {{.project_name}} dir found, copy src directly + if srcProjectDir == "" { + srcProjectDir = src + } + + log.Debugf(context.Background(), "Copying template from: %s", srcProjectDir) + + // Files and directories to skip + skipFiles := map[string]bool{ + "CLAUDE.md": true, + "AGENTS.md": true, + "databricks_template_schema.json": true, + } + skipDirs := map[string]bool{ + "docs": true, + "features": true, // Feature fragments are processed separately, not copied + } + + err = filepath.Walk(srcProjectDir, func(srcPath string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + baseName := filepath.Base(srcPath) + + // Skip certain files + if skipFiles[baseName] { + log.Debugf(context.Background(), "Skipping file: %s", baseName) + return nil + } + + // Skip certain directories + if info.IsDir() && skipDirs[baseName] { + log.Debugf(context.Background(), "Skipping directory: %s", baseName) + return filepath.SkipDir + } + + // Calculate relative path from source project dir + relPath, err := filepath.Rel(srcProjectDir, srcPath) + if err != nil { + return err + } + + // Substitute variables in path + relPath = substituteVars(relPath, vars) + + // Handle .tmpl extension - strip it + relPath = strings.TrimSuffix(relPath, ".tmpl") + + // Apply file renames (e.g., _gitignore -> .gitignore) + fileName := filepath.Base(relPath) + if newName, ok := renameFiles[fileName]; ok { + relPath = filepath.Join(filepath.Dir(relPath), newName) + } + + destPath := filepath.Join(dest, relPath) + + if info.IsDir() { + log.Debugf(context.Background(), "Creating directory: %s", relPath) + return os.MkdirAll(destPath, info.Mode()) + } + + log.Debugf(context.Background(), "Copying file: %s", relPath) + + // Read file content + content, err := os.ReadFile(srcPath) + if err != nil { + return err + } + + // Handle special files + switch filepath.Base(srcPath) { + case "package.json": + content, err = processPackageJSON(content, vars) + if err != nil { + return fmt.Errorf("process package.json: %w", err) + } + default: + // Use Go template engine for .tmpl files (handles conditionals) + if strings.HasSuffix(srcPath, ".tmpl") { + content, err = executeTemplate(srcPath, content, vars) + if err != nil { + return fmt.Errorf("process template %s: %w", srcPath, err) + } + } else if isTextFile(srcPath) { + // Simple substitution for other text files + content = []byte(substituteVars(string(content), vars)) + } + } + + // Create parent directory + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return err + } + + // Write file + if err := os.WriteFile(destPath, content, info.Mode()); err != nil { + return err + } + + fileCount++ + return nil + }) + if err != nil { + log.Debugf(context.Background(), "Error during template copy: %v", err) + } + log.Debugf(context.Background(), "Copied %d files", fileCount) + + return fileCount, err +} + +// processPackageJSON updates the package.json with project-specific values. 
+func processPackageJSON(content []byte, vars templateVars) ([]byte, error) { + // Just do string substitution to preserve key order and formatting + return []byte(substituteVars(string(content), vars)), nil +} + +// substituteVars replaces template variables in a string. +func substituteVars(s string, vars templateVars) string { + s = strings.ReplaceAll(s, "{{.project_name}}", vars.ProjectName) + s = strings.ReplaceAll(s, "{{.sql_warehouse_id}}", vars.SQLWarehouseID) + s = strings.ReplaceAll(s, "{{.app_description}}", vars.AppDescription) + s = strings.ReplaceAll(s, "{{.profile}}", vars.Profile) + s = strings.ReplaceAll(s, "{{workspace_host}}", vars.WorkspaceHost) + + // Handle plugin placeholders + if vars.PluginImport != "" { + s = strings.ReplaceAll(s, "{{.plugin_import}}", vars.PluginImport) + s = strings.ReplaceAll(s, "{{.plugin_usage}}", vars.PluginUsage) + } else { + // No plugins selected - clean up the template + // Remove ", {{.plugin_import}}" from import line + s = strings.ReplaceAll(s, ", {{.plugin_import}} ", " ") + s = strings.ReplaceAll(s, ", {{.plugin_import}}", "") + // Remove the plugin_usage line entirely + s = strings.ReplaceAll(s, " {{.plugin_usage}},\n", "") + s = strings.ReplaceAll(s, " {{.plugin_usage}},", "") + } + + return s +} + +// executeTemplate processes a .tmpl file using Go's text/template engine. +func executeTemplate(path string, content []byte, vars templateVars) ([]byte, error) { + tmpl, err := template.New(filepath.Base(path)). + Funcs(template.FuncMap{ + "workspace_host": func() string { return vars.WorkspaceHost }, + }). + Parse(string(content)) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + // Use a map to match template variable names exactly (snake_case) + data := map[string]string{ + "project_name": vars.ProjectName, + "sql_warehouse_id": vars.SQLWarehouseID, + "app_description": vars.AppDescription, + "profile": vars.Profile, + "workspace_host": vars.WorkspaceHost, + "plugin_import": vars.PluginImport, + "plugin_usage": vars.PluginUsage, + "bundle_variables": vars.BundleVariables, + "bundle_resources": vars.BundleResources, + "target_variables": vars.TargetVariables, + "app_env": vars.AppEnv, + "dotenv": vars.DotEnv, + "dotenv_example": vars.DotEnvExample, + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("execute template: %w", err) + } + + return buf.Bytes(), nil +} + +// textExtensions contains file extensions that should be treated as text files. +var textExtensions = map[string]bool{ + ".ts": true, ".tsx": true, ".js": true, ".jsx": true, + ".json": true, ".yaml": true, ".yml": true, + ".md": true, ".txt": true, ".html": true, ".css": true, + ".scss": true, ".less": true, ".sql": true, + ".sh": true, ".bash": true, ".zsh": true, + ".py": true, ".go": true, ".rs": true, + ".toml": true, ".ini": true, ".cfg": true, + ".env": true, ".gitignore": true, ".npmrc": true, + ".prettierrc": true, ".eslintrc": true, +} + +// textBaseNames contains file names (without extension) that should be treated as text files. +var textBaseNames = map[string]bool{ + "Makefile": true, "Dockerfile": true, "LICENSE": true, + "README": true, ".gitignore": true, ".env": true, + ".nvmrc": true, ".node-version": true, + "_gitignore": true, "_env": true, "_npmrc": true, +} + +// isTextFile checks if a file is likely a text file based on extension. 
+func isTextFile(path string) bool { + ext := strings.ToLower(filepath.Ext(path)) + if textExtensions[ext] { + return true + } + return textBaseNames[filepath.Base(path)] +} diff --git a/cmd/apps/init_test.go b/cmd/apps/init_test.go new file mode 100644 index 0000000000..6e889e4f4e --- /dev/null +++ b/cmd/apps/init_test.go @@ -0,0 +1,304 @@ +package apps + +import ( + "testing" + + "github.com/databricks/cli/libs/apps/prompt" + "github.com/stretchr/testify/assert" +) + +func TestParseGitHubURL(t *testing.T) { + tests := []struct { + name string + url string + wantRepoURL string + wantSubdir string + wantBranch string + }{ + { + name: "simple repo URL", + url: "https://github.com/user/repo", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "", + wantBranch: "", + }, + { + name: "repo URL with trailing slash", + url: "https://github.com/user/repo/", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "", + wantBranch: "", + }, + { + name: "repo with branch", + url: "https://github.com/user/repo/tree/main", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "", + wantBranch: "main", + }, + { + name: "repo with branch and subdir", + url: "https://github.com/user/repo/tree/main/templates/starter", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "templates/starter", + wantBranch: "main", + }, + { + name: "repo with branch and deep subdir", + url: "https://github.com/databricks/cli/tree/v0.1.0/libs/template/templates/default-python", + wantRepoURL: "https://github.com/databricks/cli", + wantSubdir: "libs/template/templates/default-python", + wantBranch: "v0.1.0", + }, + { + name: "repo with feature branch", + url: "https://github.com/user/repo/tree/feature/my-feature", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "my-feature", + wantBranch: "feature", + }, + { + name: "repo URL with trailing slash and tree", + url: "https://github.com/user/repo/tree/main/", + wantRepoURL: "https://github.com/user/repo", + wantSubdir: "", + wantBranch: "main", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotRepoURL, gotSubdir, gotBranch := parseGitHubURL(tt.url) + assert.Equal(t, tt.wantRepoURL, gotRepoURL, "repoURL mismatch") + assert.Equal(t, tt.wantSubdir, gotSubdir, "subdir mismatch") + assert.Equal(t, tt.wantBranch, gotBranch, "branch mismatch") + }) + } +} + +func TestIsTextFile(t *testing.T) { + tests := []struct { + path string + expected bool + }{ + // Text files by extension + {"file.ts", true}, + {"file.tsx", true}, + {"file.js", true}, + {"file.jsx", true}, + {"file.json", true}, + {"file.yaml", true}, + {"file.yml", true}, + {"file.md", true}, + {"file.txt", true}, + {"file.html", true}, + {"file.css", true}, + {"file.scss", true}, + {"file.sql", true}, + {"file.sh", true}, + {"file.py", true}, + {"file.go", true}, + {"file.toml", true}, + {"file.env", true}, + + // Text files by name + {"Makefile", true}, + {"Dockerfile", true}, + {"LICENSE", true}, + {"README", true}, + {".gitignore", true}, + {".env", true}, + {"_gitignore", true}, + {"_env", true}, + + // Binary files (should return false) + {"file.png", false}, + {"file.jpg", false}, + {"file.gif", false}, + {"file.pdf", false}, + {"file.exe", false}, + {"file.bin", false}, + {"file.zip", false}, + {"randomfile", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := isTextFile(tt.path) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSubstituteVars(t *testing.T) { + vars := templateVars{ + 
ProjectName: "my-app", + SQLWarehouseID: "warehouse123", + AppDescription: "My awesome app", + Profile: "default", + WorkspaceHost: "https://dbc-123.cloud.databricks.com", + PluginImport: "analytics", + PluginUsage: "analytics()", + } + + tests := []struct { + name string + input string + expected string + }{ + { + name: "project name substitution", + input: "name: {{.project_name}}", + expected: "name: my-app", + }, + { + name: "warehouse id substitution", + input: "warehouse: {{.sql_warehouse_id}}", + expected: "warehouse: warehouse123", + }, + { + name: "description substitution", + input: "description: {{.app_description}}", + expected: "description: My awesome app", + }, + { + name: "profile substitution", + input: "profile: {{.profile}}", + expected: "profile: default", + }, + { + name: "workspace host substitution", + input: "host: {{workspace_host}}", + expected: "host: https://dbc-123.cloud.databricks.com", + }, + { + name: "plugin import substitution", + input: "import { {{.plugin_import}} } from 'appkit'", + expected: "import { analytics } from 'appkit'", + }, + { + name: "plugin usage substitution", + input: "plugins: [{{.plugin_usage}}]", + expected: "plugins: [analytics()]", + }, + { + name: "multiple substitutions", + input: "{{.project_name}} - {{.app_description}}", + expected: "my-app - My awesome app", + }, + { + name: "no substitutions needed", + input: "plain text without variables", + expected: "plain text without variables", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := substituteVars(tt.input, vars) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSubstituteVarsNoPlugins(t *testing.T) { + // Test plugin cleanup when no plugins are selected + vars := templateVars{ + ProjectName: "my-app", + SQLWarehouseID: "", + AppDescription: "My app", + Profile: "", + WorkspaceHost: "", + PluginImport: "", // No plugins + PluginUsage: "", + } + + tests := []struct { + name string + input string + expected string + }{ + { + name: "removes plugin import with comma", + input: "import { core, {{.plugin_import}} } from 'appkit'", + expected: "import { core } from 'appkit'", + }, + { + name: "removes plugin usage line", + input: "plugins: [\n {{.plugin_usage}},\n]", + expected: "plugins: [\n]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := substituteVars(tt.input, vars) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseDeployAndRunFlags(t *testing.T) { + tests := []struct { + name string + deploy bool + run string + wantDeploy bool + wantRunMode prompt.RunMode + wantErr bool + }{ + { + name: "deploy true, run none", + deploy: true, + run: "none", + wantDeploy: true, + wantRunMode: prompt.RunModeNone, + wantErr: false, + }, + { + name: "deploy true, run dev", + deploy: true, + run: "dev", + wantDeploy: true, + wantRunMode: prompt.RunModeDev, + wantErr: false, + }, + { + name: "deploy false, run dev-remote", + deploy: false, + run: "dev-remote", + wantDeploy: false, + wantRunMode: prompt.RunModeDevRemote, + wantErr: false, + }, + { + name: "empty run value", + deploy: false, + run: "", + wantDeploy: false, + wantRunMode: prompt.RunModeNone, + wantErr: false, + }, + { + name: "invalid run value", + deploy: true, + run: "invalid", + wantDeploy: false, + wantRunMode: prompt.RunModeNone, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deploy, runMode, err := parseDeployAndRunFlags(tt.deploy, tt.run) + if tt.wantErr { + 
assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantDeploy, deploy) + assert.Equal(t, tt.wantRunMode, runMode) + }) + } +} diff --git a/cmd/workspace/apps/logs.go b/cmd/apps/logs.go similarity index 100% rename from cmd/workspace/apps/logs.go rename to cmd/apps/logs.go diff --git a/cmd/workspace/apps/logs_test.go b/cmd/apps/logs_test.go similarity index 100% rename from cmd/workspace/apps/logs_test.go rename to cmd/apps/logs_test.go diff --git a/cmd/workspace/apps/run_local.go b/cmd/apps/run_local.go similarity index 98% rename from cmd/workspace/apps/run_local.go rename to cmd/apps/run_local.go index 1d68f84b7e..bb6219609f 100644 --- a/cmd/workspace/apps/run_local.go +++ b/cmd/apps/run_local.go @@ -234,9 +234,3 @@ func newRunLocal() *cobra.Command { return cmd } - -func init() { - cmdOverrides = append(cmdOverrides, func(cmd *cobra.Command) { - cmd.AddCommand(newRunLocal()) - }) -} diff --git a/cmd/apps/validate.go b/cmd/apps/validate.go new file mode 100644 index 0000000000..ab37913a27 --- /dev/null +++ b/cmd/apps/validate.go @@ -0,0 +1,72 @@ +package apps + +import ( + "errors" + "fmt" + "os" + + "github.com/databricks/cli/libs/apps/validation" + "github.com/databricks/cli/libs/cmdio" + "github.com/spf13/cobra" +) + +func newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate", + Short: "Validate a Databricks App project", + Long: `Validate a Databricks App project by running build, typecheck, and lint checks. + +This command detects the project type and runs the appropriate validation: +- Node.js projects (package.json): runs npm install, build, typecheck, and lint + +Examples: + # Validate the current directory + databricks apps validate + + # Validate a specific directory + databricks apps validate --path ./my-app`, + RunE: func(cmd *cobra.Command, args []string) error { + return runValidate(cmd) + }, + } + + cmd.Flags().String("path", "", "Path to the project directory (defaults to current directory)") + + return cmd +} + +func runValidate(cmd *cobra.Command) error { + ctx := cmd.Context() + + // Get project path + projectPath, _ := cmd.Flags().GetString("path") + if projectPath == "" { + var err error + projectPath, err = os.Getwd() + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + } + + // Get validator for project type + validator := validation.GetProjectValidator(projectPath) + if validator == nil { + return errors.New("no supported project type detected (looking for package.json)") + } + + // Run validation + result, err := validator.Validate(ctx, projectPath) + if err != nil { + return fmt.Errorf("validation error: %w", err) + } + + if !result.Success { + if result.Details != nil { + cmdio.LogString(ctx, result.Details.Error()) + } + return errors.New("validation failed") + } + + cmdio.LogString(ctx, "āœ… "+result.Message) + return nil +} diff --git a/cmd/workspace/apps/dev.go b/cmd/workspace/apps/dev.go deleted file mode 100644 index 5cad8f8879..0000000000 --- a/cmd/workspace/apps/dev.go +++ /dev/null @@ -1,172 +0,0 @@ -package apps - -import ( - "bytes" - "context" - _ "embed" - "errors" - "fmt" - "net" - "os" - "os/exec" - "os/signal" - "strconv" - "syscall" - "time" - - "github.com/databricks/cli/cmd/root" - "github.com/databricks/cli/libs/cmdctx" - "github.com/databricks/cli/libs/cmdio" - "github.com/spf13/cobra" -) - -//go:embed vite-server.js -var viteServerScript []byte - -const ( - vitePort = 5173 - viteReadyCheckInterval = 100 * time.Millisecond - viteReadyMaxAttempts = 50 -) - 
-func isViteReady(port int) bool { - conn, err := net.DialTimeout("tcp", "localhost:"+strconv.Itoa(port), viteReadyCheckInterval) - if err != nil { - return false - } - conn.Close() - return true -} - -func startViteDevServer(ctx context.Context, appURL string, port int) (*exec.Cmd, chan error, error) { - // Pass script through stdin, and pass arguments in order - viteCmd := exec.Command("node", "-", appURL, strconv.Itoa(port)) - viteCmd.Stdin = bytes.NewReader(viteServerScript) - viteCmd.Stdout = os.Stdout - viteCmd.Stderr = os.Stderr - - err := viteCmd.Start() - if err != nil { - return nil, nil, fmt.Errorf("failed to start Vite server: %w", err) - } - - cmdio.LogString(ctx, fmt.Sprintf("šŸš€ Starting Vite development server on port %d...", port)) - - viteErr := make(chan error, 1) - go func() { - if err := viteCmd.Wait(); err != nil { - viteErr <- fmt.Errorf("vite server exited with error: %w", err) - } else { - viteErr <- errors.New("vite server exited unexpectedly") - } - }() - - for range viteReadyMaxAttempts { - select { - case err := <-viteErr: - return nil, nil, err - default: - if isViteReady(port) { - return viteCmd, viteErr, nil - } - time.Sleep(viteReadyCheckInterval) - } - } - - _ = viteCmd.Process.Kill() - return nil, nil, errors.New("timeout waiting for Vite server to be ready") -} - -func newRunDevCommand() *cobra.Command { - var ( - appName string - clientPath string - port int - ) - - cmd := &cobra.Command{} - - cmd.Use = "dev-remote" - cmd.Hidden = true - cmd.Short = `Run Databricks app locally with WebSocket bridge to remote server.` - cmd.Long = `Run Databricks app locally with WebSocket bridge to remote server. - - Starts a local development server and establishes a WebSocket bridge - to the remote Databricks app for development. 
- ` - - cmd.PreRunE = root.MustWorkspaceClient - - cmd.Flags().StringVar(&appName, "app-name", "", "Name of the app to connect to (required)") - cmd.Flags().StringVar(&clientPath, "client-path", "./client", "Path to the Vite client directory") - cmd.Flags().IntVar(&port, "port", vitePort, "Port to run the Vite server on") - - cmd.RunE = func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - w := cmdctx.WorkspaceClient(ctx) - - if appName == "" { - return errors.New("app name is required (use --app-name)") - } - - if _, err := os.Stat(clientPath); os.IsNotExist(err) { - return fmt.Errorf("client directory not found: %s", clientPath) - } - - bridge := NewViteBridge(ctx, w, appName, port) - - appDomain, err := bridge.GetAppDomain() - if err != nil { - return fmt.Errorf("failed to get app domain: %w", err) - } - - viteCmd, viteErr, err := startViteDevServer(ctx, appDomain.String(), port) - if err != nil { - return err - } - - done := make(chan error, 1) - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - go func() { - done <- bridge.Start() - }() - - select { - case err := <-viteErr: - bridge.Stop() - <-done - return err - case err := <-done: - cmdio.LogString(ctx, "Bridge stopped") - if viteCmd.Process != nil { - _ = viteCmd.Process.Signal(os.Interrupt) - <-viteErr - } - return err - case <-sigChan: - cmdio.LogString(ctx, "\nšŸ›‘ Shutting down...") - bridge.Stop() - <-done - if viteCmd.Process != nil { - if err := viteCmd.Process.Signal(os.Interrupt); err != nil { - cmdio.LogString(ctx, fmt.Sprintf("Failed to interrupt Vite: %v", err)) - _ = viteCmd.Process.Kill() - } - <-viteErr - } - return nil - } - } - - cmd.ValidArgsFunction = cobra.NoFileCompletions - - return cmd -} - -func init() { - cmdOverrides = append(cmdOverrides, func(cmd *cobra.Command) { - cmd.AddCommand(newRunDevCommand()) - }) -} diff --git a/cmd/workspace/apps/overrides.go b/cmd/workspace/apps/overrides.go index a1e35da903..f678cc3045 100644 --- a/cmd/workspace/apps/overrides.go +++ b/cmd/workspace/apps/overrides.go @@ -1,6 +1,9 @@ package apps import ( + "slices" + + appsCli "github.com/databricks/cli/cmd/apps" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/service/apps" "github.com/spf13/cobra" @@ -23,6 +26,9 @@ func listDeploymentsOverride(listDeploymentsCmd *cobra.Command, listDeploymentsR } func createOverride(createCmd *cobra.Command, createReq *apps.CreateAppRequest) { + createCmd.Short = `Create an app in your workspace.` + createCmd.Long = `Create an app in your workspace.` + originalRunE := createCmd.RunE createCmd.RunE = func(cmd *cobra.Command, args []string) error { err := originalRunE(cmd, args) @@ -30,14 +36,6 @@ func createOverride(createCmd *cobra.Command, createReq *apps.CreateAppRequest) } } -func deployOverride(deployCmd *cobra.Command, deployReq *apps.CreateAppDeploymentRequest) { - originalRunE := deployCmd.RunE - deployCmd.RunE = func(cmd *cobra.Command, args []string) error { - err := originalRunE(cmd, args) - return wrapDeploymentError(cmd, deployReq.AppName, err) - } -} - func createUpdateOverride(createUpdateCmd *cobra.Command, createUpdateReq *apps.AsyncUpdateAppRequest) { originalRunE := createUpdateCmd.RunE createUpdateCmd.RunE = func(cmd *cobra.Command, args []string) error { @@ -56,12 +54,42 @@ func startOverride(startCmd *cobra.Command, startReq *apps.StartAppRequest) { func init() { cmdOverrides = append(cmdOverrides, func(cmd *cobra.Command) { - cmd.AddCommand(newLogsCommand()) + // Commands 
that should NOT go into the management group + // (either they are main commands or have special grouping) + nonManagementCommands := []string{ + // 'deploy' is overloaded as API and bundle command + "deploy", + // permission commands are assigned into "permission" group in cmd/cmd.go + "get-permission-levels", + "get-permissions", + "set-permissions", + "update-permissions", + } + + // Put auto-generated API commands into 'management' group + for _, subCmd := range cmd.Commands() { + if slices.Contains(nonManagementCommands, subCmd.Name()) { + continue + } + if subCmd.GroupID == "" { + subCmd.GroupID = appsCli.ManagementGroupID + } + } + + // Add custom commands from cmd/apps/ + for _, appsCmd := range appsCli.Commands() { + cmd.AddCommand(appsCmd) + } + + // Add --var flag support for bundle operations + cmd.PersistentFlags().StringSlice("var", []string{}, `set values for variables defined in bundle config. Example: --var="key=value"`) }) + + // Register command overrides listOverrides = append(listOverrides, listOverride) listDeploymentsOverrides = append(listDeploymentsOverrides, listDeploymentsOverride) createOverrides = append(createOverrides, createOverride) - deployOverrides = append(deployOverrides, deployOverride) + deployOverrides = append(deployOverrides, appsCli.BundleDeployOverrideWithWrapper(wrapDeploymentError)) createUpdateOverrides = append(createUpdateOverrides, createUpdateOverride) startOverrides = append(startOverrides, startOverride) } diff --git a/go.mod b/go.mod index dd1a073f73..3a4c15f796 100644 --- a/go.mod +++ b/go.mod @@ -43,15 +43,32 @@ require ( // Dependencies for experimental MCP commands require github.com/google/jsonschema-go v0.4.2 // MIT +require ( + github.com/charmbracelet/huh v0.8.0 + github.com/charmbracelet/lipgloss v1.1.0 +) + require ( cloud.google.com/go/auth v0.16.5 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.8.4 // indirect github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/catppuccin/go v0.3.0 // indirect + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect + github.com/charmbracelet/bubbletea v1.3.6 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/ansi v0.9.3 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/cloudflare/circl v1.6.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -62,9 +79,18 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + 
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/zclconf/go-cty v1.16.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect diff --git a/go.sum b/go.sum index 8be5ce6c28..397fcc914b 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= +github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= @@ -16,8 +18,44 @@ github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNx github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= +github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= github.com/briandowns/spinner v1.23.1 h1:t5fDPmScwUjozhDj4FA46p5acZWIPXYE30qW2Ptu650= github.com/briandowns/spinner v1.23.1/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= +github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= +github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw= +github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= +github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/huh v0.8.0 
h1:Xz/Pm2h64cXQZn/Jvele4J3r7DDiqFCNIVteYukxDvY= +github.com/charmbracelet/huh v0.8.0/go.mod h1:5YVc+SlZ1IhQALxRPpkGwwEKftN/+OlJlnJYlDRFqN4= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= +github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= +github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/conpty v0.1.0 h1:4zc8KaIcbiL4mghEON8D72agYtSeIgq8FSThSPQIb+U= +github.com/charmbracelet/x/conpty v0.1.0/go.mod h1:rMFsDJoDwVmiYM10aD4bH2XiRgwI7NYJtQgl5yskjEQ= +github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA= +github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= +github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= +github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI= +github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= @@ -27,6 +65,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/databricks/databricks-sdk-go v0.96.0 h1:tpR3GSwkM3Vd6P9KfYEXAJiKZ1KLJ2T2+J3tF8jxlEk= @@ -34,8 +74,12 @@ github.com/databricks/databricks-sdk-go v0.96.0/go.mod h1:hWoHnHbNLjPKiTm5K/7bcI github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -101,6 +145,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -111,6 +157,18 @@ github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOA github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/nwidger/jsoncolor v0.3.2 h1:rVJJlwAWDJShnbTYOQ5RM7yTA20INyKXlJ/fg4JMhHQ= github.com/nwidger/jsoncolor v0.3.2/go.mod h1:Cs34umxLbJvgBMnVNVqhji9BhoT/N/KinHqZptQ7cf4= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= @@ -121,6 +179,9 @@ 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quasilyte/go-ruleguard/dsl v0.3.22 h1:wd8zkOhSNr+I+8Qeciml08ivDt1pSXe60+5DqOpCjPE= github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -143,6 +204,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/zclconf/go-cty v1.16.4 h1:QGXaag7/7dCzb+odlGrgr+YmYZFaOCMW6DEpS+UD1eE= github.com/zclconf/go-cty v1.16.4/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= @@ -176,6 +239,7 @@ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/libs/apps/features/features.go b/libs/apps/features/features.go new file mode 100644 index 0000000000..64a4ce3949 --- /dev/null +++ b/libs/apps/features/features.go @@ -0,0 +1,329 @@ +package features + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +// FeatureDependency defines a prompt/input required by a feature. +type FeatureDependency struct { + ID string // e.g., "sql_warehouse_id" + FlagName string // CLI flag name, e.g., "warehouse-id" (maps to --warehouse-id) + Title string // e.g., "SQL Warehouse ID" + Description string // e.g., "Required for executing SQL queries" + Placeholder string + Required bool +} + +// FeatureResourceFiles defines paths to YAML fragment files for a feature's resources. +// Paths are relative to the template's features directory (e.g., "analytics/bundle_variables.yml"). 
+type FeatureResourceFiles struct { + BundleVariables string // Variables section for databricks.yml + BundleResources string // Resources section for databricks.yml (app resources) + TargetVariables string // Dev target variables section for databricks.yml + AppEnv string // Environment variables for app.yaml + DotEnv string // Environment variables for .env (development) + DotEnvExample string // Environment variables for .env.example +} + +// Feature represents an optional feature that can be added to an AppKit project. +type Feature struct { + ID string + Name string + Description string + PluginImport string + PluginUsage string + Dependencies []FeatureDependency + ResourceFiles FeatureResourceFiles +} + +// AvailableFeatures lists all features that can be selected when creating a project. +var AvailableFeatures = []Feature{ + { + ID: "analytics", + Name: "Analytics", + Description: "SQL analytics with charts and dashboards", + PluginImport: "analytics", + PluginUsage: "analytics()", + Dependencies: []FeatureDependency{ + { + ID: "sql_warehouse_id", + FlagName: "warehouse-id", + Title: "SQL Warehouse ID", + Description: "required for SQL queries", + Required: true, + }, + }, + ResourceFiles: FeatureResourceFiles{ + BundleVariables: "analytics/bundle_variables.yml", + BundleResources: "analytics/bundle_resources.yml", + TargetVariables: "analytics/target_variables.yml", + AppEnv: "analytics/app_env.yml", + DotEnv: "analytics/dotenv.yml", + DotEnvExample: "analytics/dotenv_example.yml", + }, + }, +} + +var featureByID = func() map[string]Feature { + m := make(map[string]Feature, len(AvailableFeatures)) + for _, f := range AvailableFeatures { + m[f.ID] = f + } + return m +}() + +// featureByPluginImport maps plugin import names to features. +var featureByPluginImport = func() map[string]Feature { + m := make(map[string]Feature, len(AvailableFeatures)) + for _, f := range AvailableFeatures { + if f.PluginImport != "" { + m[f.PluginImport] = f + } + } + return m +}() + +// pluginPattern matches plugin function calls dynamically built from AvailableFeatures. +// Matches patterns like: analytics(), genie(), oauth(), etc. +var pluginPattern = func() *regexp.Regexp { + var plugins []string + for _, f := range AvailableFeatures { + if f.PluginImport != "" { + plugins = append(plugins, regexp.QuoteMeta(f.PluginImport)) + } + } + if len(plugins) == 0 { + // Fallback pattern that matches nothing + return regexp.MustCompile(`$^`) + } + // Build pattern: \b(plugin1|plugin2|plugin3)\s*\( + pattern := `\b(` + strings.Join(plugins, "|") + `)\s*\(` + return regexp.MustCompile(pattern) +}() + +// serverFilePaths lists common locations for the server entry file. +var serverFilePaths = []string{ + "src/server/index.ts", + "src/server/index.tsx", + "src/server.ts", + "server/index.ts", + "server/server.ts", + "server.ts", +} + +// TODO: We should come to an agreement if we want to do it like this, +// or maybe we should have an appkit.json manifest file in each project. 
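(Editorial sketch, not part of this patch.) To make the detection heuristic concrete: with only the "analytics" feature registered in AvailableFeatures above, pluginPattern compiles to `\b(analytics)\s*\(`, so scanning a server entry file behaves roughly as follows. The sample source string is invented for illustration, and the snippet assumes it runs inside this package, since pluginPattern is unexported.

	src := `createApp({ plugins: [server(), analytics()] })`
	for _, m := range pluginPattern.FindAllStringSubmatch(src, -1) {
		// Prints "analytics"; server() is not a registered feature, so it is ignored.
		fmt.Println(m[1])
	}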
+func DetectPluginsFromServer(templateDir string) ([]string, error) { + var content []byte + + for _, p := range serverFilePaths { + fullPath := filepath.Join(templateDir, p) + data, err := os.ReadFile(fullPath) + if err == nil { + content = data + break + } + } + + if content == nil { + return nil, nil // No server file found + } + + matches := pluginPattern.FindAllStringSubmatch(string(content), -1) + seen := make(map[string]bool) + var plugins []string + + for _, m := range matches { + plugin := m[1] + if !seen[plugin] { + seen[plugin] = true + plugins = append(plugins, plugin) + } + } + + return plugins, nil +} + +// GetPluginDependencies returns all dependencies required by the given plugin names. +func GetPluginDependencies(pluginNames []string) []FeatureDependency { + seen := make(map[string]bool) + var deps []FeatureDependency + + for _, plugin := range pluginNames { + feature, ok := featureByPluginImport[plugin] + if !ok { + continue + } + for _, dep := range feature.Dependencies { + if !seen[dep.ID] { + seen[dep.ID] = true + deps = append(deps, dep) + } + } + } + + return deps +} + +// MapPluginsToFeatures maps plugin import names to feature IDs. +// This is used to convert detected plugins (e.g., "analytics") to feature IDs +// so that ApplyFeatures can properly retain feature-specific files. +func MapPluginsToFeatures(pluginNames []string) []string { + seen := make(map[string]bool) + var featureIDs []string + + for _, plugin := range pluginNames { + feature, ok := featureByPluginImport[plugin] + if ok && !seen[feature.ID] { + seen[feature.ID] = true + featureIDs = append(featureIDs, feature.ID) + } + } + + return featureIDs +} + +// HasFeaturesDirectory checks if the template uses the feature-fragment system. +func HasFeaturesDirectory(templateDir string) bool { + featuresDir := filepath.Join(templateDir, "features") + info, err := os.Stat(featuresDir) + return err == nil && info.IsDir() +} + +// ValidateFeatureIDs checks that all provided feature IDs are valid. +// Returns an error if any feature ID is unknown. +func ValidateFeatureIDs(featureIDs []string) error { + for _, id := range featureIDs { + if _, ok := featureByID[id]; !ok { + return fmt.Errorf("unknown feature: %q; available: %s", id, strings.Join(GetFeatureIDs(), ", ")) + } + } + return nil +} + +// ValidateFeatureDependencies checks that all required dependencies for the given features +// are provided in the flagValues map. Returns an error listing missing required flags. +func ValidateFeatureDependencies(featureIDs []string, flagValues map[string]string) error { + deps := CollectDependencies(featureIDs) + var missing []string + + for _, dep := range deps { + if !dep.Required { + continue + } + value, ok := flagValues[dep.FlagName] + if !ok || value == "" { + missing = append(missing, "--"+dep.FlagName) + } + } + + if len(missing) > 0 { + return fmt.Errorf("missing required flags for selected features: %s", strings.Join(missing, ", ")) + } + return nil +} + +// GetFeatureIDs returns a list of all available feature IDs for help text. +func GetFeatureIDs() []string { + ids := make([]string, len(AvailableFeatures)) + for i, f := range AvailableFeatures { + ids[i] = f.ID + } + return ids +} + +// BuildPluginStrings builds the plugin import and usage strings from selected feature IDs. +// Returns comma-separated imports and newline-separated usages. 
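(Editorial sketch, not part of this patch.) For example, selecting only the built-in "analytics" feature produces the following from the function defined just below, matching its unit tests:

	imports, usages := BuildPluginStrings([]string{"analytics"})
	fmt.Println(imports) // "analytics"
	fmt.Println(usages)  // "analytics()"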
+func BuildPluginStrings(featureIDs []string) (pluginImport, pluginUsage string) { + if len(featureIDs) == 0 { + return "", "" + } + + var imports []string + var usages []string + + for _, id := range featureIDs { + feature, ok := featureByID[id] + if !ok || feature.PluginImport == "" { + continue + } + imports = append(imports, feature.PluginImport) + usages = append(usages, feature.PluginUsage) + } + + if len(imports) == 0 { + return "", "" + } + + // Join imports with comma (e.g., "analytics, trpc") + pluginImport = strings.Join(imports, ", ") + + // Join usages with newline and proper indentation + pluginUsage = strings.Join(usages, ",\n ") + + return pluginImport, pluginUsage +} + +// ApplyFeatures applies any post-copy modifications for selected features. +// This removes feature-specific directories if the feature is not selected. +func ApplyFeatures(projectDir string, featureIDs []string) error { + selectedSet := make(map[string]bool) + for _, id := range featureIDs { + selectedSet[id] = true + } + + // Remove analytics-specific files if analytics is not selected + if !selectedSet["analytics"] { + queriesDir := filepath.Join(projectDir, "config", "queries") + if err := os.RemoveAll(queriesDir); err != nil && !os.IsNotExist(err) { + return err + } + } + + return nil +} + +// CollectDependencies returns all unique dependencies required by the selected features. +func CollectDependencies(featureIDs []string) []FeatureDependency { + seen := make(map[string]bool) + var deps []FeatureDependency + + for _, id := range featureIDs { + feature, ok := featureByID[id] + if !ok { + continue + } + for _, dep := range feature.Dependencies { + if !seen[dep.ID] { + seen[dep.ID] = true + deps = append(deps, dep) + } + } + } + + return deps +} + +// CollectResourceFiles returns all resource file paths for the selected features. 
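(Editorial sketch, not part of this patch.) For the built-in "analytics" feature, the function defined just below returns the fragment paths declared in AvailableFeatures:

	files := CollectResourceFiles([]string{"analytics"})
	fmt.Println(files[0].BundleVariables) // "analytics/bundle_variables.yml"
	fmt.Println(files[0].DotEnvExample)   // "analytics/dotenv_example.yml"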
+func CollectResourceFiles(featureIDs []string) []FeatureResourceFiles { + var resources []FeatureResourceFiles + for _, id := range featureIDs { + feature, ok := featureByID[id] + if !ok { + continue + } + // Only include if at least one resource file is defined + rf := feature.ResourceFiles + if rf.BundleVariables != "" || rf.BundleResources != "" || + rf.TargetVariables != "" || rf.AppEnv != "" || + rf.DotEnv != "" || rf.DotEnvExample != "" { + resources = append(resources, rf) + } + } + + return resources +} diff --git a/libs/apps/features/features_test.go b/libs/apps/features/features_test.go new file mode 100644 index 0000000000..dfd2bb2f84 --- /dev/null +++ b/libs/apps/features/features_test.go @@ -0,0 +1,453 @@ +package features + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateFeatureIDs(t *testing.T) { + tests := []struct { + name string + featureIDs []string + expectError bool + errorMsg string + }{ + { + name: "valid feature - analytics", + featureIDs: []string{"analytics"}, + expectError: false, + }, + { + name: "empty feature list", + featureIDs: []string{}, + expectError: false, + }, + { + name: "nil feature list", + featureIDs: nil, + expectError: false, + }, + { + name: "unknown feature", + featureIDs: []string{"unknown-feature"}, + expectError: true, + errorMsg: "unknown feature", + }, + { + name: "mix of valid and invalid", + featureIDs: []string{"analytics", "invalid"}, + expectError: true, + errorMsg: "unknown feature", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateFeatureIDs(tt.featureIDs) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateFeatureDependencies(t *testing.T) { + tests := []struct { + name string + featureIDs []string + flagValues map[string]string + expectError bool + errorMsg string + }{ + { + name: "analytics with warehouse provided", + featureIDs: []string{"analytics"}, + flagValues: map[string]string{"warehouse-id": "abc123"}, + expectError: false, + }, + { + name: "analytics without warehouse", + featureIDs: []string{"analytics"}, + flagValues: map[string]string{}, + expectError: true, + errorMsg: "--warehouse-id", + }, + { + name: "analytics with empty warehouse", + featureIDs: []string{"analytics"}, + flagValues: map[string]string{"warehouse-id": ""}, + expectError: true, + errorMsg: "--warehouse-id", + }, + { + name: "no features - no dependencies needed", + featureIDs: []string{}, + flagValues: map[string]string{}, + expectError: false, + }, + { + name: "unknown feature - gracefully ignored", + featureIDs: []string{"unknown"}, + flagValues: map[string]string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateFeatureDependencies(tt.featureIDs, tt.flagValues) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestGetFeatureIDs(t *testing.T) { + ids := GetFeatureIDs() + + assert.NotEmpty(t, ids) + assert.Contains(t, ids, "analytics") +} + +func TestBuildPluginStrings(t *testing.T) { + tests := []struct { + name string + featureIDs []string + expectedImport string + expectedUsage string + }{ + { + name: "no features", + featureIDs: []string{}, + expectedImport: "", + expectedUsage: "", + }, + { + name: "nil features", + 
featureIDs: nil, + expectedImport: "", + expectedUsage: "", + }, + { + name: "analytics feature", + featureIDs: []string{"analytics"}, + expectedImport: "analytics", + expectedUsage: "analytics()", + }, + { + name: "unknown feature - ignored", + featureIDs: []string{"unknown"}, + expectedImport: "", + expectedUsage: "", + }, + { + name: "mix of known and unknown", + featureIDs: []string{"analytics", "unknown"}, + expectedImport: "analytics", + expectedUsage: "analytics()", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + importStr, usageStr := BuildPluginStrings(tt.featureIDs) + assert.Equal(t, tt.expectedImport, importStr) + assert.Equal(t, tt.expectedUsage, usageStr) + }) + } +} + +func TestCollectDependencies(t *testing.T) { + tests := []struct { + name string + featureIDs []string + expectedDeps int + expectedIDs []string + }{ + { + name: "no features", + featureIDs: []string{}, + expectedDeps: 0, + expectedIDs: nil, + }, + { + name: "analytics feature", + featureIDs: []string{"analytics"}, + expectedDeps: 1, + expectedIDs: []string{"sql_warehouse_id"}, + }, + { + name: "unknown feature", + featureIDs: []string{"unknown"}, + expectedDeps: 0, + expectedIDs: nil, + }, + { + name: "duplicate features - deduped deps", + featureIDs: []string{"analytics", "analytics"}, + expectedDeps: 1, + expectedIDs: []string{"sql_warehouse_id"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deps := CollectDependencies(tt.featureIDs) + assert.Len(t, deps, tt.expectedDeps) + + if tt.expectedIDs != nil { + for i, expectedID := range tt.expectedIDs { + assert.Equal(t, expectedID, deps[i].ID) + } + } + }) + } +} + +func TestCollectResourceFiles(t *testing.T) { + tests := []struct { + name string + featureIDs []string + expectedResources int + }{ + { + name: "no features", + featureIDs: []string{}, + expectedResources: 0, + }, + { + name: "analytics feature", + featureIDs: []string{"analytics"}, + expectedResources: 1, + }, + { + name: "unknown feature", + featureIDs: []string{"unknown"}, + expectedResources: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resources := CollectResourceFiles(tt.featureIDs) + assert.Len(t, resources, tt.expectedResources) + + if tt.expectedResources > 0 && tt.featureIDs[0] == "analytics" { + assert.NotEmpty(t, resources[0].BundleVariables) + assert.NotEmpty(t, resources[0].BundleResources) + } + }) + } +} + +func TestDetectPluginsFromServer(t *testing.T) { + tests := []struct { + name string + serverContent string + expectedPlugins []string + }{ + { + name: "analytics plugin", + serverContent: `import { createApp, server, analytics } from '@databricks/appkit'; +createApp({ + plugins: [ + server(), + analytics(), + ], +}).catch(console.error);`, + expectedPlugins: []string{"analytics"}, + }, + { + name: "analytics with other plugins not in AvailableFeatures", + serverContent: `import { createApp, server, analytics, genie } from '@databricks/appkit'; +createApp({ + plugins: [ + server(), + analytics(), + genie(), + ], +}).catch(console.error);`, + expectedPlugins: []string{"analytics"}, // Only analytics is detected since genie is not in AvailableFeatures + }, + { + name: "no recognized plugins", + serverContent: `import { createApp, server } from '@databricks/appkit';`, + expectedPlugins: nil, + }, + { + name: "plugin not in AvailableFeatures", + serverContent: `createApp({ + plugins: [oauth()], +});`, + expectedPlugins: nil, // oauth is not in AvailableFeatures, so not detected + }, 
+ } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temp dir with server file + tempDir := t.TempDir() + serverDir := tempDir + "/src/server" + require.NoError(t, os.MkdirAll(serverDir, 0o755)) + require.NoError(t, os.WriteFile(serverDir+"/index.ts", []byte(tt.serverContent), 0o644)) + + plugins, err := DetectPluginsFromServer(tempDir) + require.NoError(t, err) + assert.Equal(t, tt.expectedPlugins, plugins) + }) + } +} + +func TestDetectPluginsFromServerAlternatePath(t *testing.T) { + // Test server/server.ts path (common in some templates) + tempDir := t.TempDir() + serverDir := tempDir + "/server" + require.NoError(t, os.MkdirAll(serverDir, 0o755)) + + serverContent := `import { createApp, server, analytics } from '@databricks/appkit'; +createApp({ + plugins: [ + server(), + analytics(), + ], +}).catch(console.error);` + + require.NoError(t, os.WriteFile(serverDir+"/server.ts", []byte(serverContent), 0o644)) + + plugins, err := DetectPluginsFromServer(tempDir) + require.NoError(t, err) + assert.Equal(t, []string{"analytics"}, plugins) +} + +func TestDetectPluginsFromServerNoFile(t *testing.T) { + tempDir := t.TempDir() + plugins, err := DetectPluginsFromServer(tempDir) + require.NoError(t, err) + assert.Nil(t, plugins) +} + +func TestGetPluginDependencies(t *testing.T) { + tests := []struct { + name string + pluginNames []string + expectedDeps []string + }{ + { + name: "analytics plugin", + pluginNames: []string{"analytics"}, + expectedDeps: []string{"sql_warehouse_id"}, + }, + { + name: "unknown plugin", + pluginNames: []string{"server"}, + expectedDeps: nil, + }, + { + name: "empty plugins", + pluginNames: []string{}, + expectedDeps: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deps := GetPluginDependencies(tt.pluginNames) + if tt.expectedDeps == nil { + assert.Empty(t, deps) + } else { + assert.Len(t, deps, len(tt.expectedDeps)) + for i, dep := range deps { + assert.Equal(t, tt.expectedDeps[i], dep.ID) + } + } + }) + } +} + +func TestHasFeaturesDirectory(t *testing.T) { + // Test with features directory + tempDir := t.TempDir() + require.NoError(t, os.MkdirAll(tempDir+"/features", 0o755)) + assert.True(t, HasFeaturesDirectory(tempDir)) + + // Test without features directory + tempDir2 := t.TempDir() + assert.False(t, HasFeaturesDirectory(tempDir2)) +} + +func TestMapPluginsToFeatures(t *testing.T) { + tests := []struct { + name string + pluginNames []string + expectedFeatures []string + }{ + { + name: "analytics plugin maps to analytics feature", + pluginNames: []string{"analytics"}, + expectedFeatures: []string{"analytics"}, + }, + { + name: "unknown plugin", + pluginNames: []string{"server", "unknown"}, + expectedFeatures: nil, + }, + { + name: "empty plugins", + pluginNames: []string{}, + expectedFeatures: nil, + }, + { + name: "duplicate plugins", + pluginNames: []string{"analytics", "analytics"}, + expectedFeatures: []string{"analytics"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + features := MapPluginsToFeatures(tt.pluginNames) + if tt.expectedFeatures == nil { + assert.Empty(t, features) + } else { + assert.Equal(t, tt.expectedFeatures, features) + } + }) + } +} + +func TestPluginPatternGeneration(t *testing.T) { + // Test that the plugin pattern is dynamically generated from AvailableFeatures + // This ensures new features with PluginImport are automatically detected + + // Get all plugin imports from AvailableFeatures + var expectedPlugins []string + for _, f := range 
AvailableFeatures { + if f.PluginImport != "" { + expectedPlugins = append(expectedPlugins, f.PluginImport) + } + } + + // Test that each plugin is matched by the pattern + for _, plugin := range expectedPlugins { + testCode := fmt.Sprintf("plugins: [%s()]", plugin) + matches := pluginPattern.FindAllStringSubmatch(testCode, -1) + assert.NotEmpty(t, matches, "Pattern should match plugin: %s", plugin) + assert.Equal(t, plugin, matches[0][1], "Captured group should be plugin name: %s", plugin) + } + + // Test that non-plugin function calls are not matched + testCode := "const x = someOtherFunction()" + matches := pluginPattern.FindAllStringSubmatch(testCode, -1) + assert.Empty(t, matches, "Pattern should not match non-plugin functions") +} diff --git a/libs/apps/prompt/prompt.go b/libs/apps/prompt/prompt.go new file mode 100644 index 0000000000..8ff2b19a61 --- /dev/null +++ b/libs/apps/prompt/prompt.go @@ -0,0 +1,601 @@ +package prompt + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "strconv" + "time" + + "github.com/briandowns/spinner" + "github.com/charmbracelet/huh" + "github.com/charmbracelet/lipgloss" + "github.com/databricks/cli/libs/apps/features" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/listing" + "github.com/databricks/databricks-sdk-go/service/apps" + "github.com/databricks/databricks-sdk-go/service/sql" +) + +// DefaultAppDescription is the default description for new apps. +const DefaultAppDescription = "A Databricks App powered by AppKit" + +// AppkitTheme returns a custom theme for appkit prompts. +func AppkitTheme() *huh.Theme { + t := huh.ThemeBase() + + // Databricks brand colors + red := lipgloss.Color("#BD2B26") + gray := lipgloss.Color("#71717A") // Mid-tone gray, readable on light and dark + yellow := lipgloss.Color("#FFAB00") + + t.Focused.Title = t.Focused.Title.Foreground(red).Bold(true) + t.Focused.Description = t.Focused.Description.Foreground(gray) + t.Focused.SelectedOption = t.Focused.SelectedOption.Foreground(yellow) + t.Focused.TextInput.Placeholder = t.Focused.TextInput.Placeholder.Foreground(gray) + + return t +} + +// Styles for printing answered prompts. +var ( + answeredTitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("#71717A")) + answeredValueStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("#FFAB00")). + Bold(true) +) + +// PrintAnswered prints a completed prompt answer to keep history visible. +func PrintAnswered(title, value string) { + fmt.Printf("%s %s\n", answeredTitleStyle.Render(title+":"), answeredValueStyle.Render(value)) +} + +// printAnswered is an alias for internal use. +func printAnswered(title, value string) { + PrintAnswered(title, value) +} + +// RunMode specifies how to run the app after creation. +type RunMode string + +const ( + RunModeNone RunMode = "none" + RunModeDev RunMode = "dev" + RunModeDevRemote RunMode = "dev-remote" +) + +// CreateProjectConfig holds the configuration gathered from the interactive prompt. +type CreateProjectConfig struct { + ProjectName string + Description string + Features []string + Dependencies map[string]string // e.g., {"sql_warehouse_id": "abc123"} + Deploy bool // Whether to deploy the app after creation + RunMode RunMode // How to run the app after creation +} + +// App name constraints. +const ( + MaxAppNameLength = 30 + DevTargetPrefix = "dev-" +) + +// projectNamePattern is the compiled regex for validating project names. 
+// Pre-compiled for efficiency since validation is called on every keystroke. +var projectNamePattern = regexp.MustCompile(`^[a-z][a-z0-9-]*$`) + +// ValidateProjectName validates the project name for length and pattern constraints. +// It checks that the name plus the "dev-" prefix doesn't exceed 30 characters, +// and that the name follows the pattern: starts with a letter, contains only +// lowercase letters, numbers, or hyphens. +func ValidateProjectName(s string) error { + if s == "" { + return errors.New("project name is required") + } + + // Check length constraint (dev- prefix + name <= 30) + totalLength := len(DevTargetPrefix) + len(s) + if totalLength > MaxAppNameLength { + maxAllowed := MaxAppNameLength - len(DevTargetPrefix) + return fmt.Errorf("name too long (max %d chars)", maxAllowed) + } + + // Check pattern + if !projectNamePattern.MatchString(s) { + return errors.New("must start with a letter, use only lowercase letters, numbers, or hyphens") + } + + return nil +} + +// PrintHeader prints the AppKit header banner. +func PrintHeader() { + headerStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#BD2B26")). + Bold(true) + + subtitleStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#71717A")) + + fmt.Println() + fmt.Println(headerStyle.Render("ā—† Create a new Databricks AppKit project")) + fmt.Println(subtitleStyle.Render(" Full-stack TypeScript • React • Tailwind CSS")) + fmt.Println() +} + +// PromptForProjectName prompts only for project name. +// Used as the first step before resolving templates. +// outputDir is used to check if the destination directory already exists. +func PromptForProjectName(outputDir string) (string, error) { + PrintHeader() + theme := AppkitTheme() + + var name string + err := huh.NewInput(). + Title("Project name"). + Description("lowercase letters, numbers, hyphens (max 26 chars)"). + Placeholder("my-app"). + Value(&name). + Validate(func(s string) error { + if err := ValidateProjectName(s); err != nil { + return err + } + destDir := s + if outputDir != "" { + destDir = filepath.Join(outputDir, s) + } + if _, err := os.Stat(destDir); err == nil { + return fmt.Errorf("directory %s already exists", destDir) + } + return nil + }). + WithTheme(theme). + Run() + if err != nil { + return "", err + } + + printAnswered("Project name", name) + return name, nil +} + +// PromptForPluginDependencies prompts for dependencies required by detected plugins. +// Returns a map of dependency ID to value. +func PromptForPluginDependencies(ctx context.Context, deps []features.FeatureDependency) (map[string]string, error) { + theme := AppkitTheme() + result := make(map[string]string) + + for _, dep := range deps { + // Special handling for SQL warehouse - show picker instead of text input + if dep.ID == "sql_warehouse_id" { + warehouseID, err := PromptForWarehouse(ctx) + if err != nil { + return nil, err + } + result[dep.ID] = warehouseID + continue + } + + var value string + description := dep.Description + if !dep.Required { + description += " (optional)" + } + + input := huh.NewInput(). + Title(dep.Title). + Description(description). + Placeholder(dep.Placeholder). 
+ Value(&value) + + if dep.Required { + input = input.Validate(func(s string) error { + if s == "" { + return errors.New("this field is required") + } + return nil + }) + } + + if err := input.WithTheme(theme).Run(); err != nil { + return nil, err + } + printAnswered(dep.Title, value) + result[dep.ID] = value + } + + return result, nil +} + +// PromptForDeployAndRun prompts for post-creation deploy and run options. +func PromptForDeployAndRun() (deploy bool, runMode RunMode, err error) { + theme := AppkitTheme() + + // Deploy after creation? + err = huh.NewConfirm(). + Title("Deploy after creation?"). + Description("Run 'databricks apps deploy' after setup"). + Value(&deploy). + WithTheme(theme). + Run() + if err != nil { + return false, RunModeNone, err + } + if deploy { + printAnswered("Deploy after creation", "Yes") + } else { + printAnswered("Deploy after creation", "No") + } + + // Run the app? + runModeStr := string(RunModeNone) + err = huh.NewSelect[string](). + Title("Run the app after creation?"). + Description("Choose how to start the development server"). + Options( + huh.NewOption("No, I'll run it later", string(RunModeNone)), + huh.NewOption("Yes, run locally (npm run dev)", string(RunModeDev)), + huh.NewOption("Yes, run with remote bridge (dev-remote)", string(RunModeDevRemote)), + ). + Value(&runModeStr). + WithTheme(theme). + Run() + if err != nil { + return false, RunModeNone, err + } + + runModeLabels := map[string]string{ + string(RunModeNone): "No", + string(RunModeDev): "Yes (local)", + string(RunModeDevRemote): "Yes (remote)", + } + printAnswered("Run after creation", runModeLabels[runModeStr]) + + return deploy, RunMode(runModeStr), nil +} + +// PromptForProjectConfig shows an interactive form to gather project configuration. +// Flow: name -> features -> feature dependencies -> description -> deploy/run. +// If preSelectedFeatures is provided, the feature selection prompt is skipped. +func PromptForProjectConfig(ctx context.Context, preSelectedFeatures []string) (*CreateProjectConfig, error) { + config := &CreateProjectConfig{ + Dependencies: make(map[string]string), + Features: preSelectedFeatures, + } + theme := AppkitTheme() + + PrintHeader() + + // Step 1: Project name + err := huh.NewInput(). + Title("Project name"). + Description("lowercase letters, numbers, hyphens (max 26 chars)"). + Placeholder("my-app"). + Value(&config.ProjectName). + Validate(ValidateProjectName). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + printAnswered("Project name", config.ProjectName) + + // Step 2: Feature selection (skip if features already provided via flag) + if len(config.Features) == 0 && len(features.AvailableFeatures) > 0 { + options := make([]huh.Option[string], 0, len(features.AvailableFeatures)) + for _, f := range features.AvailableFeatures { + label := f.Name + " - " + f.Description + options = append(options, huh.NewOption(label, f.ID)) + } + + err = huh.NewMultiSelect[string](). + Title("Select features"). + Description("space to toggle, enter to confirm"). + Options(options...). + Value(&config.Features). + Height(8). + WithTheme(theme). 
+ Run() + if err != nil { + return nil, err + } + if len(config.Features) == 0 { + printAnswered("Features", "None") + } else { + printAnswered("Features", fmt.Sprintf("%d selected", len(config.Features))) + } + } + + // Step 3: Prompt for feature dependencies + deps := features.CollectDependencies(config.Features) + for _, dep := range deps { + // Special handling for SQL warehouse - show picker instead of text input + if dep.ID == "sql_warehouse_id" { + warehouseID, err := PromptForWarehouse(ctx) + if err != nil { + return nil, err + } + config.Dependencies[dep.ID] = warehouseID + continue + } + + var value string + description := dep.Description + if !dep.Required { + description += " (optional)" + } + + input := huh.NewInput(). + Title(dep.Title). + Description(description). + Placeholder(dep.Placeholder). + Value(&value) + + if dep.Required { + input = input.Validate(func(s string) error { + if s == "" { + return errors.New("this field is required") + } + return nil + }) + } + + if err := input.WithTheme(theme).Run(); err != nil { + return nil, err + } + printAnswered(dep.Title, value) + config.Dependencies[dep.ID] = value + } + + // Step 4: Description + config.Description = DefaultAppDescription + err = huh.NewInput(). + Title("Description"). + Placeholder(DefaultAppDescription). + Value(&config.Description). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + if config.Description == "" { + config.Description = DefaultAppDescription + } + printAnswered("Description", config.Description) + + // Step 5: Deploy after creation? + err = huh.NewConfirm(). + Title("Deploy after creation?"). + Description("Run 'databricks apps deploy' after setup"). + Value(&config.Deploy). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + if config.Deploy { + printAnswered("Deploy after creation", "Yes") + } else { + printAnswered("Deploy after creation", "No") + } + + // Step 6: Run the app? + runModeStr := string(RunModeNone) + err = huh.NewSelect[string](). + Title("Run the app after creation?"). + Description("Choose how to start the development server"). + Options( + huh.NewOption("No, I'll run it later", string(RunModeNone)), + huh.NewOption("Yes, run locally (npm run dev)", string(RunModeDev)), + huh.NewOption("Yes, run with remote bridge (dev-remote)", string(RunModeDevRemote)), + ). + Value(&runModeStr). + WithTheme(theme). + Run() + if err != nil { + return nil, err + } + config.RunMode = RunMode(runModeStr) + + runModeLabels := map[string]string{ + string(RunModeNone): "No", + string(RunModeDev): "Yes (local)", + string(RunModeDevRemote): "Yes (remote)", + } + printAnswered("Run after creation", runModeLabels[runModeStr]) + + return config, nil +} + +// ListSQLWarehouses fetches all SQL warehouses the user has access to. +func ListSQLWarehouses(ctx context.Context) ([]sql.EndpointInfo, error) { + w := cmdctx.WorkspaceClient(ctx) + if w == nil { + return nil, errors.New("no workspace client available") + } + + iter := w.Warehouses.List(ctx, sql.ListWarehousesRequest{}) + return listing.ToSlice(ctx, iter) +} + +// PromptForWarehouse shows a picker to select a SQL warehouse. 
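(Editorial sketch, not part of this patch.) A typical call site for the function defined just below, mirroring the feature-dependency loop in PromptForProjectConfig above:

	if dep.ID == "sql_warehouse_id" {
		warehouseID, err := PromptForWarehouse(ctx)
		if err != nil {
			return nil, err
		}
		config.Dependencies[dep.ID] = warehouseID
	}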
+func PromptForWarehouse(ctx context.Context) (string, error) { + var warehouses []sql.EndpointInfo + err := RunWithSpinnerCtx(ctx, "Fetching SQL warehouses...", func() error { + var fetchErr error + warehouses, fetchErr = ListSQLWarehouses(ctx) + return fetchErr + }) + if err != nil { + return "", fmt.Errorf("failed to fetch SQL warehouses: %w", err) + } + + if len(warehouses) == 0 { + return "", errors.New("no SQL warehouses found. Create one in your workspace first") + } + + theme := AppkitTheme() + + // Build options with warehouse name and state + options := make([]huh.Option[string], 0, len(warehouses)) + warehouseNames := make(map[string]string) // id -> name for printing + for _, wh := range warehouses { + state := string(wh.State) + label := fmt.Sprintf("%s (%s)", wh.Name, state) + options = append(options, huh.NewOption(label, wh.Id)) + warehouseNames[wh.Id] = wh.Name + } + + var selected string + err = huh.NewSelect[string](). + Title("Select SQL Warehouse"). + Description(fmt.Sprintf("%d warehouses available — type to filter", len(warehouses))). + Options(options...). + Value(&selected). + Filtering(true). + Height(8). + WithTheme(theme). + Run() + if err != nil { + return "", err + } + + printAnswered("SQL Warehouse", warehouseNames[selected]) + return selected, nil +} + +// RunWithSpinnerCtx runs a function while showing a spinner with the given title. +// The spinner stops and the function returns early if the context is cancelled. +// Panics in the action are recovered and returned as errors. +func RunWithSpinnerCtx(ctx context.Context, title string, action func() error) error { + s := spinner.New( + spinner.CharSets[14], + 80*time.Millisecond, + spinner.WithColor("yellow"), // Databricks brand color + spinner.WithSuffix(" "+title), + ) + s.Start() + + done := make(chan error, 1) + go func() { + defer func() { + if r := recover(); r != nil { + done <- fmt.Errorf("action panicked: %v", r) + } + }() + done <- action() + }() + + select { + case err := <-done: + s.Stop() + return err + case <-ctx.Done(): + s.Stop() + // Wait for action goroutine to complete to avoid orphaned goroutines. + // For exec.CommandContext, the process is killed when context is cancelled. + <-done + return ctx.Err() + } +} + +// ListAllApps fetches all apps the user has access to from the workspace. +func ListAllApps(ctx context.Context) ([]apps.App, error) { + w := cmdctx.WorkspaceClient(ctx) + if w == nil { + return nil, errors.New("no workspace client available") + } + + iter := w.Apps.List(ctx, apps.ListAppsRequest{}) + return listing.ToSlice(ctx, iter) +} + +// PromptForAppSelection shows a picker to select an existing app. +// Returns the selected app name or error if cancelled/no apps found. +func PromptForAppSelection(ctx context.Context, title string) (string, error) { + if !cmdio.IsPromptSupported(ctx) { + return "", errors.New("--name is required in non-interactive mode") + } + + // Fetch all apps the user has access to + var existingApps []apps.App + err := RunWithSpinnerCtx(ctx, "Fetching apps...", func() error { + var fetchErr error + existingApps, fetchErr = ListAllApps(ctx) + return fetchErr + }) + if err != nil { + return "", fmt.Errorf("failed to fetch apps: %w", err) + } + + if len(existingApps) == 0 { + return "", errors.New("no apps found. 
Create one first with 'databricks apps create '") + } + + theme := AppkitTheme() + + // Build options + options := make([]huh.Option[string], 0, len(existingApps)) + for _, app := range existingApps { + label := app.Name + if app.Description != "" { + desc := app.Description + if len(desc) > 40 { + desc = desc[:37] + "..." + } + label += " — " + desc + } + options = append(options, huh.NewOption(label, app.Name)) + } + + var selected string + err = huh.NewSelect[string](). + Title(title). + Description(fmt.Sprintf("%d apps found — type to filter", len(existingApps))). + Options(options...). + Value(&selected). + Filtering(true). + Height(8). + WithTheme(theme). + Run() + if err != nil { + return "", err + } + + printAnswered("App", selected) + return selected, nil +} + +// PrintSuccess prints a success message after project creation. +// If showNextSteps is true, also prints the "Next steps" section. +func PrintSuccess(projectName, outputDir string, fileCount int, showNextSteps bool) { + successStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#FFAB00")). // Databricks yellow + Bold(true) + + dimStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#71717A")) // Mid-tone gray + + codeStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#FF3621")) // Databricks orange + + fmt.Println() + fmt.Println(successStyle.Render("āœ” Project created successfully!")) + fmt.Println() + fmt.Println(dimStyle.Render(" Location: " + outputDir)) + fmt.Println(dimStyle.Render(" Files: " + strconv.Itoa(fileCount))) + + if showNextSteps { + fmt.Println() + fmt.Println(dimStyle.Render(" Next steps:")) + fmt.Println() + fmt.Println(codeStyle.Render(" cd " + projectName)) + fmt.Println(codeStyle.Render(" npm run dev")) + } + fmt.Println() +} diff --git a/libs/apps/prompt/prompt_test.go b/libs/apps/prompt/prompt_test.go new file mode 100644 index 0000000000..b48153fbdb --- /dev/null +++ b/libs/apps/prompt/prompt_test.go @@ -0,0 +1,187 @@ +package prompt + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateProjectName(t *testing.T) { + tests := []struct { + name string + projectName string + expectError bool + errorMsg string + }{ + { + name: "valid simple name", + projectName: "my-app", + expectError: false, + }, + { + name: "valid name with numbers", + projectName: "app123", + expectError: false, + }, + { + name: "valid name with hyphens", + projectName: "my-cool-app", + expectError: false, + }, + { + name: "empty name", + projectName: "", + expectError: true, + errorMsg: "required", + }, + { + name: "name too long", + projectName: "this-is-a-very-long-app-name-that-exceeds", + expectError: true, + errorMsg: "too long", + }, + { + name: "name at max length (26 chars)", + projectName: "abcdefghijklmnopqrstuvwxyz", + expectError: false, + }, + { + name: "name starts with number", + projectName: "123app", + expectError: true, + errorMsg: "must start with a letter", + }, + { + name: "name starts with hyphen", + projectName: "-myapp", + expectError: true, + errorMsg: "must start with a letter", + }, + { + name: "name with uppercase", + projectName: "MyApp", + expectError: true, + errorMsg: "lowercase", + }, + { + name: "name with underscore", + projectName: "my_app", + expectError: true, + errorMsg: "lowercase letters, numbers, or hyphens", + }, + { + name: "name with spaces", + projectName: "my app", + expectError: true, + errorMsg: "lowercase letters, numbers, or hyphens", + }, + { + name: 
"name with special characters", + projectName: "my@app!", + expectError: true, + errorMsg: "lowercase letters, numbers, or hyphens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateProjectName(tt.projectName) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRunWithSpinnerCtx(t *testing.T) { + t.Run("successful action", func(t *testing.T) { + ctx := context.Background() + executed := false + + err := RunWithSpinnerCtx(ctx, "Testing...", func() error { + executed = true + return nil + }) + + assert.NoError(t, err) + assert.True(t, executed) + }) + + t.Run("action returns error", func(t *testing.T) { + ctx := context.Background() + expectedErr := errors.New("action failed") + + err := RunWithSpinnerCtx(ctx, "Testing...", func() error { + return expectedErr + }) + + assert.Equal(t, expectedErr, err) + }) + + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + actionStarted := make(chan struct{}) + actionDone := make(chan struct{}) + + go func() { + _ = RunWithSpinnerCtx(ctx, "Testing...", func() error { + close(actionStarted) + time.Sleep(100 * time.Millisecond) + close(actionDone) + return nil + }) + }() + + // Wait for action to start + <-actionStarted + // Cancel context + cancel() + // Wait for action to complete (spinner should wait) + <-actionDone + }) + + t.Run("action panics - recovered", func(t *testing.T) { + ctx := context.Background() + + err := RunWithSpinnerCtx(ctx, "Testing...", func() error { + panic("test panic") + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "action panicked") + assert.Contains(t, err.Error(), "test panic") + }) +} + +func TestRunModeConstants(t *testing.T) { + assert.Equal(t, RunModeNone, RunMode("none")) + assert.Equal(t, RunModeDev, RunMode("dev")) + assert.Equal(t, RunModeDevRemote, RunMode("dev-remote")) +} + +func TestMaxAppNameLength(t *testing.T) { + // Verify the constant is set correctly + assert.Equal(t, 30, MaxAppNameLength) + assert.Equal(t, "dev-", DevTargetPrefix) + + // Max allowed name length should be 30 - 4 ("dev-") = 26 + maxAllowed := MaxAppNameLength - len(DevTargetPrefix) + assert.Equal(t, 26, maxAllowed) + + // Test at boundary + validName := "abcdefghijklmnopqrstuvwxyz" // 26 chars + assert.Len(t, validName, 26) + assert.NoError(t, ValidateProjectName(validName)) + + // Test over boundary + invalidName := "abcdefghijklmnopqrstuvwxyz1" // 27 chars + assert.Len(t, invalidName, 27) + assert.Error(t, ValidateProjectName(invalidName)) +} diff --git a/libs/apps/validation/nodejs.go b/libs/apps/validation/nodejs.go new file mode 100644 index 0000000000..6ac9c929f2 --- /dev/null +++ b/libs/apps/validation/nodejs.go @@ -0,0 +1,158 @@ +package validation + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/briandowns/spinner" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" +) + +// ValidationNodeJs implements validation for Node.js-based projects. 
+type ValidationNodeJs struct{} + +type validationStep struct { + name string + command string + errorPrefix string + displayName string + skipIf func(workDir string) bool // Optional: skip step if this returns true +} + +func (v *ValidationNodeJs) Validate(ctx context.Context, workDir string) (*ValidateResult, error) { + log.Infof(ctx, "Starting Node.js validation: build + typecheck") + startTime := time.Now() + + cmdio.LogString(ctx, "Validating project...") + + // TODO: these steps could be changed to npx appkit [command] instead if we can determine its an appkit project. + steps := []validationStep{ + { + name: "install", + command: "npm install", + errorPrefix: "Failed to install dependencies", + displayName: "Installing dependencies", + skipIf: hasNodeModules, + }, + { + name: "generate", + command: "npm run typegen --if-present", + errorPrefix: "Failed to run npm typegen", + displayName: "Generating types", + }, + { + name: "ast-grep-lint", + command: "npm run lint:ast-grep --if-present", + errorPrefix: "AST-grep lint found violations", + displayName: "Running AST-grep lint", + }, + { + name: "typecheck", + command: "npm run typecheck --if-present", + errorPrefix: "Failed to run client typecheck", + displayName: "Type checking", + }, + { + name: "build", + command: "npm run build --if-present", + errorPrefix: "Failed to run npm build", + displayName: "Building", + }, + } + + for _, step := range steps { + // Check if step should be skipped + if step.skipIf != nil && step.skipIf(workDir) { + log.Debugf(ctx, "skipping %s (condition met)", step.name) + cmdio.LogString(ctx, "ā­ļø Skipped "+step.displayName) + continue + } + + log.Debugf(ctx, "running %s...", step.name) + + // Run step with spinner + stepStart := time.Now() + var stepErr *ValidationDetail + + s := spinner.New( + spinner.CharSets[14], + 80*time.Millisecond, + spinner.WithColor("yellow"), + spinner.WithSuffix(" "+step.displayName+"..."), + ) + s.Start() + + stepErr = runValidationCommand(ctx, workDir, step.command) + + s.Stop() + stepDuration := time.Since(stepStart) + + if stepErr != nil { + log.Errorf(ctx, "%s failed (duration: %.1fs)", step.name, stepDuration.Seconds()) + cmdio.LogString(ctx, fmt.Sprintf("āŒ %s failed (%.1fs)", step.displayName, stepDuration.Seconds())) + return &ValidateResult{ + Success: false, + Message: step.errorPrefix, + Details: stepErr, + }, nil + } + + log.Debugf(ctx, "āœ“ %s passed: duration=%.1fs", step.name, stepDuration.Seconds()) + cmdio.LogString(ctx, fmt.Sprintf("āœ… %s (%.1fs)", step.displayName, stepDuration.Seconds())) + } + + totalDuration := time.Since(startTime) + log.Infof(ctx, "āœ“ all validation checks passed: total_duration=%.1fs", totalDuration.Seconds()) + + return &ValidateResult{ + Success: true, + Message: fmt.Sprintf("All validation checks passed (%.1fs)", totalDuration.Seconds()), + }, nil +} + +// hasNodeModules returns true if node_modules directory exists in the workDir. +func hasNodeModules(workDir string) bool { + nodeModules := filepath.Join(workDir, "node_modules") + info, err := os.Stat(nodeModules) + return err == nil && info.IsDir() +} + +// runValidationCommand executes a shell command in the specified directory. 
+func runValidationCommand(ctx context.Context, workDir, command string) *ValidationDetail { + cmd := exec.CommandContext(ctx, "sh", "-c", command) + cmd.Dir = workDir + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + return &ValidationDetail{ + ExitCode: -1, + Stdout: stdout.String(), + Stderr: fmt.Sprintf("Failed to execute command: %v\nStderr: %s", err, stderr.String()), + } + } + } + + if exitCode != 0 { + return &ValidationDetail{ + ExitCode: exitCode, + Stdout: stdout.String(), + Stderr: stderr.String(), + } + } + + return nil +} diff --git a/libs/apps/validation/validation.go b/libs/apps/validation/validation.go new file mode 100644 index 0000000000..804b725e02 --- /dev/null +++ b/libs/apps/validation/validation.go @@ -0,0 +1,69 @@ +package validation + +import ( + "context" + "fmt" + "os" + "path/filepath" +) + +// ValidationDetail contains detailed output from a failed validation. +type ValidationDetail struct { + ExitCode int `json:"exit_code"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` +} + +func (vd *ValidationDetail) Error() string { + return fmt.Sprintf("validation failed (exit code %d)\nStdout:\n%s\nStderr:\n%s", + vd.ExitCode, vd.Stdout, vd.Stderr) +} + +// ValidateResult contains the outcome of a validation operation. +type ValidateResult struct { + Success bool `json:"success"` + Message string `json:"message"` + Details *ValidationDetail `json:"details,omitempty"` + ProgressLog []string `json:"progress_log,omitempty"` +} + +func (vr *ValidateResult) String() string { + var result string + + if len(vr.ProgressLog) > 0 { + result = "Validation Progress:\n" + for _, entry := range vr.ProgressLog { + result += entry + "\n" + } + result += "\n" + } + + if vr.Success { + result += "āœ… " + vr.Message + } else { + result += "āŒ " + vr.Message + if vr.Details != nil { + result += fmt.Sprintf("\n\nExit code: %d\n\nStdout:\n%s\n\nStderr:\n%s", + vr.Details.ExitCode, vr.Details.Stdout, vr.Details.Stderr) + } + } + + return result +} + +// Validation defines the interface for project validation strategies. +type Validation interface { + Validate(ctx context.Context, workDir string) (*ValidateResult, error) +} + +// GetProjectValidator returns the appropriate validator based on project type. +// Returns nil if no validator is applicable. +func GetProjectValidator(workDir string) Validation { + // Check for Node.js project (package.json exists) + packageJSON := filepath.Join(workDir, "package.json") + if _, err := os.Stat(packageJSON); err == nil { + return &ValidationNodeJs{} + } + // TODO: Extend this with other project types as needed (e.g. python, etc.) 
+ return nil +} diff --git a/cmd/workspace/apps/vite_bridge.go b/libs/apps/vite/bridge.go similarity index 80% rename from cmd/workspace/apps/vite_bridge.go rename to libs/apps/vite/bridge.go index 288d6b46d8..35abea5d30 100644 --- a/cmd/workspace/apps/vite_bridge.go +++ b/libs/apps/vite/bridge.go @@ -1,8 +1,9 @@ -package apps +package vite import ( "bufio" "context" + _ "embed" "encoding/json" "errors" "fmt" @@ -23,6 +24,9 @@ import ( "golang.org/x/sync/errgroup" ) +//go:embed server.js +var ServerScript []byte + const ( localViteURL = "http://localhost:%d" localViteHMRURL = "ws://localhost:%d/dev-hmr" @@ -38,12 +42,17 @@ const ( httpIdleConnTimeout = 90 * time.Second // Bridge operation timeouts - bridgeFetchTimeout = 30 * time.Second - bridgeConnTimeout = 60 * time.Second - bridgeTunnelReadyTimeout = 30 * time.Second + BridgeFetchTimeout = 30 * time.Second + BridgeConnTimeout = 60 * time.Second + BridgeTunnelReadyTimeout = 30 * time.Second + + // Retry configuration + tunnelConnectMaxRetries = 10 + tunnelConnectInitialBackoff = 2 * time.Second + tunnelConnectMaxBackoff = 30 * time.Second ) -type ViteBridgeMessage struct { +type BridgeMessage struct { Type string `json:"type"` TunnelID string `json:"tunnelId,omitempty"` Path string `json:"path,omitempty"` @@ -65,7 +74,7 @@ type prioritizedMessage struct { priority int // 0 = high (HMR), 1 = normal (fetch) } -type ViteBridge struct { +type Bridge struct { ctx context.Context w *databricks.WorkspaceClient appName string @@ -76,11 +85,13 @@ type ViteBridge struct { stopChan chan struct{} stopOnce sync.Once httpClient *http.Client - connectionRequests chan *ViteBridgeMessage + connectionRequests chan *BridgeMessage port int + keepaliveDone chan struct{} // Signals keepalive goroutine to stop on reconnect + keepaliveMu sync.Mutex // Protects keepaliveDone } -func NewViteBridge(ctx context.Context, w *databricks.WorkspaceClient, appName string, port int) *ViteBridge { +func NewBridge(ctx context.Context, w *databricks.WorkspaceClient, appName string, port int) *Bridge { // Configure HTTP client optimized for local high-volume requests transport := &http.Transport{ MaxIdleConns: 100, @@ -90,7 +101,7 @@ func NewViteBridge(ctx context.Context, w *databricks.WorkspaceClient, appName s DisableCompression: false, } - return &ViteBridge{ + return &Bridge{ ctx: ctx, w: w, appName: appName, @@ -100,12 +111,12 @@ func NewViteBridge(ctx context.Context, w *databricks.WorkspaceClient, appName s }, stopChan: make(chan struct{}), tunnelWriteChan: make(chan prioritizedMessage, 100), // Buffered channel for async writes - connectionRequests: make(chan *ViteBridgeMessage, 10), + connectionRequests: make(chan *BridgeMessage, 10), port: port, } } -func (vb *ViteBridge) getAuthHeaders(wsURL string) (http.Header, error) { +func (vb *Bridge) getAuthHeaders(wsURL string) (http.Header, error) { req, err := http.NewRequestWithContext(vb.ctx, "GET", wsURL, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -119,7 +130,7 @@ func (vb *ViteBridge) getAuthHeaders(wsURL string) (http.Header, error) { return req.Header, nil } -func (vb *ViteBridge) GetAppDomain() (*url.URL, error) { +func (vb *Bridge) GetAppDomain() (*url.URL, error) { app, err := vb.w.Apps.Get(vb.ctx, apps.GetAppRequest{ Name: vb.appName, }) @@ -134,7 +145,7 @@ func (vb *ViteBridge) GetAppDomain() (*url.URL, error) { return url.Parse(app.Url) } -func (vb *ViteBridge) connectToTunnel(appDomain *url.URL) error { +func (vb *Bridge) connectToTunnel(appDomain *url.URL) error { 
wsURL := fmt.Sprintf("wss://%s/dev-tunnel", appDomain.Host) headers, err := vb.getAuthHeaders(wsURL) @@ -189,13 +200,71 @@ func (vb *ViteBridge) connectToTunnel(appDomain *url.URL) error { vb.tunnelConn = conn - // Start keepalive ping goroutine - go vb.tunnelKeepalive() + // Start keepalive ping goroutine (stop existing one first if any) + vb.keepaliveMu.Lock() + if vb.keepaliveDone != nil { + close(vb.keepaliveDone) + } + vb.keepaliveDone = make(chan struct{}) + keepaliveDone := vb.keepaliveDone + vb.keepaliveMu.Unlock() + + go vb.tunnelKeepalive(keepaliveDone) return nil } -func (vb *ViteBridge) connectToViteHMR() error { +// ConnectToTunnelWithRetry attempts to connect to the tunnel with exponential backoff. +// This handles cases where the app isn't fully ready yet (e.g., right after deployment). +func (vb *Bridge) ConnectToTunnelWithRetry(appDomain *url.URL) error { + var lastErr error + backoff := tunnelConnectInitialBackoff + + for attempt := 1; attempt <= tunnelConnectMaxRetries; attempt++ { + err := vb.connectToTunnel(appDomain) + if err == nil { + if attempt > 1 { + cmdio.LogString(vb.ctx, "āœ… Connected to tunnel successfully!") + } + return nil + } + + lastErr = err + + // Check if context is cancelled + select { + case <-vb.ctx.Done(): + return vb.ctx.Err() + default: + } + + // Don't retry on the last attempt + if attempt == tunnelConnectMaxRetries { + break + } + + // Log retry attempt + cmdio.LogString(vb.ctx, fmt.Sprintf("ā³ Connection attempt %d/%d failed, retrying in %v...", attempt, tunnelConnectMaxRetries, backoff)) + log.Debugf(vb.ctx, "[vite_bridge] Connection error: %v", err) + + // Wait before retrying + select { + case <-time.After(backoff): + case <-vb.ctx.Done(): + return vb.ctx.Err() + } + + // Exponential backoff with cap + backoff = time.Duration(float64(backoff) * 1.5) + if backoff > tunnelConnectMaxBackoff { + backoff = tunnelConnectMaxBackoff + } + } + + return fmt.Errorf("failed to connect after %d attempts: %w", tunnelConnectMaxRetries, lastErr) +} + +func (vb *Bridge) connectToViteHMR() error { dialer := websocket.Dialer{ Subprotocols: []string{viteHMRProtocol}, } @@ -216,14 +285,17 @@ func (vb *ViteBridge) connectToViteHMR() error { return nil } -// tunnelKeepalive sends periodic pings to keep the connection alive -// Remote servers often have 30-60s idle timeouts -func (vb *ViteBridge) tunnelKeepalive() { +// tunnelKeepalive sends periodic pings to keep the connection alive. +// Remote servers often have 30-60s idle timeouts. +// The done channel is used to stop this goroutine on reconnect. 
+func (vb *Bridge) tunnelKeepalive(done <-chan struct{}) { ticker := time.NewTicker(wsKeepaliveInterval) defer ticker.Stop() for { select { + case <-done: + return case <-vb.stopChan: return case <-ticker.C: @@ -244,7 +316,7 @@ func (vb *ViteBridge) tunnelKeepalive() { // tunnelWriter handles all writes to the tunnel websocket in a single goroutine // This eliminates mutex contention and ensures ordered delivery -func (vb *ViteBridge) tunnelWriter(ctx context.Context) error { +func (vb *Bridge) tunnelWriter(ctx context.Context) error { for { select { case <-ctx.Done(): @@ -260,7 +332,7 @@ func (vb *ViteBridge) tunnelWriter(ctx context.Context) error { } } -func (vb *ViteBridge) handleTunnelMessages(ctx context.Context) error { +func (vb *Bridge) handleTunnelMessages(ctx context.Context) error { for { select { case <-ctx.Done(): @@ -273,15 +345,14 @@ func (vb *ViteBridge) handleTunnelMessages(ctx context.Context) error { _, message, err := vb.tunnelConn.ReadMessage() if err != nil { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure) { - log.Infof(vb.ctx, "[vite_bridge] Tunnel closed, reconnecting...") - time.Sleep(time.Second) + cmdio.LogString(vb.ctx, "šŸ”„ Tunnel closed, reconnecting...") appDomain, err := vb.GetAppDomain() if err != nil { return fmt.Errorf("failed to get app domain for reconnection: %w", err) } - if err := vb.connectToTunnel(appDomain); err != nil { + if err := vb.ConnectToTunnelWithRetry(appDomain); err != nil { return fmt.Errorf("failed to reconnect to tunnel: %w", err) } continue @@ -292,7 +363,7 @@ func (vb *ViteBridge) handleTunnelMessages(ctx context.Context) error { // Debug: Log raw message log.Debugf(vb.ctx, "[vite_bridge] Raw message: %s", string(message)) - var msg ViteBridgeMessage + var msg BridgeMessage if err := json.Unmarshal(message, &msg); err != nil { log.Errorf(vb.ctx, "[vite_bridge] Failed to parse message: %v", err) continue @@ -307,7 +378,7 @@ func (vb *ViteBridge) handleTunnelMessages(ctx context.Context) error { } } -func (vb *ViteBridge) handleMessage(msg *ViteBridgeMessage) error { +func (vb *Bridge) handleMessage(msg *BridgeMessage) error { switch msg.Type { case "tunnel:ready": vb.tunnelID = msg.TunnelID @@ -319,7 +390,7 @@ func (vb *ViteBridge) handleMessage(msg *ViteBridgeMessage) error { return nil case "fetch": - go func(fetchMsg ViteBridgeMessage) { + go func(fetchMsg BridgeMessage) { if err := vb.handleFetchRequest(&fetchMsg); err != nil { log.Errorf(vb.ctx, "[vite_bridge] Error handling fetch request for %s: %v", fetchMsg.Path, err) } @@ -328,7 +399,7 @@ func (vb *ViteBridge) handleMessage(msg *ViteBridgeMessage) error { case "file:read": // Handle file read requests in parallel like fetch requests - go func(fileReadMsg ViteBridgeMessage) { + go func(fileReadMsg BridgeMessage) { if err := vb.handleFileReadRequest(&fileReadMsg); err != nil { log.Errorf(vb.ctx, "[vite_bridge] Error handling file read request for %s: %v", fileReadMsg.Path, err) } @@ -344,7 +415,7 @@ func (vb *ViteBridge) handleMessage(msg *ViteBridgeMessage) error { } } -func (vb *ViteBridge) handleConnectionRequest(msg *ViteBridgeMessage) error { +func (vb *Bridge) handleConnectionRequest(msg *BridgeMessage) error { cmdio.LogString(vb.ctx, "") cmdio.LogString(vb.ctx, "šŸ”” Connection Request") cmdio.LogString(vb.ctx, " User: "+msg.Viewer) @@ -370,13 +441,13 @@ func (vb *ViteBridge) handleConnectionRequest(msg *ViteBridgeMessage) error { approved = 
strings.ToLower(strings.TrimSpace(input)) == "y" case err := <-errChan: return fmt.Errorf("failed to read user input: %w", err) - case <-time.After(bridgeConnTimeout): + case <-time.After(BridgeConnTimeout): // Default to denying after timeout cmdio.LogString(vb.ctx, "ā±ļø Timeout waiting for response, denying connection") approved = false } - response := ViteBridgeMessage{ + response := BridgeMessage{ Type: "connection:response", RequestID: msg.RequestID, Viewer: msg.Viewer, @@ -408,7 +479,7 @@ func (vb *ViteBridge) handleConnectionRequest(msg *ViteBridgeMessage) error { return nil } -func (vb *ViteBridge) handleFetchRequest(msg *ViteBridgeMessage) error { +func (vb *Bridge) handleFetchRequest(msg *BridgeMessage) error { targetURL := fmt.Sprintf(localViteURL, vb.port) + msg.Path log.Debugf(vb.ctx, "[vite_bridge] Fetch request: %s %s", msg.Method, msg.Path) @@ -437,7 +508,7 @@ func (vb *ViteBridge) handleFetchRequest(msg *ViteBridgeMessage) error { } } - metadataResponse := ViteBridgeMessage{ + metadataResponse := BridgeMessage{ Type: "fetch:response:meta", Path: msg.Path, Status: resp.StatusCode, @@ -456,7 +527,7 @@ func (vb *ViteBridge) handleFetchRequest(msg *ViteBridgeMessage) error { data: responseData, priority: 1, // Normal priority }: - case <-time.After(bridgeFetchTimeout): + case <-time.After(BridgeFetchTimeout): return errors.New("timeout sending fetch metadata") } @@ -467,7 +538,7 @@ func (vb *ViteBridge) handleFetchRequest(msg *ViteBridgeMessage) error { data: body, priority: 1, // Normal priority }: - case <-time.After(bridgeFetchTimeout): + case <-time.After(BridgeFetchTimeout): return errors.New("timeout sending fetch body") } } @@ -480,17 +551,17 @@ const ( allowedExtension = ".sql" ) -func (vb *ViteBridge) handleFileReadRequest(msg *ViteBridgeMessage) error { +func (vb *Bridge) handleFileReadRequest(msg *BridgeMessage) error { log.Debugf(vb.ctx, "[vite_bridge] File read request: %s", msg.Path) - if err := validateFilePath(msg.Path); err != nil { + if err := ValidateFilePath(msg.Path); err != nil { log.Warnf(vb.ctx, "[vite_bridge] File validation failed for %s: %v", msg.Path, err) return vb.sendFileReadError(msg.RequestID, fmt.Sprintf("Invalid file path: %v", err)) } content, err := os.ReadFile(msg.Path) - response := ViteBridgeMessage{ + response := BridgeMessage{ Type: "file:read:response", RequestID: msg.RequestID, } @@ -521,7 +592,7 @@ func (vb *ViteBridge) handleFileReadRequest(msg *ViteBridgeMessage) error { return nil } -func validateFilePath(requestedPath string) error { +func ValidateFilePath(requestedPath string) error { // Clean the path to resolve any ../ or ./ components cleanPath := filepath.Clean(requestedPath) @@ -561,8 +632,8 @@ func validateFilePath(requestedPath string) error { } // Helper to send error response -func (vb *ViteBridge) sendFileReadError(requestID, errorMsg string) error { - response := ViteBridgeMessage{ +func (vb *Bridge) sendFileReadError(requestID, errorMsg string) error { + response := BridgeMessage{ Type: "file:read:response", RequestID: requestID, Error: errorMsg, @@ -586,10 +657,10 @@ func (vb *ViteBridge) sendFileReadError(requestID, errorMsg string) error { return nil } -func (vb *ViteBridge) handleHMRMessage(msg *ViteBridgeMessage) error { +func (vb *Bridge) handleHMRMessage(msg *BridgeMessage) error { log.Debugf(vb.ctx, "[vite_bridge] HMR message received: %s", msg.Body) - response := ViteBridgeMessage{ + response := BridgeMessage{ Type: "hmr:client", Body: msg.Body, } @@ -613,7 +684,7 @@ func (vb *ViteBridge) 
handleHMRMessage(msg *ViteBridgeMessage) error { return nil } -func (vb *ViteBridge) handleViteHMRMessages(ctx context.Context) error { +func (vb *Bridge) handleViteHMRMessages(ctx context.Context) error { for { select { case <-ctx.Done(): @@ -636,7 +707,7 @@ func (vb *ViteBridge) handleViteHMRMessages(ctx context.Context) error { return err } - response := ViteBridgeMessage{ + response := BridgeMessage{ Type: "hmr:message", Body: string(message), } @@ -659,13 +730,14 @@ func (vb *ViteBridge) handleViteHMRMessages(ctx context.Context) error { } } -func (vb *ViteBridge) Start() error { +func (vb *Bridge) Start() error { appDomain, err := vb.GetAppDomain() if err != nil { return fmt.Errorf("failed to get app domain: %w", err) } - if err := vb.connectToTunnel(appDomain); err != nil { + // Use retry logic for initial connection (app may not be ready yet) + if err := vb.ConnectToTunnelWithRetry(appDomain); err != nil { return err } @@ -678,7 +750,7 @@ func (vb *ViteBridge) Start() error { return } - var msg ViteBridgeMessage + var msg BridgeMessage if err := json.Unmarshal(message, &msg); err != nil { continue } @@ -697,7 +769,7 @@ func (vb *ViteBridge) Start() error { if err != nil { return fmt.Errorf("failed waiting for tunnel ready: %w", err) } - case <-time.After(bridgeTunnelReadyTimeout): + case <-time.After(BridgeTunnelReadyTimeout): return errors.New("timeout waiting for tunnel ready") } @@ -753,7 +825,7 @@ func (vb *ViteBridge) Start() error { return g.Wait() } -func (vb *ViteBridge) Stop() { +func (vb *Bridge) Stop() { vb.stopOnce.Do(func() { close(vb.stopChan) diff --git a/cmd/workspace/apps/vite_bridge_test.go b/libs/apps/vite/bridge_test.go similarity index 89% rename from cmd/workspace/apps/vite_bridge_test.go rename to libs/apps/vite/bridge_test.go index 8d5f5c3f8d..60f0aecf00 100644 --- a/cmd/workspace/apps/vite_bridge_test.go +++ b/libs/apps/vite/bridge_test.go @@ -1,4 +1,4 @@ -package apps +package vite import ( "context" @@ -83,7 +83,7 @@ func TestValidateFilePath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateFilePath(tt.path) + err := ValidateFilePath(tt.path) if tt.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) @@ -94,21 +94,21 @@ func TestValidateFilePath(t *testing.T) { } } -func TestViteBridgeMessageSerialization(t *testing.T) { +func TestBridgeMessageSerialization(t *testing.T) { tests := []struct { name string - msg ViteBridgeMessage + msg BridgeMessage }{ { name: "tunnel ready message", - msg: ViteBridgeMessage{ + msg: BridgeMessage{ Type: "tunnel:ready", TunnelID: "test-tunnel-123", }, }, { name: "fetch request message", - msg: ViteBridgeMessage{ + msg: BridgeMessage{ Type: "fetch", Path: "/src/components/ui/card.tsx", Method: "GET", @@ -117,7 +117,7 @@ func TestViteBridgeMessageSerialization(t *testing.T) { }, { name: "connection request message", - msg: ViteBridgeMessage{ + msg: BridgeMessage{ Type: "connection:request", Viewer: "user@example.com", RequestID: "req-456", @@ -125,7 +125,7 @@ func TestViteBridgeMessageSerialization(t *testing.T) { }, { name: "fetch response with headers", - msg: ViteBridgeMessage{ + msg: BridgeMessage{ Type: "fetch:response:meta", Status: 200, Headers: map[string]any{ @@ -141,7 +141,7 @@ func TestViteBridgeMessageSerialization(t *testing.T) { data, err := json.Marshal(tt.msg) require.NoError(t, err) - var decoded ViteBridgeMessage + var decoded BridgeMessage err = json.Unmarshal(data, &decoded) require.NoError(t, err) @@ -154,21 +154,21 @@ func 
TestViteBridgeMessageSerialization(t *testing.T) { } } -func TestViteBridgeHandleMessage(t *testing.T) { +func TestBridgeHandleMessage(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) w := &databricks.WorkspaceClient{} - vb := NewViteBridge(ctx, w, "test-app", 5173) + vb := NewBridge(ctx, w, "test-app", 5173) tests := []struct { name string - msg *ViteBridgeMessage + msg *BridgeMessage expectError bool }{ { name: "tunnel ready message", - msg: &ViteBridgeMessage{ + msg: &BridgeMessage{ Type: "tunnel:ready", TunnelID: "tunnel-123", }, @@ -176,7 +176,7 @@ func TestViteBridgeHandleMessage(t *testing.T) { }, { name: "unknown message type", - msg: &ViteBridgeMessage{ + msg: &BridgeMessage{ Type: "unknown:type", }, expectError: false, @@ -199,7 +199,7 @@ func TestViteBridgeHandleMessage(t *testing.T) { } } -func TestViteBridgeHandleFileReadRequest(t *testing.T) { +func TestBridgeHandleFileReadRequest(t *testing.T) { // Create a temporary directory structure tmpDir := t.TempDir() oldWd, err := os.Getwd() @@ -250,12 +250,12 @@ func TestViteBridgeHandleFileReadRequest(t *testing.T) { defer resp.Body.Close() defer conn.Close() - vb := NewViteBridge(ctx, w, "test-app", 5173) + vb := NewBridge(ctx, w, "test-app", 5173) vb.tunnelConn = conn go func() { _ = vb.tunnelWriter(ctx) }() - msg := &ViteBridgeMessage{ + msg := &BridgeMessage{ Type: "file:read", Path: "config/queries/test_query.sql", RequestID: "req-123", @@ -268,7 +268,7 @@ func TestViteBridgeHandleFileReadRequest(t *testing.T) { time.Sleep(100 * time.Millisecond) // Parse the response - var response ViteBridgeMessage + var response BridgeMessage err = json.Unmarshal(lastMessage, &response) require.NoError(t, err) @@ -307,12 +307,12 @@ func TestViteBridgeHandleFileReadRequest(t *testing.T) { defer resp.Body.Close() defer conn.Close() - vb := NewViteBridge(ctx, w, "test-app", 5173) + vb := NewBridge(ctx, w, "test-app", 5173) vb.tunnelConn = conn go func() { _ = vb.tunnelWriter(ctx) }() - msg := &ViteBridgeMessage{ + msg := &BridgeMessage{ Type: "file:read", Path: "config/queries/nonexistent.sql", RequestID: "req-456", @@ -324,7 +324,7 @@ func TestViteBridgeHandleFileReadRequest(t *testing.T) { // Give the message time to be sent time.Sleep(100 * time.Millisecond) - var response ViteBridgeMessage + var response BridgeMessage err = json.Unmarshal(lastMessage, &response) require.NoError(t, err) @@ -334,11 +334,11 @@ func TestViteBridgeHandleFileReadRequest(t *testing.T) { }) } -func TestViteBridgeStop(t *testing.T) { +func TestBridgeStop(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) w := &databricks.WorkspaceClient{} - vb := NewViteBridge(ctx, w, "test-app", 5173) + vb := NewBridge(ctx, w, "test-app", 5173) // Call Stop multiple times to ensure it's idempotent vb.Stop() @@ -354,12 +354,12 @@ func TestViteBridgeStop(t *testing.T) { } } -func TestNewViteBridge(t *testing.T) { +func TestNewBridge(t *testing.T) { ctx := context.Background() w := &databricks.WorkspaceClient{} appName := "test-app" - vb := NewViteBridge(ctx, w, appName, 5173) + vb := NewBridge(ctx, w, appName, 5173) assert.NotNil(t, vb) assert.Equal(t, appName, vb.appName) diff --git a/cmd/workspace/apps/vite-server.js b/libs/apps/vite/server.js similarity index 100% rename from cmd/workspace/apps/vite-server.js rename to libs/apps/vite/server.js
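
The new `libs/apps/validation` package is consumed by the enhanced deploy flow, but it can also be driven directly. Below is a minimal sketch of a standalone caller, assuming the import path `github.com/databricks/cli/libs/apps/validation` from the file paths in this patch; the `main` wrapper and its output handling are illustrative only, and inside the CLI the context would normally carry a cmdio logger rather than a bare `context.Background()`.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/databricks/cli/libs/apps/validation"
)

func main() {
	ctx := context.Background()

	workDir, err := os.Getwd()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// GetProjectValidator returns nil when no validator applies; today only
	// Node.js projects (a package.json in workDir) are recognized.
	v := validation.GetProjectValidator(workDir)
	if v == nil {
		fmt.Println("no validator for this project type; skipping validation")
		return
	}

	// Validate runs the install/typegen/lint/typecheck/build steps and reports
	// the first failing step via ValidateResult.Details.
	result, err := v.Validate(ctx, workDir)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// String() renders the message plus any captured stdout/stderr on failure.
	fmt.Println(result.String())
	if !result.Success {
		os.Exit(1)
	}
}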
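
RunWithSpinnerCtx is the general-purpose wrapper behind both pickers (warehouses and apps), and it can be reused for any slow call that should show a spinner, honor cancellation, and recover panics. A short usage sketch, assuming the import path `github.com/databricks/cli/libs/apps/prompt` and a stand-in sleep in place of a real API call:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/databricks/cli/libs/apps/prompt"
)

func main() {
	// Bound the whole operation; on timeout RunWithSpinnerCtx stops the
	// spinner, waits for the action goroutine to finish, and returns ctx.Err().
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	var count int
	err := prompt.RunWithSpinnerCtx(ctx, "Fetching SQL warehouses...", func() error {
		// Stand-in for a workspace API call. A panic here would be recovered
		// and surfaced as an "action panicked" error.
		time.Sleep(2 * time.Second)
		count = 3
		return nil
	})
	if err != nil {
		fmt.Println("fetch failed or was cancelled:", err)
		return
	}
	fmt.Printf("%d warehouses available\n", count)
}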
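
With the bridge moved out of cmd/workspace/apps and its constructor exported, it can be started from any code path that holds a workspace client. The sketch below shows the intended call sequence under stated assumptions: the app name "my-app" and port 5173 are placeholders, a local Vite dev server is assumed to already be listening on that port, and in the CLI the context would carry a cmdio logger so that connection prompts can be answered.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/databricks/cli/libs/apps/vite"
	"github.com/databricks/databricks-sdk-go"
)

func main() {
	ctx := context.Background()

	w, err := databricks.NewWorkspaceClient()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// NewBridge proxies dev-tunnel traffic to the local Vite server on the
	// given port. Start connects with retry/backoff and blocks until the
	// bridge shuts down or hits an unrecoverable error.
	b := vite.NewBridge(ctx, w, "my-app", 5173)
	defer b.Stop()

	if err := b.Start(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}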