Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ols/app/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ class PostgresConfig(BaseModel):
gss_encmode: str = constants.POSTGRES_CACHE_GSSENCMODE
ca_cert_path: Optional[FilePath] = None
max_entries: PositiveInt = constants.POSTGRES_CACHE_MAX_ENTRIES
tls_security_profile: Optional["TLSSecurityProfile"] = None

def __init__(self, **data: Any) -> None:
"""Initialize configuration."""
Expand Down Expand Up @@ -1209,6 +1210,7 @@ def __init__(
data.get("tlsSecurityProfile", None)
)
self.quota_handlers = QuotaHandlersConfig(data.get("quota_handlers", None))
self._propagate_tls_profile()
self.proxy_config = ProxyConfig(data.get("proxy_config"))
if data.get("tool_filtering", None) is not None:
self.tool_filtering = ToolFilteringConfig(**data.get("tool_filtering"))
Expand All @@ -1232,6 +1234,20 @@ def __init__(
"offload_storage_path", constants.DEFAULT_OFFLOAD_STORAGE_PATH
)

def _propagate_tls_profile(self) -> None:
"""Set the TLS security profile on all PostgresConfig instances."""
if (
self.tls_security_profile is None
or self.tls_security_profile.profile_type is None
):
return
if self.conversation_cache and self.conversation_cache.postgres:
self.conversation_cache.postgres.tls_security_profile = (
self.tls_security_profile
)
if self.quota_handlers and self.quota_handlers.storage:
self.quota_handlers.storage.tls_security_profile = self.tls_security_profile

def validate_yaml(self, disable_tls: bool = False) -> None:
"""Validate OLS config."""
if self.conversation_cache is not None:
Expand Down
2 changes: 1 addition & 1 deletion ols/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class GenericLLMParameters:

# look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE
# for all possible options
POSTGRES_CACHE_SSL_MODE = "prefer"
POSTGRES_CACHE_SSL_MODE = "require"

# look at https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-GSSENCMODE
# for all possible options
Expand Down
23 changes: 13 additions & 10 deletions ols/runners/quota_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ols import constants
from ols.app.models.config import LimiterConfig, PostgresConfig, QuotaHandlersConfig
from ols.utils.config import AppConfig
from ols.utils.ssl import libpq_tls_params

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -158,16 +159,18 @@ def get_subject_id(limiter_type: str) -> str:
def connect(config: PostgresConfig) -> Any:
"""Initialize connection to database."""
logger.info("Initializing connection to quota limiter database")
connection = psycopg2.connect(
host=config.host,
port=config.port,
user=config.user,
password=config.password,
dbname=config.dbname,
sslmode=config.ssl_mode,
# sslrootcert=config.ca_cert_path,
gssencmode=config.gss_encmode,
)
connect_kwargs: dict[str, Any] = {
"host": config.host,
"port": config.port,
"user": config.user,
"password": config.password,
"dbname": config.dbname,
"sslmode": config.ssl_mode,
"sslrootcert": config.ca_cert_path,
"gssencmode": config.gss_encmode,
**libpq_tls_params(config.tls_security_profile),
}
connection = psycopg2.connect(**connect_kwargs)
if connection is not None:
connection.autocommit = True
return connection
Expand Down
23 changes: 13 additions & 10 deletions ols/utils/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import psycopg2

from ols.app.models.config import PostgresConfig
from ols.utils.ssl import libpq_tls_params

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,16 +58,18 @@ def connect(self) -> None:
logger.info("Establishing connection to Postgres")
self.connection = None
config = self.connection_config
self.connection = psycopg2.connect(
host=config.host,
port=config.port,
user=config.user,
password=config.password,
dbname=config.dbname,
sslmode=config.ssl_mode,
sslrootcert=config.ca_cert_path,
gssencmode=config.gss_encmode,
)
connect_kwargs: dict[str, Any] = {
"host": config.host,
"port": config.port,
"user": config.user,
"password": config.password,
"dbname": config.dbname,
"sslmode": config.ssl_mode,
"sslrootcert": config.ca_cert_path,
"gssencmode": config.gss_encmode,
**libpq_tls_params(config.tls_security_profile),
}
self.connection = psycopg2.connect(**connect_kwargs)
try:
cursor = self.connection.cursor()
cursor.execute("SET LOCAL lock_timeout = '60s'")
Expand Down
41 changes: 39 additions & 2 deletions ols/utils/ssl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Utility function for retrieving SSL version and list of ciphers for TLS security profile."""
"""Utility functions for TLS security profile enforcement."""

import logging
import ssl
from typing import Optional
from typing import Any, Optional

from ols import constants
from ols.app.models.config import TLSSecurityProfile
Expand All @@ -11,6 +11,43 @@
logger = logging.getLogger(__name__)


_LIBPQ_TLS_VERSION_MAP: dict[ssl.TLSVersion, str] = {
ssl.TLSVersion.TLSv1: "TLSv1",
ssl.TLSVersion.TLSv1_1: "TLSv1.1",
ssl.TLSVersion.TLSv1_2: "TLSv1.2",
ssl.TLSVersion.TLSv1_3: "TLSv1.3",
}


def libpq_tls_params(
sec_profile: Optional[TLSSecurityProfile],
) -> dict[str, Any]:
"""Return extra libpq connection kwargs enforcing the TLS security profile.

Maps the OpenShift TLS security profile to libpq's
``ssl_min_protocol_version`` parameter. Returns an empty dict when no
profile is configured so the caller can simply ``**``-merge it.

Cipher enforcement is not supported by libpq on the client side —
cipher negotiation is controlled by the PostgreSQL server's
``ssl_ciphers`` setting.
"""
if sec_profile is None or sec_profile.profile_type is None:
return {}

min_version = get_min_tls_version(sec_profile)
if min_version is None:
return {}

libpq_value = _LIBPQ_TLS_VERSION_MAP.get(min_version)
if libpq_value is None:
logger.warning("Unmapped TLS version %s, skipping enforcement", min_version)
return {}

logger.info("Enforcing Postgres ssl_min_protocol_version=%s", libpq_value)
return {"ssl_min_protocol_version": libpq_value}


def get_ssl_version(sec_profile: Optional[TLSSecurityProfile]) -> int:
"""Get SSL protocol constant for TLS context creation."""
logger.info("Using SSL protocol version: %s", ssl.PROTOCOL_TLS_SERVER)
Expand Down
43 changes: 39 additions & 4 deletions tests/unit/runners/test_quota_scheduler_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,56 @@ def test_connect():
"""Test the connection to Postgres."""
exception_message = "Exception during PostgreSQL storage."

# connection won't be established
config = PostgresConfig()

# don't connect to real PostgreSQL instance
with patch("psycopg2.connect") as mock_connect:
mock_connect.return_value.cursor.return_value.execute.side_effect = Exception(
exception_message
)
# try to connect to mocked Postgres
connection = connect(config)

# connection should not be established
assert connection is not None


def test_connect_passes_sslrootcert():
"""Verify connect passes sslrootcert to psycopg2.connect."""
config = PostgresConfig()
config.ca_cert_path = "/certs/ca.pem"

with patch("psycopg2.connect") as mock_connect:
connect(config)

kwargs = mock_connect.call_args
assert kwargs.kwargs.get("sslrootcert") == "/certs/ca.pem"


def test_connect_passes_ssl_min_protocol_version_when_profile_set():
"""Verify connect passes ssl_min_protocol_version when profile is on config."""
from ols.app.models.config import TLSSecurityProfile

config = PostgresConfig()
profile = TLSSecurityProfile()
profile.profile_type = "IntermediateType"
config.tls_security_profile = profile

with patch("psycopg2.connect") as mock_connect:
connect(config)

kwargs = mock_connect.call_args.kwargs
assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2"


def test_connect_no_ssl_min_protocol_version_without_profile():
"""Verify connect has no ssl_min_protocol_version without a TLS profile."""
config = PostgresConfig()

with patch("psycopg2.connect") as mock_connect:
connect(config)

kwargs = mock_connect.call_args.kwargs
assert "ssl_min_protocol_version" not in kwargs


def test_get_subject_id():
"""Check the function to get subject ID based on quota limiter type."""
assert get_subject_id(constants.USER_QUOTA_LIMITER) == "u"
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/utils/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import psycopg2
import pytest

from ols.app.models.config import TLSSecurityProfile
from ols.utils.postgres import PostgresBase, connection


Expand Down Expand Up @@ -165,3 +166,44 @@ def test_connected_returns_false_on_interface_error(self):
psycopg2.InterfaceError("cannot reach server")
)
assert component.connected() is False


class TestPostgresBaseTlsProfile:
"""Tests for TLS security profile integration in PostgresBase."""

def _mock_config(self, profile: TLSSecurityProfile | None = None) -> MagicMock:
"""Return a MagicMock PostgresConfig with the given TLS profile."""
cfg = MagicMock()
cfg.ca_cert_path = None
cfg.tls_security_profile = profile
return cfg

def test_ssl_min_protocol_version_passed_when_profile_set(self):
"""Verify psycopg2.connect receives ssl_min_protocol_version."""
profile = TLSSecurityProfile()
profile.profile_type = "IntermediateType"

with patch("psycopg2.connect") as mock_connect:
FakeComponent(config=self._mock_config(profile))

kwargs = mock_connect.call_args.kwargs
assert kwargs.get("ssl_min_protocol_version") == "TLSv1.2"

def test_ssl_min_protocol_version_not_passed_when_no_profile(self):
"""Verify psycopg2.connect has no ssl_min_protocol_version without profile."""
with patch("psycopg2.connect") as mock_connect:
FakeComponent(config=self._mock_config())

kwargs = mock_connect.call_args.kwargs
assert "ssl_min_protocol_version" not in kwargs

def test_ssl_min_protocol_version_not_passed_when_profile_type_is_none(self):
"""Verify psycopg2.connect skips enforcement when profile_type is None."""
profile = TLSSecurityProfile()
profile.profile_type = None

with patch("psycopg2.connect") as mock_connect:
FakeComponent(config=self._mock_config(profile))

kwargs = mock_connect.call_args.kwargs
assert "ssl_min_protocol_version" not in kwargs
44 changes: 44 additions & 0 deletions tests/unit/utils/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import ssl as stdlib_ssl

import pytest
from psycopg2 import extensions

from ols import constants
from ols.app.models.config import TLSSecurityProfile
from ols.utils import ssl as ssl_utils
from ols.utils import tls


def test_postgres_ssl_mode_default_is_require():
"""Verify the default Postgres SSL mode is 'require', not 'prefer'."""
assert constants.POSTGRES_CACHE_SSL_MODE == "require"


def test_get_ssl_version_returns_protocol_constant():
"""Check the function to get SSL version."""
assert ssl_utils.get_ssl_version(None) == constants.DEFAULT_SSL_VERSION
Expand Down Expand Up @@ -71,3 +77,41 @@ def test_get_ciphers_with_proper_security_profile(tls_profile_name):
allowed_ciphers = ssl_utils.get_ciphers(security_profile)
assert allowed_ciphers is not None
assert allowed_ciphers == tls.ciphers_for_tls_profile(tls_profile_name)


class TestLibpqTlsParams:
"""Tests for the libpq_tls_params helper."""

def test_returns_empty_when_profile_is_none(self):
"""Return empty dict when no TLS security profile is provided."""
assert ssl_utils.libpq_tls_params(None) == {}

def test_returns_empty_when_profile_type_is_none(self):
"""Return empty dict when profile exists but profile_type is unset."""
profile = TLSSecurityProfile()
profile.profile_type = None
assert ssl_utils.libpq_tls_params(profile) == {}

@pytest.mark.parametrize(
("profile_type", "expected_version"),
[
("IntermediateType", "TLSv1.2"),
("ModernType", "TLSv1.3"),
],
)
def test_maps_profile_to_libpq_version(self, profile_type, expected_version):
"""Verify the profile maps to the correct libpq version string."""
profile = TLSSecurityProfile()
profile.profile_type = profile_type
params = ssl_utils.libpq_tls_params(profile)
assert params == {"ssl_min_protocol_version": expected_version}

def test_result_can_be_merged_into_connect_kwargs(self):
"""Verify the returned dict produces a valid libpq DSN."""
profile = TLSSecurityProfile()
profile.profile_type = "IntermediateType"
params = ssl_utils.libpq_tls_params(profile)
dsn = extensions.make_dsn(
host="127.0.0.1", dbname="test", sslmode="require", **params
)
assert "ssl_min_protocol_version=TLSv1.2" in dsn