Source code for tardis.utilities.executors.sshexecutor

from typing import Optional, NamedTuple
from ...configuration.utilities import enable_yaml_load
from ...exceptions.tardisexceptions import TardisAuthError
from ...exceptions.executorexceptions import CommandExecutionFailure, ExecutorFailure
from ...interfaces.executor import Executor
from ..attributedict import AttributeDict
from cobald.daemon.plugins import yaml_tag

import asyncio
import asyncssh
import logging
import pyotp
from asyncssh.auth import KbdIntPrompts, KbdIntResponse
from asyncssh.client import SSHClient
from asyncssh.misc import MaybeAwait

from asyncstdlib import (
    ExitStack as AsyncExitStack,
    contextmanager as asynccontextmanager,
)

from functools import partial

# Module-level logger; its dotted name places it below the "cobald.runtime" logger
logger = logging.getLogger("cobald.runtime.tardis.utilities.executors.sshexecutor")


async def probe_max_session(connection: asyncssh.SSHClientConnection):
    """
    Count how many sessions can be multiplexed over ``connection``

    This probes the sshd `MaxSessions` limit: sessions are opened one after
    another until opening a further channel fails with
    :py:class:`asyncssh.ChannelOpenError`. The sessions opened for probing are
    handed to an exit stack, which is unwound before the count is returned.

    :param connection: the established SSH connection to probe
    :return: the number of sessions that could be opened at the same time
    """
    opened = 0
    # The concrete kind of session is irrelevant for probing. `create_process`
    # is comparatively heavy, but a process stays open without a separate task
    # looking after it and is cleaned up reliably once the exit stack unwinds.
    async with AsyncExitStack() as stack:
        while True:
            try:
                await stack.enter_context(await connection.create_process())
            except asyncssh.ChannelOpenError:
                # no further channel could be opened: the limit is reached
                break
            opened += 1
    return opened
class MFASSHClient(SSHClient):
    """
    SSH client answering keyboard-interactive challenges with TOTP codes

    :param mfa_config: iterable of mappings; each entry provides the expected
        ``"prompt"`` text and the ``"totp"`` value handed to
        :py:class:`pyotp.TOTP` to generate the answer for that prompt
    """

    def __init__(self, *args, mfa_config, **kwargs):
        super().__init__(*args, **kwargs)
        # stripped prompt text -> TOTP generator answering that prompt
        self._mfa_responses = {}
        for item in mfa_config:
            generator = pyotp.TOTP(item["totp"])
            self._mfa_responses[item["prompt"].strip()] = generator

    async def kbdint_auth_requested(self) -> MaybeAwait[Optional[str]]:
        """
        Handle a request for keyboard-interactive authentication

        The return value is the comma-separated list of submethods the server
        should use. This client always returns an empty string, which leaves
        the choice of submethod to the server.
        """
        return ""

    async def kbdint_challenge_received(
        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
    ) -> MaybeAwait[Optional[KbdIntResponse]]:
        """
        Answer a keyboard-interactive authentication challenge

        Every prompt is looked up by its stripped text among the configured
        prompts and answered with the current code of the matching TOTP
        generator.

        :return: list of responses, one per prompt and in the order of prompts
        :raises TardisAuthError: if a prompt has not been configured
        """
        # prompts is of type Sequence[Tuple[str, bool]]; only the text is used
        answers = []
        try:
            for prompt in prompts:
                answers.append(self._mfa_responses[prompt[0].strip()].now())
        except KeyError as ke:
            msg = f"Keyboard interactive authentication failed: Unexpected Prompt {ke}"
            logger.error(msg)
            raise TardisAuthError(msg) from ke
        return answers
class ConnectionState(NamedTuple):
    """State associated with an active SSH connection"""

    #: the SSH connection itself
    connection: asyncssh.SSHClientConnection
    #: semaphore bounding the number of concurrent sessions over the connection
    bound: asyncio.Semaphore
@enable_yaml_load("!SSHExecutor")
@yaml_tag(eager=True)
class SSHExecutor(Executor):
    """
    Execute shell commands via an SSH connection

    This class provides several convenience features over a raw SSH connection:

    - Establishing a connection includes retries for temporary unavailability
    - An established connection is multiplexed for concurrent commands
    - Executing commands are used as feedback on the connection state
    - On connection failure both connection and commands are automatically retried

    Notably, these features work in accord: Once a single command fails due to a
    broken connection, multiplexing means all commands are queued until the
    connection is reestablished. Retrying failed commands efficiently waits for the
    single connection to be retried.

    :param on_disconnect_retry: How many additional times a command is attempted
        if the connection is lost; the value is converted with ``int``, so
        ``True`` means one retry and ``False`` means none
    :param parameters: Keyword arguments passed on to :py:func:`asyncssh.connect`.
        A truthy ``mfa_config`` entry is removed from them and used to set up
        :py:class:`MFASSHClient` as the ``client_factory`` instead.
    """

    def __init__(self, *, on_disconnect_retry: "int | bool" = 3, **parameters):
        self.on_disconnect_retry = int(on_disconnect_retry)
        self._parameters = parameters
        # enable Multi-factor Authentication if required
        if mfa_config := self._parameters.pop("mfa_config", None):
            self._parameters["client_factory"] = partial(
                MFASSHClient, mfa_config=mfa_config
            )
        # the current SSH connection unless it must be (re-)established
        self._connection_state: "ConnectionState | None" = None
        # created lazily by the `lock` property
        self._lock = None

    async def _establish_connection(self):
        """
        Open a new SSH connection, retrying on connection-level failures

        Up to nine attempts are each followed by an exponentially growing sleep
        (1, 2, 4, ... seconds) if they fail with a reset, disconnect, lost
        connection or broken pipe. A final attempt is made without catching
        anything, so its exception propagates to the caller.
        """
        for retry in range(9):
            try:
                return await asyncssh.connect(**self._parameters)
            except (
                ConnectionResetError,
                asyncssh.DisconnectError,
                asyncssh.ConnectionLost,
                BrokenPipeError,
            ):
                await asyncio.sleep(2**retry)
        return await asyncssh.connect(**self._parameters)

    def _handle_broken_ssh_connection(
        self,
        ssh_connection: asyncssh.SSHClientConnection,
        chained_exception: "Exception | None" = None,
    ):
        """
        Discard ``ssh_connection`` if it is still current and signal the failure

        :param ssh_connection: the connection that was found to be broken
        :param chained_exception: the exception to set as cause of the failure
        :raises ExecutorFailure: always
        """
        # clear broken connection to get it replaced
        # by a new connection during next command
        # A different connection may already be in place, in which case the
        # replacement must be kept.
        if (
            self._connection_state is not None
            and ssh_connection is self._connection_state.connection
        ):
            self._connection_state = None
        raise ExecutorFailure(
            description="SSH connection lost",
            executor=self,
        ) from chained_exception

    @property
    @asynccontextmanager
    async def bounded_connection(self):
        """
        Get the current connection with a single reserved session slot

        This is a context manager that guards the current
        :py:class:`~asyncssh.SSHClientConnection`
        so that only `MaxSessions` commands run at once.

        If there is no current connection, a new one is established and probed
        for its session limit first.

        :raises ExecutorFailure: if a new connection does not allow opening any
            session at all
        """
        if self._connection_state is None:
            async with self.lock:
                # check that connection has not been initialized in a different task
                while self._connection_state is None:
                    connection = await self._establish_connection()
                    max_session = await probe_max_session(connection)
                    if max_session < 1:
                        # A semaphore with value 0 is never released by anyone,
                        # so storing it would make every command wait forever.
                        # Report the unusable connection as a failure instead.
                        connection.close()
                        raise ExecutorFailure(
                            description="SSH connection permits no sessions",
                            executor=self,
                        )
                    self._connection_state = ConnectionState(
                        connection, asyncio.Semaphore(value=max_session)
                    )
        assert self._connection_state is not None
        session, bound = self._connection_state
        async with bound:
            yield session

    @property
    def lock(self):
        """Lock protecting the connection"""
        # Create the lock lazily, once the tardis event loop is running. This
        # avoids the "got Future <Future pending> attached to a different loop"
        # exception.
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    async def run_command(self, command: str, stdin_input: "str | None" = None):
        """
        Run ``command`` over SSH, retrying if the executor fails

        :param command: the command to run on the remote side
        :param stdin_input: optional data passed to the command as its input
        :return: the result of the first attempt that succeeds
        :raises ExecutorFailure: the failure of the initial attempt, if that and
            all ``on_disconnect_retry`` further attempts fail
        """
        try:
            return await self._run_command_once(command, stdin_input)
        except ExecutorFailure:
            for _ in range(self.on_disconnect_retry):
                try:
                    return await self._run_command_once(command, stdin_input)
                except ExecutorFailure:
                    pass
            # all retries failed: re-raise the failure of the initial attempt
            raise

    async def _run_command_once(self, command, stdin_input=None):
        """
        Run ``command`` a single time using one session of the connection

        :return: :py:class:`AttributeDict` with ``stdout``, ``stderr`` and
            ``exit_code`` of the command
        :raises CommandExecutionFailure: if asyncssh reports a process error
        :raises ExecutorFailure: if no channel could be opened or the command
            finished without an exit status
        """
        async with self.bounded_connection as ssh_connection:
            try:
                response = await ssh_connection.run(
                    command, check=True, input=stdin_input
                )
            except asyncssh.ProcessError as pe:
                raise CommandExecutionFailure(
                    message=f"Run command {command} via SSHExecutor failed",
                    exit_code=pe.exit_status,
                    stdin=stdin_input,
                    stdout=pe.stdout,
                    stderr=pe.stderr,
                ) from pe
            except asyncssh.ChannelOpenError as coe:
                self._handle_broken_ssh_connection(
                    ssh_connection, chained_exception=coe
                )
            else:
                # In case asyncssh loses the connection while running a command, the
                # connection loss seems to be silently ignored, however the
                # exit_status is None in that case.
                if response.exit_status is None:
                    self._handle_broken_ssh_connection(ssh_connection)
                return AttributeDict(
                    stdout=response.stdout,
                    stderr=response.stderr,
                    exit_code=response.exit_status,
                )
@enable_yaml_load("!DupingSSHExecutor")
@yaml_tag(eager=True)
class DupingSSHExecutor(SSHExecutor):
    """
    SSH executor that feeds each command to a wrapper program via its input

    Instead of running a command directly, the wrapper is run and receives the
    command, followed by any input for the command, as its own input.

    :param wrapper: the program run on the remote side for every command
    :param parameters: keyword arguments passed on to :py:class:`SSHExecutor`
    """

    def __init__(self, *, wrapper="/bin/bash", **parameters):
        self._wrapper_script = wrapper
        super().__init__(**parameters)

    async def run_command(self, command, stdin_input=None):
        """
        Run ``command`` by sending it to the wrapper program

        :param command: the command, sent as the first line of the wrapper's input
        :param stdin_input: optional data appended after the command line
        :return: whatever running the wrapper via :py:class:`SSHExecutor` returns
        """
        # the command always comes first, terminated by a newline
        payload = f"{command}\n"
        if stdin_input:
            # empty or missing input adds nothing after the command line
            payload += f"{stdin_input}\n"
        return await super().run_command(self._wrapper_script, stdin_input=payload)