Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/critic/libs/ddb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
import os
Expand Down Expand Up @@ -78,6 +80,15 @@ def __call__(self, data: dict) -> dict:
deserialize = Deserializer()


@dataclass
class CascadeRelationship:
    """Describe how deleting a parent row cascades to rows of one child table.

    ``Table.delete`` walks these relationships: it asks ``get_child_query_key``
    for the key to query ``child_table`` with, then deletes every child found
    using the keys returned by ``get_child_delete_keys``.
    """

    # Table class that holds the dependent (child) rows.
    child_table: type['Table']
    # Called with the parent's (partition key, sort key); returns the partition
    # key used to query the child table.
    get_child_query_key: Callable[[Any, Any | None], Any]
    # Called with a child item; returns that child's (partition key, sort key)
    # so it can be deleted.
    get_child_delete_keys: Callable[[BaseModel], tuple[Any, Any | None]]


class Table:
base_name: str
model: type[BaseModel]
Expand Down Expand Up @@ -220,8 +231,17 @@ def update(
raise
return True

@classmethod
def cascade_relationships(cls) -> list[CascadeRelationship]:
    """Return the child relationships to cascade through when deleting a row.

    The base implementation declares none; subclasses override this to list
    their dependent tables.
    """
    no_children: list[CascadeRelationship] = []
    return no_children

@classmethod
def delete(cls, partition_value: Any, sort_value: Any | None = None):
for rel in cls.cascade_relationships():
child_partition_key = rel.get_child_query_key(partition_value, sort_value)
for child in rel.child_table.query(child_partition_key):
rel.child_table.delete(*rel.get_child_delete_keys(child))

get_client().delete_item(
TableName=cls.name(),
Key=cls.key(partition_value, sort_value),
Expand Down
16 changes: 14 additions & 2 deletions src/critic/libs/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pydantic import BaseModel

from critic.libs.ddb import Table, get_client
from critic.models import UptimeMonitorModel
from critic.tables import UptimeMonitorTable
from critic.models import ProjectModel, UptimeLogModel, UptimeMonitorModel
from critic.tables import ProjectTable, UptimeLogTable, UptimeMonitorTable


def create_tables():
Expand Down Expand Up @@ -112,7 +112,19 @@ def put(cls, **kwargs) -> BaseModel:
return item


class ProjectFactory(PutMixin, ModelFactory):
    """Test-data factory for ``ProjectModel`` items, bound to ``ProjectTable``."""

    __model__ = ProjectModel
    __table__ = ProjectTable
    __use_defaults__ = True


class UptimeMonitorFactory(PutMixin, ModelFactory):
    """Test-data factory for ``UptimeMonitorModel`` items, bound to ``UptimeMonitorTable``."""

    __model__ = UptimeMonitorModel
    __table__ = UptimeMonitorTable
    __use_defaults__ = True


class UptimeLogFactory(PutMixin, ModelFactory):
    """Test-data factory for ``UptimeLogModel`` items, bound to ``UptimeLogTable``."""

    __model__ = UptimeLogModel
    __table__ = UptimeLogTable
    __use_defaults__ = True
6 changes: 3 additions & 3 deletions src/critic/libs/uptime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx

from critic.libs.dt import round_minute
from critic.models import MonitorState, UptimeLog, UptimeMonitorModel
from critic.models import MonitorState, UptimeLogModel, UptimeMonitorModel
from critic.tables import UptimeLogTable, UptimeMonitorTable


Expand Down Expand Up @@ -112,8 +112,8 @@ def put_log(self, state: MonitorState, status_code: int, latency: float):
"""
if self._put_log:
raise Exception('Log already put! Do not call this method more than once in one run.')
uptime_log = UptimeLog(
monitor_id=f'{self.monitor.project_id}/{self.monitor.slug}',
uptime_log = UptimeLogModel(
monitor_id=self.monitor.id,
timestamp=self.now,
status=state,
resp_code=status_code,
Expand Down
18 changes: 13 additions & 5 deletions src/critic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MonitorState(str, Enum):
paused = 'paused'


class Project(BaseModel):
class ProjectModel(BaseModel):
id: UUID
name: str

Expand Down Expand Up @@ -50,18 +50,26 @@ def validate_next_due_at(cls, v: datetime) -> datetime:
raise ValueError('next_due_at must be no more precise than minutes')
return to_utc(v)

@property
def id(self) -> str:
    """Return this monitor's composite ID, built from its project ID and slug.

    The value is produced by ``UptimeLogModel.monitor_id_from_parts`` so it
    matches the ``monitor_id`` stored on uptime logs.
    """
    build_id = UptimeLogModel.monitor_id_from_parts
    return build_id(self.project_id, self.slug)

class UptimeLog(BaseModel):

class UptimeLogModel(BaseModel):
monitor_id: str = Field(
# Project ID and monitor slug, separated by a slash
# pattern = UUID / slug
pattern=r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/[a-z0-9]+(?:-[a-z0-9]+)*$'
)
timestamp: AwareDatetime
status: MonitorState
resp_code: int | None
latency_secs: float | None
resp_code: int
latency_secs: float

@staticmethod
def monitor_id_from_parts(project_id: UUID | str, slug: str) -> str:
return f'{project_id}/{slug}'


class ProjectMonitors(BaseModel):
class ProjectMonitorsModel(BaseModel):
uptime: list[UptimeMonitorModel] = Field(default_factory=list)
36 changes: 32 additions & 4 deletions src/critic/tables.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,50 @@
from datetime import datetime

from critic.libs.ddb import CONSTANT_GSI_PK, Table, deserialize, get_client, serialize
from critic.libs.ddb import (
CONSTANT_GSI_PK,
CascadeRelationship,
Table,
deserialize,
get_client,
serialize,
)

from .models import Project, UptimeLog, UptimeMonitorModel
from .models import ProjectModel, UptimeLogModel, UptimeMonitorModel


class ProjectTable(Table):
base_name = 'Project'
model = Project
model = ProjectModel
partition_key = 'id'

@classmethod
def cascade_relationships(cls) -> list[CascadeRelationship]:
    """Declare that deleting a project also deletes its uptime monitors."""
    monitors = CascadeRelationship(
        child_table=UptimeMonitorTable,
        # Monitors are partitioned by project ID, so the project's own
        # partition key is the query key; the sort key is unused.
        get_child_query_key=lambda project_id, _sort_key: project_id,
        get_child_delete_keys=lambda monitor: (monitor.project_id, monitor.slug),
    )
    return [monitors]


class UptimeMonitorTable(Table):
base_name = 'UptimeMonitor'
model = UptimeMonitorModel
partition_key = 'project_id'
sort_key = 'slug'

@classmethod
def cascade_relationships(cls) -> list[CascadeRelationship]:
    """Declare that deleting a monitor also deletes its uptime logs."""
    logs = CascadeRelationship(
        child_table=UptimeLogTable,
        # TODO: have a universal function for this
        # Logs are partitioned by the composite monitor ID built from the
        # monitor's partition key (project ID) and sort key (slug).
        get_child_query_key=lambda project_id, slug: UptimeLogModel.monitor_id_from_parts(
            project_id, slug
        ),
        get_child_delete_keys=lambda log: (log.monitor_id, log.timestamp),
    )
    return [logs]

@classmethod
def get_due_since(cls, timestamp: datetime) -> list[UptimeMonitorModel]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After looking into this a bit, it appears that a DynamoDB query returns at most 1 MB of data per response, so any results beyond the first page would be silently dropped. When more data is available, DynamoDB includes a `LastEvaluatedKey` in the response, and we need to pass it back as `ExclusiveStartKey` to fetch the next page. I don't know how much data we expect to be returned, so this may not be a major issue, but we could loop until `LastEvaluatedKey` is no longer present.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. I'm going to create a separate issue for this since the problem spans more than this PR.

response = get_client().query(
Expand All @@ -35,6 +63,6 @@ def get_due_since(cls, timestamp: datetime) -> list[UptimeMonitorModel]:

class UptimeLogTable(Table):
base_name = 'UptimeLog'
model = UptimeLog
model = UptimeLogModel
partition_key = 'monitor_id'
sort_key = 'timestamp'
80 changes: 43 additions & 37 deletions tests/critic_tests/test_libs/test_ddb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,27 @@
from botocore.exceptions import ClientError
import pytest

from critic.models import UptimeMonitorModel
from critic.tables import UptimeMonitorTable
from critic.libs.testing import ProjectFactory, UptimeLogFactory, UptimeMonitorFactory
from critic.models import ProjectModel, UptimeLogModel, UptimeMonitorModel
from critic.tables import ProjectTable, UptimeLogTable, UptimeMonitorTable


class TestDDB:
class TestTable:
@pytest.mark.integration
def test_integration(self):
# Pretend we've received data via the API
IN_DATA = {
'project_id': '6033aa47-a9f7-4d7f-b7ff-a11ba9b34474',
'slug': 'my-monitor',
'url': 'https://example.com/health',
'frequency_mins': 5,
'consecutive_fails': 0,
'next_due_at': '2025-11-10T20:35:00Z',
'timeout_secs': 30,
'assertions': {'status_code': 200, 'body_contains': 'OK'},
'failures_before_alerting': 2,
'alert_slack_channels': ['#ops'],
'alert_emails': ['alerts@example.com'],
'realert_interval_mins': 60,
}

UptimeMonitorTable.put(IN_DATA)
UptimeMonitorFactory.put(
project_id='6033aa47-a9f7-4d7f-b7ff-a11ba9b34474',
slug='my-monitor',
url='https://example.com/health',
)
out_data = UptimeMonitorTable.get('6033aa47-a9f7-4d7f-b7ff-a11ba9b34474', 'my-monitor')

# Check one of the values to make sure it's what we expect
assert str(out_data.url) == 'https://example.com/health'

@pytest.mark.parametrize('input_as_model', [True, False])
def test_unit(self, input_as_model):
# Pretend we've received data via the API
# Sometimes we may want to pass input as a dict (not a model). Make sure we handle that.
in_data = {
'project_id': '6033aa47-a9f7-4d7f-b7ff-a11ba9b34474',
'slug': 'my-monitor',
Expand Down Expand Up @@ -70,22 +59,11 @@ def test_missing_sort_key(self):
UptimeMonitorTable.get('6033aa47-a9f7-4d7f-b7ff-a11ba9b34474')

def test_query_from_monitor_table(self):
in_data = {
'project_id': '6033aa47-a9f7-4d7f-b7ff-a11ba9b34474',
'slug': 'my-monitor',
'url': 'https://example.com/health',
'frequency_mins': 5,
'consecutive_fails': 0,
'next_due_at': '2025-11-10T20:35:00Z',
'timeout_secs': 30,
'assertions': {'status_code': 200, 'body_contains': 'OK'},
'failures_before_alerting': 2,
'alert_slack_channels': ['#ops'],
'alert_emails': ['alerts@example.com'],
'realert_interval_mins': 60,
}
in_data = UptimeMonitorModel(**in_data)
UptimeMonitorTable.put(in_data)
UptimeMonitorFactory.put(
project_id='6033aa47-a9f7-4d7f-b7ff-a11ba9b34474',
slug='my-monitor',
url='https://example.com/health',
)
out_data = UptimeMonitorTable.query('6033aa47-a9f7-4d7f-b7ff-a11ba9b34474')
assert len(out_data) == 1
assert str(out_data[0].url) == 'https://example.com/health'
Expand All @@ -109,3 +87,31 @@ def test_update_error_not_conditional(self, m_get_client):
'6033aa47-a9f7-4d7f-b7ff-a11ba9b34474', 'my-monitor', {'a': 1}
)
assert excinfo.value == error

def test_cascade_delete(self):
    """Deleting a project removes its monitors and their logs, and nothing else."""
    # Happy path fixtures: a project with two monitors, the first of which has two logs.
    doomed_proj: ProjectModel = ProjectFactory.put()
    doomed_mon: UptimeMonitorModel = UptimeMonitorFactory.put(project_id=doomed_proj.id)
    for _ in range(2):
        UptimeLogFactory.put(monitor_id=doomed_mon.id)
    UptimeMonitorFactory.put(project_id=doomed_proj.id)

    # Sad path fixtures: an unrelated project/monitor/log chain that must survive.
    kept_proj: ProjectModel = ProjectFactory.put()
    kept_mon: UptimeMonitorModel = UptimeMonitorFactory.put(project_id=kept_proj.id)
    kept_log: UptimeLogModel = UptimeLogFactory.put(monitor_id=kept_mon.id)

    # Act: delete only the first project.
    ProjectTable.delete(doomed_proj.id)

    # Everything hanging off the deleted project is gone.
    assert ProjectTable.get(doomed_proj.id) is None
    assert UptimeMonitorTable.query(doomed_mon.project_id) == []
    assert UptimeLogTable.query(doomed_mon.id) == []

    # Everything belonging to the other project is untouched.
    assert ProjectTable.get(kept_proj.id) == kept_proj
    assert UptimeMonitorTable.get(kept_mon.project_id, kept_mon.slug) == kept_mon
    assert UptimeLogTable.get(kept_log.monitor_id, kept_log.timestamp) == kept_log
16 changes: 6 additions & 10 deletions tests/critic_tests/test_libs/test_uptime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from critic.libs.testing import UptimeMonitorFactory
from critic.libs.uptime import MonitorNotFoundError, UptimeCheck
from critic.models import MonitorState, UptimeLog, UptimeMonitorModel
from critic.models import MonitorState, UptimeLogModel, UptimeMonitorModel
from critic.tables import UptimeLogTable, UptimeMonitorTable


Expand Down Expand Up @@ -50,7 +50,7 @@ def test_race_condition(self, httpx_mock):
assert monitor.next_due_at == datetime(2026, 2, 10, 12, 0, 0, tzinfo=UTC)

# No log is created
logs = UptimeLogTable.query(f'{monitor.project_id}/{monitor.slug}')
logs = UptimeLogTable.query(monitor.id)
assert len(logs) == 0

def test_run_up(self, caplog, httpx_mock):
Expand All @@ -74,8 +74,7 @@ def test_run_up(self, caplog, httpx_mock):
assert response.next_due_at > time_to_check
assert response.consecutive_fails == 0

monitor_id = f'{monitor.project_id}/{monitor.slug}'
response: UptimeLog = UptimeLogTable.query(monitor_id)[-1]
response: UptimeLogModel = UptimeLogTable.query(monitor.id)[-1]

# check logging stuff
assert response.status == MonitorState.up
Expand All @@ -97,8 +96,7 @@ def test_down_with_consec_fails_above_threshold(self, httpx_mock):
assert response.state == MonitorState.down
assert response.consecutive_fails == 2

monitor_id = f'{monitor.project_id}/{monitor.slug}'
response: UptimeLog = UptimeLogTable.query(monitor_id)[-1]
response: UptimeLogModel = UptimeLogTable.query(monitor.id)[-1]
# log should have resp of 0 since there was a timeout
assert response.status == MonitorState.down
assert response.resp_code == 0
Expand All @@ -118,8 +116,7 @@ def test_down_with_consec_fails_below_threshold(self, httpx_mock):
assert response.state == MonitorState.down
assert response.consecutive_fails == 1

monitor_id = f'{monitor.project_id}/{monitor.slug}'
response: UptimeLog = UptimeLogTable.query(monitor_id)[-1]
response: UptimeLogModel = UptimeLogTable.query(monitor.id)[-1]
# log should have resp of 0 since there was a timeout
assert response.status == MonitorState.down
assert response.resp_code == 0
Expand All @@ -137,8 +134,7 @@ def test_paused(self):

response: UptimeMonitorModel = UptimeMonitorTable.get(monitor.project_id, monitor.slug)
assert response.next_due_at > time_to_check
monitor_id = f'{monitor.project_id}/{monitor.slug}'
response: UptimeLog = UptimeLogTable.query(monitor_id)
response: UptimeLogModel = UptimeLogTable.query(monitor.id)
# does not have item because no log is created since the monitor is paused
assert response == []

Expand Down
Loading