# Source code for poiesis.repository.mongo

"""MongoDB adaptor for NoSQL database operations."""

import logging
import os
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any

import motor.motor_asyncio
from bson.objectid import ObjectId

from poiesis.api.exceptions import DBException
from poiesis.api.tes.models import TesExecutorLog, TesState, TesTask, TesTaskLog
from poiesis.constants import get_poiesis_constants
from poiesis.core.adaptors.kubernetes.kubernetes import KubernetesAdapter
from poiesis.core.constants import get_poiesis_core_constants
from poiesis.core.services.models import PodPhase
from poiesis.repository.schemas import TaskSchema

logger = logging.getLogger(__name__)

poiesis_constants = get_poiesis_constants()
poiesis_core_constants = get_poiesis_core_constants()


class MongoDBClient:
    """Simple MongoDB client using Motor for async operations.

    All task operations target the collection named by
    ``poiesis_constants.Database.MongoDB.TASK_COLLECTION``. Task documents
    are looked up by their ``task_id`` field, and the TES task payload lives
    under the ``data`` field.
    """

    def __init__(self) -> None:
        """Initialize the MongoDB client with connection pooling.

        Nothing is passed in. Configuration is read from:

        - the environment variable named by
          ``poiesis_core_constants.K8s.MONGODB_URI_SECRET_KEY`` (connection URI);
        - ``poiesis_constants.Database.MongoDB.DATABASE`` (database name);
        - ``poiesis_constants.Database.MongoDB.MAX_POOL_SIZE`` (pool size).

        A ``KubernetesAdapter`` is also created. It is used to fetch pod logs
        in ``add_tes_task_system_logs``.
        """
        uri_env_var = poiesis_core_constants.K8s.MONGODB_URI_SECRET_KEY
        self.connection_string = os.getenv(uri_env_var)
        # The URI comes from a secret and may embed credentials, so only
        # record whether it was found, never its value.
        logger.debug(
            "MongoDB connection string read from %s (set: %s)",
            uri_env_var,
            self.connection_string is not None,
        )
        self.database = poiesis_constants.Database.MongoDB.DATABASE
        self.max_pool_size = poiesis_constants.Database.MongoDB.MAX_POOL_SIZE
        self.client: motor.motor_asyncio.AsyncIOMotorClient = (
            motor.motor_asyncio.AsyncIOMotorClient(
                self.connection_string, maxPoolSize=self.max_pool_size
            )
        )
        self.db = self.client[self.database]
        self.kubernetes_client = KubernetesAdapter()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @property
    def _tasks(self):
        """Return the Motor collection that stores task documents."""
        return self.db[poiesis_constants.Database.MongoDB.TASK_COLLECTION]

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time in the format stored in TES logs."""
        return datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S%z")

    @staticmethod
    def _db_error(message: str, exc: Exception) -> DBException:
        """Log a failed task-collection operation and build the exception.

        Args:
            message: Leading text, e.g. ``"Error adding task log in collection"``.
            exc: The underlying exception.

        Returns:
            A ``DBException`` whose message is
            ``"<message> <collection>: <exc>"``. The caller raises it.
        """
        collection = poiesis_constants.Database.MongoDB.TASK_COLLECTION
        full_message = f"{message} {collection}: {exc}"
        logger.error(full_message)
        return DBException(full_message)

    async def _save_task_logs(self, task_id: str, logs: list[TesTaskLog]) -> None:
        """Overwrite the ``data.logs`` field of a task with ``logs``.

        Args:
            task_id: ID of the task.
            logs: Complete list of task logs to store.
        """
        await self._tasks.update_one(
            {"task_id": task_id},
            {"$set": {"data.logs": [log.model_dump() for log in logs]}},
        )

    async def _collect_pod_logs(self, job_name: str) -> list[str]:
        """Fetch logs from every pod labelled with ``job-name=<job_name>``.

        This is best-effort. If listing pods or reading a log fails, the
        entries collected so far are kept and an error line is appended
        instead of raising.

        Args:
            job_name: Kubernetes job name used in the label selector.

        Returns:
            One ``"Logs from <pod>: <logs>"`` entry per pod with non-empty
            logs, possibly followed by a single error entry.
        """
        entries: list[str] = []
        try:
            pods = await self.kubernetes_client.get_pods(
                label_selector=f"job-name={job_name}"
            )
            for pod in pods:
                pod_name = pod.metadata.name if pod.metadata is not None else None
                if pod_name is None:
                    # Without a name there is no way to request the log.
                    continue
                pod_logs = await self.kubernetes_client.get_pod_log(pod_name)
                if pod_logs:
                    entries.append(f"Logs from {pod_name}: {pod_logs}")
        except Exception as e:
            entries.append(f"Error getting logs for {job_name}: {str(e)}")
        return entries

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_task(self, task_id: str) -> TaskSchema:
        """Get a task from the database.

        Args:
            task_id: ID of the task.

        Returns:
            The stored task document parsed into a ``TaskSchema``.

        Raises:
            DBException: If no task with ``task_id`` exists.
        """
        task = await self._tasks.find_one({"task_id": task_id})
        if task is None:
            raise DBException(f"Task with ID {task_id} not found")
        return TaskSchema(**task)

    async def insert_task(self, task: TaskSchema) -> str:
        """Insert a single task document into the task collection.

        Args:
            task: Task to insert.

        Returns:
            The inserted document ID as a string.

        Raises:
            DBException: If the document is not valid or the insert fails.
        """
        try:
            result = await self._tasks.insert_one(task.model_dump())
        except Exception as e:
            raise self._db_error(
                "Error inserting document into collection", e
            ) from e
        return str(result.inserted_id)

    async def update_task_state(self, task_id: str, state: TesState) -> None:
        """Update the state of a task in the database.

        This would be called by jobs in case of task state change or failure.
        Nothing is written if the stored state already equals ``state``.

        Args:
            task_id: ID of the task.
            state: New state of the task.

        Raises:
            DBException: If the task cannot be loaded or the update fails.
        """
        try:
            task = await self.get_task(task_id)
            if task.data.state == state:
                return
            await self._tasks.update_one(
                {"task_id": task_id},
                {
                    "$set": {
                        "state": state.value,
                        "updated_at": datetime.now(UTC),
                        "data.state": state.value,
                    }
                },
            )
        except Exception as e:
            raise self._db_error("Error updating document in collection", e) from e

    async def add_task_log(self, task_id: str) -> None:
        """Append a new, empty task log (one attempt) to a task.

        The spec defines that when a task fails and is retried, another log
        is added to the task. This creates that log.

        Args:
            task_id: ID of the task.

        Raises:
            DBException: If the task cannot be loaded or the update fails.
        """
        new_log = TesTaskLog(logs=[], outputs=[])
        try:
            task = await self.get_task(task_id)
            logs = [*(task.data.logs or []), new_log]
            await self._save_task_logs(task_id, logs)
        except Exception as e:
            raise self._db_error("Error adding task log in collection", e) from e

    async def add_tes_task_log_end_time(self, task_id: str) -> None:
        """Set ``end_time`` on the latest task log to the current UTC time.

        Args:
            task_id: ID of the task.

        Raises:
            DBException: If the task has no task log, cannot be loaded, or
                the update fails.
        """
        try:
            task = await self.get_task(task_id)
            if not task.data.logs:
                raise DBException(f"Task {task_id} has no task log to close")
            task.data.logs[-1].end_time = self._utc_timestamp()
            # TODO: check if this can be optimized by only setting the end
            # time of the last log instead of rewriting data.logs.
            await self._save_task_logs(task_id, task.data.logs)
        except Exception as e:
            raise self._db_error("Error adding task log in collection", e) from e

    async def add_tes_task_system_logs(
        self, task_id: str, system_logs: list[str] | None = None
    ) -> None:
        """Store system logs on the latest task log of a task.

        The stored list contains ``system_logs`` (if given) followed by the
        logs of the pods belonging to this task's TEXAM, TOF and TIF jobs.
        Pod-log retrieval is best-effort: failures are recorded as entries
        rather than raised. The caller's ``system_logs`` list is not modified.

        Args:
            task_id: ID of the task.
            system_logs: Custom logs to store in addition to the pod logs.

        Raises:
            DBException: If the task cannot be loaded, has no ``logs`` field,
                or the update fails.
        """
        job_names = (
            f"{poiesis_core_constants.K8s.TEXAM_PREFIX}-{task_id}",
            f"{poiesis_core_constants.K8s.TOF_PREFIX}-{task_id}",
            f"{poiesis_core_constants.K8s.TIF_PREFIX}-{task_id}",
        )
        try:
            task = await self.get_task(task_id)
            if task.data.logs is None:
                raise DBException(f"Task {task_id} has no task logs")
            # Copy so that the caller's list is never mutated.
            collected = list(system_logs or [])
            for job_name in job_names:
                collected.extend(await self._collect_pod_logs(job_name))
            if task.data.logs:
                task.data.logs[-1].system_logs = collected
            await self._save_task_logs(task_id, task.data.logs)
        except Exception as e:
            raise self._db_error("Error adding system logs in collection", e) from e

    async def add_task_executor_log(self, task_id: str) -> None:
        """Append a new executor log to the latest task log.

        Each executor has a log. Executors are assumed to run sequentially,
        so the new log goes into the last (current) task log. The log starts
        with exit code 0; ``update_executor_log`` later sets the real value.

        Args:
            task_id: ID of the task.

        Raises:
            DBException: If the task cannot be loaded or the update fails.
        """
        new_log = TesExecutorLog(exit_code=0)
        try:
            task = await self.get_task(task_id)
            if not task.data.logs:
                # add_task_log should already have created the attempt log.
                # If it did not, create it here rather than failing on [-1].
                task.data.logs = [TesTaskLog(logs=[], outputs=[])]
            task.data.logs[-1].logs.append(new_log)
            await self._save_task_logs(task_id, task.data.logs)
        except Exception as e:
            raise self._db_error("Error upserting task log in collection", e) from e

    async def update_executor_log(
        self,
        job_name: str,
        pod_phase: str,
        stdout: str,
        stderr: str | None = None,
    ) -> None:
        """Fill in an executor log once its pod has finished.

        The task id and executor index are parsed from ``job_name``. The
        matching executor log in the latest task log then gets:

        - ``end_time``: the current UTC time;
        - ``stdout`` and ``stderr``: the given values (``""`` if ``stderr``
          is None);
        - ``exit_code``: 0 if ``pod_phase`` is ``PodPhase.SUCCEEDED``,
          else 1.

        This is best-effort: any failure is logged and not raised.

        Args:
            job_name: Name of the job, ``<TE_PREFIX>-<task_id>-<idx>``.
            pod_phase: Phase of the pod.
            stdout: Standard output of the pod.
            stderr: Standard error of the pod.
        """
        try:
            # Strip the executor prefix, then split off the trailing executor
            # index. Everything before the last "-" is the task id, so task
            # ids containing hyphens (e.g. UUIDs) are handled.
            without_prefix = job_name.split(
                f"{poiesis_core_constants.K8s.TE_PREFIX}-"
            )[-1]
            task_id, idx_text = without_prefix.rsplit("-", 1)
            idx = int(idx_text)

            task = await self.get_task(task_id)
            if not task.data.logs:
                raise DBException(f"Task {task_id} has no task log")
            exec_log = task.data.logs[-1].logs[idx]
            exec_log.end_time = self._utc_timestamp()
            exec_log.stderr = stderr or ""
            exec_log.stdout = stdout
            exec_log.exit_code = 0 if pod_phase == PodPhase.SUCCEEDED.value else 1
            await self._save_task_logs(task_id, task.data.logs)
        except Exception as e:
            logger.error(
                "Error updating executor log in collection"
                f" {poiesis_constants.Database.MongoDB.TASK_COLLECTION}: {str(e)}"
            )

    async def list_tasks(
        self,
        query: dict[str, Any],
        page_size: int | None = None,
        next_page_token: str | None = None,
    ) -> tuple[list[TesTask], str | None]:
        """List tasks from the database with ``_id``-ordered pagination.

        Args:
            query: Query to filter tasks (not modified).
            page_size: Maximum number of tasks to return. ``None`` returns all
                matching tasks.
            next_page_token: Token returned by a previous call. It is the
                ``_id`` of the first task of the page to return.

        Returns:
            A tuple of the tasks on this page and the token for the next
            page. The token is ``None`` when no further tasks exist.

        Raises:
            ValueError: If ``next_page_token`` is not a valid ObjectId.
        """
        db_query = query.copy()
        if next_page_token:
            try:
                # $gte, not $gt: the token is the _id of the first task that
                # has not been returned yet, so it must be included.
                db_query["_id"] = {"$gte": ObjectId(next_page_token)}
            except Exception as e:
                raise ValueError("Invalid next_page_token") from e

        cursor = self._tasks.find(db_query).sort("_id", 1)
        if page_size is not None:
            # Fetch one extra document to learn whether another page exists.
            cursor = cursor.limit(page_size + 1)
        docs = await cursor.to_list(None)

        tasks = [TesTask(**doc["data"]) for doc in docs[:page_size]]
        next_token = None
        if page_size is not None and len(docs) > page_size:
            next_token = str(docs[page_size]["_id"])
        return tasks, next_token

    @asynccontextmanager
    async def connection(self):
        """Async context manager for explicit connection handling.

        Yields:
            The ``AsyncIOMotorDatabase`` instance. Nothing is released on
            exit because Motor manages its connection pool internally.
        """
        try:
            yield self.db
        finally:
            # Motor handles connection pooling internally.
            pass

    async def close(self):
        """Close all connections in the pool."""
        self.client.close()