Source code for hydroflows.workflow.rule

"""HydroFlows Rule class.

This class is responsible for:
- detecting and validating wildcards in the method.
- creating method instances based on the wildcards.
- parsing input, and output paths of the rule (i.e. for all method instances).
- running the rule (i.e. running all method instances).
"""

import logging
import weakref
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from tqdm.contrib.concurrent import thread_map

from hydroflows.utils.parsers import get_wildcards
from hydroflows.utils.path_utils import cwd
from hydroflows.workflow.method import ExpandMethod, Method, ReduceMethod
from hydroflows.workflow.method_parameters import Parameters
from hydroflows.workflow.wildcards import resolve_wildcards

if TYPE_CHECKING:
    from hydroflows.workflow.workflow import Workflow

__all__ = ["Rule"]

logger = logging.getLogger(__name__)


[docs] class Rule: """Rule class. A rule is a wrapper around a method to be run in the context of a workflow. The rule is responsible for detecting wildcards and evaluating them based on the workflow wildcards. It creates method instances based on the wildcards and evaluates all input and output paths of the rule. The rule can be run and dryrun. There is one common rule class to rule all methods. """ def __init__( self, method: Method, workflow: "Workflow", rule_id: Optional[str] = None, ) -> None: """Create a rule instance. Parameters ---------- method : Method The method instance to run. workflow : Workflow The workflow instance to which the rule belongs. rule_id : str, optional The rule id, by default None (method name). """ # set the method self.method: Method = method # set rule id which defaults to method name if rule_id is None: rule_id = method.name self.rule_id: str = str(rule_id) # add weak reference to workflow to avoid circular references self._workflow_ref = weakref.ref(workflow) # placeholders self._wildcard_fields: Dict[str, List] = {} # wildcard - fieldname dictionary self._wildcards: Dict[str, List] = {} # repeat, expand, reduce wildcards self._loop_depth: int = 0 # loop depth of the rule (based on repeat wildcards) self._method_instances: List[Method] = [] # list of method instances self._input: Dict[str, list[Path]] = {} # input paths for all method instances self._output: Dict[ str, list[Path] ] = {} # output paths for all method instances self._output_refs: Dict[str, str] = {} # output path references # add expand wildcards to workflow wildcards if isinstance(self.method, ExpandMethod): for wc, val in self.method.expand_wildcards.items(): self.workflow.wildcards.set(wc, val) # detect and validate wildcards self._detect_wildcards() self._validate_wildcards() # get method instances and in- and output paths self._set_method_instances() self._set_input_output() # add references to other rule outputs and config self._create_references_for_method_inputs() def __repr__(self) -> str: """Return the representation of the rule.""" repr_dict = { "id": self.rule_id, "method": self.method.name, "runs": self.n_runs, } for key, values in self.wildcards.items(): if values: repr_dict[key] = values repr_str = ", ".join([f"{k}={v}" for k, v in repr_dict.items()]) return f"Rule({repr_str})" @property def workflow(self) -> "Workflow": """Return the workflow.""" return self._workflow_ref() @property def n_runs(self) -> int: """Return the number of required method runs.""" return len(self._method_instances) @property def wildcards(self) -> Dict[str, List]: """Return the wildcards of the rule per wildcard type. Wildcards are saved for three types, based on whether these "expand" (1:n), "reduce" (n:1) and "repeat" (n:n) the method. """ return self._wildcards @property def wildcard_fields(self) -> Dict[str, List]: """Return a wildcard - fieldname dictionary. Per wildcard it contains all input, output and params field names which have the wildcard. """ return self._wildcard_fields @property def method_instances(self) -> List[Method]: """Return a list of all method instances.""" return self._method_instances @property def input(self) -> Dict[str, list[Path]]: """Return the input paths of the rule per field.""" return self._input @property def output(self) -> Dict[str, list[Path]]: """Return the output paths of the rule per field.""" return self._output @property def _all_wildcard_fields(self) -> List[str]: """Return all input, output, and params fields with wildcards.""" return list(set(sum(self.wildcard_fields.values(), []))) @property def _all_wildcards(self) -> List[str]: """Return all wildcards.""" return list(set(sum(self.wildcards.values(), []))) ## SERIALIZATION METHODS
[docs] def to_dict(self) -> Dict: """Return the rule as a dictionary.""" out = { "method": self.method.name, "kwargs": self.method.to_kwargs(return_refs=True, posix_path=True), } if self.rule_id != self.method.name: out["rule_id"] = self.rule_id return out
## WILDCARD METHODS def _detect_wildcards(self) -> None: """Detect wildcards based on known workflow wildcard names.""" # check for wildcards in input and output known_wildcards = self.workflow.wildcards.names wildcards: Dict[str, List] = {"input": [], "output": [], "params": []} wildcard_fields: Dict[str, List] = {} for sec in wildcards.keys(): for field, value in getattr(self.method, sec): # skip if value is not a string or path if not isinstance(value, (str, Path)): continue val_wildcards = get_wildcards(value) # loop over wildcards that are known and in the value for wc in set(val_wildcards) & set(known_wildcards): if wc not in wildcards[sec]: wildcards[sec].append(wc) if wc not in wildcard_fields: wildcard_fields[wc] = [] wildcard_fields[wc].append(field) # loop over wildcards that are not known for wc in set(val_wildcards) - set(known_wildcards): # raise warning if wildcard is not known logger.warning(f"Wildcard {wc} not found in workflow wildcards.") # organize wildcards in expand, reduce and repeat wc_in = set(wildcards["input"] + wildcards["params"]) wc_out = set(wildcards["output"]) wildcards_dict = { "repeat": list(wc_in & wc_out), "expand": list(wc_out - wc_in), "reduce": list(wc_in - wc_out), } # set the wildcard properties self._wildcards = wildcards_dict self._wildcard_fields = wildcard_fields self._loop_depth = len(self.wildcards["repeat"]) def _validate_wildcards(self) -> None: """Validate wildcards based on method type.""" msg = "" if isinstance(self.method, ExpandMethod) and not self.wildcards["expand"]: msg = f"ExpandMethod {self.method.name} requires a new expand wildcard on output (Rule {self.rule_id})." elif isinstance(self.method, ReduceMethod) and not self.wildcards["reduce"]: msg = f"ReduceMethod {self.method.name} requires a reduce wildcard on input only (Rule {self.rule_id})." elif self.wildcards["expand"] and not isinstance(self.method, ExpandMethod): wcs = self.wildcards["expand"] msg = f"Wildcard(s) {wcs} missing on input or method {self.method.name} should be an ExpandMethod (Rule {self.rule_id})." elif self.wildcards["reduce"] and not isinstance(self.method, ReduceMethod): wcs = self.wildcards["reduce"] msg = f"Wildcard(s) {wcs} missing on output or method {self.method.name} should be a ReduceMethod (Rule {self.rule_id})." if msg: raise ValueError(msg) def _create_method_instance(self, wildcards: Dict[str, str | list[str]]) -> Method: """Return a new method instance with wildcards replaced by values. Parameters ---------- wildcards : Dict[str, str | list[str]] The wildcards to replace in the method instance. For repeat wildcards, the value should be a single string. For reduce wildcards, the value should be a list of strings. Expand wildcards are only on the output and are set in the method. """ # repeat kwargs should always be a single value; for wc in self.wildcards["repeat"]: if not isinstance(wildcards[wc], str): raise ValueError({f"Repeat wildcard '{wc}' should be a string."}) # reduce should be lists; for wc in self.wildcards["reduce"]: if not isinstance(wildcards[wc], list): raise ValueError(f"Reduce wildcard '{wc}' should be a list.") # expand wildcards should not be in instance wildcards -> only inputs for wc in self.wildcards["expand"]: if wc in wildcards: raise ValueError(f"Expand wildcard '{wc}' should not be in wildcards.") # get kwargs from method kwargs = self.method.to_kwargs() # get input fields over which the method should reduce reduce_fields = [] for wc in self.wildcards["reduce"]: reduce_fields.extend(self.wildcard_fields[wc]) reduce_fields = list(set(reduce_fields)) # keep unique values if reduce_fields: # make sure all values are a list # then take the product of the lists wc_list = [ val if isinstance(val, list) else [val] for val in wildcards.values() ] wildcards_reduce: List[Dict] = [ dict(zip(wildcards.keys(), wc)) for wc in list(product(*wc_list)) ] for key in kwargs: if key in reduce_fields: # reduce method -> turn values into lists # wildcards = {wc: [v1, v2, ...], ...} kwargs[key] = [ resolve_wildcards(kwargs[key], d) for d in wildcards_reduce ] elif key in self._all_wildcard_fields: # repeat method # wildcards = {wc: v, ...} kwargs[key] = resolve_wildcards(kwargs[key], wildcards) method = self.method.from_kwargs(**kwargs) return method @property def _wildcard_product(self) -> List[Dict[str, str]]: """Return the values of wildcards per method instance.""" # only repeat if there are wildcards on the output wildcards = self.wildcards["repeat"] wc_values = [self.workflow.wildcards.get(wc) for wc in wildcards] # drop None from list of values; this occurs when the workflow is not fully initialized yet wc_values = [v for v in wc_values if v is not None] wc_tuples: List[Tuple] = list(product(*wc_values)) wc_product: List[Dict] = [ dict(zip(wildcards, list(wc_val))) for wc_val in wc_tuples ] # add reduce wildcards for wc in self.wildcards["reduce"]: wc_val = self.workflow.wildcards.get(wc) wc_product = [{**wc_dict, wc: wc_val} for wc_dict in wc_product] return wc_product @property def _output_path_refs(self) -> Dict[str, str]: """Retrieve output path references of rule method. Returns ------- Dict[str, str] Dictionary containing the output path as the key and the reference as the value """ if not self._output_refs: for key, value in self.method.output: if isinstance(value, Path): value = value.as_posix() self._output_refs[value] = f"$rules.{self.rule_id}.output.{key}" return self._output_refs def _create_references_for_method_inputs(self): """Create references for method inputs based on output paths of previous rules.""" # chain all output_path_refs of previous rules together output_path_refs = {} for rule in self.workflow.rules: output_path_refs.update(rule._output_path_refs) # Check on duplicate output values for key, value in self.method.output: if not isinstance(value, Path): continue value = value.as_posix() if value in output_path_refs: duplicate_field = output_path_refs[value].replace("$rules.", "") raise ValueError( "All output file paths must be unique. " f"{self.rule_id}.output.{key} ({value}) is already an output of {duplicate_field}" ) for key, value in self.method.input: # Skip if key is already present in input refs if key in self.method.input._refs or value is None: continue if isinstance(value, Path): value = value.as_posix() if value in list(output_path_refs.keys()): self.method.input._refs.update({key: output_path_refs[value]}) def _set_method_instances(self): """Set a list with all instances of the method based on the wildcards.""" self._method_instances = [] for wildcard_dict in self._wildcard_product: method = self._create_method_instance(wildcard_dict) self._method_instances.append(method) def _set_input_output(self): """Set the input and output paths dicts of the rule.""" parameters = {"input": {}, "output": {}} for method in self._method_instances: for name in parameters: if name == "output" and isinstance(method, ExpandMethod): inout_dict = method.output_expanded else: obj: Parameters = getattr(method, name) inout_dict = {key: getattr(obj, key) for key in obj.all_fields} for key, value in inout_dict.items(): if key not in parameters[name]: parameters[name][key] = [] if name in ["input", "output"]: if not isinstance(value, list): value = [value] if not isinstance(value[0], Path): continue # Removes duplicates # Using set() does not preserve insertion order, this does and also filters uniques for val in value: if val not in parameters[name][key]: parameters[name][key].append(val) elif value not in parameters[name][key]: parameters[name][key].append(value) self._input = parameters["input"] self._output = parameters["output"] ## RUN METHODS
[docs] def run(self, max_workers=1) -> None: """Run the rule. Parameters ---------- max_workers : int, optional The maximum number of workers to use, by default 1 """ nruns = self.n_runs # set working directory to workflow root with cwd(self.workflow.root): if nruns == 1 or max_workers == 1: for i, method in enumerate(self._method_instances): msg = f"Running {self.rule_id} {i + 1}/{nruns}" logger.info(msg) method.run() else: tqdm_kwargs = {} if max_workers is not None: tqdm_kwargs.update(max_workers=max_workers) thread_map( lambda method: method.run(), self._method_instances, **tqdm_kwargs, )
[docs] def dryrun( self, input_files: Optional[List[Path]] = None, missing_file_error: bool = False, ) -> List[Path]: """Dryrun the rule. Parameters ---------- input_files : List[Path], optional The input files to use for the dryrun, by default None missing_file_error : bool, optional Whether to raise an error if a file is missing, by default False Returns ------- List[Path] The output files of the dryrun. """ nruns = self.n_runs input_files = input_files or [] output_files = [] # set working directory to workflow root with cwd(self.workflow.root): for i, method in enumerate(self._method_instances): msg = f"Running {self.rule_id} {i + 1}/{nruns}" logger.debug(msg) output_files_i = method.dryrun( missing_file_error=missing_file_error, input_files=input_files ) output_files.extend(output_files_i) return output_files