# Source code for poiesis.core.services.texam.texam

"""TExAM (Task Executor and Monitor) service."""

import asyncio
import logging
import os
import shlex
from pathlib import Path

from kubernetes import watch  # type: ignore
from kubernetes.client import (
    V1Container,
    V1EnvVar,
    V1Job,
    V1JobSpec,
    V1ObjectMeta,
    V1PersistentVolumeClaimVolumeSource,
    V1PodSpec,
    V1PodTemplateSpec,
    V1ResourceRequirements,
    V1Volume,
    V1VolumeMount,
)

from poiesis.api.tes.models import (
    TesExecutor,
    TesFileType,
    TesTask,
)
from poiesis.core.adaptors.kubernetes.kubernetes import KubernetesAdapter
from poiesis.core.adaptors.message_broker.redis_adaptor import RedisMessageBroker
from poiesis.core.constants import (
    get_executor_container_security_context,
    get_executor_pod_security_context,
    get_executor_security_volume,
    get_labels,
    get_poiesis_core_constants,
)
from poiesis.core.ports.message_broker import Message, MessageStatus
from poiesis.core.services.models import PodPhase
from poiesis.repository.mongo import MongoDBClient

core_constants = get_poiesis_core_constants()

logger = logging.getLogger(__name__)


class Texam:
    """TExAM service.

    Runs the executors of a TES task one after another, each as its own
    Kubernetes Job, records the outcome of every executor in the task
    database and finally publishes the overall result on the message broker.

    Args:
        task: TesTask object.

    Attributes:
        task: TesTask object.
        task_id: Task ID.
        kubernetes_client: Kubernetes client.
        message_broker: Message broker.
        failed: Flag defining if TE failed.
        db: MongoDB client.
        _mounts_cache: Cache for volume mounts.
    """

    def __init__(
        self,
        task: TesTask,
    ) -> None:
        """TExAM service.

        Args:
            task: TesTask object.
        """
        self.task = task
        self.task_id = task.id
        self.kubernetes_client = KubernetesAdapter()
        self.message_broker = RedisMessageBroker()
        self.failed = False
        self.db = MongoDBClient()
        self._mounts_cache: list[V1VolumeMount] | None = None

    def _executor_name(self, idx: int) -> str:
        """Return the k8s Job name used for the executor at position ``idx``."""
        return f"{core_constants.K8s.TE_PREFIX}-{self.task_id}-{idx}"

    async def _mark_not_started(self, idx: int, reason: str) -> None:
        """Record executor ``idx`` as failed without ever launching it.

        Args:
            idx: Index of the executor that is being skipped.
            reason: Text stored as the executor's stderr.
        """
        executor_name = self._executor_name(idx)
        if self.task_id is not None:
            await self.db.add_task_executor_log(self.task_id)
        await self.db.update_executor_log(
            executor_name,
            PodPhase.FAILED.value,
            stdout="",
            stderr=reason,
        )

    async def execute(self) -> None:
        """Execute TExAM.

        Creates individual k8s Jobs for each executor sequentially. If any
        executor fails, remaining executors are marked as failed. A message
        with the overall outcome is published once the loop is done.
        """
        executor_count = len(self.task.executors)

        for idx, executor in enumerate(self.task.executors):
            if self.failed:
                # Only reachable when ``failed`` was already set before this
                # loop started; a failure inside the loop breaks out below.
                await self._mark_not_started(
                    idx,
                    f"Executor {idx} failed to start because previous executor"
                    " failed.",
                )
                continue

            # Create and monitor executor sequentially
            success = await self.run_single_executor(executor, idx)
            if success:
                continue

            self.failed = True
            # Mark remaining executors as failed
            for remaining_idx in range(idx + 1, executor_count):
                await self._mark_not_started(
                    remaining_idx,
                    f"Executor {remaining_idx} failed to start because"
                    f" executor {idx} failed.",
                )
            break

        await self.message()

    async def run_single_executor(self, executor: TesExecutor, idx: int) -> bool:
        """Run a single executor and monitor it to completion.

        Args:
            executor: Executor object.
            idx: Index of the executor.

        Returns:
            True if executor completed successfully, False otherwise
            (including when the job could not be created at all).
        """
        job_created = await self.create_executor_job(executor, idx)
        if not job_created:
            return False
        return await self.monitor_executor_job(executor, idx)

    def _build_command_string(self, executor: TesExecutor) -> str:
        """Constructs a shell command string.

        Get the command from the executor and construct a shell command string
        with proper redirection and error handling. Every argument and every
        redirection target is quoted with ``shlex.quote``.

        Args:
            executor: Executor object.

        Returns:
            The command line to hand to ``/bin/sh -c``.
        """
        pieces = [" ".join(shlex.quote(arg) for arg in executor.command)]

        # Handle stdin redirection from a file
        if executor.stdin:
            pieces.append(f"< {shlex.quote(executor.stdin)}")

        # Handle stdout and stderr redirection
        if executor.stdout:
            pieces.append(f"> {shlex.quote(executor.stdout)}")
        if executor.stderr:
            pieces.append(f"2> {shlex.quote(executor.stderr)}")

        # Ignore errors if required
        if executor.ignore_error:
            pieces.append("|| true")

        return " ".join(pieces)

    async def create_executor_job(self, executor: TesExecutor, idx: int) -> bool:
        """Create a k8s Job for the executor.

        Creation is retried with a doubling delay (starting at one second)
        for as long as the delay stays below
        ``core_constants.Texam.BACKOFF_LIMIT``.

        Args:
            executor: Executor object.
            idx: Index of the executor.

        Returns:
            True if job was created successfully, False otherwise.

        Raises:
            ValueError: If the task has no ID.
        """
        executor_name = self._executor_name(idx)
        job_manifest = self._create_executor_job_manifest(executor, idx)

        # Manifest creation already rejects a missing task ID, so this is not
        # retried inside the loop; the check only narrows the type here.
        task_id = self.task_id
        if task_id is None:
            raise ValueError("Task ID is None")

        backoff_limit = core_constants.Texam.BACKOFF_LIMIT
        backoff_time = 1

        while backoff_time < backoff_limit:
            logger.debug(
                "Exponential backoff attempt: "
                f"{backoff_time}/{backoff_limit} "
                f"to create job for {executor_name}."
            )
            try:
                logger.debug(
                    f"Creating job for {executor_name}: {job_manifest.to_dict()}"
                )
                await self.kubernetes_client.create_job(job_manifest)
                await self.db.add_task_executor_log(task_id)
                return True
            except Exception as e:
                logger.error(f"Failed to create job {executor_name}: {e}")
                logger.info(f"Deleting job {executor_name}")
                # Best-effort cleanup: the job may never have been created, so
                # a failing delete must not abort the remaining retries.
                try:
                    await self.kubernetes_client.delete_job(executor_name)
                except Exception as cleanup_error:
                    logger.warning(
                        f"Failed to delete job {executor_name}: {cleanup_error}"
                    )
                # We don't need to delete the executor log from the DB,
                # since it isn't added until after the job is created.
                # If TExAM has launched successfully, the DB is clearly
                # functional.
                logger.info(f"Retrying in {backoff_time} seconds")
                await asyncio.sleep(backoff_time)
                backoff_time = min(backoff_time * 2, backoff_limit)

        # After all retries failed, log the failure and mark run as failed so
        # all executors can be marked as failed.
        await self.db.update_executor_log(
            executor_name,
            PodPhase.FAILED.value,
            stdout="",
            stderr="Failed to create executor job after multiple retries.",
        )
        logger.error(f"Job {executor_name} failed to be created after all retries")
        return False

    def _is_covered(self, path: str, mounted_set: set) -> bool:
        """Check if any mounted path is a prefix of this path.

        Prefixes are compared component-wise (``/a/b`` covers ``/a/b/c`` but
        not ``/a/bc``); the path itself counts as its own prefix.

        Args:
            path: Path to test.
            mounted_set: Paths that are already mounted.

        Returns:
            True if ``path`` or one of its ancestors is in ``mounted_set``.
        """
        parts = Path(path).parts
        return any(
            str(Path(*parts[:end])) in mounted_set
            for end in range(1, len(parts) + 1)
        )

    @staticmethod
    def _file_mount_target(file_path: str) -> str:
        """Return the path whose coverage decides if a file input needs a mount.

        This is the file's parent directory, except when that parent is the
        filesystem root: root cannot be mounted, so the file path itself is
        used instead.

        Args:
            file_path: Path of the input file inside the container.
        """
        parent_path = Path(file_path).parent
        # ``Path.root`` is a plain string, so compare string to string;
        # comparing the Path object itself to it is always False.
        if str(parent_path) == parent_path.root:
            return file_path
        return str(parent_path)

    def _get_mounts(
        self,
    ) -> list[V1VolumeMount]:
        """Get the mounts for the executor.

        The result is computed once per instance and cached. Every mount
        refers to the common PVC volume and uses the mount path, stripped of
        leading and trailing slashes, as ``sub_path``.

        Returns:
            Volume mounts sorted by mount path.
        """
        if self._mounts_cache is not None:
            return self._mounts_cache

        # Volumes – mount all directly, as they will be dirs
        volume_mounts = set(self.task.volumes or [])

        # Outputs – derive parent dirs and pick minimal set (no nested ones)
        # as output mount wont be file, and the path will always be at least
        # nested 1 level
        output_dirs = set()
        for output in self.task.outputs or []:
            if str(output.type) == str(TesFileType.FILE):
                output_dirs.add(str(Path(output.path).parent))
            else:
                output_dirs.add(output.path)

        # Remove subdirectories if parent is present; visiting shallow paths
        # first guarantees a parent is registered before its children.
        output_mounts: set = set()
        for directory in sorted(output_dirs, key=lambda x: x.count("/")):
            if not self._is_covered(directory, output_mounts):
                output_mounts.add(directory)

        # Inputs – only add if not covered by volumes or output mounts,
        # inputs can be a file, if the parent is root then mount as is
        # because we can't mount root.
        already_mounted = volume_mounts | output_mounts
        input_mounts = set()
        for inp in self.task.inputs or []:
            if str(inp.type) == str(TesFileType.DIRECTORY):
                mount_target = inp.path
            else:
                mount_target = self._file_mount_target(inp.path)
            if not self._is_covered(mount_target, already_mounted):
                # The input's own path is what gets mounted, for files too.
                input_mounts.add(inp.path)

        # Combine all
        final_mounts = already_mounted | input_mounts

        self._mounts_cache = [
            V1VolumeMount(
                name=core_constants.K8s.COMMON_PVC_VOLUME_NAME,
                mount_path=p,
                sub_path=p.strip("/"),
            )
            for p in sorted(final_mounts)
        ]
        logger.debug(f"Mounts: {[m.to_dict() for m in self._mounts_cache]}")
        return self._mounts_cache

    def _create_executor_job_manifest(self, executor: TesExecutor, idx: int) -> V1Job:
        """Create a k8s Job for the executor.

        Args:
            executor: Executor object.
            idx: Index of the executor.

        Returns:
            The Job manifest; nothing is submitted to the cluster here.

        Raises:
            ValueError: If the task has no ID.
        """
        executor_name = self._executor_name(idx)
        if self.task_id is None:
            raise ValueError(f"task_id cannot be None for executor '{executor_name}'")

        # Only resources that are actually set (and non-zero) are passed on;
        # the same mapping is used for both limits and requests.
        resource: dict = {}
        task_resources = self.task.resources
        if task_resources:
            if task_resources.cpu_cores:
                resource["cpu"] = str(task_resources.cpu_cores)
            if task_resources.ram_gb:
                resource["memory"] = f"{task_resources.ram_gb}Gi"

        # Job and pod template carry the same labels; build the arguments once
        # but request a fresh label object for each of them.
        label_kwargs = {
            "component": core_constants.K8s.TE_PREFIX,
            "task_id": self.task_id,
            "name": executor_name,
            "parent": f"{core_constants.K8s.TEXAM_PREFIX}-{self.task_id}",
        }

        env_vars = (
            [V1EnvVar(name=key, value=value) for key, value in executor.env.items()]
            if executor.env is not None
            else []
        )

        container = V1Container(
            name=executor_name,
            image=executor.image,
            command=["/bin/sh", "-c"],
            args=[self._build_command_string(executor)],
            working_dir=executor.workdir,
            env=env_vars,
            resources=V1ResourceRequirements(
                limits=resource,
                requests=resource,
            ),
            volume_mounts=self._get_mounts(),
            image_pull_policy=core_constants.K8s.IMAGE_PULL_POLICY,
            security_context=get_executor_container_security_context(),
        )

        pvc_volume = V1Volume(
            name=core_constants.K8s.COMMON_PVC_VOLUME_NAME,
            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
                claim_name=f"{core_constants.K8s.PVC_PREFIX}-{self.task_id}"
            ),
        )

        pod_spec = V1PodSpec(
            security_context=get_executor_pod_security_context(),
            containers=[container],
            volumes=[pvc_volume] + get_executor_security_volume(),
            restart_policy=core_constants.K8s.RESTART_POLICY,
        )

        return V1Job(
            api_version="batch/v1",
            kind="Job",
            metadata=V1ObjectMeta(
                name=executor_name,
                labels=get_labels(**label_kwargs),
            ),
            spec=V1JobSpec(
                # Note: Backoff limit is set to 0 to fail immediately when pod
                # fails. This is because we want to fail all subsequent
                # executors if any executor fails.
                backoff_limit=0,
                ttl_seconds_after_finished=(
                    int(core_constants.K8s.JOB_TTL)
                    if core_constants.K8s.JOB_TTL
                    else None
                ),
                template=V1PodTemplateSpec(
                    metadata=V1ObjectMeta(labels=get_labels(**label_kwargs)),
                    spec=pod_spec,
                ),
            ),
        )

    async def monitor_executor_job(self, executor: TesExecutor, idx: int) -> bool:
        """Monitor the executor job and log details in TaskDB.

        Watches the job until it reports a ``Complete`` or ``Failed``
        condition, or until the watch times out. The timeout is read from the
        ``MONITOR_TIMEOUT_SECONDS`` environment variable, falling back to
        ``core_constants.Texam.MONITOR_TIMEOUT_SECONDS``.

        Args:
            executor: Executor object.
            idx: Index of the executor.

        Returns:
            True if executor completed successfully, False otherwise
            (job failed, monitoring timed out, or monitoring itself errored).
        """
        executor_name = self._executor_name(idx)
        timeout = int(
            os.getenv(
                "MONITOR_TIMEOUT_SECONDS", core_constants.Texam.MONITOR_TIMEOUT_SECONDS
            )
        )

        w = None
        try:
            w = watch.Watch()
            logger.info(f"Starting watch for job: {executor_name}")

            # Watch for job completion
            # NOTE(review): this stream is iterated synchronously inside a
            # coroutine, so it presumably blocks the event loop while waiting
            # for events — confirm that nothing else needs the loop meanwhile.
            for event in w.stream(
                self.kubernetes_client.batch_api.list_namespaced_job,
                namespace=self.kubernetes_client.namespace,
                field_selector=f"metadata.name={executor_name}",
                timeout_seconds=timeout,
            ):
                if not isinstance(event, dict):
                    continue

                job = event["object"]
                if job.metadata.name != executor_name:
                    continue

                # Check job status
                if not (job.status and job.status.conditions):
                    continue

                for condition in job.status.conditions:
                    if condition.status != "True":
                        continue

                    if condition.type == "Complete":
                        # Job completed successfully
                        stdout, stderr = await self._get_job_logs(executor_name)
                        await self.db.update_executor_log(
                            executor_name,
                            PodPhase.SUCCEEDED.value,
                            stdout=stdout,
                            stderr=stderr,
                        )
                        logger.info(f"Job {executor_name} completed successfully")
                        return True

                    if condition.type == "Failed":
                        # Job failed
                        stdout, stderr = await self._get_job_logs(executor_name)
                        await self.db.update_executor_log(
                            executor_name,
                            PodPhase.FAILED.value,
                            stdout=stdout,
                            stderr=stderr,
                        )
                        logger.error(
                            f"Job {executor_name} failed: {condition.message}"
                        )
                        return False

            # If we reach here, the timeout was reached
            logger.error(
                f"Job {executor_name} monitoring timed out after {timeout} seconds"
            )
            await self.db.update_executor_log(
                executor_name,
                PodPhase.FAILED.value,
                stdout="",
                stderr=f"Job monitoring timed out after {timeout} seconds.",
            )
            return False

        except Exception as e:
            logger.exception(f"Error monitoring job {executor_name}: {e}")
            await self.db.update_executor_log(
                executor_name,
                PodPhase.FAILED.value,
                stdout="",
                stderr=f"Error monitoring job: {str(e)}",
            )
            return False

        finally:
            # Stop the watch on every exit path, including errors.
            if w is not None:
                w.stop()

    async def _get_job_logs(self, job_name: str) -> tuple[str, str]:
        """Get logs from the job's pod.

        Looks up the job's pods through the ``job-name`` label and reads the
        log of the first one. Up to three attempts are made, one second apart.

        Args:
            job_name: Name of the job.

        Returns:
            Tuple of stdout and stderr logs. On success the pod log is
            returned as stdout with an empty stderr; on failure stdout is
            empty and stderr carries an explanatory message.
        """
        max_retries = 3
        retry_delay = 1

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            try:
                # Get pods for this job using the job-name label
                pods = self.kubernetes_client.core_api.list_namespaced_pod(
                    namespace=self.kubernetes_client.namespace,
                    label_selector=f"job-name={job_name}",
                )

                if not pods.items:
                    logger.warning(
                        f"No pods found for job {job_name} (attempt "
                        f"{attempt + 1}/{max_retries})"
                    )
                else:
                    # Get logs from the first pod (jobs typically create one
                    # pod)
                    pod = pods.items[0]
                    if pod.metadata and pod.metadata.name:
                        pod_name = pod.metadata.name
                        logger.debug(
                            f"Getting logs from pod {pod_name} of job {job_name}"
                        )
                        # Try to get logs, with retry for pods that aren't
                        # ready yet
                        try:
                            return await self.kubernetes_client.get_pod_log(
                                pod_name
                            ), ""
                        except Exception as log_error:
                            logger.warning(
                                f"Failed to get logs from pod {pod_name}: {log_error}"
                            )
                            if is_last_attempt:
                                logger.error(
                                    f"Failed to get logs from pod {pod_name} after "
                                    f"{max_retries} attempts"
                                )
                                return (
                                    "",
                                    f"Failed to get logs for executor {job_name} "
                                    f"after {max_retries} attempts",
                                )
                            logger.info(
                                f"Retrying log retrieval for pod {pod_name} "
                                f"(attempt {attempt + 1}/{max_retries})"
                            )
                    else:
                        logger.warning(
                            f"Pod metadata or name is missing for job {job_name}"
                        )
            except Exception as e:
                logger.warning(
                    f"Could not get logs for job {job_name} (attempt "
                    f"{attempt + 1}/{max_retries}): {e}"
                )

            # Every unsuccessful attempt waits before the next one, so no
            # branch retries back-to-back.
            if not is_last_attempt:
                await asyncio.sleep(retry_delay)

        logger.error(f"Failed to get logs for job {job_name} after all attempts")
        return "", f"Internal error while getting logs for executor {job_name}."

    async def message(self) -> None:
        """Send message to TORC.

        Publishes a completion message when no executor failed, otherwise an
        error message, on the channel named after the task ID.

        Raises:
            ValueError: If the task has no ID.
        """
        # Explicit check instead of ``assert`` so it also holds under ``-O``.
        if self.task_id is None:
            raise ValueError("Task ID is None")

        if self.failed:
            self.message_broker.publish(
                self.task_id,
                Message(
                    message="TExAM job failed to run all jobs successfully.",
                    status=MessageStatus.ERROR,
                ),
            )
            return

        self.message_broker.publish(
            self.task_id,
            Message(f"TExAM job for {self.task_id} has been completed."),
        )