Skip to content

Commit

Permalink
[DPE-6296] Pyright fixes + structured_config additions + break down o…
Browse files Browse the repository at this point in the history
…f actions.py (#13)
  • Loading branch information
phvalguima authored Jan 22, 2025
1 parent 14d4daa commit 69f7982
Show file tree
Hide file tree
Showing 21 changed files with 756 additions and 591 deletions.
207 changes: 104 additions & 103 deletions poetry.lock

Large diffs are not rendered by default.

167 changes: 19 additions & 148 deletions src/benchmark/base_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@

import logging
import subprocess
from abc import ABC, abstractmethod
from typing import Any

import ops
from charms.data_platform_libs.v0.data_models import TypedCharmBase
from charms.grafana_agent.v0.cos_agent import COSAgentProvider
from ops.charm import CharmEvents
from ops.framework import EventBase, EventSource
from ops.model import BlockedStatus

from benchmark.core.models import DPBenchmarkLifecycleState
from benchmark.core.pebble_workload_base import DPBenchmarkPebbleWorkloadBase
from benchmark.core.structured_config import BenchmarkCharmConfig
from benchmark.core.systemd_workload_base import DPBenchmarkSystemdWorkloadBase
from benchmark.core.workload_base import WorkloadBase
from benchmark.events.actions import ActionsHandler
from benchmark.events.db import DatabaseRelationHandler
from benchmark.events.peer import PeerRelationHandler
from benchmark.literals import (
COS_AGENT_RELATION,
METRICS_PORT,
PEER_RELATION,
DPBenchmarkLifecycleTransition,
DPBenchmarkMissingOptionsError,
)
from benchmark.managers.config import ConfigManager
Expand Down Expand Up @@ -70,34 +70,22 @@ def workload_build(workload_params_template: str) -> WorkloadBase:
return DPBenchmarkSystemdWorkloadBase(workload_params_template)


class DPBenchmarkCharmBase(ops.CharmBase, ABC):
class DPBenchmarkCharmBase(TypedCharmBase[BenchmarkCharmConfig]):
"""The base benchmark class."""

on = DPBenchmarkEvents() # pyright: ignore [reportGeneralTypeIssues]
on = DPBenchmarkEvents() # pyright: ignore [reportAssignmentType]

RESOURCE_DEB_NAME = "benchmark-deb"
workload_params_template = ""

config_type = BenchmarkCharmConfig

def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None = None):
super().__init__(*args)
self.framework.observe(self.on.install, self._on_install)
self.framework.observe(self.on.config_changed, self._on_config_changed)
self.framework.observe(self.on.update_status, self._on_update_status)

self.framework.observe(self.on.prepare_action, self.on_prepare_action)
self.framework.observe(self.on.run_action, self.on_run_action)
self.framework.observe(self.on.stop_action, self.on_stop_action)
self.framework.observe(self.on.cleanup_action, self.on_clean_action)

self.framework.observe(
self.on.check_upload,
self._on_check_upload,
)
self.framework.observe(
self.on.check_collect,
self._on_check_collect,
)

self.database = DatabaseRelationHandler(self, db_relation_name)
self.peers = PeerRelationHandler(self, PEER_RELATION)
self.framework.observe(self.database.on.db_config_update, self._on_config_changed)
Expand All @@ -119,8 +107,8 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =

self.config_manager = ConfigManager(
workload=self.workload,
database=self.database.state,
peer=self.peers.peers(),
database_state=self.database.state,
peers=self.peers.peers(),
config=self.config,
labels=self.labels,
)
Expand All @@ -129,11 +117,7 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =
self.peers.this_unit(),
self.config_manager,
)

@abstractmethod
def supported_workloads(self) -> list[str]:
"""List of supported workloads."""
...
self.actions = ActionsHandler(self)

###########################################################################
#
Expand All @@ -146,28 +130,6 @@ def _on_install(self, event: EventBase) -> None:
self.workload.install()
self.peers.state.lifecycle = DPBenchmarkLifecycleState.UNSET

def _on_check_collect(self, event: EventBase) -> None:
"""Check if the upload is finished."""
if self.config_manager.is_collecting():
# Nothing to do, upload is still in progress
event.defer()
return

if self.unit.is_leader():
self.peers.state.set(DPBenchmarkLifecycleState.UPLOADING)
# Raise we are running an upload and we will check the status later
self.on.check_upload.emit()
return
self.peers.state.set(DPBenchmarkLifecycleState.FINISHED)

def _on_check_upload(self, event: EventBase) -> None:
"""Check if the upload is finished."""
if self.config_manager.is_uploading():
# Nothing to do, upload is still in progress
event.defer()
return
self.peers.state.lifecycle = DPBenchmarkLifecycleState.FINISHED

def _on_update_status(self, event: EventBase | None = None) -> None:
"""Set status for the operator and finishes the service.
Expand All @@ -176,34 +138,20 @@ def _on_update_status(self, event: EventBase | None = None) -> None:
benchmark service and the benchmark status.
"""
try:
status = self.database.state.get()
status = self.database.state.model()
except DPBenchmarkMissingOptionsError as e:
self.unit.status = BlockedStatus(str(e))
return
if not status:
self.unit.status = BlockedStatus("No database relation available")
return

# We need to narrow the options of workload_name to the supported ones
if self.config.get("workload_name") not in self.supported_workloads():
self.unit.status = BlockedStatus(
f"Unsupported workload: {self.config.get('workload_name')}"
)
return

# Now, let's check if we need to update our lifecycle position
self._update_state()
self.update_state()
self.unit.status = self.lifecycle.status

def _on_config_changed(self, event: EventBase) -> None:
"""Config changed event."""
# We need to narrow the options of workload_name to the supported ones
if self.config.get("workload_name") not in self.supported_workloads():
self.unit.status = BlockedStatus(
f"Unsupported workload: {self.config.get('workload_name')}"
)
return

if not self.config_manager.is_prepared():
# nothing to do: set the status and leave
self._on_update_status()
Expand All @@ -228,88 +176,6 @@ def scrape_config(self) -> list[dict[str, Any]]:
}
]

###########################################################################
#
# Action and Lifecycle Handlers
#
###########################################################################

def _preflight_checks(self) -> bool:
"""Check if we have the necessary relations."""
if len(self.peers.units()) > 0 and not bool(self.peers.state.get()):
return False
try:
return bool(self.database.state.get())
except DPBenchmarkMissingOptionsError:
return False

def on_prepare_action(self, event: EventBase) -> None:
"""Process the prepare action."""
if not self._preflight_checks():
event.fail("Missing DB or S3 relations")
return

if not (state := self.lifecycle.next(DPBenchmarkLifecycleTransition.PREPARE)):
event.fail("Failed to prepare the benchmark: already done")
return

if state != DPBenchmarkLifecycleState.PREPARING:
event.fail(
"Another peer is already in prepare state. Wait or call clean action to reset."
)
return

# We process the special case of PREPARE, as explained in lifecycle.make_transition()
if not self.config_manager.prepare():
event.fail("Failed to prepare the benchmark")
return

self.lifecycle.make_transition(state)
self.unit.status = self.lifecycle.status
event.set_results({"message": "Benchmark is being prepared"})

def on_run_action(self, event: EventBase) -> None:
"""Process the run action."""
if not self._preflight_checks():
event.fail("Missing DB or S3 relations")
return

if not self._process_action_transition(DPBenchmarkLifecycleTransition.RUN):
event.fail("Failed to run the benchmark")
event.set_results({"message": "Benchmark has started"})

def on_stop_action(self, event: EventBase) -> None:
"""Process the stop action."""
if not self._preflight_checks():
event.fail("Missing DB or S3 relations")
return

if not self._process_action_transition(DPBenchmarkLifecycleTransition.STOP):
event.fail("Failed to stop the benchmark")
event.set_results({"message": "Benchmark has stopped"})

def on_clean_action(self, event: EventBase) -> None:
"""Process the clean action."""
if not self._preflight_checks():
event.fail("Missing DB or S3 relations")
return

if not self._process_action_transition(DPBenchmarkLifecycleTransition.CLEAN):
event.fail("Failed to clean the benchmark")
event.set_results({"message": "Benchmark is cleaning"})

def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition) -> bool:
"""Process the action."""
# First, check if we have an update in our lifecycle state
self._update_state()

if not (state := self.lifecycle.next(transition)):
return False

self.lifecycle.make_transition(state)
self.unit.status = self.lifecycle.status
return True

###########################################################################
#
# Helpers
Expand All @@ -318,9 +184,14 @@ def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition)

def _unit_ip(self) -> str:
"""Current unit ip."""
return self.model.get_binding(PEER_RELATION).network.bind_address
bind_address = None
if PEER_RELATION:
if binding := self.model.get_binding(PEER_RELATION):
bind_address = binding.network.bind_address

return str(bind_address) if bind_address else ""

def _update_state(self) -> None:
def update_state(self) -> None:
"""Update the state of the charm."""
if (next_state := self.lifecycle.next(None)) and self.lifecycle.current() != next_state:
self.lifecycle.make_transition(next_state)
32 changes: 9 additions & 23 deletions src/benchmark/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
"""

import logging
from typing import Any, Optional
from typing import Any, MutableMapping, Optional

from ops.model import Application, Relation, Unit
from overrides import override
from pydantic import BaseModel, error_wrappers, root_validator

from benchmark.literals import (
LIFECYCLE_KEY,
STOP_KEY,
DPBenchmarkLifecycleState,
DPBenchmarkMissingOptionsError,
Scope,
Expand Down Expand Up @@ -106,7 +105,6 @@ class DPBenchmarkWrapperOptionsModel(BaseModel):
workload_name: str
db_info: DPBenchmarkBaseDatabaseModel
report_interval: int
workload_profile: str
labels: str
peers: str | None = None

Expand All @@ -125,20 +123,18 @@ def __init__(
self.scope = scope

@property
def relation_data(self) -> dict[str, str]:
def relation_data(self) -> MutableMapping[str, str]:
"""Returns the relation data."""
if self.relation:
return self.relation.data[self.component]
return {}

@property
def remote_data(self) -> dict[str, str]:
def remote_data(self) -> MutableMapping[str, str]:
"""Returns the remote relation data."""
if not self.relation:
if not self.relation or self.scope != Scope.APP:
return {}
if self.scope == Scope.APP:
return self.relation.data[self.relation.app]
return self.relation.data[self.relation.unit]
return self.relation.data[self.relation.app]

def __bool__(self) -> bool:
"""Boolean evaluation based on the existence of self.relation."""
Expand Down Expand Up @@ -191,16 +187,6 @@ def lifecycle(self, status: DPBenchmarkLifecycleState | str) -> None:
else:
self.set({LIFECYCLE_KEY: status})

@property
def stop(self) -> bool:
"""Returns the value of the stop key."""
return self.relation_data.get(STOP_KEY, False)

@stop.setter
def stop(self, switch: bool) -> bool:
"""Toggles the stop key value."""
self.set({STOP_KEY: switch})


class DatabaseState(RelationState):
"""State collection for the database relation."""
Expand Down Expand Up @@ -236,7 +222,7 @@ def tls_ca(self) -> str | None:
return None
return tls_ca

def get(self) -> DPBenchmarkBaseDatabaseModel | None:
def model(self) -> DPBenchmarkBaseDatabaseModel | None:
"""Returns the value of the key."""
if not self.relation or not (endpoints := self.remote_data.get("endpoints")):
return None
Expand All @@ -248,9 +234,9 @@ def get(self) -> DPBenchmarkBaseDatabaseModel | None:
return DPBenchmarkBaseDatabaseModel(
hosts=endpoints.split(),
unix_socket=unix_socket,
username=self.data.get("username"),
password=self.data.get("password"),
db_name=self.remote_data.get(self.database_key),
username=self.data.get("username", ""),
password=self.data.get("password", ""),
db_name=self.remote_data.get(self.database_key, ""),
tls=self.tls,
tls_ca=self.tls_ca,
)
Expand Down
12 changes: 7 additions & 5 deletions src/benchmark/core/pebble_workload_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@
class DPBenchmarkPebbleTemplatePaths(WorkloadTemplatePaths):
"""Represents the benchmark service template paths."""

def __init__(self):
super().__init__()
self.svc_name = "dpe_benchmark"

@property
@override
def service(self) -> str | None:
def service(self) -> str:
"""The optional path to the service file managing the script."""
return f"/etc/systemd/system/{self.svc_name}.service"

Expand All @@ -44,6 +40,12 @@ def templates(self) -> str:
"""The path to the workload template folder."""
return os.path.join(os.environ.get("CHARM_DIR", ""), "templates")

@property
@override
def results(self) -> str:
"""The path to the results folder."""
return "/root/.benchmark/charmed_parameters/results/"

@property
@override
def service_template(self) -> str:
Expand Down
Loading

0 comments on commit 69f7982

Please sign in to comment.