diff --git a/aztk/client.py b/aztk/client.py index 342370f1..b3a96c15 100644 --- a/aztk/client.py +++ b/aztk/client.py @@ -229,7 +229,7 @@ def __delete_user_on_pool(self, username, pool_id, nodes): concurrent.futures.wait(futures) - def __cluster_run(self, cluster_id, command, internal, container_name=None): + def __cluster_run(self, cluster_id, command, internal, container_name=None, timeout=None): pool, nodes = self.__get_pool_details(cluster_id) nodes = [node for node in nodes] if internal: @@ -242,14 +242,15 @@ def __cluster_run(self, cluster_id, command, internal, container_name=None): 'aztk', cluster_nodes, ssh_key=ssh_key.exportKey().decode('utf-8'), - container_name=container_name)) + container_name=container_name, + timeout=timeout)) return output except OSError as exc: raise exc finally: self.__delete_user_on_pool('aztk', pool.id, nodes) - def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False): + def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False, timeout=None): pool, nodes = self.__get_pool_details(cluster_id) nodes = [node for node in nodes] if internal: @@ -265,7 +266,8 @@ def __cluster_copy(self, cluster_id, source_path, destination_path, container_na source_path=source_path, destination_path=destination_path, ssh_key=ssh_key.exportKey().decode('utf-8'), - get=get)) + get=get, + timeout=timeout)) return output except (OSError, batch_error.BatchErrorException) as exc: raise exc diff --git a/aztk/spark/client.py b/aztk/spark/client.py index b5a60a9c..0d47b7c7 100644 --- a/aztk/spark/client.py +++ b/aztk/spark/client.py @@ -161,23 +161,39 @@ def get_application_status(self, cluster_id: str, app_name: str): except batch_error.BatchErrorException as e: raise error.AztkError(helpers.format_batch_exception(e)) - def cluster_run(self, cluster_id: str, command: str, host=False, internal: bool = False): + def cluster_run(self, cluster_id: 
str, command: str, host=False, internal: bool = False, timeout=None): try: - return self.__cluster_run(cluster_id, command, internal, container_name='spark' if not host else None) + return self.__cluster_run(cluster_id, + command, + internal, + container_name='spark' if not host else None, + timeout=timeout) except batch_error.BatchErrorException as e: raise error.AztkError(helpers.format_batch_exception(e)) - def cluster_copy(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False): + def cluster_copy(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False, timeout=None): try: container_name = None if host else 'spark' - return self.__cluster_copy(cluster_id, source_path, destination_path, container_name=container_name, get=False, internal=internal) + return self.__cluster_copy(cluster_id, + source_path, + destination_path, + container_name=container_name, + get=False, + internal=internal, + timeout=timeout) except batch_error.BatchErrorException as e: raise error.AztkError(helpers.format_batch_exception(e)) - def cluster_download(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False): + def cluster_download(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False, timeout=None): try: container_name = None if host else 'spark' - return self.__cluster_copy(cluster_id, source_path, destination_path, container_name=container_name, get=True, internal=internal) + return self.__cluster_copy(cluster_id, + source_path, + destination_path, + container_name=container_name, + get=True, + internal=internal, + timeout=timeout) except batch_error.BatchErrorException as e: raise error.AztkError(helpers.format_batch_exception(e)) diff --git a/aztk/utils/ssh.py b/aztk/utils/ssh.py index 733a7ea0..6ad3fc2f 100644 --- a/aztk/utils/ssh.py +++ b/aztk/utils/ssh.py @@ -5,10 +5,12 @@ 
import io import os import select +import socket import socketserver as SocketServer import sys from concurrent.futures import ThreadPoolExecutor +from aztk.error import AztkError from . import helpers @@ -16,11 +18,11 @@ def connect(hostname, port=22, username=None, password=None, - pkey=None): + pkey=None, + timeout=None): import paramiko client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if pkey: @@ -28,19 +30,20 @@ def connect(hostname, else: ssh_key = None - client.connect( - hostname, - port=port, - username=username, - password=password, - pkey=ssh_key - ) + timeout = timeout or 20 + try: + client.connect(hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout) + except socket.timeout: + raise AztkError("Connection timed out to: {}".format(hostname)) return client -def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None): - client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key) +def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None): + try: + client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout) + except AztkError as e: + return (node_id, e) if container_name: cmd = 'sudo docker exec 2>&1 -t {0} /bin/bash -c \'set -e; set -o pipefail; {1}; wait\''.format(container_name, command) else: @@ -51,7 +54,7 @@ def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, return (node_id, output) -async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, password=None, container_name=None): +async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, password=None, container_name=None, timeout=None): return await asyncio.gather( *[asyncio.get_event_loop().run_in_executor(ThreadPoolExecutor(), 
node_exec_command, @@ -62,12 +65,16 @@ async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, node_rls.port, ssh_key, password, - container_name) for node, node_rls in nodes] + container_name, + timeout) for node, node_rls in nodes] ) -def copy_from_node(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None): - client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key) +def copy_from_node(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None): + try: + client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout) + except AztkError as e: + return (node_id, False, e) sftp_client = client.open_sftp() try: destination_path = os.path.join(os.path.dirname(destination_path), node_id, os.path.basename(destination_path)) @@ -82,8 +89,11 @@ def copy_from_node(node_id, source_path, destination_path, username, hostname, p client.close() -def node_copy(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None): - client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key) +def node_copy(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None): + try: + client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout) + except AztkError as e: + return (node_id, False, e) sftp_client = client.open_sftp() try: if container_name: @@ -108,7 +118,7 @@ def node_copy(node_id, source_path, destination_path, username, hostname, port, #TODO: progress bar -async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None, password=None, container_name=None, get=False): +async def clus_copy(username, nodes, 
source_path, destination_path, ssh_key=None, password=None, container_name=None, get=False, timeout=None): return await asyncio.gather( *[asyncio.get_event_loop().run_in_executor(ThreadPoolExecutor(), copy_from_node if get else node_copy, @@ -120,5 +130,6 @@ async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None node_rls.port, ssh_key, password, - container_name) for node, node_rls in nodes] + container_name, + timeout) for node, node_rls in nodes] ) diff --git a/aztk_cli/spark/endpoints/cluster/cluster_run.py b/aztk_cli/spark/endpoints/cluster/cluster_run.py index 5567b0fc..0157ae43 100644 --- a/aztk_cli/spark/endpoints/cluster/cluster_run.py +++ b/aztk_cli/spark/endpoints/cluster/cluster_run.py @@ -28,5 +28,8 @@ def print_execute_result(node_id, result): print("-" * (len(node_id) + 6)) print("| ", node_id, " |") print("-" * (len(node_id) + 6)) - for line in result: - print(line) + if isinstance(result, Exception): + print(str(result) + "\n") + else: + for line in result: + print(line)