From 5ea693ec6ec486652642fff8a294d20285ebf8e1 Mon Sep 17 00:00:00 2001 From: CamDavidsonPilon Date: Wed, 12 Jun 2024 20:18:20 -0400 Subject: [PATCH] adding a when action --- .../actions/leader/experiment_profile.py | 115 ++++++++++--- .../experiment_profiles/profile_struct.py | 11 +- .../tests/test_execute_experiment_profile.py | 156 ++++++++++++++++++ 3 files changed, 259 insertions(+), 23 deletions(-) diff --git a/pioreactor/actions/leader/experiment_profile.py b/pioreactor/actions/leader/experiment_profile.py index 3281e01e..a4ec5c87 100644 --- a/pioreactor/actions/leader/experiment_profile.py +++ b/pioreactor/actions/leader/experiment_profile.py @@ -16,23 +16,25 @@ from pioreactor.experiment_profiles import profile_struct as struct from pioreactor.logging import create_logger from pioreactor.logging import CustomLogger -from pioreactor.pubsub import publish +from pioreactor.pubsub import Client from pioreactor.pubsub import put_into_leader from pioreactor.utils import ClusterJobManager from pioreactor.utils import managed_lifecycle from pioreactor.utils.timing import current_utc_timestamp from pioreactor.whoami import get_assigned_experiment_name from pioreactor.whoami import get_unit_name +from pioreactor.whoami import is_testing_env bool_expression = str | bool -def wrap_in_try_except(func, logger: CustomLogger) -> Callable: +def wrap_in_try_except(func, logger: CustomLogger, silent=False) -> Callable: def inner_function(*args, **kwargs) -> None: try: func(*args, **kwargs) except Exception as e: - logger.warning(f"Error in action: {e}") + if not silent: + logger.warning(f"Error in action: {e}") return inner_function @@ -142,12 +144,14 @@ def get_simple_priority(action: struct.Action): return 3 case struct.Update(): return 4 + case struct.When(): + return 5 case struct.Repeat(): return 6 case struct.Log(): return 10 case _: - raise ValueError(f"Not a valid action: {action}") + raise ValueError(f"Not a defined action: {action}") def wrapped_execute_action( @@ -156,6 +160,7 @@ def wrapped_execute_action( job_name: str, logger: CustomLogger, schedule: scheduler, + client: Client, action: struct.Action, dry_run: bool = False, ) -> Callable[..., None]: @@ -165,27 +170,28 @@ def wrapped_execute_action( match action: case struct.Start(_, if_, options, args): - return start_job(unit, experiment, job_name, options, args, dry_run, if_, logger) + return start_job(unit, experiment, client, job_name, options, args, dry_run, if_, logger) case struct.Pause(_, if_): - return pause_job(unit, experiment, job_name, dry_run, if_, logger) + return pause_job(unit, experiment, client, job_name, dry_run, if_, logger) case struct.Resume(_, if_): - return resume_job(unit, experiment, job_name, dry_run, if_, logger) + return resume_job(unit, experiment, client, job_name, dry_run, if_, logger) case struct.Stop(_, if_): - return stop_job(unit, experiment, job_name, dry_run, if_, logger) + return stop_job(unit, experiment, client, job_name, dry_run, if_, logger) case struct.Update(_, if_, options): - return update_job(unit, experiment, job_name, options, dry_run, if_, logger) + return update_job(unit, experiment, client, job_name, options, dry_run, if_, logger) case struct.Log(_, options, if_): - return log(unit, experiment, job_name, options, dry_run, if_, logger) + return log(unit, experiment, client, job_name, options, dry_run, if_, logger) case struct.Repeat(_, if_, repeat_every_hours, while_, max_hours, actions): return repeat( unit, experiment, + client, job_name, dry_run, if_, @@ -198,6 +204,21 @@ def wrapped_execute_action( schedule, ) + case struct.When(_, if_, condition, actions): + return when( + unit, + experiment, + client, + job_name, + dry_run, + if_, + condition, + logger, + action, + actions, + schedule, + ) + case _: raise ValueError(f"Not a valid action: {action}") @@ -215,21 +236,67 @@ def common_wrapped_execute_action( job_name: str, logger: CustomLogger, schedule: scheduler, + client: Client, action: struct.Action, dry_run: bool = False, ) -> Callable[..., None]: actions_to_execute = [] for worker in get_active_workers_in_experiment(experiment): actions_to_execute.append( - wrapped_execute_action(worker, experiment, job_name, logger, schedule, action, dry_run) + wrapped_execute_action(worker, experiment, job_name, logger, schedule, client, action, dry_run) ) return chain_functions(*actions_to_execute) +def when( + unit: str, + experiment: str, + client: Client, + job_name: str, + dry_run: bool, + if_: Optional[bool_expression], + condition: bool_expression, + logger: CustomLogger, + when_action: struct.When, + actions: list[struct.Action], + schedule: scheduler, +) -> Callable[..., None]: + def _callable() -> None: + # first check if the Pioreactor is still part of the experiment. + if (get_assigned_experiment_name(unit) != experiment) and not is_testing_env(): + return + + if (if_ is None) or evaluate_bool_expression(if_, unit): + if evaluate_bool_expression(condition, unit): + for action in actions: + schedule.enter( + delay=hours_to_seconds(action.hours_elapsed), + priority=get_simple_priority(action), + action=wrapped_execute_action( + unit, experiment, job_name, logger, schedule, client, action, dry_run + ), + ) + + else: + schedule.enter( + delay=10, # check every 10 seconds?? + priority=get_simple_priority(when_action), + action=wrapped_execute_action( + unit, experiment, job_name, logger, schedule, client, when_action, dry_run + ), + ) + + else: + logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.") + + return wrap_in_try_except(_callable, logger, silent=False) + + def repeat( unit: str, experiment: str, + client: Client, job_name: str, dry_run: bool, if_: Optional[bool_expression], @@ -238,9 +305,9 @@ def repeat( while_: Optional[bool_expression], repeat_every_hours: float, max_hours: Optional[float], - actions: list[struct.ActionWithoutRepeat], + actions: list[struct.BasicAction], schedule: scheduler, -): +) -> Callable[..., None]: def _callable() -> None: # first check if the Pioreactor is still part of the experiment. if get_assigned_experiment_name(unit) != experiment: @@ -261,7 +328,7 @@ def _callable() -> None: delay=hours_to_seconds(action.hours_elapsed), priority=get_simple_priority(action), action=wrapped_execute_action( - unit, experiment, job_name, logger, schedule, action, dry_run + unit, experiment, job_name, logger, schedule, client, action, dry_run ), ) @@ -276,7 +343,7 @@ def _callable() -> None: delay=hours_to_seconds(repeat_every_hours), priority=get_simple_priority(repeat_action), action=wrapped_execute_action( - unit, experiment, job_name, logger, schedule, repeat_action, dry_run + unit, experiment, job_name, logger, schedule, client, repeat_action, dry_run ), ) else: @@ -293,6 +360,7 @@ def _callable() -> None: def log( unit: str, experiment: str, + client: Client, job_name: str, options: struct._LogOptions, dry_run: bool, @@ -315,6 +383,7 @@ def _callable() -> None: def start_job( unit: str, experiment: str, + client: Client, job_name: str, options: dict, args: list, @@ -331,7 +400,7 @@ def _callable() -> None: if dry_run: logger.info(f"Dry-run: Starting {job_name} on {unit} with options {options} and args {args}.") else: - publish( + client.publish( f"pioreactor/{unit}/{experiment}/run/{job_name}", encode( { @@ -349,6 +418,7 @@ def _callable() -> None: def pause_job( unit: str, experiment: str, + client: Client, job_name: str, dry_run: bool, if_: Optional[str | bool], @@ -363,7 +433,7 @@ def _callable() -> None: if dry_run: logger.info(f"Dry-run: Pausing {job_name} on {unit}.") else: - publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "sleeping") + client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "sleeping") else: logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.") @@ -373,6 +443,7 @@ def _callable() -> None: def resume_job( unit: str, experiment: str, + client: Client, job_name: str, dry_run: bool, if_: Optional[str | bool], @@ -386,7 +457,7 @@ def _callable() -> None: if dry_run: logger.info(f"Dry-run: Resuming {job_name} on {unit}.") else: - publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "ready") + client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "ready") else: logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.") @@ -396,6 +467,7 @@ def _callable() -> None: def stop_job( unit: str, experiment: str, + client: Client, job_name: str, dry_run: bool, if_: Optional[str | bool], @@ -409,7 +481,7 @@ def _callable() -> None: if dry_run: logger.info(f"Dry-run: Stopping {job_name} on {unit}.") else: - publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "disconnected") + client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "disconnected") else: logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.") @@ -419,6 +491,7 @@ def _callable() -> None: def update_job( unit: str, experiment: str, + client: Client, job_name: str, options: dict, dry_run: bool, @@ -436,7 +509,7 @@ def _callable() -> None: else: for setting, value in evaluate_options(options, unit).items(): - publish(f"pioreactor/{unit}/{experiment}/{job_name}/{setting}/set", value) + client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/{setting}/set", value) else: logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.") @@ -622,6 +695,7 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run: job_name, logger, sched, + state.mqtt_client, action, dry_run, ), @@ -645,6 +719,7 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run: job_name, logger, sched, + state.mqtt_client, action, dry_run, ), diff --git a/pioreactor/experiment_profiles/profile_struct.py b/pioreactor/experiment_profiles/profile_struct.py index 50a3656a..457076a0 100644 --- a/pioreactor/experiment_profiles/profile_struct.py +++ b/pioreactor/experiment_profiles/profile_struct.py @@ -68,19 +68,24 @@ class Resume(_Action): pass +class When(_Action): + condition: str = "" + actions: list[Action] = [] + + class Repeat(_Action): repeat_every_hours: float = 1.0 while_: t.Optional[str | bool] = field(name="while", default=None) max_hours: t.Optional[float] = None - actions: list[ActionWithoutRepeat] = [] + actions: list[BasicAction] = [] _completed_loops: int = 0 def __str__(self) -> str: return f"{self.__class__.__name__}({self.hours_elapsed=:.5f}, {self.repeat_every_hours=}, {self.max_hours=}, {self.while_=})" -Action = t.Union[Log, Start, Pause, Stop, Update, Resume, Repeat] -ActionWithoutRepeat = t.Union[Log, Start, Pause, Stop, Update, Resume] +BasicAction = Log | Start | Pause | Stop | Update | Resume +Action = BasicAction | Repeat | When ####### diff --git a/pioreactor/tests/test_execute_experiment_profile.py b/pioreactor/tests/test_execute_experiment_profile.py index d40fc85d..79c63232 100644 --- a/pioreactor/tests/test_execute_experiment_profile.py +++ b/pioreactor/tests/test_execute_experiment_profile.py @@ -4,11 +4,13 @@ from unittest.mock import patch import pytest +from msgspec.json import encode from msgspec.yaml import decode from pioreactor.actions.leader.experiment_profile import _verify_experiment_profile from pioreactor.actions.leader.experiment_profile import execute_experiment_profile from pioreactor.actions.leader.experiment_profile import hours_to_seconds +from pioreactor.background_jobs.stirring import start_stirring from pioreactor.experiment_profiles.profile_struct import _LogOptions from pioreactor.experiment_profiles.profile_struct import CommonBlock from pioreactor.experiment_profiles.profile_struct import Job @@ -20,9 +22,12 @@ from pioreactor.experiment_profiles.profile_struct import Start from pioreactor.experiment_profiles.profile_struct import Stop from pioreactor.experiment_profiles.profile_struct import Update +from pioreactor.experiment_profiles.profile_struct import When from pioreactor.pubsub import collect_all_logs_of_level from pioreactor.pubsub import publish from pioreactor.pubsub import subscribe_and_callback +from pioreactor.structs import ODReading +from pioreactor.utils.timing import current_utc_datetime # First test the hours_to_seconds function @@ -524,6 +529,157 @@ def collection_actions(msg): ) +@patch("pioreactor.actions.leader.experiment_profile._load_experiment_profile") +def test_execute_experiment_profile_when_action(mock__load_experiment_profile) -> None: + experiment = "_testing_experiment" + action = When( + hours_elapsed=0.0005, + condition="${{unit1:od_reading:od1.od > 2.0}}", + actions=[ + Start(hours_elapsed=0, options={"target_rpm": 500}), + Update(hours_elapsed=0.001, options={"target_rpm": 600}), + ], + ) + + profile = Profile( + experiment_profile_name="test_when_action_profile", + plugins=[], + pioreactors={ + "unit1": PioreactorSpecificBlock( + jobs={"stirring": Job(actions=[action])}, + ) + }, + metadata=Metadata(author="test_author"), + ) + + mock__load_experiment_profile.return_value = profile + + actions = [] + + def collect_actions(msg): + actions.append(msg.topic) + + subscribe_and_callback( + collect_actions, + [f"pioreactor/unit1/{experiment}/#"], + allow_retained=False, + ) + + # Simulate OD value + publish( + f"pioreactor/unit1/{experiment}/od_reading/od1", + encode(ODReading(od=2.5, angle="90", timestamp=current_utc_datetime(), channel="1")), + retain=True, + ) + + execute_experiment_profile("profile.yaml", experiment) + + assert actions == [ + f"pioreactor/unit1/{experiment}/od_reading/od1", + f"pioreactor/unit1/{experiment}/run/stirring", + f"pioreactor/unit1/{experiment}/stirring/target_rpm/set", + ] + + +@patch("pioreactor.actions.leader.experiment_profile._load_experiment_profile") +def test_execute_experiment_profile_when_action_with_if(mock__load_experiment_profile) -> None: + experiment = "_testing_experiment" + action = When( + hours_elapsed=0.0005, + if_="1 == 1", + condition="${{unit1:od_reading:od1.od > 2.0}}", + actions=[ + Start(hours_elapsed=0, options={"target_rpm": 500}), + Update(hours_elapsed=0.001, options={"target_rpm": 600}), + ], + ) + + profile = Profile( + experiment_profile_name="test_when_action_with_if_profile", + plugins=[], + pioreactors={ + "unit1": PioreactorSpecificBlock( + jobs={"stirring": Job(actions=[action])}, + ) + }, + metadata=Metadata(author="test_author"), + ) + + mock__load_experiment_profile.return_value = profile + + actions = [] + + def collect_actions(msg): + actions.append(msg.topic) + + subscribe_and_callback( + collect_actions, + [f"pioreactor/unit1/{experiment}/#"], + allow_retained=False, + ) + + # Simulate OD value + publish( + f"pioreactor/unit1/{experiment}/od_reading/od1", + encode(ODReading(od=2.5, angle="90", timestamp=current_utc_datetime(), channel="1")), + retain=True, + ) + + execute_experiment_profile("profile.yaml", experiment) + + assert actions == [ + f"pioreactor/unit1/{experiment}/od_reading/od1", + f"pioreactor/unit1/{experiment}/run/stirring", + f"pioreactor/unit1/{experiment}/stirring/target_rpm/set", + ] + + +@patch("pioreactor.actions.leader.experiment_profile._load_experiment_profile") +def test_execute_experiment_profile_when_action_condition_eventually_met( + mock__load_experiment_profile, +) -> None: + experiment = "_testing_experiment" + + when = When( + hours_elapsed=0.00, + condition="${{unit1:stirring:target_rpm > 800}}", + actions=[ + Update(hours_elapsed=0, options={"target_rpm": 200}), + ], + ) + update = Update(hours_elapsed=0.002, options={"target_rpm": 1000}) + + profile = Profile( + experiment_profile_name="test_when_action_condition_not_met_profile", + plugins=[], + pioreactors={ + "unit1": PioreactorSpecificBlock( + jobs={"stirring": Job(actions=[when, update])}, + ) + }, + metadata=Metadata(author="test_author"), + ) + + mock__load_experiment_profile.return_value = profile + + actions = [] + + def collect_actions(msg): + if msg.payload: + actions.append(float(msg.payload.decode())) + + subscribe_and_callback( + collect_actions, + [f"pioreactor/unit1/{experiment}/stirring/target_rpm"], + allow_retained=False, + ) + + with start_stirring(target_rpm=500, unit="unit1", experiment=experiment, use_rpm=True): + execute_experiment_profile("profile.yaml", experiment) + + assert actions == [500, 1000, 200] + + def test_profiles_in_github_repo() -> None: from pioreactor.mureq import get