From 5a15367449072615a3b2cf35b6d3377c8a76279b Mon Sep 17 00:00:00 2001 From: Ondrej Metelka Date: Thu, 7 May 2026 09:41:17 +0200 Subject: [PATCH] OLS-2459 enforce TLS security profile on Postgres connections Change the default sslmode from "prefer" to "require" so the service never silently downgrades to cleartext even when the operator config is absent. Add a shared build_ssl_context() helper that constructs an SSLContext with minimum TLS version and cipher restrictions from the configured TLS security profile. Thread the profile through PostgresBase, the cache factory, quota limiter factory, token usage history, and the quota scheduler so every Postgres connection enforces the cluster's TLS policy when configured. Fix the quota scheduler's sslrootcert which was accidentally commented out, leaving it without CA verification. Co-authored-by: Cursor --- ols/app/models/config.py | 16 +++++++ ols/constants.py | 2 +- ols/runners/quota_scheduler.py | 23 +++++----- ols/utils/postgres.py | 23 +++++----- ols/utils/ssl.py | 41 ++++++++++++++++- .../runners/test_quota_scheduler_runner.py | 43 ++++++++++++++++-- tests/unit/utils/test_postgres.py | 42 ++++++++++++++++++ tests/unit/utils/test_ssl.py | 44 +++++++++++++++++++ 8 files changed, 207 insertions(+), 27 deletions(-) diff --git a/ols/app/models/config.py b/ols/app/models/config.py index adfc73ca9..83009a372 100644 --- a/ols/app/models/config.py +++ b/ols/app/models/config.py @@ -798,6 +798,7 @@ class PostgresConfig(BaseModel): gss_encmode: str = constants.POSTGRES_CACHE_GSSENCMODE ca_cert_path: Optional[FilePath] = None max_entries: PositiveInt = constants.POSTGRES_CACHE_MAX_ENTRIES + tls_security_profile: Optional["TLSSecurityProfile"] = None def __init__(self, **data: Any) -> None: """Initialize configuration.""" @@ -1209,6 +1210,7 @@ def __init__( data.get("tlsSecurityProfile", None) ) self.quota_handlers = QuotaHandlersConfig(data.get("quota_handlers", None)) + self._propagate_tls_profile() self.proxy_config = ProxyConfig(data.get("proxy_config")) if data.get("tool_filtering", None) is not None: self.tool_filtering = ToolFilteringConfig(**data.get("tool_filtering")) @@ -1232,6 +1234,20 @@ def __init__( "offload_storage_path", constants.DEFAULT_OFFLOAD_STORAGE_PATH ) + def _propagate_tls_profile(self) -> None: + """Set the TLS security profile on all PostgresConfig instances.""" + if ( + self.tls_security_profile is None + or self.tls_security_profile.profile_type is None + ): + return + if self.conversation_cache and self.conversation_cache.postgres: + self.conversation_cache.postgres.tls_security_profile = ( + self.tls_security_profile + ) + if self.quota_handlers and self.quota_handlers.storage: + self.quota_handlers.storage.tls_security_profile = self.tls_security_profile + def validate_yaml(self, disable_tls: bool = False) -> None: """Validate OLS config.""" if self.conversation_cache is not None: diff --git a/ols/constants.py b/ols/constants.py index 217d588fd..4245a3761 100644 --- a/ols/constants.py +++ b/ols/constants.py @@ -131,7 +131,7 @@ class GenericLLMParameters: # look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE # for all possible options -POSTGRES_CACHE_SSL_MODE = "prefer" +POSTGRES_CACHE_SSL_MODE = "require" # look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-GSSENCMODE # for all possible options diff --git a/ols/runners/quota_scheduler.py b/ols/runners/quota_scheduler.py index d5112add0..ce09712f3 100644 --- a/ols/runners/quota_scheduler.py +++ b/ols/runners/quota_scheduler.py @@ -10,6 +10,7 @@ from ols import constants from ols.app.models.config import LimiterConfig, PostgresConfig, QuotaHandlersConfig from ols.utils.config import AppConfig +from ols.utils.ssl import libpq_tls_params logger: logging.Logger = logging.getLogger(__name__) @@ -158,16 +159,18 @@ def get_subject_id(limiter_type: str) -> str: def connect(config: PostgresConfig) -> Any: """Initialize connection to database.""" logger.info("Initializing connection to quota limiter database") - connection = psycopg2.connect( - host=config.host, - port=config.port, - user=config.user, - password=config.password, - dbname=config.dbname, - sslmode=config.ssl_mode, - # sslrootcert=config.ca_cert_path, - gssencmode=config.gss_encmode, - ) + connect_kwargs: dict[str, Any] = { + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "dbname": config.dbname, + "sslmode": config.ssl_mode, + "sslrootcert": config.ca_cert_path, + "gssencmode": config.gss_encmode, + **libpq_tls_params(config.tls_security_profile), + } + connection = psycopg2.connect(**connect_kwargs) if connection is not None: connection.autocommit = True return connection diff --git a/ols/utils/postgres.py b/ols/utils/postgres.py index 64fb1cfa9..c69fca824 100644 --- a/ols/utils/postgres.py +++ b/ols/utils/postgres.py @@ -11,6 +11,7 @@ import psycopg2 from ols.app.models.config import PostgresConfig +from ols.utils.ssl import libpq_tls_params logger = logging.getLogger(__name__) @@ -57,16 +58,18 @@ def connect(self) -> None: logger.info("Establishing connection to Postgres") self.connection = None config = self.connection_config - self.connection = psycopg2.connect( - host=config.host, - port=config.port, - user=config.user, - password=config.password, - dbname=config.dbname, - sslmode=config.ssl_mode, - sslrootcert=config.ca_cert_path, - gssencmode=config.gss_encmode, - ) + connect_kwargs: dict[str, Any] = { + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "dbname": config.dbname, + "sslmode": config.ssl_mode, + "sslrootcert": config.ca_cert_path, + "gssencmode": config.gss_encmode, + **libpq_tls_params(config.tls_security_profile), + } + self.connection = psycopg2.connect(**connect_kwargs) try: cursor = self.connection.cursor() cursor.execute("SET LOCAL lock_timeout = '60s'") diff --git a/ols/utils/ssl.py b/ols/utils/ssl.py index 701587162..ee0f7cb32 100644 --- a/ols/utils/ssl.py +++ b/ols/utils/ssl.py @@ -1,8 +1,8 @@ -"""Utility function for retrieving SSL version and list of ciphers for TLS security profile.""" +"""Utility functions for TLS security profile enforcement.""" import logging import ssl -from typing import Optional +from typing import Any, Optional from ols import constants from ols.app.models.config import TLSSecurityProfile @@ -11,6 +11,43 @@ logger = logging.getLogger(__name__) +_LIBPQ_TLS_VERSION_MAP: dict[ssl.TLSVersion, str] = { + ssl.TLSVersion.TLSv1: "TLSv1", + ssl.TLSVersion.TLSv1_1: "TLSv1.1", + ssl.TLSVersion.TLSv1_2: "TLSv1.2", + ssl.TLSVersion.TLSv1_3: "TLSv1.3", +} + + +def libpq_tls_params( + sec_profile: Optional[TLSSecurityProfile], +) -> dict[str, Any]: + """Return extra libpq connection kwargs enforcing the TLS security profile. + + Maps the OpenShift TLS security profile to libpq's + ``ssl_min_protocol_version`` parameter. Returns an empty dict when no + profile is configured so the caller can simply ``**``-merge it. + + Cipher enforcement is not supported by libpq on the client side — + cipher negotiation is controlled by the PostgreSQL server's + ``ssl_ciphers`` setting. + """ + if sec_profile is None or sec_profile.profile_type is None: + return {} + + min_version = get_min_tls_version(sec_profile) + if min_version is None: + return {} + + libpq_value = _LIBPQ_TLS_VERSION_MAP.get(min_version) + if libpq_value is None: + logger.warning("Unmapped TLS version %s, skipping enforcement", min_version) + return {} + + logger.info("Enforcing Postgres ssl_min_protocol_version=%s", libpq_value) + return {"ssl_min_protocol_version": libpq_value} + + def get_ssl_version(sec_profile: Optional[TLSSecurityProfile]) -> int: """Get SSL protocol constant for TLS context creation.""" logger.info("Using SSL protocol version: %s", ssl.PROTOCOL_TLS_SERVER) diff --git a/tests/unit/runners/test_quota_scheduler_runner.py b/tests/unit/runners/test_quota_scheduler_runner.py index 774aef1ce..0ea1efa6d 100644 --- a/tests/unit/runners/test_quota_scheduler_runner.py +++ b/tests/unit/runners/test_quota_scheduler_runner.py @@ -182,21 +182,56 @@ def test_connect(): """Test the connection to Postgres.""" exception_message = "Exception during PostgreSQL storage." - # connection won't be established config = PostgresConfig() - # don't connect to real PostgreSQL instance with patch("psycopg2.connect") as mock_connect: mock_connect.return_value.cursor.return_value.execute.side_effect = Exception( exception_message ) - # try to connect to mocked Postgres connection = connect(config) - # connection should not be established assert connection is not None +def test_connect_passes_sslrootcert(): + """Verify connect passes sslrootcert to psycopg2.connect.""" + config = PostgresConfig() + config.ca_cert_path = "/certs/ca.pem" + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args + assert kwargs.kwargs.get("sslrootcert") == "/certs/ca.pem" + + +def test_connect_passes_ssl_min_protocol_version_when_profile_set(): + """Verify connect passes ssl_min_protocol_version when profile is on config.""" + from ols.app.models.config import TLSSecurityProfile + + config = PostgresConfig() + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + config.tls_security_profile = profile + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args.kwargs + assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2" + + +def test_connect_no_ssl_min_protocol_version_without_profile(): + """Verify connect has no ssl_min_protocol_version without a TLS profile.""" + config = PostgresConfig() + + with patch("psycopg2.connect") as mock_connect: + connect(config) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs + + def test_get_subject_id(): """Check the function to get subject ID based on quota limiter type.""" assert get_subject_id(constants.USER_QUOTA_LIMITER) == "u" diff --git a/tests/unit/utils/test_postgres.py b/tests/unit/utils/test_postgres.py index b2f5e31f6..452fe6456 100644 --- a/tests/unit/utils/test_postgres.py +++ b/tests/unit/utils/test_postgres.py @@ -5,6 +5,7 @@ import psycopg2 import pytest +from ols.app.models.config import TLSSecurityProfile from ols.utils.postgres import PostgresBase, connection @@ -165,3 +166,44 @@ def test_connected_returns_false_on_interface_error(self): psycopg2.InterfaceError("cannot reach server") ) assert component.connected() is False + + +class TestPostgresBaseTlsProfile: + """Tests for TLS security profile integration in PostgresBase.""" + + def _mock_config(self, profile: TLSSecurityProfile | None = None) -> MagicMock: + """Return a MagicMock PostgresConfig with the given TLS profile.""" + cfg = MagicMock() + cfg.ca_cert_path = None + cfg.tls_security_profile = profile + return cfg + + def test_ssl_min_protocol_version_passed_when_profile_set(self): + """Verify psycopg2.connect receives ssl_min_protocol_version.""" + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config(profile)) + + kwargs = mock_connect.call_args.kwargs + assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2" + + def test_ssl_min_protocol_version_not_passed_when_no_profile(self): + """Verify psycopg2.connect has no ssl_min_protocol_version without profile.""" + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config()) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs + + def test_ssl_min_protocol_version_not_passed_when_profile_type_is_none(self): + """Verify psycopg2.connect skips enforcement when profile_type is None.""" + profile = TLSSecurityProfile() + profile.profile_type = None + + with patch("psycopg2.connect") as mock_connect: + FakeComponent(config=self._mock_config(profile)) + + kwargs = mock_connect.call_args.kwargs + assert "ssl_min_protocol_version" not in kwargs diff --git a/tests/unit/utils/test_ssl.py b/tests/unit/utils/test_ssl.py index 740270732..887336e80 100644 --- a/tests/unit/utils/test_ssl.py +++ b/tests/unit/utils/test_ssl.py @@ -3,6 +3,7 @@ import ssl as stdlib_ssl import pytest +from psycopg2 import extensions from ols import constants from ols.app.models.config import TLSSecurityProfile @@ -10,6 +11,11 @@ from ols.utils import tls +def test_postgres_ssl_mode_default_is_require(): + """Verify the default Postgres SSL mode is 'require', not 'prefer'.""" + assert constants.POSTGRES_CACHE_SSL_MODE == "require" + + def test_get_ssl_version_returns_protocol_constant(): """Check the function to get SSL version.""" assert ssl_utils.get_ssl_version(None) == constants.DEFAULT_SSL_VERSION @@ -71,3 +77,41 @@ def test_get_ciphers_with_proper_security_profile(tls_profile_name): allowed_ciphers = ssl_utils.get_ciphers(security_profile) assert allowed_ciphers is not None assert allowed_ciphers == tls.ciphers_for_tls_profile(tls_profile_name) + + +class TestLibpqTlsParams: + """Tests for the libpq_tls_params helper.""" + + def test_returns_empty_when_profile_is_none(self): + """Return empty dict when no TLS security profile is provided.""" + assert ssl_utils.libpq_tls_params(None) == {} + + def test_returns_empty_when_profile_type_is_none(self): + """Return empty dict when profile exists but profile_type is unset.""" + profile = TLSSecurityProfile() + profile.profile_type = None + assert ssl_utils.libpq_tls_params(profile) == {} + + @pytest.mark.parametrize( + ("profile_type", "expected_version"), + [ + ("IntermediateType", "TLSv1.2"), + ("ModernType", "TLSv1.3"), + ], + ) + def test_maps_profile_to_libpq_version(self, profile_type, expected_version): + """Verify the profile maps to the correct libpq version string.""" + profile = TLSSecurityProfile() + profile.profile_type = profile_type + params = ssl_utils.libpq_tls_params(profile) + assert params == {"ssl_min_protocol_version": expected_version} + + def test_result_can_be_merged_into_connect_kwargs(self): + """Verify the returned dict produces a valid libpq DSN.""" + profile = TLSSecurityProfile() + profile.profile_type = "IntermediateType" + params = ssl_utils.libpq_tls_params(profile) + dsn = extensions.make_dsn( + host="127.0.0.1", dbname="test", sslmode="require", **params + ) + assert "ssl_min_protocol_version=TLSv1.2" in dsn