"""HydroFlows Method class.
A method is where the actual work of a rule happens.
It should have a name, inputs, and outputs, and optionally params.
All HydroFlow methods should inherit from this class and implement specific
validators and a run method.
"""
import inspect
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, ClassVar, Dict, Generator, List, Optional, Tuple
from hydroflows.utils.parsers import has_wildcards
from hydroflows.workflow.method_entrypoints import METHODS
from hydroflows.workflow.method_parameters import Parameters
from hydroflows.workflow.wildcards import resolve_wildcards
__all__ = ["Method"]
logger = logging.getLogger(__name__)
[docs]
class Method(ABC):
"""Base method for all methods.
The method class defines the structure of a method in a HydroFlow workflow.
It should have a name, input, output and params, and implement a _run and __init__ method.
"""
# name of the method, should be replaced in subclass
name: ClassVar[str] = "abstract_method"
# Define the method kwargs for testing
_test_kwargs = {}
@abstractmethod
def __init__(self) -> None:
# NOTE: the parameter fields are specific to each method and should
# be initialized in the method __init__ method.
self.input: Parameters = Parameters()
# NOTE: wildcards on outputs should be defined on file parent, not the file name itself!
self.output: Parameters = Parameters()
self.params: Parameters = Parameters()
## ABSTRACT METHODS
@abstractmethod
def _run(self) -> None:
"""Implement the rule logic here.
This method is called when executing the rule.
"""
# NOTE: this should be implemented in the specific rule
# it can use input, output and params, e.g. self.input.file1
# raise NotImplementedError
pass
## INPUT/OUTPUT/PARAMS PROPERTIES
@property
def input(self) -> Parameters:
"""Return the input parameters of the method."""
if not hasattr(self, "_input"):
raise ValueError("Input parameters not set")
return self._input
@input.setter
def input(self, value: Parameters) -> None:
"""Set the input parameters of the method."""
if not isinstance(value, Parameters):
raise ValueError("Input should be a Parameters instance")
self._input = value
@property
def output(self) -> Parameters:
"""Return the output parameters of the method."""
if not hasattr(self, "_output"):
raise ValueError("Output parameters not set")
return self._output
@output.setter
def output(self, value: Parameters) -> None:
"""Set the output parameters of the method."""
if not isinstance(value, Parameters):
raise ValueError("Output should be a Parameters instance")
self._output = value
@property
def params(self) -> Parameters:
"""Return the additional parameters of the method."""
if not hasattr(self, "_params"):
return Parameters()
return self._params
@params.setter
def params(self, value: Parameters) -> None:
"""Set the additional parameters of the method."""
if not isinstance(value, Parameters):
raise ValueError("Params should be a Parameters instance")
self._params = value
## MAGIC METHODS
def __repr__(self) -> str:
parameters = "\n ".join([f"{k}=({getattr(self, k)})" for k in self.dict])
method_type = self.__class__.__bases__[0].__name__
expand = ""
if isinstance(self, ExpandMethod):
wcs = " ".join([f"{k}={v}" for k, v in self.expand_wildcards.items()])
expand = f" expand_wildcards=({wcs})\n"
return f"{method_type}(\n name={self.name}\n {parameters}\n{expand})"
def __eq__(self, other: "Method") -> bool:
return (
self.__class__ == other.__class__
and self.name == other.name
and self.to_dict() == other.to_dict()
)
## SERIALIZATION METHODS
[docs]
def to_kwargs(
self,
mode="json",
exclude_defaults=True,
posix_path=False,
return_refs=False,
**kwargs,
) -> Dict[str, Any]:
"""Return flattened keyword-arguments which result in the same method parametrization."""
kwargs = dict(
mode=mode, posix_path=posix_path, return_refs=return_refs, **kwargs
)
# get all input, output which are in the __init__ signature
par = set(inspect.signature(self.__init__).parameters.keys())
# Add extra optional (not defined) vars from input
par.update(par ^ set(self.input.to_dict().keys()))
par = list(par)
# Get the arguments
in_kw = self.input.to_dict(filter_keys=par, **kwargs)
out_kw = self.output.to_dict(filter_keys=par, **kwargs)
# get all non-default params
params_kw = self.params.to_dict(exclude_defaults=exclude_defaults, **kwargs)
return {**in_kw, **out_kw, **params_kw}
def _kwargs_to_key_mapping(self) -> Dict[str, str]:
"""Return kwarg-key to input/output/params-key mapping."""
# check if key is in input, output or params
mapping = {}
for key in self.to_kwargs().keys():
for c in ["input", "output", "params"]:
if key in self.dict.get(c, {}):
mapping[key] = f"{c}.{key}"
break
if key not in mapping:
raise ValueError(f"Key {key} not found in {self.name}.")
return mapping
@property
def dict(self) -> Dict[str, Dict]:
"""Return a dictionary representation of the method input, output and params."""
if not hasattr(self, "._dict") or not self._dict:
self._dict = self.to_dict()
return self._dict
[docs]
def to_dict(self, **kwargs) -> Dict:
"""Return a serialized dictionary representation of the method input, output and params."""
_kwargs = dict(exclude_defaults=True, round_trip=True, mode="json")
dump_kwargs = {**_kwargs, **kwargs}
out_dict = {
"input": self.input.to_dict(**dump_kwargs),
"output": self.output.to_dict(**dump_kwargs),
"params": {},
}
if hasattr(self, "_params"): # params are optional
out_dict["params"] = self.params.to_dict(**dump_kwargs)
return out_dict
## SERIALIZATION METHODS
[docs]
@classmethod
def from_kwargs(cls, method_name: Optional[str] = None, **kwargs) -> "Method":
"""Create a new method instance from the method `name` and its initialization arguments."""
# if called from the parent class, get the subclass by name
if cls.name == "abstract_method":
if method_name is None:
raise ValueError("Cannot initiate from Method without a method name")
cls = cls._get_subclass(method_name)
return cls(**kwargs)
@classmethod
def _get_subclasses(cls) -> Generator[type["Method"], None, None]:
"""Get all imported subclasses of the Method class."""
for subclass in cls.__subclasses__():
yield from subclass._get_subclasses()
yield subclass
@classmethod
def _get_subclass(cls, name: str) -> type["Method"]:
"""Get a subclass by name."""
name = name.lower()
for subclass in cls._get_subclasses():
if subclass.name.lower() == name or subclass.__name__.lower() == name:
return subclass
# if not found, try to import the module using entry points
return METHODS.load(name)
## TESTING METHODS (we keep these here such that external implementations can use them)
def _test_roundtrip(self) -> None:
"""Test if the method can be serialized and deserialized."""
# parse all values to strings to test serialization
kw = self.to_kwargs(exclude_defaults=False, posix_path=True)
kw = {k: str(v) for k, v in kw.items()}
m = self.from_kwargs(self.name, **kw)
if m.dict != self.dict:
raise ValueError("Method serialization failed")
def _test_unique_keys(self) -> None:
"""Check if the method input, output and params keys are unique."""
inputs = list(self.input.model_fields.keys())
outputs = list(self.output.model_fields.keys())
params = list(self.params.model_fields.keys())
ukeys = set(inputs + outputs + params)
if "method_name" in ukeys:
raise ValueError("Key 'method_name' is reserved for the method name")
nkeys = len(inputs) + len(outputs) + len(params)
# check for unique keys
if len(ukeys) != nkeys:
raise ValueError("Keys of input, output and params should all be unique")
[docs]
def test_method(self) -> None:
"""Run all tests on the method."""
errors = []
try:
self._test_unique_keys()
except Exception as e:
errors.append(e)
try:
self._test_roundtrip()
except Exception as e:
errors.append(e)
if errors:
msg = "\n".join([str(e) for e in errors])
raise ValueError(f"Method tests failed:\n{msg}")
## RUN METHODS
[docs]
def dryrun(
self,
input_files: List[Path],
missing_file_error: bool = False,
touch_output: bool = False,
) -> List[Path]:
"""Run method with dummy outputs.
Parameters
----------
input_files : List[Path]
List of input paths to check, by default None.
missing_file_error : bool, optional
Raise an error if a missing file is encountered, by default False.
Returns
-------
List[Path]
List of output paths.
"""
for key, value in self.input.model_dump().items():
if isinstance(value, Path):
if value not in input_files and not value.is_file():
msg = f"Input file {self.name}.input.{key} not found: {value}"
if missing_file_error:
raise FileNotFoundError(msg)
else:
logger.warning(msg)
# return output paths
out_paths = [value for _, value in self._output_paths]
if touch_output:
for path in out_paths:
logger.debug(f"Touch file at {path.as_posix()}")
path.parent.mkdir(parents=True, exist_ok=True)
path.touch()
return out_paths
[docs]
def run(self, check_output: bool = True) -> None:
"""Run the method with input/output checks.
Parameters
----------
check_output : bool, optional
Check if output files are created, by default True.
"""
# TODO warning if wildcards on input / params
self.check_input_output_paths()
self._run()
if check_output:
self.check_output_exists()
@property
def _output_paths(self) -> List[Tuple[str, Path]]:
"""Return a list of output key-path tuples."""
return [
(key, val) for key, val in self.output.to_dict(filter_types=(Path)).items()
]
[docs]
def check_output_exists(self):
"""Check if output files exist."""
for key, path in self._output_paths:
if not path.is_file():
raise FileNotFoundError(
f"Output file {self.name}.output.{key} not found: {path}"
)
[docs]
class ExpandMethod(Method, ABC):
"""Base class for methods that expand on a wildcard."""
@abstractmethod
def __init__(self) -> None:
"""Create a new expand method instance."""
# NOTE: see init of super method for requirements
# in addition call the set_expand_wildcard method
# self.set_expand_wildcard(wildcard, values)
[docs]
def set_expand_wildcard(self, wildcard: str, values: List[str]) -> None:
"""Set wildcard key and values.
Parameters
----------
wildcard : str
The wildcard key.
values : List[str]
The list of expand values for the wildcard.
"""
if not hasattr(self, "_expand_wildcards"):
self._expand_wildcards: Dict[str, List] = {}
self._expand_wildcards[wildcard] = values
@property
def expand_wildcards(self) -> Dict[str, List[str]]:
"""Return a dict with a list of expand values per wildcard key."""
if not hasattr(self, "_expand_wildcards"):
return {}
return self._expand_wildcards
@property
def output_expanded(self) -> Dict[str, Path | List[Path]]:
"""Output dict with wildcards in output path evaluated."""
if not hasattr(self, "_output_expanded"):
self._evaluate_expand_wildcards()
return self._output_expanded
def _evaluate_expand_wildcards(self) -> None:
"""Evaluate wildcards in output paths."""
self._output_expanded = {}
for key, value in self.output.to_dict(filter_types=(Path)).items():
value = resolve_wildcards(value, self.expand_wildcards)
self._output_expanded[key] = value
@property
def _output_paths(self) -> List[Tuple[str, Path]]:
"""Return a list of output key-path tuples."""
paths = []
for key, val in self.output_expanded.items():
if isinstance(val, list):
for path in val:
paths.append((key, path))
else:
paths.append((key, val))
return paths
[docs]
def get_output_for_wildcards(self, wildcards: dict[str, str]) -> dict[str, Path]:
"""Get the output paths for a specific wildcard configuration.
Parameters
----------
wildcards : dict[str, str]
The wildcard configuration.
Note that the wildcard values should be single strings.
Returns
-------
dict[str, Path]
The output paths for the wildcard configuration.
"""
# make sure all wildcard values are strings
if not all(isinstance(v, str) for v in wildcards.values()):
raise ValueError("All wildcard values should be strings.")
output = {}
for key, path in self.output.to_dict(filter_types=(Path)).items():
output[key] = resolve_wildcards(path, wildcards)
return output
[docs]
class ReduceMethod(Method):
"""Base class for methods that reduce multiple inputs to one output."""
# NOTE: for now this class merely serves to flag a a reduce method
# it may be extended in the future to include specific reduce logic or methods