diff --git a/superset-frontend/cypress-base/cypress/e2e/database/modal.test.ts b/superset-frontend/cypress-base/cypress/e2e/database/modal.test.ts index 340fb7392c54..44a02d051bbe 100644 --- a/superset-frontend/cypress-base/cypress/e2e/database/modal.test.ts +++ b/superset-frontend/cypress-base/cypress/e2e/database/modal.test.ts @@ -63,56 +63,57 @@ describe('Add database', () => { it('show error alerts on dynamic form for bad host', () => { cy.get('.preferred > :nth-child(1)').click(); - cy.get('input[name="host"]').type('badhost', { force: true }); - cy.get('input[name="port"]').type('5432', { force: true }); - cy.get('input[name="username"]').type('testusername', { force: true }); - cy.get('input[name="database"]').type('testdb', { force: true }); - cy.get('input[name="password"]').type('testpass', { force: true }); - - cy.get('body').click(0, 0); - + cy.get('input[name="host"]').type('badhost', { force: true }).blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="port"]').type('5432', { force: true }).blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="username"]') + .type('testusername', { force: true }) + .blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="database"]').type('testdb', { force: true }).blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="password"]').type('testpass', { force: true }).blur(); cy.wait('@validateParams', { timeout: 30000 }); - cy.getBySel('btn-submit-connection').should('not.be.disabled'); + cy.getBySel('btn-submit-connection', { timeout: 60000 }).should( + 'not.be.disabled', + ); cy.getBySel('btn-submit-connection').click({ force: true }); - cy.wait('@validateParams', { timeout: 30000 }).then(() => { - cy.wait('@createDb', { timeout: 60000 }).then(() => { - cy.contains( - '.ant-form-item-explain-error', - "The hostname provided can't be resolved", - ).should('exist'); - }); + cy.wait('@createDb', { timeout: 60000 }).then(() => { + cy.contains( + '.ant-form-item-explain-error', + "The hostname provided can't be resolved", + ).should('exist'); }); }); it('show error alerts on dynamic form for bad port', () => { cy.get('.preferred > :nth-child(1)').click(); - cy.get('input[name="host"]').type('localhost', { force: true }); - cy.get('body').click(0, 0); + cy.get('input[name="host"]').type('localhost', { force: true }).blur(); cy.wait('@validateParams', { timeout: 30000 }); - - cy.get('input[name="port"]').type('5430', { force: true }); - cy.get('input[name="database"]').type('testdb', { force: true }); - cy.get('input[name="username"]').type('testusername', { force: true }); - + cy.get('input[name="port"]').type('5430', { force: true }).blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="database"]').type('testdb', { force: true }).blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="username"]') + .type('testusername', { force: true }) + .blur(); + cy.wait('@validateParams', { timeout: 30000 }); + cy.get('input[name="password"]').type('testpass', { force: true }).blur(); cy.wait('@validateParams', { timeout: 30000 }); - cy.get('input[name="password"]').type('testpass', { force: true }); - cy.wait('@validateParams'); - - cy.getBySel('btn-submit-connection').should('not.be.disabled'); + cy.getBySel('btn-submit-connection', { timeout: 60000 }).should( + 'not.be.disabled', + ); cy.getBySel('btn-submit-connection').click({ force: true }); - cy.wait('@validateParams', { timeout: 30000 }).then(() => { - cy.get('body').click(0, 0); - cy.getBySel('btn-submit-connection').click({ force: true }); - cy.wait('@createDb', { timeout: 60000 }).then(() => { - cy.contains( - '.ant-form-item-explain-error', - 'The port is closed', - ).should('exist'); - }); + + cy.wait('@createDb', { timeout: 60000 }).then(() => { + cy.contains('.ant-form-item-explain-error', 'The port is closed').should( + 'exist', + ); }); }); }); diff --git a/superset-frontend/packages/superset-ui-core/src/components/Form/LabeledErrorBoundInput.tsx b/superset-frontend/packages/superset-ui-core/src/components/Form/LabeledErrorBoundInput.tsx index 882fabf8b414..2095cca2f03d 100644 --- a/superset-frontend/packages/superset-ui-core/src/components/Form/LabeledErrorBoundInput.tsx +++ b/superset-frontend/packages/superset-ui-core/src/components/Form/LabeledErrorBoundInput.tsx @@ -79,7 +79,7 @@ export const LabeledErrorBoundInput = ({ isValidating ? 'validating' : hasError ? 'error' : 'success' } help={errorMessage || helpText} - hasFeedback={!!hasError} + hasFeedback={isValidating || !!hasError} > {visibilityToggle || props.name === 'password' ? ( ( @@ -250,6 +251,7 @@ export const accessTokenField = ({ id="access_token" name="access_token" required={required} + isValidating={isValidating} visibilityToggle={!isEditMode} value={db?.parameters?.access_token} validationMethods={{ onBlur: getValidation }} diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx index 87456f342442..146aae23fe2d 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx @@ -33,6 +33,7 @@ export const TableCatalog = ({ getValidation, validationErrors, db, + isValidating, }: FieldPropTypes) => { const tableCatalog = db?.catalog || []; const catalogError = validationErrors || {}; @@ -51,6 +52,7 @@ export const TableCatalog = ({ ( `${theme.sizeUnit}px 0 ${theme.sizeUnit * 2}px`}; -`; +interface SSHTunnelFormProps { + db: DatabaseObject | null; + onSSHTunnelParametersChange: CustomEventHandlerType; + setSSHTunnelLoginMethod: (method: AuthType) => void; + isValidating?: boolean; + validationErrors?: JsonObject | null; + getValidation: () => void; +} const SSHTunnelForm = ({ db, onSSHTunnelParametersChange, setSSHTunnelLoginMethod, -}: { - db: DatabaseObject | null; - onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange']; - setSSHTunnelLoginMethod: (method: AuthType) => void; -}) => { + isValidating = false, + validationErrors, + getValidation, +}: SSHTunnelFormProps) => { const [usePassword, setUsePassword] = useState(AuthType.Password); + const sshErrors = validationErrors?.ssh_tunnel || {}; return (
- - {t('SSH Host')} - - - - {t('SSH Port')} - - @@ -100,15 +112,17 @@ const SSHTunnelForm = ({ - - {t('Username')} - - @@ -148,16 +162,20 @@ const SSHTunnelForm = ({ - - {t('SSH Password')} - - + iconRender={(visible: boolean) => visible ? ( @@ -182,30 +200,47 @@ const SSHTunnelForm = ({ {t('Private Key')} - + + + - - {t('Private Key Password')} - - + iconRender={(visible: boolean) => visible ? ( diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx index 455c4740a260..2a9b259288cc 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx @@ -1212,26 +1212,40 @@ describe('DatabaseModal', () => { 'ssh-tunnel-server_address-input', ); expect(SSHTunnelServerAddressInput).toHaveValue(''); - userEvent.type(SSHTunnelServerAddressInput, 'localhost'); - expect(SSHTunnelServerAddressInput).toHaveValue('localhost'); + fireEvent.change(SSHTunnelServerAddressInput, { + target: { value: 'localhost' }, + }); + await waitFor(() => + expect(SSHTunnelServerAddressInput).toHaveValue('localhost'), + ); const SSHTunnelServerPortInput = screen.getByTestId( 'ssh-tunnel-server_port-input', ); expect(SSHTunnelServerPortInput).toHaveValue(null); - userEvent.type(SSHTunnelServerPortInput, '22'); - expect(SSHTunnelServerPortInput).toHaveValue(22); + fireEvent.change(SSHTunnelServerPortInput, { + target: { value: '22' }, + }); + await waitFor(() => expect(SSHTunnelServerPortInput).toHaveValue(22)); const SSHTunnelUsernameInput = screen.getByTestId( 'ssh-tunnel-username-input', ); expect(SSHTunnelUsernameInput).toHaveValue(''); - userEvent.type(SSHTunnelUsernameInput, 'test'); - expect(SSHTunnelUsernameInput).toHaveValue('test'); + fireEvent.change(SSHTunnelUsernameInput, { + target: { value: 'test' }, + }); + await waitFor(() => + expect(SSHTunnelUsernameInput).toHaveValue('test'), + ); const SSHTunnelPasswordInput = screen.getByTestId( 'ssh-tunnel-password-input', ); expect(SSHTunnelPasswordInput).toHaveValue(''); - userEvent.type(SSHTunnelPasswordInput, 'pass'); - expect(SSHTunnelPasswordInput).toHaveValue('pass'); + fireEvent.change(SSHTunnelPasswordInput, { + target: { value: 'pass' }, + }); + await waitFor(() => + expect(SSHTunnelPasswordInput).toHaveValue('pass'), + ); }); test('properly interacts with SSH Tunnel form textboxes', async () => { @@ -1250,26 +1264,40 @@ describe('DatabaseModal', () => { 'ssh-tunnel-server_address-input', ); expect(SSHTunnelServerAddressInput).toHaveValue(''); - userEvent.type(SSHTunnelServerAddressInput, 'localhost'); - expect(SSHTunnelServerAddressInput).toHaveValue('localhost'); + fireEvent.change(SSHTunnelServerAddressInput, { + target: { value: 'localhost' }, + }); + await waitFor(() => + expect(SSHTunnelServerAddressInput).toHaveValue('localhost'), + ); const SSHTunnelServerPortInput = screen.getByTestId( 'ssh-tunnel-server_port-input', ); expect(SSHTunnelServerPortInput).toHaveValue(null); - userEvent.type(SSHTunnelServerPortInput, '22'); - expect(SSHTunnelServerPortInput).toHaveValue(22); + fireEvent.change(SSHTunnelServerPortInput, { + target: { value: '22' }, + }); + await waitFor(() => expect(SSHTunnelServerPortInput).toHaveValue(22)); const SSHTunnelUsernameInput = screen.getByTestId( 'ssh-tunnel-username-input', ); expect(SSHTunnelUsernameInput).toHaveValue(''); - userEvent.type(SSHTunnelUsernameInput, 'test'); - expect(SSHTunnelUsernameInput).toHaveValue('test'); + fireEvent.change(SSHTunnelUsernameInput, { + target: { value: 'test' }, + }); + await waitFor(() => + expect(SSHTunnelUsernameInput).toHaveValue('test'), + ); const SSHTunnelPasswordInput = screen.getByTestId( 'ssh-tunnel-password-input', ); expect(SSHTunnelPasswordInput).toHaveValue(''); - userEvent.type(SSHTunnelPasswordInput, 'pass'); - expect(SSHTunnelPasswordInput).toHaveValue('pass'); + fireEvent.change(SSHTunnelPasswordInput, { + target: { value: 'pass' }, + }); + await waitFor(() => + expect(SSHTunnelPasswordInput).toHaveValue('pass'), + ); }); test('if the SSH Tunneling toggle is not true, no inputs are displayed', async () => { @@ -1364,7 +1392,10 @@ describe('DatabaseModal', () => { }), ); - const textboxes = screen.getAllByRole('textbox'); + // Wait for step 2 to render + expect(await screen.findByText(/step 2 of 3/i)).toBeInTheDocument(); + + const textboxes = await screen.findAllByRole('textbox'); const hostField = textboxes[0]; const portField = screen.getByRole('spinbutton'); const databaseNameField = textboxes[1]; @@ -1380,15 +1411,20 @@ describe('DatabaseModal', () => { expect(connectButton).toBeDisabled(); - userEvent.type(hostField, 'localhost'); - userEvent.type(portField, '5432'); - userEvent.type(databaseNameField, 'postgres'); - userEvent.type(usernameField, 'testdb'); - userEvent.type(passwordField, 'demoPassword'); + fireEvent.change(hostField, { target: { value: 'localhost' } }); + fireEvent.blur(hostField); + fireEvent.change(portField, { target: { value: '5432' } }); + fireEvent.blur(portField); + fireEvent.change(databaseNameField, { target: { value: 'postgres' } }); + fireEvent.blur(databaseNameField); + fireEvent.change(usernameField, { target: { value: 'testdb' } }); + fireEvent.blur(usernameField); + fireEvent.change(passwordField, { target: { value: 'demoPassword' } }); + fireEvent.blur(passwordField); await waitFor(() => expect(connectButton).toBeEnabled()); - expect(await screen.findByDisplayValue(/5432/i)).toBeInTheDocument(); + await waitFor(() => expect(portField).toHaveValue(5432)); expect(hostField).toHaveValue('localhost'); expect(portField).toHaveValue(5432); expect(databaseNameField).toHaveValue('postgres'); @@ -1397,10 +1433,48 @@ describe('DatabaseModal', () => { expect(connectButton).toBeEnabled(); userEvent.click(connectButton); + // Verify that validation was called during the form interaction + // Note: With the optimized validation, redundant calls on the same db state are skipped + await waitFor(() => { + expect( + fetchMock.callHistory.calls(VALIDATE_PARAMS_ENDPOINT).length, + ).toBeGreaterThan(0); + }); + }); + + test('does not fire redundant validation on blur when db has not changed', async () => { + setup(); + + userEvent.click( + await screen.findByRole('button', { + name: /postgresql/i, + }), + ); + + expect(await screen.findByText(/step 2 of 3/i)).toBeInTheDocument(); + + const textboxes = await screen.findAllByRole('textbox'); + const hostField = textboxes[0]; + + // Type a value and blur - should trigger validation + fireEvent.change(hostField, { target: { value: 'localhost' } }); + fireEvent.blur(hostField); + + await waitFor(() => { + expect( + fetchMock.callHistory.calls(VALIDATE_PARAMS_ENDPOINT).length, + ).toEqual(1); + }); + + // Blur again without changing the value - should NOT trigger another validation + fireEvent.focus(hostField); + fireEvent.blur(hostField); + + // Wait a tick to ensure no additional calls are made await waitFor(() => { expect( fetchMock.callHistory.calls(VALIDATE_PARAMS_ENDPOINT).length, - ).toEqual(5); + ).toEqual(1); }); }); }); diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index 0919941938ad..bd9a26084af1 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -617,6 +617,7 @@ const DatabaseModal: FunctionComponent = ({ hasValidated, setHasValidated, ] = useDatabaseValidation(); + const lastValidatedDbSnapshotRef = useRef(null); const [hasConnectedDb, setHasConnectedDb] = useState(false); const [showCTAbtns, setShowCTAbtns] = useState(false); const [dbName, setDbName] = useState(''); @@ -724,6 +725,7 @@ const DatabaseModal: FunctionComponent = ({ const handleClearValidationErrors = useCallback(() => { setValidationErrors(null); setHasValidated(false); + lastValidatedDbSnapshotRef.current = null; clearError(); }, [setValidationErrors, setHasValidated, clearError]); @@ -800,6 +802,16 @@ const DatabaseModal: FunctionComponent = ({ [onChange], ); + const handleTextChange = useCallback( + ({ target }: { target: HTMLInputElement }) => { + onChange(ActionType.TextChange, { + name: target.name, + value: target.value, + }); + }, + [onChange], + ); + const handleChangeWithValidation = useCallback( ( actionType: ActionType, @@ -811,6 +823,21 @@ const DatabaseModal: FunctionComponent = ({ [onChange, handleClearValidationErrors], ); + const getBlurValidation = useCallback(async () => { + const currentDbSnapshot = JSON.stringify(db); + if (currentDbSnapshot === lastValidatedDbSnapshotRef.current) { + return []; + } + const result = await getValidation(db); + // Only cache after a request that produced a usable response. ``null`` + // signals an unexpected/network failure, in which case we leave the + // snapshot untouched so the next blur retries. + if (result !== null) { + lastValidatedDbSnapshotRef.current = currentDbSnapshot; + } + return result; + }, [db, getValidation]); + const onClose = () => { setDB({ type: ActionType.Reset }); setHasConnectedDb(false); @@ -1796,7 +1823,6 @@ const DatabaseModal: FunctionComponent = ({ name: target.name, value: target.value, }); - handleClearValidationErrors(); }} setSSHTunnelLoginMethod={(method: AuthType) => setDB({ @@ -1804,6 +1830,9 @@ const DatabaseModal: FunctionComponent = ({ payload: { login_method: method }, }) } + isValidating={isValidating} + validationErrors={validationErrors} + getValidation={getBlurValidation} /> ); @@ -1872,13 +1901,8 @@ const DatabaseModal: FunctionComponent = ({ }); }} onParametersChange={handleParametersChange} - onChange={({ target }: { target: HTMLInputElement }) => - handleChangeWithValidation(ActionType.TextChange, { - name: target.name, - value: target.value, - }) - } - getValidation={() => getValidation(db)} + onChange={handleTextChange} + getValidation={getBlurValidation} validationErrors={validationErrors} getPlaceholder={getPlaceholder} clearValidationErrors={handleClearValidationErrors} diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts index fa69b5840707..4cb9e9f9c974 100644 --- a/superset-frontend/src/views/CRUD/hooks.ts +++ b/superset-frontend/src/views/CRUD/hooks.ts @@ -819,16 +819,13 @@ export function useDatabaseValidation() { ); const [isValidating, setIsValidating] = useState(false); const [hasValidated, setHasValidated] = useState(false); + const latestRequestIdRef = useRef(0); const getValidation = useCallback( async (database: Partial | null, onCreate = false) => { - if (database?.parameters?.ssh) { - setValidationErrors(null); - setIsValidating(false); - setHasValidated(true); - return Promise.resolve([]); - } - + const requestId = latestRequestIdRef.current + 1; + latestRequestIdRef.current = requestId; + const isLatest = () => latestRequestIdRef.current === requestId; setIsValidating(true); try { @@ -837,6 +834,7 @@ export function useDatabaseValidation() { body: JSON.stringify(transformDB(database)), headers: { 'Content-Type': 'application/json' }, }); + if (!isLatest()) return []; setValidationErrors(null); setIsValidating(false); setHasValidated(true); @@ -866,6 +864,19 @@ export function useDatabaseValidation() { return acc; } + if (extra?.ssh_tunnel) { + acc.ssh_tunnel = { + ...acc.ssh_tunnel, + ...Object.fromEntries( + (extra.missing ?? []).map((field: string) => [ + field, + 'This is a required field', + ]), + ), + }; + return acc; + } + if (extra?.invalid) { extra.invalid.forEach((field: string) => { acc[field] = message; @@ -885,6 +896,7 @@ export function useDatabaseValidation() { return acc; }, {}); + if (!isLatest()) return parsedErrors; setValidationErrors(parsedErrors); setIsValidating(false); setHasValidated(true); @@ -893,9 +905,11 @@ export function useDatabaseValidation() { } console.error('Unexpected error during validation:', error); - setIsValidating(false); - setHasValidated(true); - return {}; + if (isLatest()) { + setIsValidating(false); + setHasValidated(true); + } + return null; } }, [setValidationErrors], diff --git a/superset/commands/database/validate.py b/superset/commands/database/validate.py index 4d9952ad89a1..a37bb3cf53c6 100644 --- a/superset/commands/database/validate.py +++ b/superset/commands/database/validate.py @@ -19,6 +19,7 @@ from flask_babel import gettext as __ +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( DatabaseOfflineError, @@ -26,6 +27,10 @@ InvalidEngineError, InvalidParametersError, ) +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelingNotEnabledError, +) from superset.daos.database import DatabaseDAO from superset.databases.utils import make_url_safe from superset.db_engine_specs import get_engine_spec @@ -42,14 +47,23 @@ def __init__(self, properties: dict[str, Any]): self._properties = properties.copy() self._model: Optional[Database] = None - def run(self) -> None: + def run(self) -> None: # noqa: C901 self.validate() engine = self._properties["engine"] driver = self._properties.get("driver") if engine in BYPASS_VALIDATION_ENGINES: - # Skip engines that are only validated onCreate + # Skip engines that are only validated onCreate, but still surface + # database_name uniqueness and SSH tunnel field errors so the + # progressive validation flow stays consistent across engines. + errors: list[SupersetError] = [] + if database_name_error := self._get_database_name_error(): + errors.append(database_name_error) + errors.extend(self._get_ssh_tunnel_errors()) + if errors: + event_logger.log_with_context(action="validation_error", engine=engine) + raise InvalidParametersError(errors) return engine_spec = get_engine_spec(engine, driver) @@ -65,8 +79,17 @@ def run(self) -> None: ), ) - # perform initial validation + # perform initial validation (host, port, database, username) errors = engine_spec.validate_parameters(self._properties) # type: ignore + + # Collect database_name errors along with parameter errors + if database_name_error := self._get_database_name_error(): + errors.append(database_name_error) + + # Collect SSH tunnel errors + ssh_tunnel_errors = self._get_ssh_tunnel_errors() + errors.extend(ssh_tunnel_errors) + if errors: event_logger.log_with_context(action="validation_error", engine=engine) raise InvalidParametersError(errors) @@ -138,6 +161,107 @@ def run(self) -> None: ), ) - def validate(self) -> None: + def _load_model(self) -> None: + """Load the existing database model if updating.""" if (database_id := self._properties.get("id")) is not None: self._model = DatabaseDAO.find_by_id(database_id) + + def _get_database_name_error(self) -> Optional[SupersetError]: + """Check for duplicate database name and return error if found.""" + database_id = self._properties.get("id") + + if database_name := self._properties.get("database_name"): + is_unique = ( + DatabaseDAO.validate_update_uniqueness(database_id, database_name) + if database_id is not None + else DatabaseDAO.validate_uniqueness(database_name) + ) + if not is_unique: + return SupersetError( + message=__("A database with the same name already exists."), + error_type=SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR, + level=ErrorLevel.ERROR, + extra={"invalid": ["database_name"]}, + ) + return None + + def validate(self) -> None: + """Load the model and validate SSH tunnel if enabled.""" + self._load_model() + self._validate_ssh_tunnel() + + def _validate_ssh_tunnel(self) -> None: + """Validate SSH tunnel configuration if enabled.""" + ssh_tunnel = self._properties.get("ssh_tunnel") or {} + parameters = self._properties.get("parameters") or {} + # SSH can be turned on via the dedicated tunnel payload OR the + # ``parameters.ssh`` flag the UI sets while the user is filling the + # form. Both paths must enforce the feature flag and the database + # port requirement. + ssh_enabled = bool(ssh_tunnel) or bool(parameters.get("ssh")) + if not ssh_enabled: + return + if not is_feature_enabled("SSH_TUNNELING"): + raise SSHTunnelingNotEnabledError() + if not parameters.get("port"): + raise SSHTunnelDatabasePortError() + + def _get_ssh_tunnel_errors(self) -> list[SupersetError]: + """Validate SSH tunnel fields and return list of errors.""" + errors: list[SupersetError] = [] + ssh_tunnel = self._properties.get("ssh_tunnel") or {} + parameters = self._properties.get("parameters", {}) + + # Check if SSH is enabled via parameters.ssh flag + ssh_enabled = parameters.get("ssh", False) + + # Only validate SSH tunnel if SSH is enabled or ssh_tunnel is provided + if not ssh_enabled and not ssh_tunnel: + return errors + + # Required fields + required_fields = ["server_address", "server_port", "username"] + missing = [f for f in required_fields if not ssh_tunnel.get(f)] + + if missing: + errors.append( + SupersetError( + message=__( + "One or more parameters are missing: %(missing)s", + missing=", ".join(missing), + ), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.WARNING, + extra={"missing": missing, "ssh_tunnel": True}, + ) + ) + + # Either password or private_key is required + has_password = bool(ssh_tunnel.get("password")) + has_private_key = bool(ssh_tunnel.get("private_key")) + + if not has_password and not has_private_key: + errors.append( + SupersetError( + message=__("Must provide credentials for the SSH Tunnel"), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.WARNING, + extra={"missing": ["password"], "ssh_tunnel": True}, + ) + ) + + # If private_key is provided, private_key_password is required + if has_private_key and not ssh_tunnel.get("private_key_password"): + errors.append( + SupersetError( + message=__( + "One or more parameters are missing: %(missing)s", + missing="private_key_password", + ), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.WARNING, + extra={"missing": ["private_key_password"], "ssh_tunnel": True}, + ) + ) + + return errors diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index da0dd3cf3403..e283af302dc4 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -443,6 +443,24 @@ class Meta: # pylint: disable=too-few-public-methods required=True, metadata={"description": configuration_method_description}, ) + ssh_tunnel = fields.Nested("DatabaseSSHTunnelValidation", allow_none=True) + + +class DatabaseSSHTunnelValidation(Schema): + """SSH Tunnel schema for validation. + + Allows partial data without strict authentication requirements. + """ + + id = fields.Integer( + allow_none=True, metadata={"description": "SSH Tunnel ID (for updates)"} + ) + server_address = fields.String(allow_none=True) + server_port = fields.Integer(allow_none=True) + username = fields.String(allow_none=True) + password = fields.String(required=False, allow_none=True) + private_key = fields.String(required=False, allow_none=True) + private_key_password = fields.String(required=False, allow_none=True) class DatabaseSSHTunnel(Schema): diff --git a/tests/unit_tests/commands/databases/validate_test.py b/tests/unit_tests/commands/databases/validate_test.py index 96f613315fe5..53ea51a634d0 100644 --- a/tests/unit_tests/commands/databases/validate_test.py +++ b/tests/unit_tests/commands/databases/validate_test.py @@ -23,6 +23,10 @@ DatabaseTestConnectionFailedError, InvalidParametersError, ) +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelingNotEnabledError, +) from superset.commands.database.validate import ValidateDatabaseParametersCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -205,3 +209,358 @@ def test_command_with_oauth2_not_configured(mocker: MockerFixture) -> None: extra={"engine_name": "gsheets"}, ) ] + + +def test_command_duplicate_database_name(mocker: MockerFixture) -> None: + """ + Validation surfaces a duplicate-name error for a new database with a + name already in use. + """ + DatabaseDAO = mocker.patch( # noqa: N806 + "superset.commands.database.validate.DatabaseDAO" + ) + DatabaseDAO.validate_uniqueness.return_value = False + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + ), + ) + + properties = { + "engine": "postgresql", + "database_name": "duplicate", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + assert any( + err.error_type == SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR + and err.extra is not None + and err.extra.get("invalid") == ["database_name"] + for err in excinfo.value.errors + ) + + +def test_command_duplicate_database_name_on_update(mocker: MockerFixture) -> None: + """ + Validation uses ``validate_update_uniqueness`` when an ``id`` is provided. + """ + DatabaseDAO = mocker.patch( # noqa: N806 + "superset.commands.database.validate.DatabaseDAO" + ) + DatabaseDAO.find_by_id.return_value = mocker.MagicMock() + DatabaseDAO.validate_update_uniqueness.return_value = False + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + ), + ) + + properties = { + "id": 1, + "engine": "postgresql", + "database_name": "existing", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError): + command.run() + DatabaseDAO.validate_update_uniqueness.assert_called_once_with(1, "existing") + + +def test_command_duplicate_database_name_bypass_engine( + mocker: MockerFixture, +) -> None: + """ + Bypass engines (e.g. ``bigquery``) still validate database name uniqueness. + """ + DatabaseDAO = mocker.patch( # noqa: N806 + "superset.commands.database.validate.DatabaseDAO" + ) + DatabaseDAO.validate_uniqueness.return_value = False + + properties = { + "engine": "bigquery", + "database_name": "duplicate", + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + assert excinfo.value.errors[0].error_type == ( + SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR + ) + + +def test_validate_ssh_tunnel_feature_disabled(mocker: MockerFixture) -> None: + """ + Enabling SSH tunnel without the feature flag raises an error. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=False, + ) + + properties = { + "engine": "postgresql", + "ssh_tunnel": {"server_address": "localhost"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(SSHTunnelingNotEnabledError): + command.run() + + +def test_validate_ssh_tunnel_missing_db_port(mocker: MockerFixture) -> None: + """ + SSH tunneling requires a database port. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=True, + ) + + properties = { + "engine": "postgresql", + "ssh_tunnel": {"server_address": "localhost"}, + "parameters": {"host": "localhost"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(SSHTunnelDatabasePortError): + command.run() + + +def test_get_ssh_tunnel_errors_missing_required_fields( + mocker: MockerFixture, +) -> None: + """ + SSH tunnel collects missing required fields (server_address, server_port, + username) and missing credentials. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=True, + ) + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + ), + ) + + properties = { + "engine": "postgresql", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + "ssh_tunnel": {"server_address": "ssh.example.com"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + + assert any( + err.extra is not None + and err.extra.get("ssh_tunnel") is True + and err.extra.get("missing") == ["server_port", "username"] + for err in excinfo.value.errors + ) + assert any( + err.extra is not None + and err.extra.get("ssh_tunnel") is True + and err.extra.get("missing") == ["password"] + for err in excinfo.value.errors + ) + + +def test_get_ssh_tunnel_errors_private_key_without_password( + mocker: MockerFixture, +) -> None: + """ + Providing a private_key without private_key_password raises a missing + parameters error. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=True, + ) + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + ), + ) + + properties = { + "engine": "postgresql", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + "ssh_tunnel": { + "server_address": "ssh.example.com", + "server_port": 22, + "username": "ssh-user", + "private_key": "----- KEY -----", + }, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + + assert any( + err.extra is not None + and err.extra.get("ssh_tunnel") is True + and err.extra.get("missing") == ["private_key_password"] + for err in excinfo.value.errors + ) + + +def test_get_ssh_tunnel_errors_skipped_when_not_enabled( + mocker: MockerFixture, +) -> None: + """ + SSH tunnel validation is a no-op when ssh is not enabled and no tunnel + is provided. + """ + DatabaseDAO = mocker.patch( # noqa: N806 + "superset.commands.database.validate.DatabaseDAO" + ) + DatabaseDAO.validate_uniqueness.return_value = True + + database = mocker.MagicMock() + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.return_value = True + DatabaseDAO.build_db_for_connection_test.return_value = database + + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + build_sqlalchemy_uri=mocker.MagicMock(return_value="postgresql://"), + unmask_encrypted_extra=mocker.MagicMock(return_value="{}"), + ), + ) + + properties = { + "engine": "postgresql", + "database_name": "ok", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + } + command = ValidateDatabaseParametersCommand(properties) + command.run() + + +def test_bypass_engine_surfaces_ssh_tunnel_errors(mocker: MockerFixture) -> None: + """ + Bypass engines also surface SSH tunnel field errors so the progressive + validation flow stays consistent across engine types. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=True, + ) + DatabaseDAO = mocker.patch( # noqa: N806 + "superset.commands.database.validate.DatabaseDAO" + ) + DatabaseDAO.validate_uniqueness.return_value = True + + properties = { + "engine": "snowflake", + "database_name": "ok", + "parameters": {"port": 443}, + "ssh_tunnel": {"server_address": "ssh.example.com"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + assert any( + err.extra is not None and err.extra.get("ssh_tunnel") is True + for err in excinfo.value.errors + ) + + +def test_validate_ssh_tunnel_feature_disabled_via_parameters_ssh( + mocker: MockerFixture, +) -> None: + """ + The SSH feature-flag guard fires when the UI marks ``parameters.ssh`` + even if ``ssh_tunnel`` itself is empty during progressive validation. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=False, + ) + + properties = { + "engine": "postgresql", + "parameters": {"host": "localhost", "port": 5432, "ssh": True}, + "ssh_tunnel": {}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(SSHTunnelingNotEnabledError): + command.run() + + +def test_ssh_tunnel_missing_message_is_interpolated( + mocker: MockerFixture, +) -> None: + """ + The translated ``parameters are missing`` message is interpolated with + the actual missing fields rather than the raw ``%(missing)s`` token. + """ + mocker.patch( + "superset.commands.database.validate.is_feature_enabled", + return_value=True, + ) + mocker.patch( + "superset.commands.database.validate.get_engine_spec", + return_value=mocker.MagicMock( + validate_parameters=mocker.MagicMock(return_value=[]), + ), + ) + + properties = { + "engine": "postgresql", + "parameters": { + "host": "localhost", + "port": 5432, + "username": "u", + "database": "d", + }, + "ssh_tunnel": {"server_address": "ssh.example.com"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + missing_field_messages = [ + err.message + for err in excinfo.value.errors + if err.extra is not None + and err.extra.get("missing") + and err.extra.get("ssh_tunnel") # noqa: E501 + ] + assert missing_field_messages + assert all("%(missing)s" not in msg for msg in missing_field_messages) + assert any("server_port" in msg for msg in missing_field_messages)