Skip to content

Commit

Permalink
Added timeout to SSHConnector (#644)
Browse files Browse the repository at this point in the history
This commit adds the timeout in the `asyncssh.connect()` call.
The user chooses the time to wait using the `connectTimeout`
option. If the `retry` is enabled, the time to wait is increased at
each attempt.
  • Loading branch information
LanderOtto authored Feb 10, 2025
1 parent e86d21c commit 99c062a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
python -m pip install -r docs/requirements.txt
- name: "Build documentation and check for consistency"
env:
CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94"
CHECKSUM: "198f61804843130d3cae0675c67cad121d980c4648cf27c7541d87219afa3d6e"
run: |
cd docs
HASH="$(make checksum | tail -n1)"
Expand Down
5 changes: 5 additions & 0 deletions streamflow/deployment/connector/schemas/ssh.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@
"description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)",
"default": true
},
"connectTimeout": {
"type": "integer",
"description": "Max time (in seconds) to wait for establishing an SSH connection.",
"default": 30
},
"dataTransferConnection": {
"oneOf": [
{
Expand Down
38 changes: 20 additions & 18 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def _get_connection(
port=port,
tunnel=await self._get_connection(config.tunnel),
username=config.username,
connect_timeout=config.connect_timeout,
)

def _get_param_from_file(self, file_path: str):
Expand Down Expand Up @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err:
await self.close()
raise
if isinstance(err, asyncio.TimeoutError):
raise asyncio.TimeoutError(
f"SSH connection to {self.get_hostname()} failed: connection timed out"
)
else:
raise
finally:
self._connect_event.set()
else:
Expand Down Expand Up @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
ConnectionError,
ConnectionLost,
DisconnectError,
asyncio.TimeoutError,
) as exc:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Expand Down Expand Up @@ -336,6 +339,7 @@ def __init__(
self,
check_host_key: bool,
client_keys: MutableSequence[str],
connect_timeout: int,
hostname: str,
password_file: str | None,
ssh_key_passphrase_file: str | None,
Expand All @@ -344,6 +348,7 @@ def __init__(
):
self.check_host_key: bool = check_host_key
self.client_keys: MutableSequence[str] = client_keys
self.connect_timeout: int = connect_timeout
self.hostname: str = hostname
self.password_file: str | None = password_file
self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file
Expand All @@ -359,6 +364,7 @@ def __init__(
nodes: MutableSequence[Any],
username: str | None = None,
checkHostKey: bool = True,
connectTimeout: int = 30,
dataTransferConnection: str | MutableMapping[str, Any] | None = None,
file: str | None = None,
maxConcurrentSessions: int = 10,
Expand Down Expand Up @@ -399,6 +405,7 @@ def __init__(
template_map=services_map,
)
self.checkHostKey: bool = checkHostKey
self.connectTimeout: int = connectTimeout
self.passwordFile: str | None = passwordFile
self.maxConcurrentSessions: int = maxConcurrentSessions
self.maxConnections: int = maxConnections
Expand Down Expand Up @@ -522,21 +529,16 @@ def _get_config(self, node: str | MutableMapping[str, Any]):
return None
elif isinstance(node, str):
node = {"hostname": node}
ssh_key = node["sshKey"] if "sshKey" in node else self.sshKey
ssh_key = node.get("sshKey", self.sshKey)
return SSHConfig(
hostname=node["hostname"],
username=node["username"] if "username" in node else self.username,
check_host_key=(
node["checkHostKey"] if "checkHostKey" in node else self.checkHostKey
),
username=node.get("username", self.username),
check_host_key=node.get("checkHostKey", self.checkHostKey),
client_keys=[ssh_key] if ssh_key is not None else [],
password_file=(
node["passwordFile"] if "passwordFile" in node else self.passwordFile
),
ssh_key_passphrase_file=(
node["sshKeyPassphraseFile"]
if "sshKeyPassphraseFile" in node
else self.sshKeyPassphraseFile
connect_timeout=node.get("connectTimeout", self.connectTimeout),
password_file=node.get("passwordFile", self.passwordFile),
ssh_key_passphrase_file=node.get(
"sshKeyPassphraseFile", self.sshKeyPassphraseFile
),
tunnel=(
self._get_config(node["tunnel"])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def test_schema_generation():
"""Check that the `streamflow schema` command generates a correct JSON Schema."""
assert (
hashlib.sha256(SfSchema().dump("v1.0", False).encode()).hexdigest()
== "f8e3f739678510fc34afe419b215b54d1467d84ee6433fbb0c107bc30eb1f062"
== "c60eabe4335124cb1a241496ac370667a1525b8ab5847584f8dbf3877419282e"
)
assert (
hashlib.sha256(SfSchema().dump("v1.0", True).encode()).hexdigest()
== "b91f949c055e3f5de305751540725eeba7e1a6deb1082c11bca3c6e7cfa09929"
== "d62e2dc7b71778c6aa278ee28562bbc1b0a534e286296825162ce56ecd4aeb3c"
)


Expand Down

0 comments on commit 99c062a

Please sign in to comment.