Source code for batter.exec.slurm

"""Execution backend that submits steps to Slurm via ``sbatch``."""

from __future__ import annotations

import os
import shlex
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional

from loguru import logger

from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem

from .base import ExecBackend, Resources

__all__ = ["SlurmBackend"]


_SBATCH_TPL = """#!/usr/bin/env bash
#SBATCH --job-name={job_name}
#SBATCH --output={log_dir}/{job_name}.%j.out
#SBATCH --error={log_dir}/{job_name}.%j.err
{partition}{account}{time}{cpus}{gpus}{mem}{extras}

set -euo pipefail

echo "[batter] node: $(hostname)"
echo "[batter] start: $(date -Is)"

# user environment (optional)
{env_block}

# step payload
{payload}

echo "[batter] done: $(date -Is)"
"""


def _fmt(flag: str, value: Optional[str | int]) -> str:
    """Render a ``#SBATCH`` line when ``value`` is provided."""
    return f"#SBATCH --{flag}={value}\n" if value not in (None, "", 0) else ""


[docs] class SlurmBackend(ExecBackend): """Slurm backend that materializes lightweight job scripts.""" name: str = "slurm"
[docs] def run(self, step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult: """Submit ``step`` to Slurm. Parameters ---------- step : Step Pipeline step metadata. system : SimSystem Simulation system whose ``root`` directory stores scripts and logs. params : dict Backend-specific options. Recognised keys include ``resources``, ``env`` (exported variables), and ``payload`` (shell snippet). Returns ------- ExecResult Artifacts referencing the generated script and log paths together with the submitted job identifier (if available). """ root = system.root sbatch_dir = root / "sbatch" log_dir = root / "logs" sbatch_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True) # resources rdict = params.get("resources", {}) or {} res = Resources( time=rdict.get("time"), cpus=rdict.get("cpus"), gpus=rdict.get("gpus"), mem=rdict.get("mem"), partition=rdict.get("partition"), account=rdict.get("account"), extra=rdict.get("extra", {}), ) sbatch_text = _SBATCH_TPL.format( job_name=step.name, log_dir=log_dir.as_posix(), partition=_fmt("partition", res.partition), account=_fmt("account", res.account), time=_fmt("time", res.time), cpus=_fmt("cpus-per-task", res.cpus), gpus=_fmt("gpus", res.gpus), mem=_fmt("mem", res.mem), extras="".join(_fmt(k, v) for k, v in (res.extra or {}).items()), env_block=self._format_env(params.get("env", {})), payload=self._payload(step, system, params), ) script_path = sbatch_dir / f"{step.name}.sh" script_path.write_text(sbatch_text) script_path.chmod(0o755) job_id = self._submit(script_path) logger.info("SLURM: submitted step {!r} as job {}", step.name, job_id or "<unknown>") return ExecResult( job_ids=[job_id] if job_id else [], artifacts={ "script": script_path, "stdout": log_dir / f"{step.name}.{job_id}.out" if job_id else None, "stderr": log_dir / f"{step.name}.{job_id}.err" if job_id else None, }, )
@staticmethod def _format_env(env: Dict[str, str]) -> str: """Create export statements for user-provided environment variables.""" if not env: return ":\n" lines = [f'export {k}={shlex.quote(str(v))}' for k, v in env.items()] return "\n".join(lines) @staticmethod def _payload(step: Step, system: SimSystem, params: Dict[str, Any]) -> str: """Return the shell snippet that drives the step.""" payload = params.get("payload") if isinstance(payload, str) and payload.strip(): return payload return f'echo "[batter] no payload for step {step.name}; system root: {system.root}"' @staticmethod def _submit(script_path: Path) -> Optional[str]: """Submit a script via ``utils.slurm_job`` or ``sbatch``. Parameters ---------- script_path : pathlib.Path Script to submit. Returns ------- str or None Parsed Slurm job identifier when submission succeeds. """ try: from batter.utils.slurm_job import submit_job # type: ignore job_id = submit_job(script_path) # expected to return string/int id return str(job_id) except Exception: pass try: res = subprocess.run(["sbatch", str(script_path)], check=True, capture_output=True, text=True) # typical output: "Submitted batch job 123456" out = res.stdout.strip() for tok in out.split(): if tok.isdigit(): return tok return None except Exception as e: logger.error("Failed to submit sbatch: {}", e) return None