Skip to content

Commit

Permalink
adding a when action
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jun 13, 2024
1 parent 2bd8d8f commit 5ea693e
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 23 deletions.
115 changes: 95 additions & 20 deletions pioreactor/actions/leader/experiment_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,25 @@
from pioreactor.experiment_profiles import profile_struct as struct
from pioreactor.logging import create_logger
from pioreactor.logging import CustomLogger
from pioreactor.pubsub import publish
from pioreactor.pubsub import Client
from pioreactor.pubsub import put_into_leader
from pioreactor.utils import ClusterJobManager
from pioreactor.utils import managed_lifecycle
from pioreactor.utils.timing import current_utc_timestamp
from pioreactor.whoami import get_assigned_experiment_name
from pioreactor.whoami import get_unit_name
from pioreactor.whoami import is_testing_env

bool_expression = str | bool


def wrap_in_try_except(func, logger: CustomLogger, silent=False) -> Callable:
    """
    Wrap `func` so that exceptions raised while running it never propagate.

    Parameters
    ----------
    func:
        The callable to protect. It is invoked with whatever positional and
        keyword arguments the returned wrapper receives.
    logger:
        Receives a single `warning` record describing the exception, unless
        `silent` is set.
    silent:
        When True, exceptions are swallowed without logging anything.

    Returns
    -------
    A callable that forwards its arguments to `func` and always returns None;
    the return value of `func` is discarded.
    """

    def inner_function(*args, **kwargs) -> None:
        try:
            func(*args, **kwargs)
        except Exception as e:
            # Log at most once per failure; the caller opts out with `silent`.
            if not silent:
                logger.warning(f"Error in action: {e}")

    return inner_function

Expand Down Expand Up @@ -142,12 +144,14 @@ def get_simple_priority(action: struct.Action):
return 3
case struct.Update():
return 4
case struct.When():
return 5
case struct.Repeat():
return 6
case struct.Log():
return 10
case _:
raise ValueError(f"Not a valid action: {action}")
raise ValueError(f"Not a defined action: {action}")


def wrapped_execute_action(
Expand All @@ -156,6 +160,7 @@ def wrapped_execute_action(
job_name: str,
logger: CustomLogger,
schedule: scheduler,
client: Client,
action: struct.Action,
dry_run: bool = False,
) -> Callable[..., None]:
Expand All @@ -165,27 +170,28 @@ def wrapped_execute_action(

match action:
case struct.Start(_, if_, options, args):
return start_job(unit, experiment, job_name, options, args, dry_run, if_, logger)
return start_job(unit, experiment, client, job_name, options, args, dry_run, if_, logger)

case struct.Pause(_, if_):
return pause_job(unit, experiment, job_name, dry_run, if_, logger)
return pause_job(unit, experiment, client, job_name, dry_run, if_, logger)

case struct.Resume(_, if_):
return resume_job(unit, experiment, job_name, dry_run, if_, logger)
return resume_job(unit, experiment, client, job_name, dry_run, if_, logger)

case struct.Stop(_, if_):
return stop_job(unit, experiment, job_name, dry_run, if_, logger)
return stop_job(unit, experiment, client, job_name, dry_run, if_, logger)

case struct.Update(_, if_, options):
return update_job(unit, experiment, job_name, options, dry_run, if_, logger)
return update_job(unit, experiment, client, job_name, options, dry_run, if_, logger)

case struct.Log(_, options, if_):
return log(unit, experiment, job_name, options, dry_run, if_, logger)
return log(unit, experiment, client, job_name, options, dry_run, if_, logger)

case struct.Repeat(_, if_, repeat_every_hours, while_, max_hours, actions):
return repeat(
unit,
experiment,
client,
job_name,
dry_run,
if_,
Expand All @@ -198,6 +204,21 @@ def wrapped_execute_action(
schedule,
)

case struct.When(_, if_, condition, actions):
return when(
unit,
experiment,
client,
job_name,
dry_run,
if_,
condition,
logger,
action,
actions,
schedule,
)

case _:
raise ValueError(f"Not a valid action: {action}")

Expand All @@ -215,21 +236,67 @@ def common_wrapped_execute_action(
job_name: str,
logger: CustomLogger,
schedule: scheduler,
client: Client,
action: struct.Action,
dry_run: bool = False,
) -> Callable[..., None]:
actions_to_execute = []
for worker in get_active_workers_in_experiment(experiment):
actions_to_execute.append(
wrapped_execute_action(worker, experiment, job_name, logger, schedule, action, dry_run)
wrapped_execute_action(worker, experiment, job_name, logger, schedule, client, action, dry_run)
)

return chain_functions(*actions_to_execute)


def when(
    unit: str,
    experiment: str,
    client: Client,
    job_name: str,
    dry_run: bool,
    if_: Optional[bool_expression],
    condition: bool_expression,
    logger: CustomLogger,
    when_action: struct.When,
    actions: list[struct.Action],
    schedule: scheduler,
) -> Callable[..., None]:
    """
    Build the callable that executes a `When` action for a single unit.

    When the returned callable runs it:

    1. does nothing if `unit` is no longer assigned to `experiment`
       (this check is bypassed in the testing environment);
    2. logs a debug message and stops if `if_` is given and evaluates falsy;
    3. if `condition` evaluates falsy, re-enqueues `when_action` itself on
       `schedule` so the condition is checked again later;
    4. otherwise enqueues every entry of `actions` on `schedule`, each delayed
       by its own `hours_elapsed`.

    The callable is wrapped with `wrap_in_try_except`, so exceptions raised
    inside it are logged as warnings rather than propagated.
    """

    def _callable() -> None:
        # first check if the Pioreactor is still part of the experiment.
        assigned_elsewhere = get_assigned_experiment_name(unit) != experiment
        if assigned_elsewhere and not is_testing_env():
            return

        if if_ is not None and not evaluate_bool_expression(if_, unit):
            logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.")
            return

        if not evaluate_bool_expression(condition, unit):
            # Condition not met yet: schedule this same `when` again.
            schedule.enter(
                delay=10,  # check every 10 seconds??
                priority=get_simple_priority(when_action),
                action=wrapped_execute_action(
                    unit, experiment, job_name, logger, schedule, client, when_action, dry_run
                ),
            )
            return

        # Condition met: hand every nested action to the scheduler.
        for nested_action in actions:
            schedule.enter(
                delay=hours_to_seconds(nested_action.hours_elapsed),
                priority=get_simple_priority(nested_action),
                action=wrapped_execute_action(
                    unit, experiment, job_name, logger, schedule, client, nested_action, dry_run
                ),
            )

    return wrap_in_try_except(_callable, logger, silent=False)


def repeat(
unit: str,
experiment: str,
client: Client,
job_name: str,
dry_run: bool,
if_: Optional[bool_expression],
Expand All @@ -238,9 +305,9 @@ def repeat(
while_: Optional[bool_expression],
repeat_every_hours: float,
max_hours: Optional[float],
actions: list[struct.ActionWithoutRepeat],
actions: list[struct.BasicAction],
schedule: scheduler,
):
) -> Callable[..., None]:
def _callable() -> None:
# first check if the Pioreactor is still part of the experiment.
if get_assigned_experiment_name(unit) != experiment:
Expand All @@ -261,7 +328,7 @@ def _callable() -> None:
delay=hours_to_seconds(action.hours_elapsed),
priority=get_simple_priority(action),
action=wrapped_execute_action(
unit, experiment, job_name, logger, schedule, action, dry_run
unit, experiment, job_name, logger, schedule, client, action, dry_run
),
)

Expand All @@ -276,7 +343,7 @@ def _callable() -> None:
delay=hours_to_seconds(repeat_every_hours),
priority=get_simple_priority(repeat_action),
action=wrapped_execute_action(
unit, experiment, job_name, logger, schedule, repeat_action, dry_run
unit, experiment, job_name, logger, schedule, client, repeat_action, dry_run
),
)
else:
Expand All @@ -293,6 +360,7 @@ def _callable() -> None:
def log(
unit: str,
experiment: str,
client: Client,
job_name: str,
options: struct._LogOptions,
dry_run: bool,
Expand All @@ -315,6 +383,7 @@ def _callable() -> None:
def start_job(
unit: str,
experiment: str,
client: Client,
job_name: str,
options: dict,
args: list,
Expand All @@ -331,7 +400,7 @@ def _callable() -> None:
if dry_run:
logger.info(f"Dry-run: Starting {job_name} on {unit} with options {options} and args {args}.")
else:
publish(
client.publish(
f"pioreactor/{unit}/{experiment}/run/{job_name}",
encode(
{
Expand All @@ -349,6 +418,7 @@ def _callable() -> None:
def pause_job(
unit: str,
experiment: str,
client: Client,
job_name: str,
dry_run: bool,
if_: Optional[str | bool],
Expand All @@ -363,7 +433,7 @@ def _callable() -> None:
if dry_run:
logger.info(f"Dry-run: Pausing {job_name} on {unit}.")
else:
publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "sleeping")
client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "sleeping")
else:
logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.")

Expand All @@ -373,6 +443,7 @@ def _callable() -> None:
def resume_job(
unit: str,
experiment: str,
client: Client,
job_name: str,
dry_run: bool,
if_: Optional[str | bool],
Expand All @@ -386,7 +457,7 @@ def _callable() -> None:
if dry_run:
logger.info(f"Dry-run: Resuming {job_name} on {unit}.")
else:
publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "ready")
client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "ready")
else:
logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.")

Expand All @@ -396,6 +467,7 @@ def _callable() -> None:
def stop_job(
unit: str,
experiment: str,
client: Client,
job_name: str,
dry_run: bool,
if_: Optional[str | bool],
Expand All @@ -409,7 +481,7 @@ def _callable() -> None:
if dry_run:
logger.info(f"Dry-run: Stopping {job_name} on {unit}.")
else:
publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "disconnected")
client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/$state/set", "disconnected")
else:
logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.")

Expand All @@ -419,6 +491,7 @@ def _callable() -> None:
def update_job(
unit: str,
experiment: str,
client: Client,
job_name: str,
options: dict,
dry_run: bool,
Expand All @@ -436,7 +509,7 @@ def _callable() -> None:

else:
for setting, value in evaluate_options(options, unit).items():
publish(f"pioreactor/{unit}/{experiment}/{job_name}/{setting}/set", value)
client.publish(f"pioreactor/{unit}/{experiment}/{job_name}/{setting}/set", value)
else:
logger.debug(f"Action's `if` condition, `{if_}`, evaluated False. Skipping action.")

Expand Down Expand Up @@ -622,6 +695,7 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run:
job_name,
logger,
sched,
state.mqtt_client,
action,
dry_run,
),
Expand All @@ -645,6 +719,7 @@ def execute_experiment_profile(profile_filename: str, experiment: str, dry_run:
job_name,
logger,
sched,
state.mqtt_client,
action,
dry_run,
),
Expand Down
11 changes: 8 additions & 3 deletions pioreactor/experiment_profiles/profile_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,24 @@ class Resume(_Action):
pass


class When(_Action):
    # Expression that the `when` action evaluates; the nested actions are
    # scheduled only once it evaluates truthy.
    condition: str = ""
    # Actions to schedule once `condition` holds.
    # NOTE(review): the mutable `[]` default presumably relies on the struct
    # base class copying defaults per instance - confirm against `_Action`.
    actions: list[Action] = []


class Repeat(_Action):
    # Hours between successive runs of the nested actions.
    repeat_every_hours: float = 1.0
    # Optional loop condition; stored as `while_` because `while` is a Python
    # keyword, and mapped to the name "while" via `field(name=...)`.
    while_: t.Optional[str | bool] = field(name="while", default=None)
    # Optional upper bound, in hours, for the loop.
    max_hours: t.Optional[float] = None
    # Nested actions, restricted to basic (non-Repeat, non-When) actions.
    actions: list[BasicAction] = []
    # Loop counter. NOTE(review): presumably incremented by the `repeat`
    # action executor - not visible here, confirm.
    _completed_loops: int = 0

    def __str__(self) -> str:
        """Return a compact description of the repeat settings."""
        return f"{self.__class__.__name__}({self.hours_elapsed=:.5f}, {self.repeat_every_hours=}, {self.max_hours=}, {self.while_=})"


Action = t.Union[Log, Start, Pause, Stop, Update, Resume, Repeat]
ActionWithoutRepeat = t.Union[Log, Start, Pause, Stop, Update, Resume]
BasicAction = Log | Start | Pause | Stop | Update | Resume
Action = BasicAction | Repeat | When

#######

Expand Down
Loading

0 comments on commit 5ea693e

Please sign in to comment.