diff --git a/ols/app/models/config.py b/ols/app/models/config.py index adfc73ca9..83009a372 100644 --- a/ols/app/models/config.py +++ b/ols/app/models/config.py @@ -798,6 +798,7 @@ class PostgresConfig(BaseModel): gss_encmode: str = constants.POSTGRES_CACHE_GSSENCMODE ca_cert_path: Optional[FilePath] = None max_entries: PositiveInt = constants.POSTGRES_CACHE_MAX_ENTRIES + tls_security_profile: Optional["TLSSecurityProfile"] = None def __init__(self, **data: Any) -> None: """Initialize configuration.""" @@ -1209,6 +1210,7 @@ def __init__( data.get("tlsSecurityProfile", None) ) self.quota_handlers = QuotaHandlersConfig(data.get("quota_handlers", None)) + self._propagate_tls_profile() self.proxy_config = ProxyConfig(data.get("proxy_config")) if data.get("tool_filtering", None) is not None: self.tool_filtering = ToolFilteringConfig(**data.get("tool_filtering")) @@ -1232,6 +1234,20 @@ def __init__( "offload_storage_path", constants.DEFAULT_OFFLOAD_STORAGE_PATH ) + def _propagate_tls_profile(self) -> None: + """Set the TLS security profile on all PostgresConfig instances.""" + if ( + self.tls_security_profile is None + or self.tls_security_profile.profile_type is None + ): + return + if self.conversation_cache and self.conversation_cache.postgres: + self.conversation_cache.postgres.tls_security_profile = ( + self.tls_security_profile + ) + if self.quota_handlers and self.quota_handlers.storage: + self.quota_handlers.storage.tls_security_profile = self.tls_security_profile + def validate_yaml(self, disable_tls: bool = False) -> None: """Validate OLS config.""" if self.conversation_cache is not None: diff --git a/ols/constants.py b/ols/constants.py index 217d588fd..4245a3761 100644 --- a/ols/constants.py +++ b/ols/constants.py @@ -131,7 +131,7 @@ class GenericLLMParameters: # look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE # for all possible options -POSTGRES_CACHE_SSL_MODE = "prefer" +POSTGRES_CACHE_SSL_MODE = "require" # look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-GSSENCMODE # for all possible options diff --git a/ols/runners/quota_scheduler.py b/ols/runners/quota_scheduler.py index d5112add0..ce09712f3 100644 --- a/ols/runners/quota_scheduler.py +++ b/ols/runners/quota_scheduler.py @@ -10,6 +10,7 @@ from ols import constants from ols.app.models.config import LimiterConfig, PostgresConfig, QuotaHandlersConfig from ols.utils.config import AppConfig +from ols.utils.ssl import libpq_tls_params logger: logging.Logger = logging.getLogger(__name__) @@ -158,16 +159,18 @@ def get_subject_id(limiter_type: str) -> str: def connect(config: PostgresConfig) -> Any: """Initialize connection to database.""" logger.info("Initializing connection to quota limiter database") - connection = psycopg2.connect( - host=config.host, - port=config.port, - user=config.user, - password=config.password, - dbname=config.dbname, - sslmode=config.ssl_mode, - # sslrootcert=config.ca_cert_path, - gssencmode=config.gss_encmode, - ) + connect_kwargs: dict[str, Any] = { + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "dbname": config.dbname, + "sslmode": config.ssl_mode, + "sslrootcert": config.ca_cert_path, + "gssencmode": config.gss_encmode, + **libpq_tls_params(config.tls_security_profile), + } + connection = psycopg2.connect(**connect_kwargs) if connection is not None: connection.autocommit = True return connection diff --git a/ols/utils/postgres.py b/ols/utils/postgres.py index 64fb1cfa9..c69fca824 100644 --- a/ols/utils/postgres.py +++ b/ols/utils/postgres.py @@ -11,6 +11,7 @@ import psycopg2 from ols.app.models.config import PostgresConfig +from ols.utils.ssl import libpq_tls_params logger = logging.getLogger(__name__) @@ -57,16 +58,18 @@ def connect(self) -> None: logger.info("Establishing connection to Postgres") self.connection = None config = self.connection_config - self.connection = psycopg2.connect( - host=config.host, - port=config.port, - user=config.user, - password=config.password, - dbname=config.dbname, - sslmode=config.ssl_mode, - sslrootcert=config.ca_cert_path, - gssencmode=config.gss_encmode, - ) + connect_kwargs: dict[str, Any] = { + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "dbname": config.dbname, + "sslmode": config.ssl_mode, + "sslrootcert": config.ca_cert_path, + "gssencmode": config.gss_encmode, + **libpq_tls_params(config.tls_security_profile), + } + self.connection = psycopg2.connect(**connect_kwargs) try: cursor = self.connection.cursor() cursor.execute("SET LOCAL lock_timeout = '60s'") diff --git a/ols/utils/ssl.py b/ols/utils/ssl.py index 701587162..ee0f7cb32 100644 --- a/ols/utils/ssl.py +++ b/ols/utils/ssl.py @@ -1,8 +1,8 @@ -"""Utility function for retrieving SSL version and list of ciphers for TLS security profile.""" +"""Utility functions for TLS security profile enforcement.""" import logging import ssl -from typing import Optional +from typing import Any, Optional from ols import constants from ols.app.models.config import TLSSecurityProfile @@ -11,6 +11,43 @@ logger = logging.getLogger(__name__) +_LIBPQ_TLS_VERSION_MAP: dict[ssl.TLSVersion, str] = { + ssl.TLSVersion.TLSv1: "TLSv1", + ssl.TLSVersion.TLSv1_1: "TLSv1.1", + ssl.TLSVersion.TLSv1_2: "TLSv1.2", + ssl.TLSVersion.TLSv1_3: "TLSv1.3", +} + + +def libpq_tls_params( + sec_profile: Optional[TLSSecurityProfile], +) -> dict[str, Any]: + """Return extra libpq connection kwargs enforcing the TLS security profile. + + Maps the OpenShift TLS security profile to libpq's + ``ssl_min_protocol_version`` parameter. Returns an empty dict when no + profile is configured so the caller can simply ``**``-merge it. + + Cipher enforcement is not supported by libpq on the client side — + cipher negotiation is controlled by the PostgreSQL server's + ``ssl_ciphers`` setting. + """ + if sec_profile is None or sec_profile.profile_type is None: + return {} + + min_version = get_min_tls_version(sec_profile) + if min_version is None: + return {} + + libpq_value = _LIBPQ_TLS_VERSION_MAP.get(min_version) + if libpq_value is None: + logger.warning("Unmapped TLS version %s, skipping enforcement", min_version) + return {} + + logger.info("Enforcing Postgres ssl_min_protocol_version=%s", libpq_value) + return {"ssl_min_protocol_version": libpq_value} + + def get_ssl_version(sec_profile: Optional[TLSSecurityProfile]) -> int: """Get SSL protocol constant for TLS context creation.""" logger.info("Using SSL protocol version: %s", ssl.PROTOCOL_TLS_SERVER) diff --git a/tests/unit/runners/test_quota_scheduler_runner.py b/tests/unit/runners/test_quota_scheduler_runner.py index 774aef1ce..0ea1efa6d 100644 --- a/tests/unit/runners/test_quota_scheduler_runner.py +++ b/tests/unit/runners/test_quota_scheduler_runner.py @@ -182,21 +182,56 @@ def test_connect(): """Test the connection to Postgres.""" exception_message = "Exception during PostgreSQL storage." - # connection won't be established config = PostgresConfig() - # don't connect to real PostgreSQL instance with patch("psycopg2.connect") as mock_connect: mock_connect.return_value.cursor.return_value.execute.side_effect = Exception( exception_message ) - # try to connect to mocked Postgres connection = connect(config) - # connection should not be established assert connection is not None +def test_connect_passes_sslrootcert(): + """Verify connect passes sslrootcert to psycopg2.connect.""" + config = PostgresConfig() + config.ca_cert_path = "/certs/ca.pem" + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args + assert kwargs.kwargs.get("sslrootcert") == "/certs/ca.pem" + + +def test_connect_passes_ssl_min_protocol_version_when_profile_set(): + """Verify connect passes ssl_min_protocol_version when profile is on config.""" + from ols.app.models.config import TLSSecurityProfile + + config = PostgresConfig() + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + config.tls_security_profile = profile + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args.kwargs + assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2" + + +def test_connect_no_ssl_min_protocol_version_without_profile(): + """Verify connect has no ssl_min_protocol_version without a TLS profile.""" + config = PostgresConfig() + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs + + def test_get_subject_id(): """Check the function to get subject ID based on quota limiter type.""" assert get_subject_id(constants.USER_QUOTA_LIMITER) == "u" diff --git a/tests/unit/utils/test_postgres.py b/tests/unit/utils/test_postgres.py index b2f5e31f6..452fe6456 100644 --- a/tests/unit/utils/test_postgres.py +++ b/tests/unit/utils/test_postgres.py @@ -5,6 +5,7 @@ import psycopg2 import pytest +from ols.app.models.config import TLSSecurityProfile from ols.utils.postgres import PostgresBase, connection @@ -165,3 +166,44 @@ def test_connected_returns_false_on_interface_error(self): psycopg2.InterfaceError("cannot reach server") ) assert component.connected() is False + + +class TestPostgresBaseTlsProfile: + """Tests for TLS security profile integration in PostgresBase.""" + + def _mock_config(self, profile: TLSSecurityProfile | None = None) -> MagicMock: + """Return a MagicMock PostgresConfig with the given TLS profile.""" + cfg = MagicMock() + cfg.ca_cert_path = None + cfg.tls_security_profile = profile + return cfg + + def test_ssl_min_protocol_version_passed_when_profile_set(self): + """Verify psycopg2.connect receives ssl_min_protocol_version.""" + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config(profile)) + + kwargs = mock_connect.call_args.kwargs + assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2" + + def test_ssl_min_protocol_version_not_passed_when_no_profile(self): + """Verify psycopg2.connect has no ssl_min_protocol_version without profile.""" + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config()) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs + + def test_ssl_min_protocol_version_not_passed_when_profile_type_is_none(self): + """Verify psycopg2.connect skips enforcement when profile_type is None.""" + profile = TLSSecurityProfile() + profile.profile_type = None + + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config(profile)) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs diff --git a/tests/unit/utils/test_ssl.py b/tests/unit/utils/test_ssl.py index 740270732..887336e80 100644 --- a/tests/unit/utils/test_ssl.py +++ b/tests/unit/utils/test_ssl.py @@ -3,6 +3,7 @@ import ssl as stdlib_ssl import pytest +from psycopg2 import extensions from ols import constants from ols.app.models.config import TLSSecurityProfile @@ -10,6 +11,11 @@ from ols.utils import tls +def test_postgres_ssl_mode_default_is_require(): + """Verify the default Postgres SSL mode is 'require', not 'prefer'.""" + assert constants.POSTGRES_CACHE_SSL_MODE == "require" + + def test_get_ssl_version_returns_protocol_constant(): """Check the function to get SSL version.""" assert ssl_utils.get_ssl_version(None) == constants.DEFAULT_SSL_VERSION @@ -71,3 +77,41 @@ def test_get_ciphers_with_proper_security_profile(tls_profile_name): allowed_ciphers = ssl_utils.get_ciphers(security_profile) assert allowed_ciphers is not None assert allowed_ciphers == tls.ciphers_for_tls_profile(tls_profile_name) + + +class TestLibpqTlsParams: + """Tests for the libpq_tls_params helper.""" + + def test_returns_empty_when_profile_is_none(self): + """Return empty dict when no TLS security profile is provided.""" + assert ssl_utils.libpq_tls_params(None) == {} + + def test_returns_empty_when_profile_type_is_none(self): + """Return empty dict when profile exists but profile_type is unset.""" + profile = TLSSecurityProfile() + profile.profile_type = None + assert ssl_utils.libpq_tls_params(profile) == {} + + @pytest.mark.parametrize( + ("profile_type", "expected_version"), + [ + ("IntermediateType", "TLSv1.2"), + ("ModernType", "TLSv1.3"), + ], + ) + def test_maps_profile_to_libpq_version(self, profile_type, expected_version): + """Verify the profile maps to the correct libpq version string.""" + profile = TLSSecurityProfile() + profile.profile_type = profile_type + params = ssl_utils.libpq_tls_params(profile) + assert params == {"ssl_min_protocol_version": expected_version} + + def test_result_can_be_merged_into_connect_kwargs(self): + """Verify the returned dict produces a valid libpq DSN.""" + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + params = ssl_utils.libpq_tls_params(profile) + dsn = extensions.make_dsn( + host="127.0.0.1", dbname="test", sslmode="require", **params + ) + assert "ssl_min_protocol_version=TLSv1.2" in dsn