Bug: add timeout handling to cluster_run and copy (#524)
* update cluster_run and copy to handle timeouts

* fix

* move timeout default to connect function
jafreck committed Apr 30, 2018
1 parent 9ccc1c6 commit 47000a5
Showing 4 changed files with 63 additions and 31 deletions.
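
In short: an optional timeout keyword now flows from the public spark client methods (cluster_run, cluster_copy, cluster_download) through the private __cluster_run/__cluster_copy helpers and the SSH utilities down to paramiko, so a node that cannot be reached fails fast with an AztkError instead of hanging the whole operation. A caller-side sketch of the new keyword; the cluster id and command are hypothetical, and the secrets setup is elided:

    import aztk.spark

    # Assumes a populated SecretsConfiguration, as the aztk_cli builds one
    secrets = aztk.spark.models.SecretsConfiguration()  # fields elided
    client = aztk.spark.Client(secrets)

    # Give each node's SSH connection up to 30 seconds to be established.
    # Unreachable nodes come back as (node_id, AztkError) pairs rather than
    # hanging the call; omitting timeout falls back to the 20 s default in connect().
    results = client.cluster_run('my-cluster', 'echo hello', timeout=30)
    for node_id, output in results:
        print(node_id, output)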
10 changes: 6 additions & 4 deletions aztk/client.py
@@ -229,7 +229,7 @@ def __delete_user_on_pool(self, username, pool_id, nodes):
         concurrent.futures.wait(futures)


-    def __cluster_run(self, cluster_id, command, internal, container_name=None):
+    def __cluster_run(self, cluster_id, command, internal, container_name=None, timeout=None):
         pool, nodes = self.__get_pool_details(cluster_id)
         nodes = [node for node in nodes]
         if internal:
@@ -242,14 +242,15 @@ def __cluster_run(self, cluster_id, command, internal, container_name=None):
                 'aztk',
                 cluster_nodes,
                 ssh_key=ssh_key.exportKey().decode('utf-8'),
-                container_name=container_name))
+                container_name=container_name,
+                timeout=timeout))
             return output
         except OSError as exc:
             raise exc
         finally:
             self.__delete_user_on_pool('aztk', pool.id, nodes)

-    def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False):
+    def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False, timeout=None):
         pool, nodes = self.__get_pool_details(cluster_id)
         nodes = [node for node in nodes]
         if internal:
@@ -265,7 +266,8 @@ def __cluster_copy(self, cluster_id, source_path, destination_path, container_name=None, internal=False, get=False):
                     source_path=source_path,
                     destination_path=destination_path,
                     ssh_key=ssh_key.exportKey().decode('utf-8'),
-                    get=get))
+                    get=get,
+                    timeout=timeout))
             return output
         except (OSError, batch_error.BatchErrorException) as exc:
             raise exc
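Both private helpers share the same shape: create a temporary aztk user on the pool, do the SSH work, and always remove the user in a finally block, so a failure cannot leave the temporary credentials behind. A minimal sketch of that shape; the generate/delete helpers are hypothetical stand-ins for the private methods:

    def run_with_ephemeral_user(pool_id, nodes, do_work, timeout=None):
        # hypothetical condensation of __cluster_run / __cluster_copy
        ssh_key = generate_user_on_pool('aztk', pool_id, nodes)  # hypothetical helper
        try:
            return do_work(ssh_key, timeout=timeout)  # SSH work, bounded by timeout
        finally:
            delete_user_on_pool('aztk', pool_id, nodes)  # cleanup always runs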
28 changes: 22 additions & 6 deletions aztk/spark/client.py
@@ -161,23 +161,39 @@ def get_application_status(self, cluster_id: str, app_name: str):
         except batch_error.BatchErrorException as e:
             raise error.AztkError(helpers.format_batch_exception(e))

-    def cluster_run(self, cluster_id: str, command: str, host=False, internal: bool = False):
+    def cluster_run(self, cluster_id: str, command: str, host=False, internal: bool = False, timeout=None):
         try:
-            return self.__cluster_run(cluster_id, command, internal, container_name='spark' if not host else None)
+            return self.__cluster_run(cluster_id,
+                                      command,
+                                      internal,
+                                      container_name='spark' if not host else None,
+                                      timeout=timeout)
         except batch_error.BatchErrorException as e:
             raise error.AztkError(helpers.format_batch_exception(e))

-    def cluster_copy(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False):
+    def cluster_copy(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False, timeout=None):
         try:
             container_name = None if host else 'spark'
-            return self.__cluster_copy(cluster_id, source_path, destination_path, container_name=container_name, get=False, internal=internal)
+            return self.__cluster_copy(cluster_id,
+                                       source_path,
+                                       destination_path,
+                                       container_name=container_name,
+                                       get=False,
+                                       internal=internal,
+                                       timeout=timeout)
         except batch_error.BatchErrorException as e:
             raise error.AztkError(helpers.format_batch_exception(e))

-    def cluster_download(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False):
+    def cluster_download(self, cluster_id: str, source_path: str, destination_path: str, host: bool = False, internal: bool = False, timeout=None):
         try:
             container_name = None if host else 'spark'
-            return self.__cluster_copy(cluster_id, source_path, destination_path, container_name=container_name, get=True, internal=internal)
+            return self.__cluster_copy(cluster_id,
+                                       source_path,
+                                       destination_path,
+                                       container_name=container_name,
+                                       get=True,
+                                       internal=internal,
+                                       timeout=timeout)
         except batch_error.BatchErrorException as e:
             raise error.AztkError(helpers.format_batch_exception(e))

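Every public method repeats the same try/except that converts Azure Batch service errors into the library's AztkError, and the new timeout keyword is simply forwarded untouched. That recurring wrap-and-translate pattern could be factored as a decorator; a hedged sketch, where the decorator itself is hypothetical and the imports mirror how these modules are used above:

    import azure.batch.models.batch_error as batch_error

    from aztk import error
    from aztk.utils import helpers

    def translate_batch_errors(fn):
        # hypothetical decorator capturing the try/except shared by the methods above
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except batch_error.BatchErrorException as e:
                raise error.AztkError(helpers.format_batch_exception(e))
        return wrapper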
49 changes: 30 additions & 19 deletions aztk/utils/ssh.py
@@ -5,42 +5,45 @@
 import io
 import os
 import select
+import socket
 import socketserver as SocketServer
 import sys
 from concurrent.futures import ThreadPoolExecutor

+from aztk.error import AztkError
 from . import helpers


 def connect(hostname,
             port=22,
             username=None,
             password=None,
-            pkey=None):
+            pkey=None,
+            timeout=None):
     import paramiko

     client = paramiko.SSHClient()

     client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

     if pkey:
         ssh_key = paramiko.RSAKey.from_private_key(file_obj=io.StringIO(pkey))
     else:
         ssh_key = None

-    client.connect(
-        hostname,
-        port=port,
-        username=username,
-        password=password,
-        pkey=ssh_key
-    )
+    timeout = timeout or 20
+    try:
+        client.connect(hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
+    except socket.timeout:
+        raise AztkError("Connection timed out to: {}".format(hostname))

     return client


-def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None):
-    client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key)
+def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None):
+    try:
+        client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
+    except AztkError as e:
+        return (node_id, e)
     if container_name:
         cmd = 'sudo docker exec 2>&1 -t {0} /bin/bash -c \'set -e; set -o pipefail; {1}; wait\''.format(container_name, command)
     else:
@@ -51,7 +54,7 @@ def node_exec_command(node_id, command, username, hostname, port, ssh_key=None, password=None, container_name=None):
     return (node_id, output)


-async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, password=None, container_name=None):
+async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, password=None, container_name=None, timeout=None):
     return await asyncio.gather(
         *[asyncio.get_event_loop().run_in_executor(ThreadPoolExecutor(),
                                                    node_exec_command,
@@ -62,12 +65,16 @@ async def clus_exec_command(command, username, nodes, ports=None, ssh_key=None, password=None, container_name=None):
                                                    node_rls.port,
                                                    ssh_key,
                                                    password,
-                                                   container_name) for node, node_rls in nodes]
+                                                   container_name,
+                                                   timeout) for node, node_rls in nodes]
     )


-def copy_from_node(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None):
-    client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key)
+def copy_from_node(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None):
+    try:
+        client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
+    except AztkError as e:
+        return (node_id, False, e)
     sftp_client = client.open_sftp()
     try:
         destination_path = os.path.join(os.path.dirname(destination_path), node_id, os.path.basename(destination_path))
@@ -82,8 +89,11 @@ def copy_from_node(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None):
         client.close()


-def node_copy(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None):
-    client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key)
+def node_copy(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None, timeout=None):
+    try:
+        client = connect(hostname=hostname, port=port, username=username, password=password, pkey=ssh_key, timeout=timeout)
+    except AztkError as e:
+        return (node_id, False, e)
     sftp_client = client.open_sftp()
     try:
         if container_name:
@@ -108,7 +118,7 @@ def node_copy(node_id, source_path, destination_path, username, hostname, port, ssh_key=None, password=None, container_name=None):
 #TODO: progress bar


-async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None, password=None, container_name=None, get=False):
+async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None, password=None, container_name=None, get=False, timeout=None):
     return await asyncio.gather(
         *[asyncio.get_event_loop().run_in_executor(ThreadPoolExecutor(),
                                                    copy_from_node if get else node_copy,
@@ -120,5 +130,6 @@ async def clus_copy(username, nodes, source_path, destination_path, ssh_key=None, password=None, container_name=None, get=False):
                                                    node_rls.port,
                                                    ssh_key,
                                                    password,
-                                                   container_name) for node, node_rls in nodes]
+                                                   container_name,
+                                                   timeout) for node, node_rls in nodes]
     )
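
Two design points in this file are worth spelling out. First, connect applies a 20-second default when no timeout is given and converts socket.timeout into an AztkError. Second, the per-node helpers return that error as a value, (node_id, error) for commands and (node_id, False, error) for copies, so asyncio.gather still completes for the healthy nodes instead of aborting on the first failure. A self-contained sketch of that errors-as-values fan-out; the node names and probe body are hypothetical:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    def probe(node_id):
        # Stand-in for node_exec_command: always returns a tuple, never raises.
        try:
            if node_id == 'node-2':
                raise TimeoutError('unreachable')
            return (node_id, 'ok')
        except Exception as e:
            return (node_id, e)

    async def fan_out(node_ids):
        loop = asyncio.get_event_loop()
        return await asyncio.gather(
            *[loop.run_in_executor(ThreadPoolExecutor(), probe, n) for n in node_ids])

    results = asyncio.get_event_loop().run_until_complete(fan_out(['node-1', 'node-2']))
    print(results)  # [('node-1', 'ok'), ('node-2', TimeoutError('unreachable'))]

Spinning up a fresh ThreadPoolExecutor per task mirrors the code above; a shared executor would be the more conventional choice, but the behavior is the same for these small fan-outs. A quick manual test of the timeout path is to point ssh.connect at an unroutable TEST-NET address such as 192.0.2.1 with a small timeout and observe the AztkError.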
7 changes: 5 additions & 2 deletions aztk_cli/spark/endpoints/cluster/cluster_run.py
@@ -28,5 +28,8 @@ def print_execute_result(node_id, result):
     print("-" * (len(node_id) + 6))
     print("| ", node_id, " |")
     print("-" * (len(node_id) + 6))
-    for line in result:
-        print(line)
+    if isinstance(result, Exception):
+        print(str(result) + "\n")
+    else:
+        for line in result:
+            print(line)
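
Because per-node failures now arrive as return values rather than raised exceptions, the CLI has to branch on the shape of result. Note the str() call: concatenating an Exception to a string directly would raise a TypeError. Both branches, with hypothetical values (assumes from aztk.error import AztkError):

    print_execute_result('node-1', ['line one', 'line two'])                        # iterable output: printed line by line
    print_execute_result('node-2', AztkError('Connection timed out to: 10.0.0.5'))  # exception: printed as one message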
