Skip to content

Commit

Permalink
Added timeout to asyncssh.connect() call
Browse files Browse the repository at this point in the history
  • Loading branch information
LanderOtto committed Jan 20, 2025
1 parent 6e21f11 commit 50e2a4f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
5 changes: 5 additions & 0 deletions streamflow/deployment/connector/schemas/ssh.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@
"description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)",
"default": true
},
"connectTimeout": {
"type": "integer",
"description": "Time (in seconds) to wait for establish the connection. When an attempt fails, the time to wait is increased.",
"default": 30
},
"dataTransferConnection": {
"oneOf": [
{
Expand Down
24 changes: 18 additions & 6 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def _get_connection(
port=port,
tunnel=await self._get_connection(config.tunnel),
username=config.username,
connect_timeout=config.connect_timeout * (self.connection_attempts + 1),
)

def _get_param_from_file(self, file_path: str):
Expand Down Expand Up @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err:
await self.close()
raise
if isinstance(err, asyncio.TimeoutError):
raise asyncio.TimeoutError(
f"The SSH connection attempt to {self.get_hostname()} took too long."
)
else:
raise
finally:
self._connect_event.set()
else:
Expand Down Expand Up @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
ConnectionError,
ConnectionLost,
DisconnectError,
asyncio.TimeoutError,
) as exc:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Expand Down Expand Up @@ -336,6 +339,7 @@ def __init__(
self,
check_host_key: bool,
client_keys: MutableSequence[str],
connect_timeout: int,
hostname: str,
password_file: str | None,
ssh_key_passphrase_file: str | None,
Expand All @@ -344,6 +348,7 @@ def __init__(
):
self.check_host_key: bool = check_host_key
self.client_keys: MutableSequence[str] = client_keys
self.connect_timeout: int = connect_timeout
self.hostname: str = hostname
self.password_file: str | None = password_file
self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file
Expand All @@ -359,6 +364,7 @@ def __init__(
nodes: MutableSequence[Any],
username: str | None = None,
checkHostKey: bool = True,
connectTimeout: int = 30,
dataTransferConnection: str | MutableMapping[str, Any] | None = None,
file: str | None = None,
maxConcurrentSessions: int = 10,
Expand Down Expand Up @@ -399,6 +405,7 @@ def __init__(
template_map=services_map,
)
self.checkHostKey: bool = checkHostKey
self.connect_timeout: int = connectTimeout
self.passwordFile: str | None = passwordFile
self.maxConcurrentSessions: int = maxConcurrentSessions
self.maxConnections: int = maxConnections
Expand Down Expand Up @@ -543,6 +550,11 @@ def _get_config(self, node: str | MutableMapping[str, Any]):
if "tunnel" in node
else self.tunnel if hasattr(self, "tunnel") else None
),
connect_timeout=(
node["connect_timeout"]
if "connect_timeout" in node
else self.connect_timeout
),
)

def _get_ssh_client_process(
Expand Down

0 comments on commit 50e2a4f

Please sign in to comment.