Skip to content

Script

The hera.workflows.script module provides the Script class.

See https://argoproj.github.io/argo-workflows/workflow-concepts/#script for more on scripts in Argo Workflows.

Script

A Script in Argo Workflows acts as a wrapper around a Container, where you can specify Python code to run through source.

In Hera, you should aim to use the script decorator, rather than the Script class directly. You will need to refer to the Script class for the kwargs that the decorator can take, but your IDE should give you code completion and type hints.

Source code in src/hera/workflows/script.py
@dataclass(kw_only=True)
class Script(
    EnvIOMixin,
    CallableTemplateMixin,
    ContainerMixin,
    TemplateMixin,
    ResourceMixin,
    VolumeMountMixin,
):
    """A Script in Argo Workflows acts as a wrapper around a Container, where you can specify Python code to run through `source`.

    In Hera, you should aim to use the script decorator, rather than the Script class directly.
    You will need to refer to the Script class for the kwargs that the decorator can take, but your IDE should give you code completion and type hints.
    """

    container_name: Optional[str] = None
    args: Optional[List[str]] = None
    # Replaced by `global_config.script_command` in `__post_init__` when unset or empty.
    command: Optional[List[str]] = None
    lifecycle: Optional[Lifecycle] = None
    security_context: Optional[SecurityContext] = None
    # Either a Python callable (inputs/outputs are inferred from its signature) or a literal source string.
    source: Optional[Union[Callable, str]] = None
    working_dir: Optional[str] = None
    # `None` is resolved to `True` in `__post_init__`.
    add_cwd_to_sys_path: Optional[bool] = None
    # "inline" / "runner" (case-insensitive), a ScriptConstructor instance, or None for the inline constructor.
    constructor: str | ScriptConstructor | None = None

    def __post_init__(self):
        """Resolve the constructor, apply defaults, then let the constructor adjust the field values."""
        super().__post_init__()

        self.constructor = self._set_constructor(self.constructor)

        # An unset or empty command falls back to the globally configured script command.
        self.command = self.command or global_config.script_command

        if self.add_cwd_to_sys_path is None:
            self.add_cwd_to_sys_path = True

        # Type narrowing only: `_set_constructor` always returns a ScriptConstructor.
        assert isinstance(self.constructor, ScriptConstructor)

        # Last step, so constructors see the fully defaulted Script (e.g. the runner rewrites `args`).
        self.constructor.transform_values(self)

    @staticmethod
    def _set_constructor(constructor: str | ScriptConstructor | None):
        """Turn the user-supplied `constructor` value into a ScriptConstructor instance.

        Args:
            constructor: None (selects the inline constructor), a ScriptConstructor instance (returned
                unchanged), or one of the case-insensitive names "inline" or "runner".

        Returns:
            ScriptConstructor: the resolved constructor.

        Raises:
            ValueError: if `constructor` is an unknown name or is not a str, ScriptConstructor or None.
        """
        if constructor is None:
            # TODO: In the future we can insert
            # detection code here to determine
            # the best constructor to use.
            return InlineScriptConstructor()
        if isinstance(constructor, ScriptConstructor):
            return constructor
        if isinstance(constructor, str):
            factories = {"inline": InlineScriptConstructor, "runner": RunnerScriptConstructor}
            factory = factories.get(constructor.lower())
            if factory is not None:
                return factory()
        # Explicit error instead of an `assert`, which would be stripped under `python -O`.
        raise ValueError(f"Unknown constructor {constructor}")

    def _build_template(self) -> _ModelTemplate:
        """Build the Argo template model and pass it through the constructor's post-build hook."""
        assert isinstance(self.constructor, ScriptConstructor)

        return self.constructor.transform_template_post_build(
            self,
            _ModelTemplate(
                active_deadline_seconds=IntOrString(root=self.active_deadline_seconds)
                if self.active_deadline_seconds
                else None,
                affinity=self.affinity,
                archive_location=self.archive_location,
                automount_service_account_token=self.automount_service_account_token,
                daemon=self.daemon,
                executor=self.executor,
                fail_fast=self.fail_fast,
                host_aliases=self.host_aliases,
                init_containers=self._build_init_containers(),
                inputs=self._build_inputs(),
                memoize=self.memoize,
                metadata=self._build_metadata(),
                metrics=self._build_metrics(),
                name=self.name,
                node_selector=self.node_selector,
                outputs=self._build_outputs(),
                parallelism=self.parallelism,
                plugin=self.plugin,
                pod_spec_patch=self.pod_spec_patch,
                priority_class_name=self.priority_class_name,
                retry_strategy=self._build_retry_strategy(),
                scheduler_name=self.scheduler_name,
                script=self._build_script(),
                security_context=self.pod_security_context,
                service_account_name=self.service_account_name,
                sidecars=self._build_sidecars(),
                synchronization=self.synchronization,
                timeout=self.timeout,
                tolerations=self.tolerations,
                volumes=self._build_volumes(),
            ),
        )

    def _build_script(self) -> _ModelScriptTemplate:
        """Build the script template model and pass it through the constructor's post-build hook.

        When output annotations are used with the runner constructor, this also defaults the
        outputs directory and, if a `volume_for_outputs` is configured, mounts it there.
        """
        assert isinstance(self.constructor, ScriptConstructor)
        if _output_annotations_used(cast(Callable, self.source)) and isinstance(
            self.constructor, RunnerScriptConstructor
        ):
            if not self.constructor.outputs_directory:
                self.constructor.outputs_directory = self.constructor.DEFAULT_HERA_OUTPUTS_DIRECTORY
            if self.constructor.volume_for_outputs is not None:
                if self.constructor.volume_for_outputs.mount_path is None:
                    self.constructor.volume_for_outputs.mount_path = self.constructor.outputs_directory
                self._create_hera_outputs_volume(self.constructor.volume_for_outputs)
        assert self.image
        return self.constructor.transform_script_template_post_build(
            self,
            _ModelScriptTemplate(
                args=self.args,
                command=self.command,
                env=self._build_env(),
                env_from=self._build_env_from(),
                image=self.image,
                # `image_pull_policy` in script wants a string not an `ImagePullPolicy` object
                image_pull_policy=self._build_image_pull_policy(),
                lifecycle=self.lifecycle,
                liveness_probe=self.liveness_probe,
                name=self.container_name,
                ports=self.ports,
                readiness_probe=self.readiness_probe,
                resize_policy=self.resize_policy,
                resources=self._build_resources(),
                restart_policy=self.restart_policy,
                security_context=self.security_context,
                source=self.constructor.generate_source(self),
                startup_probe=self.startup_probe,
                stdin=self.stdin,
                stdin_once=self.stdin_once,
                termination_message_path=self.termination_message_path,
                termination_message_policy=self.termination_message_policy,
                tty=self.tty,
                volume_devices=self.volume_devices,
                volume_mounts=self._build_volume_mounts(),
                working_dir=self.working_dir,
            ),
        )

    def _build_inputs(self) -> Optional[ModelInputs]:
        """Build the inputs, adding parameters/artifacts inferred from a callable `source`."""
        inputs = super()._build_inputs()
        func_parameters: List[Parameter] = []
        func_artifacts: List[Artifact] = []
        if callable(self.source):
            func_parameters, func_artifacts = _get_inputs_from_callable(self.source)

        return cast(Optional[ModelInputs], self._aggregate_callable_io(inputs, func_parameters, func_artifacts, False))

    def _build_outputs(self) -> Optional[ModelOutputs]:
        """Build the outputs, adding parameters/artifacts inferred from a callable `source`'s annotations."""
        outputs = super()._build_outputs()

        if not callable(self.source):
            return outputs

        outputs_directory = None
        if isinstance(self.constructor, RunnerScriptConstructor):
            outputs_directory = self.constructor.outputs_directory or self.constructor.DEFAULT_HERA_OUTPUTS_DIRECTORY

        out_parameters, out_artifacts = _get_outputs_from_return_annotation(self.source, outputs_directory)
        func_parameters, func_artifacts = _get_outputs_from_parameter_annotations(self.source, outputs_directory)
        func_parameters.extend(out_parameters)
        func_artifacts.extend(out_artifacts)

        return cast(
            Optional[ModelOutputs], self._aggregate_callable_io(outputs, func_parameters, func_artifacts, True)
        )

    def _aggregate_callable_io(
        self,
        current_io: Optional[Union[ModelInputs, ModelOutputs]],
        func_parameters: List[Parameter],
        func_artifacts: List[Artifact],
        output: bool,
    ) -> Union[ModelOutputs, ModelInputs, None]:
        """Aggregate the Inputs/Outputs with parameters and artifacts extracted from a callable.

        Args:
            current_io: the explicitly declared inputs/outputs, possibly None.
            func_parameters: parameters inferred from the callable.
            func_artifacts: artifacts inferred from the callable.
            output: True to build output models (`as_output`), False for inputs (`as_input`).

        Returns:
            The merged model. `current_io` is mutated in place when given. Explicitly declared
            names take precedence, and a callable parameter is also skipped when an explicitly
            declared artifact has the same name.
        """
        if not func_parameters and not func_artifacts:
            return current_io
        if current_io is None:
            if output:
                return ModelOutputs(
                    parameters=[p.as_output() for p in func_parameters] or None,
                    artifacts=[a._build_artifact() for a in func_artifacts] or None,
                )

            return ModelInputs(
                parameters=[p.as_input() for p in func_parameters] or None,
                artifacts=[a._build_artifact() for a in func_artifacts] or None,
            )

        seen_params = {p.name for p in current_io.parameters or []}
        seen_artifacts = {a.name for a in current_io.artifacts or []}

        for param in func_parameters:
            if param.name in seen_params or param.name in seen_artifacts:
                continue
            if current_io.parameters is None:
                current_io.parameters = []
            current_io.parameters.append(param.as_output() if output else param.as_input())

        for artifact in func_artifacts:
            if artifact.name in seen_artifacts:
                continue
            if current_io.artifacts is None:
                current_io.artifacts = []
            current_io.artifacts.append(artifact._build_artifact())

        return current_io

    def _create_hera_outputs_volume(self, volume: _BaseVolume) -> None:
        """Add given volume to the script template for the automatic saving of the hera outputs."""
        assert isinstance(self.constructor, RunnerScriptConstructor)

        # Normalise `volumes` to a mutable list: None -> [], any sequence -> a list copy, a single volume -> [volume].
        if self.volumes is None:
            self.volumes = []
        elif isinstance(self.volumes, Sequence):
            self.volumes = list(self.volumes)
        else:
            self.volumes = [self.volumes]

        if volume not in self.volumes:
            self.volumes.append(volume)

active_deadline_seconds

active_deadline_seconds: Optional[int | str] = None

add_cwd_to_sys_path

add_cwd_to_sys_path: Optional[bool] = None

affinity

affinity: Optional[Affinity] = None

annotations

annotations: Optional[Dict[str, str]] = None

archive_location

archive_location: Optional[ArtifactLocation] = None

args

args: Optional[List[str]] = None

automount_service_account_token

automount_service_account_token: Optional[bool] = None

command

command: Optional[List[str]] = None

constructor

constructor: str | ScriptConstructor | None = None

container_name

container_name: Optional[str] = None

daemon

daemon: Optional[bool] = None

env

env: EnvT = None

env_from

env_from: EnvFromT = None

executor

executor: Optional[ExecutorConfig] = None

fail_fast

fail_fast: Optional[bool] = None

host_aliases

host_aliases: Optional[List[HostAlias]] = None

image

image: Optional[str] = None

image_pull_policy

image_pull_policy: Optional[Union[str, ImagePullPolicy]] = (
    None
)

init_containers

init_containers: Optional[
    List[Union[UserContainer, UserContainer]]
] = None

inputs

inputs: InputsT = None

labels

labels: Optional[Dict[str, str]] = None

lifecycle

lifecycle: Optional[Lifecycle] = None

liveness_probe

liveness_probe: Optional[Probe] = None

memoize

memoize: Optional[Memoize] = None

metrics

metrics: Optional[MetricsT] = None

name

name: Optional[str] = None

node_selector

node_selector: Optional[Dict[str, str]] = None

outputs

outputs: OutputsT = None

parallelism

parallelism: Optional[int] = None

plugin

plugin: Optional[Plugin] = None

pod_security_context

pod_security_context: Optional[PodSecurityContext] = None

pod_spec_patch

pod_spec_patch: Optional[str] = None

ports

ports: Optional[List[ContainerPort]] = None

priority_class_name

priority_class_name: Optional[str] = None

readiness_probe

readiness_probe: Optional[Probe] = None

resize_policy

resize_policy: Optional[List[ContainerResizePolicy]] = None

resources

resources: Optional[
    Union[ResourceRequirements, Resources]
] = None

restart_policy

restart_policy: Optional[str] = None

retry_strategy

retry_strategy: Optional[
    Union[RetryStrategy, RetryStrategy]
] = None

scheduler_name

scheduler_name: Optional[str] = None

security_context

security_context: Optional[SecurityContext] = None

service_account_name

service_account_name: Optional[str] = None

sidecars

sidecars: Optional[
    OneOrMany[UserContainer | UserContainer]
] = None

source

source: Optional[Union[Callable, str]] = None

startup_probe

startup_probe: Optional[Probe] = None

stdin

stdin: Optional[bool] = None

stdin_once

stdin_once: Optional[bool] = None

synchronization

synchronization: Optional[Synchronization] = None

termination_message_path

termination_message_path: Optional[str] = None

termination_message_policy

termination_message_policy: Optional[str] = None

timeout

timeout: Optional[str] = None

tolerations

tolerations: Optional[List[Toleration]] = None

tty

tty: Optional[bool] = None

volume_devices

volume_devices: Optional[List[VolumeDevice]] = None

volume_mounts

volume_mounts: Optional[List[VolumeMount]] = None

volumes

volumes: Optional[VolumesT] = None

working_dir

working_dir: Optional[str] = None

get_artifact

get_artifact(name: str) -> Artifact

Finds and returns the artifact with the supplied name.

Note that this method will raise an error if the artifact is not found.

Parameters:

Name Type Description Default
name str

name of the input artifact to find and return.

required

Returns:

Name Type Description
Artifact Artifact

the artifact with the supplied name.

Raises:

Type Description
KeyError

if the artifact is not found.

Source code in src/hera/workflows/_mixins.py
def get_artifact(self, name: str) -> Artifact:
    """Look up one of this template's input artifacts by name.

    The inputs are rebuilt via `self._build_inputs()` and searched for an artifact called
    `name`. A match is returned as a new `Artifact` whose `from_` references
    `{{inputs.artifacts.<name>}}`.

    Args:
        name: name of the input artifact to find and return.

    Returns:
        Artifact: the artifact with the supplied name.

    Raises:
        KeyError: if there are no inputs, no input artifacts, or no artifact called `name`.
    """
    built = self._build_inputs()
    if built is None:
        raise KeyError(f"No inputs set. Artifact {name} not found.")
    candidates = built.artifacts
    if candidates is None:
        raise KeyError(f"No artifacts set. Artifact {name} not found.")
    match = next((candidate for candidate in candidates if candidate.name == name), None)
    if match is None:
        raise KeyError(f"Artifact {name} not found.")
    return Artifact(name=name, from_=f"{{{{inputs.artifacts.{match.name}}}}}")

get_parameter

get_parameter(name: str) -> Parameter

Finds and returns the parameter with the supplied name.

Note that this method will raise an error if the parameter is not found.

Parameters:

Name Type Description Default
name str

name of the input parameter to find and return.

required

Returns:

Name Type Description
Parameter Parameter

the parameter with the supplied name.

Raises:

Type Description
KeyError

if the parameter is not found.

Source code in src/hera/workflows/_mixins.py
def get_parameter(self, name: str) -> Parameter:
    """Look up one of this template's input parameters by name.

    The inputs are rebuilt via `self._build_inputs()` and searched for a parameter called
    `name`. A match is converted with `Parameter.from_model`, and its value is replaced by the
    template reference `{{inputs.parameters.<name>}}`.

    Args:
        name: name of the input parameter to find and return.

    Returns:
        Parameter: the parameter with the supplied name.

    Raises:
        KeyError: if there are no inputs, no input parameters, or no parameter called `name`.
    """
    built = self._build_inputs()
    if built is None:
        raise KeyError(f"No inputs set. Parameter {name} not found.")
    candidates = built.parameters
    if candidates is None:
        raise KeyError(f"No parameters set. Parameter {name} not found.")
    match = next((candidate for candidate in candidates if candidate.name == name), None)
    if match is None:
        raise KeyError(f"Parameter {name} not found.")
    found = Parameter.from_model(match)
    found.value = f"{{{{inputs.parameters.{found.name}}}}}"
    return found

ScriptConstructor

A ScriptConstructor is responsible for generating the source code for a Script given a python callable.

This allows users to customize the behaviour of the template that hera generates when a python callable is passed to the Script class.

In order to use your custom ScriptConstructor implementation, you can set it as the Script.constructor field.

Source code in src/hera/workflows/script.py
class ScriptConstructor(BaseMixin):
    """Base class for the objects that decide how a Script's template is generated.

    A ScriptConstructor produces the `source` of a Script (typically from a Python callable) and can
    adjust the Script and its built templates through the hooks below. To customise what Hera
    generates, subclass it and assign an instance to the `Script.constructor` field.
    """

    @abstractmethod
    def generate_source(self, instance: "Script") -> str:
        """Return the text for the script template's `source` field, derived from `instance`."""
        raise NotImplementedError()

    def transform_values(self, script: "Script") -> None:
        """Adjust `script`'s fields; called at the end of `Script.__post_init__`. No-op by default."""

    def transform_script_template_post_build(
        self,
        instance: "Script",
        script: _ModelScriptTemplate,
    ) -> _ModelScriptTemplate:
        """Post-process the built script template; the default hands `script` back unchanged."""
        return script

    def transform_template_post_build(
        self,
        instance: "Script",
        template: _ModelTemplate,
    ) -> _ModelTemplate:
        """Post-process the built template; the default hands `template` back unchanged."""
        return template

generate_source

generate_source(instance: Script) -> str

A function that can inspect the Script instance and generate the source field.

Source code in src/hera/workflows/script.py
@abstractmethod
def generate_source(self, instance: "Script") -> str:
    """Produce the `source` text for the given Script.

    Concrete constructors must override this; the base version only signals that it is abstract.
    """
    raise NotImplementedError()

transform_script_template_post_build

transform_script_template_post_build(
    instance: Script, script: ScriptTemplate
) -> ScriptTemplate

A hook to transform the generated script template.

Source code in src/hera/workflows/script.py
def transform_script_template_post_build(
    self,
    instance: "Script",
    script: _ModelScriptTemplate,
) -> _ModelScriptTemplate:
    """Hook for adjusting the built script template.

    Args:
        instance: the Script whose template was built.
        script: the freshly built script template.

    Returns:
        The script template to use; this default returns `script` unmodified.
    """
    return script

transform_template_post_build

transform_template_post_build(
    instance: Script, template: Template
) -> Template

A hook to transform the generated template.

Source code in src/hera/workflows/script.py
def transform_template_post_build(
    self,
    instance: "Script",
    template: _ModelTemplate,
) -> _ModelTemplate:
    """Hook for adjusting the built template.

    Args:
        instance: the Script whose template was built.
        template: the freshly built template.

    Returns:
        The template to use; this default returns `template` unmodified.
    """
    return template

transform_values

transform_values(script: Script) -> None

A function that will be invoked by the `__post_init__` of the Script class.

Source code in src/hera/workflows/script.py
def transform_values(self, script: "Script") -> None:
    """Hook called from `Script.__post_init__` to adjust the Script's fields; a no-op by default."""
    # Intentionally empty: subclasses override this to rewrite fields on `script`.

InlineScriptConstructor

InlineScriptConstructor is a script constructor that submits a script as a source to Argo.

This script constructor is focused on taking a Python script/function “as is” for remote execution. The constructor processes the script to infer what parameters it needs to deserialize so the script can execute. The submitted script will contain prefixes such as new imports, e.g. import os, import json, etc. and will contain the necessary json.loads calls to deserialize the parameters so they are usable by the script just like a normal Python script/function.

Source code in src/hera/workflows/script.py
@dataclass(kw_only=True)
class InlineScriptConstructor(ScriptConstructor):
    """`InlineScriptConstructor` is a script constructor that submits a script as a `source` to Argo.

    This script constructor is focused on taking a Python script/function "as is" for remote execution. The
    constructor processes the script to infer what parameters it needs to deserialize so the script can execute.
    The submitted script will contain prefixes such as new imports, e.g. `import os`, `import json`, etc. and
    will contain the necessary `json.loads` calls to deserialize the parameters so they are usable by the script just
    like a normal Python script/function.
    """

    # When truthy, `generate_source` prefixes the script with code that appends the working directory
    # to `sys.path`. This is OR-ed with the Script's own `add_cwd_to_sys_path`.
    add_cwd_to_sys_path: Optional[bool] = None

    @staticmethod
    def _roundtrip(source):
        """Parse and unparse `source`, normalising its formatting and dropping comments.

        Requires `ast.unparse` (Python 3.9+).
        """
        return ast.unparse(ast.parse(source))

    def _get_param_script_portion(self, instance: Script) -> str:
        """Constructs and returns a script that loads the parameters of the specified arguments.

        Since Argo passes parameters through `{{inputs.parameters.name}}` it can be very cumbersome for users to
        manage that. This creates a script that automatically imports json and loads/adds code to interpret
        each independent argument into the script.

        Args:
            instance: the Script whose built inputs supply the parameter names.

        Returns:
            str: the parameter-loading preamble, or an empty string when no parameter needs loading.
        """
        inputs = instance._build_inputs()
        assert inputs
        loaders = []
        for param in sorted(inputs.parameters or [], key=lambda x: x.name):
            # Hera does not know what the content of the `InputFrom` is, coming from another task. In some cases
            # non-JSON encoded strings are returned, which fail the loads, but they can be used as plain strings
            # which is why this captures that in an except. Parameters that carry a `value_from` are skipped.
            if param.value_from is None:
                loaders.append(f"""try: {param.name} = json.loads(r'''{{{{inputs.parameters.{param.name}}}}}''')\n""")
                loaders.append(f"""except: {param.name} = r'''{{{{inputs.parameters.{param.name}}}}}'''\n""")
        if not loaders:
            return ""
        return textwrap.dedent("import json\n" + "".join(loaders))

    def generate_source(self, instance: Script) -> str:
        """Assembles and returns a script representation of the given function.

        This also assembles any extra script material prefixed to the string source.
        The script is expected to be a callable function the client is interested in submitting
        for execution on Argo and the `script_extra` material represents the parameter loading part obtained
        through `_get_param_script_portion`.

        Args:
            instance: the Script whose `source` is converted.

        Returns:
            str: the final formatted script. A string `source` is returned unchanged.
        """
        if not callable(instance.source):
            assert isinstance(instance.source, str)
            return instance.source
        args = inspect.getfullargspec(instance.source).args
        script = ""
        # Argo will save the script as a file and run it with cmd:
        # - python /argo/staging/script
        # However, this prevents the script from importing modules in its cwd,
        # since it's looking for files relative to the script path.
        # We fix this by appending the cwd path to sys:
        if instance.add_cwd_to_sys_path or self.add_cwd_to_sys_path:
            script = "import os\nimport sys\nsys.path.append(os.getcwd())\n"

        script_extra = self._get_param_script_portion(instance) if args else None
        if script_extra:
            # Strings are immutable, so no copy is needed before concatenating.
            script += script_extra
            script += "\n"

        # We use ast parse/unparse to get the source code of the function
        # in order to have consistent looking functions and getting rid of any comments
        # parsing issues.
        # See https://github.com/argoproj-labs/hera/issues/572
        content = self._roundtrip(textwrap.dedent(inspect.getsource(instance.source))).splitlines()
        # Keep only the body: skip decorator lines and the `def` line itself. If no `def` line is
        # found (or `content` is empty) nothing is kept, instead of failing on an unbound index.
        def_index = next(
            (idx for idx, line in enumerate(content) if line.startswith(("def", "async def"))),
            len(content) - 1,
        )

        body = "\n".join(content[def_index + 1 :])
        script += textwrap.dedent(body)
        return textwrap.dedent(script)

add_cwd_to_sys_path

add_cwd_to_sys_path: Optional[bool] = None

generate_source

generate_source(instance: Script) -> str

Assembles and returns a script representation of the given function.

This also assembles any extra script material prefixed to the string source. The script is expected to be a callable function the client is interested in submitting for execution on Argo and the script_extra material represents the parameter loading part obtained, likely, through `_get_param_script_portion`.

Returns:

str Final formatted script.

Source code in src/hera/workflows/script.py
def generate_source(self, instance: Script) -> str:
    """Assembles and returns a script representation of the given function.

    This also assembles any extra script material prefixed to the string source.
    The script is expected to be a callable function the client is interested in submitting
    for execution on Argo and the `script_extra` material represents the parameter loading part obtained
    through `_get_param_script_portion`.

    Args:
        instance: the Script whose `source` is converted.

    Returns:
        str: the final formatted script. A string `source` is returned unchanged.
    """
    if not callable(instance.source):
        assert isinstance(instance.source, str)
        return instance.source
    args = inspect.getfullargspec(instance.source).args
    script = ""
    # Argo will save the script as a file and run it with cmd:
    # - python /argo/staging/script
    # However, this prevents the script from importing modules in its cwd,
    # since it's looking for files relative to the script path.
    # We fix this by appending the cwd path to sys:
    if instance.add_cwd_to_sys_path or self.add_cwd_to_sys_path:
        script = "import os\nimport sys\nsys.path.append(os.getcwd())\n"

    script_extra = self._get_param_script_portion(instance) if args else None
    if script_extra:
        # Strings are immutable, so no copy is needed before concatenating.
        script += script_extra
        script += "\n"

    # We use ast parse/unparse to get the source code of the function
    # in order to have consistent looking functions and getting rid of any comments
    # parsing issues.
    # See https://github.com/argoproj-labs/hera/issues/572
    content = self._roundtrip(textwrap.dedent(inspect.getsource(instance.source))).splitlines()
    # Keep only the body: skip decorator lines and the `def` line itself. If no `def` line is
    # found (or `content` is empty) nothing is kept, instead of failing on an unbound index.
    def_index = next(
        (idx for idx, line in enumerate(content) if line.startswith(("def", "async def"))),
        len(content) - 1,
    )

    body = "\n".join(content[def_index + 1 :])
    script += textwrap.dedent(body)
    return textwrap.dedent(script)

transform_script_template_post_build

transform_script_template_post_build(
    instance: Script, script: ScriptTemplate
) -> ScriptTemplate

A hook to transform the generated script template.

Source code in src/hera/workflows/script.py
def transform_script_template_post_build(
    self,
    instance: "Script",
    script: _ModelScriptTemplate,
) -> _ModelScriptTemplate:
    """Hook for adjusting the built script template.

    Args:
        instance: the Script whose template was built.
        script: the freshly built script template.

    Returns:
        The script template to use; this default returns `script` unmodified.
    """
    return script

transform_template_post_build

transform_template_post_build(
    instance: Script, template: Template
) -> Template

A hook to transform the generated template.

Source code in src/hera/workflows/script.py
def transform_template_post_build(
    self,
    instance: "Script",
    template: _ModelTemplate,
) -> _ModelTemplate:
    """Hook for adjusting the built template.

    Args:
        instance: the Script whose template was built.
        template: the freshly built template.

    Returns:
        The template to use; this default returns `template` unmodified.
    """
    return template

transform_values

transform_values(script: Script) -> None

A function that will be invoked by the `__post_init__` of the Script class.

Source code in src/hera/workflows/script.py
def transform_values(self, script: "Script") -> None:
    """Hook called from `Script.__post_init__` to adjust the Script's fields; a no-op by default."""
    # Intentionally empty: subclasses override this to rewrite fields on `script`.

RunnerScriptConstructor

RunnerScriptConstructor is a script constructor that runs a script in a container.

The runner script, also known as “The Hera runner”, takes a script/Python function definition, infers the path to the function (module import), assembles a path to invoke the function, and passes any specified parameters to the function. This helps users “save” on the source space required for submitting a function for remote execution on Argo. Execution within the container requires the executing container to include the file that contains the submitted script. More specifically, the container must be created in some process (e.g. CI), so that it contains the script to run remotely.

Source code in src/hera/workflows/script.py
@dataclass(kw_only=True)
class RunnerScriptConstructor(ScriptConstructor):
    """`RunnerScriptConstructor` is a script constructor that runs a script in a container.

    The runner script, also known as "The Hera runner", takes a script/Python function definition, infers the path
    to the function (module import), assembles a path to invoke the function, and passes any specified parameters
    to the function. This helps users "save" on the `source` space required for submitting a function for remote
    execution on Argo. Execution within the container *requires* the executing container to include the file that
    contains the submitted script. More specifically, the container must be created in some process (e.g. CI), so that
    it contains the script to run remotely.
    """

    outputs_directory: Optional[str] = None
    """Used for saving outputs when defined using annotations."""

    volume_for_outputs: Optional[_BaseVolume] = None
    """Volume to use if saving outputs when defined using annotations."""

    DEFAULT_HERA_OUTPUTS_DIRECTORY: str = "/tmp/hera-outputs"
    """Used as the default value for when the outputs_directory is not set"""

    pydantic_mode: Optional[Literal[1, 2]] = None
    """Used for selecting the pydantic version used for BaseModels.
    Allows for using pydantic.v1 BaseModels with pydantic v2.
    Defaults to the installed version of Pydantic."""

    def __post_init__(self):
        """Reject a `pydantic_mode` higher than `_PYDANTIC_VERSION`.

        `_PYDANTIC_VERSION` is presumably the installed pydantic major version; confirm where it is defined.

        Raises:
            ValueError: if `pydantic_mode` exceeds `_PYDANTIC_VERSION`.
        """
        super().__post_init__()
        if self.pydantic_mode and self.pydantic_mode > _PYDANTIC_VERSION:
            raise ValueError("v2 pydantic mode only available for pydantic>=2")

    def transform_values(self, script: Script) -> None:
        """Point a callable Script at the Hera runner by setting its `args`.

        Called from `Script.__post_init__`. A non-callable `source` leaves the Script untouched.
        Otherwise `args` becomes `-m hera.workflows.runner -e <module>:<function name>`. A function
        defined in `__main__` has its module string derived from its defining file.

        Raises:
            ValueError: if the Script already has `args` set.
        """
        if not callable(script.source):
            return

        if script.args is not None:
            raise ValueError("Cannot specify args when callable is True")

        module = script.source.__module__

        if module == "__main__":
            from hera.workflows._runner.util import create_module_string

            module = create_module_string(Path(script.source.__globals__["__file__"]))

        script.args = [
            "-m",
            "hera.workflows.runner",
            "-e",
            f"{module}:{script.source.__name__}",
        ]

    def generate_source(self, instance: Script) -> str:
        """Return the `source` for the runner: the `g.inputs.parameters` expression formatted with the `$` spec.

        The `$` spec presumably renders an Argo expression template over all input parameters,
        which the runner reads at execution time; confirm against the `g` expression helper.
        """
        return f"{g.inputs.parameters:$}"

    def transform_script_template_post_build(
        self, instance: "Script", script: _ModelScriptTemplate
    ) -> _ModelScriptTemplate:
        """Append the runner's configuration env vars to the built script template.

        Adds `hera__outputs_directory` when `outputs_directory` is set and `hera__pydantic_mode`
        when `pydantic_mode` is set. User-provided env vars are kept ahead of these.
        """
        script_env = []

        if self.outputs_directory:
            script_env.append(EnvVar(name="hera__outputs_directory", value=self.outputs_directory))
        if self.pydantic_mode:
            script_env.append(EnvVar(name="hera__pydantic_mode", value=str(self.pydantic_mode)))

        if not script_env:
            return script

        if not script.env:
            # If user did not set any env vars themselves then we need to initialise the list
            script.env = []

        script.env.extend(script_env)
        return script

DEFAULT_HERA_OUTPUTS_DIRECTORY

DEFAULT_HERA_OUTPUTS_DIRECTORY: str = '/tmp/hera-outputs'

Used as the default value for when the outputs_directory is not set

outputs_directory

outputs_directory: Optional[str] = None

Used for saving outputs when defined using annotations.

pydantic_mode

pydantic_mode: Optional[Literal[1, 2]] = None

Used for selecting the pydantic version used for BaseModels. Allows for using pydantic.v1 BaseModels with pydantic v2. Defaults to the installed version of Pydantic.

volume_for_outputs

volume_for_outputs: Optional[_BaseVolume] = None

Volume to use if saving outputs when defined using annotations.

generate_source

generate_source(instance: Script) -> str

A function that can inspect the Script instance and generate the source field.

Source code in src/hera/workflows/script.py
def generate_source(self, instance: Script) -> str:
    """Produce the ``source`` field for a Script built by this constructor.

    ``instance`` is not inspected here: the result is always the ``$``-formatted
    ``g.inputs.parameters`` expression (presumably rendering as the Argo template
    variable for the template's input parameters - confirm against ``hera.expr``).
    """
    parameters_expr = f"{g.inputs.parameters:$}"
    return parameters_expr

transform_script_template_post_build

transform_script_template_post_build(
    instance: Script, script: ScriptTemplate
) -> ScriptTemplate

A hook to transform the generated script template.

Source code in src/hera/workflows/script.py
def transform_script_template_post_build(
    self, instance: "Script", script: _ModelScriptTemplate
) -> _ModelScriptTemplate:
    """Expose this constructor's settings to the runner as environment variables.

    ``hera__outputs_directory`` is appended when ``outputs_directory`` is set, and
    ``hera__pydantic_mode`` when ``pydantic_mode`` is set. They are appended after any
    env vars the user already declared. The ``script`` model is mutated in place and returned.
    """
    pending = []
    if self.outputs_directory:
        pending.append(("hera__outputs_directory", self.outputs_directory))
    if self.pydantic_mode:
        pending.append(("hera__pydantic_mode", str(self.pydantic_mode)))

    # Nothing to add: leave the template exactly as it was built.
    if not pending:
        return script

    # The user may not have declared any env vars, in which case the list has to be created first.
    if not script.env:
        script.env = []
    script.env.extend([EnvVar(name=key, value=val) for key, val in pending])
    return script

transform_template_post_build

transform_template_post_build(
    instance: Script, template: Template
) -> Template

A hook to transform the generated template.

Source code in src/hera/workflows/script.py
def transform_template_post_build(self, instance: "Script", template: _ModelTemplate) -> _ModelTemplate:
    """Hook for adjusting the fully built template; this implementation leaves it untouched.

    Args:
        instance: the Hera ``Script`` the template was built from (not used here).
        template: the built template model.

    Returns:
        The same ``template`` object, unmodified.
    """
    # No template-level changes by default: hand the built object straight back.
    return template

transform_values

transform_values(script: Script) -> None

A function that inspects the Script instance and, when its `source` is a callable, sets `args` so the container runs that callable through the `hera.workflows.runner` module.

Source code in src/hera/workflows/script.py
def transform_values(self, script: "Script") -> None:
    """Point a callable-sourced Script at the Hera runner.

    When ``script.source`` is a callable, set ``script.args`` to
    ``["-m", "hera.workflows.runner", "-e", "<module>:<function name>"]``. Otherwise the
    Script is left untouched.

    Args:
        script: the Script whose ``args`` may be populated (mutated in place).

    Raises:
        ValueError: if ``script.args`` is already set while ``source`` is a callable, or if the
            callable was defined in ``__main__`` without a ``__file__`` (e.g. an interactive
            session), so no importable module path can be derived for it.
    """
    source = script.source
    if not callable(source):
        return

    if script.args is not None:
        # NOTE(review): "callable" here refers to `source` being callable (presumably a legacy option name).
        raise ValueError("Cannot specify args when callable is True")

    module = source.__module__

    if module == "__main__":
        # `__main__` is not importable by the runner, so derive a module string from the defining file.
        # Interactive sessions (REPL, notebooks) have no `__file__`; fail with a clear message instead of
        # a bare KeyError.
        main_file = source.__globals__.get("__file__")
        if main_file is None:
            raise ValueError(
                f"Cannot determine an importable module for '{source.__name__}': it is defined in "
                "'__main__' without a '__file__' (e.g. in an interactive session). "
                "Define the function in an importable module instead."
            )

        from hera.workflows._runner.util import create_module_string

        module = create_module_string(Path(main_file))

    script.args = [
        "-m",
        "hera.workflows.runner",
        "-e",
        f"{module}:{source.__name__}",
    ]

script

script(**script_kwargs) -> Callable

A decorator that wraps a function into a Script object.

Using this decorator, users can define a function that will be executed as a script in a container. Once the Script is returned, users can use it as they would any Script, e.g. as a callable inside a DAG or Steps. Note that invoking the function adds the template associated with the script to the workflow context, so users do not have to add it themselves.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**script_kwargs` | | Keyword arguments to be passed to the Script object. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | Function that wraps a given function into a Script. |

Source code in src/hera/workflows/script.py
@_add_type_hints(Script)
def script(**script_kwargs) -> Callable:
    """A decorator that wraps a function into a Script object.

    Using this decorator, users can define a function that will be executed as a script in a container. Once the
    `Script` is returned, users can use it as they would any `Script`, e.g. as a callable inside a DAG or Steps.
    Note that invoking the function adds the template associated with the script to the workflow context, so
    users do not have to add it themselves.

    Args:
        **script_kwargs: Keyword arguments to be passed to the Script object.

    Returns:
        Function that wraps a given function into a `Script`.
    """

    def script_wrapper(func: Callable[FuncIns, FuncR]) -> Callable:
        """Wraps the given callable so it can be invoked as a Step or Task.

        Args:
            func: Function to wrap. A `staticmethod` object is unwrapped to its underlying function.

        Returns:
            Callable that invokes the `Script` object's `__call__` method when in a Steps or DAG context,
            and otherwise calls the wrapped function directly and returns its result.
        """
        # Work on a copy: this wrapper may be applied more than once (e.g. a decorator stored in a variable
        # and reused), and popping `name` from the shared closure dict would silently drop it for every
        # later application.
        script_params = dict(script_kwargs)

        # instance methods are wrapped in `staticmethod`. Hera can capture that type and extract the underlying
        # function for remote submission since it does not depend on any class or instance attributes, so it is
        # submittable
        if isinstance(func, staticmethod):
            source: Callable = func.__func__
        else:
            source = func

        if "name" in script_params:
            # take the client-provided `name` if it is submitted, pop the name for otherwise there will be two
            # kwargs called `name`
            name = script_params.pop("name")
        else:
            # otherwise populate the `name` from the function name
            name = source.__name__.replace("_", "-")

        s = Script(name=name, source=source, **script_params)

        @wraps(func)
        def task_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]:
            """Invokes a `Script` object's `__call__` method using the given SubNode (Step or Task) args/kwargs."""
            if _context.active:
                return s.__call__(*args, **kwargs)
            # Call the unwrapped function so this path does not rely on `staticmethod` objects being callable.
            return source(*args, **kwargs)

        # Set the wrapped function to the original function so that we can use it later
        task_wrapper.wrapped_function = func  # type: ignore
        return task_wrapper

    return script_wrapper

Comments