Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/dstack/_internal/server/routers/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ async def create_or_update_secret(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
_, project = user_project
user, project = user_project
return CustomORJSONResponse(
await secrets_services.create_or_update_secret(
session=session,
project=project,
name=body.name,
value=body.value,
user=user,
)
)

Expand All @@ -76,9 +77,10 @@ async def delete_secrets(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
_, project = user_project
user, project = user_project
await secrets_services.delete_secrets(
session=session,
project=project,
names=body.secrets_names,
user=user,
)
127 changes: 84 additions & 43 deletions src/dstack/_internal/server/services/secrets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import re
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Dict, List, Optional

import sqlalchemy.exc
from sqlalchemy import delete, select, update
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import (
Expand All @@ -11,11 +14,10 @@
ServerClientError,
)
from dstack._internal.core.models.secrets import Secret
from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)

from dstack._internal.server.db import get_db
from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel, UserModel
from dstack._internal.server.services import events
from dstack._internal.server.services.locking import get_locker

_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$"
_SECRET_VALUE_MAX_LENGTH = 5000
Expand Down Expand Up @@ -57,6 +59,7 @@ async def create_or_update_secret(
project: ProjectModel,
name: str,
value: str,
user: UserModel,
) -> Secret:
_validate_secret(name=name, value=value)
try:
Expand All @@ -65,13 +68,15 @@ async def create_or_update_secret(
project=project,
name=name,
value=value,
user=user,
)
except ResourceExistsError:
secret_model = await update_secret(
session=session,
project=project,
name=name,
value=value,
user=user,
)
return secret_model_to_secret(secret_model, include_value=True)

Expand All @@ -80,26 +85,24 @@ async def delete_secrets(
session: AsyncSession,
project: ProjectModel,
names: List[str],
user: UserModel,
):
existing_secrets_query = await session.execute(
select(SecretModel).where(
SecretModel.project_id == project.id,
SecretModel.name.in_(names),
)
)
existing_names = [s.name for s in existing_secrets_query.scalars().all()]
missing_names = set(names) - set(existing_names)
if missing_names:
raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}")

await session.execute(
delete(SecretModel).where(
SecretModel.project_id == project.id,
SecretModel.name.in_(names),
)
)
await session.commit()
logger.info("Deleted secrets %s in project %s", names, project.name)
async with get_project_secret_models_by_name_for_update(
session=session, project=project, names=names
) as secret_models:
existing_names = [s.name for s in secret_models]
missing_names = set(names) - set(existing_names)
if missing_names:
raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}")
for secret_model in secret_models:
await session.delete(secret_model)
events.emit(
session,
"Secret deleted",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(secret_model)],
)
await session.commit()


def secret_model_to_secret(secret_model: SecretModel, include_value: bool = False) -> Secret:
Expand Down Expand Up @@ -142,20 +145,60 @@ async def get_project_secret_model_by_name(
return res.scalar_one_or_none()


@asynccontextmanager
async def get_project_secret_models_by_name_for_update(
session: AsyncSession, project: ProjectModel, names: list[str]
) -> AsyncGenerator[list[SecretModel], None]:
"""
Fetch secrets from the database and lock them for update.

**NOTE**: commit changes to the database before exiting from this context manager,
so that in-memory locks are only released after commit.
"""
filters = [
SecretModel.project_id == project.id,
SecretModel.name.in_(names),
]
res = await session.execute(select(SecretModel.id).where(*filters))
secret_ids = res.scalars().all()
if not secret_ids:
yield []
else:
async with get_locker(get_db().dialect_name).lock_ctx(
SecretModel.__tablename__, sorted(secret_ids)
):
# Refetch after lock
res = await session.execute(
select(SecretModel)
.where(SecretModel.id.in_(secret_ids), *filters)
.with_for_update(key_share=True)
.order_by(SecretModel.id) # take locks in order
)
yield list(res.scalars().all())


async def create_secret(
session: AsyncSession,
project: ProjectModel,
name: str,
value: str,
user: UserModel,
) -> SecretModel:
secret_model = SecretModel(
id=uuid.uuid4(),
project_id=project.id,
name=name,
value=DecryptedString(plaintext=value),
)
try:
async with session.begin_nested():
session.add(secret_model)
events.emit(
session,
"Secret created",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(secret_model)],
)
except sqlalchemy.exc.IntegrityError:
raise ResourceExistsError()
await session.commit()
Expand All @@ -167,25 +210,23 @@ async def update_secret(
project: ProjectModel,
name: str,
value: str,
user: UserModel,
) -> SecretModel:
await session.execute(
update(SecretModel)
.where(
SecretModel.project_id == project.id,
SecretModel.name == name,
)
.values(
value=DecryptedString(plaintext=value),
)
)
await session.commit()
secret_model = await get_project_secret_model_by_name(
session=session,
project=project,
name=name,
)
if secret_model is None:
raise ResourceNotExistsError()
async with get_project_secret_models_by_name_for_update(
session=session, project=project, names=[name]
) as secret_models:
if not secret_models:
raise ResourceNotExistsError()
secret_model = secret_models[0]
if secret_model.value.get_plaintext_or_error() != value:
secret_model.value = DecryptedString(plaintext=value)
events.emit(
session,
"Secret updated",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(secret_model)],
)
await session.commit()
return secret_model


Expand Down
32 changes: 32 additions & 0 deletions src/tests/_internal/server/routers/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_secret,
create_user,
get_auth_headers,
list_events,
)


Expand Down Expand Up @@ -145,6 +146,9 @@ async def test_creates_secret(self, test_db, session: AsyncSession, client: Asyn
res = await session.execute(select(SecretModel))
secret_model = res.scalar()
assert secret_model is not None
events = await list_events(session)
assert len(events) == 1
assert events[0].message == "Secret created"

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand All @@ -165,6 +169,29 @@ async def test_updates_secret(self, test_db, session: AsyncSession, client: Asyn
assert response.status_code == 200
await session.refresh(secret)
assert secret.value.get_plaintext_or_error() == "new_value"
events = await list_events(session)
assert len(events) == 1
assert events[0].message == "Secret updated"

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_no_event_if_value_unchanged(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
)
await create_secret(session=session, project=project, name="secret1", value="value")
response = await client.post(
f"/api/project/{project.name}/secrets/create_or_update",
headers=get_auth_headers(user.token),
json={"name": "secret1", "value": "value"},
)
assert response.status_code == 200
events = await list_events(session)
assert len(events) == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -253,6 +280,11 @@ async def test_deletes_secrets(self, test_db, session: AsyncSession, client: Asy
assert len(secrets) == 1
assert secrets[0].name == "secret2"

# Verify event was emitted
events = await list_events(session)
assert len(events) == 1
assert events[0].message == "Secret deleted"

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_delete_nonexistent_secret_raises_error(
Expand Down