Skip to content

Commit

Permalink
Merge pull request #2 from zubairahm3d/feature/shahzaib
Browse files Browse the repository at this point in the history
Exit cleanup, get_host_port logics modified for efficiency & spawn ma…
  • Loading branch information
zubairahm3d authored Nov 27, 2024
2 parents 53d3a0b + e666456 commit d44581b
Showing 1 changed file with 96 additions and 70 deletions.
166 changes: 96 additions & 70 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,10 @@

logging.basicConfig(level=logging.DEBUG) # gives us access to 'app.logger'

## SSH KEY SETUP ##
# Key pair used for SSH access to spawned containers; both files sit in the user's ~/.ssh directory.
ssh_key_dir = os.path.expanduser("~/.ssh")
ssh_key_path = os.path.join(ssh_key_dir, "docker_container_key")  # private key file
public_key_path = f"{ssh_key_path}.pub"  # public half, same name plus '.pub'

## CLEANUP ON EXIT SETUP ##
def cleanup_on_exit():
"""calls the cleanup route at app shutdown to remove any spawned containers."""
with app.test_client() as flask_client:
response = flask_client.post('/remove_all_containers') # calls remove_all_containers route
app.logger.info(f"Cleanup result: {response.get_data(as_text=True)}")

atexit.register(cleanup_on_exit) # register 'cleanup' to run when the app quits or receives SIGINT

def handle_sigint(signum, frame):
"""Custom signal handler to catch Ctrl+C (SIGINT) interrupts."""
app.logger.info("Received SIGINT (Ctrl+C), cleaning up...")
cleanup_on_exit() # run the cleanup before exiting
exit(0)

signal.signal(signal.SIGINT, handle_sigint) # register the SIGINT handler


## HELPER FUNCTIONS ##
def ensure_ssh_key():
Expand All @@ -55,6 +37,7 @@ def ensure_ssh_key():
with open(public_key_path, "r") as key_file:
return key_file.read().strip()


def spawn_container(public_key):
try:
container = client.containers.run(
Expand Down Expand Up @@ -85,46 +68,55 @@ def spawn_container(public_key):
app.logger.error(f"Error spawning container: {str(e)}")
raise

def wait_for_ssh(host, port, retries=5, delay=5):

def get_port_mapping(container):
    """
    Returns the host port mapping for a given container.

    The container's attributes are refreshed first, then the binding for the
    SSH port ('22/tcp') is read from NetworkSettings.Ports.

    Returns:
        The 'HostPort' value of the first 22/tcp binding, or None when the
        port table is missing/empty or 22/tcp has no host binding.
    """
    container.reload()  # refresh container status
    network_settings = container.attrs.get('NetworkSettings') or {}
    ports = network_settings.get('Ports') or {}

    # The entry for a port can be present but null/empty (no host binding),
    # so the key's presence alone does not make indexing [0] safe.
    bindings = ports.get('22/tcp')
    if not bindings:
        return None
    return bindings[0].get('HostPort')

app.logger.info(f"SSH is ready on {host}:{port}.")
return True

except Exception as e:
app.logger.debug(f"SSH attempt {attempt + 1}/{retries} failed: {e}")
time.sleep(delay * (2 ** attempt)) # exponential backoff
## [NOTE] SSH readiness check is not reliable:
# def wait_for_ssh(host, port, retries=5, delay=5):
# """
# Waits for SSH to be available on the specified host and port, with exponential backoff.
# """
# for attempt in range(retries):
# try:
# with paramiko.SSHClient() as ssh_client:
# ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# ssh_client.connect(hostname=host, port=port, username='root', key_filename=ssh_key_path, timeout=5)

app.logger.error(f"SSH not ready after {retries} attempts.")
return False
# app.logger.info(f"SSH is ready on {host}:{port}.")
# return True

def wait_for_container(container, timeout=60, interval=5):
"""
Waits for a container to initialize and report its SSH readiness.
"""
elapsed = 0
while elapsed < timeout:
container.reload()
ports = container.attrs['NetworkSettings']['Ports']
# except Exception as e:
# app.logger.debug(f"SSH attempt {attempt + 1}/{retries} failed: {e}")
# time.sleep(delay * (2 ** attempt)) # exponential backoff

# app.logger.error(f"SSH not ready after {retries} attempts.")
# return False

if ports and ports.get('22/tcp') and ports['22/tcp'][0]['HostPort']:
host_port = ports['22/tcp'][0]['HostPort']

if wait_for_ssh('localhost', host_port):
return host_port
# def wait_for_container(container, timeout=60, interval=5):
# """
# Waits for a container to initialize and report its SSH readiness.
# """
# elapsed = 0
# while elapsed < timeout:
# host_port = get_port_mapping(container)

app.logger.info(f"Waiting for container {container.id[:12]} to initialize...")
time.sleep(interval)
elapsed += interval
# if host_port and wait_for_ssh('localhost', host_port):
# return host_port

# app.logger.info(f"Waiting for container {container.id[:12]} to initialize...")
# time.sleep(interval)
# elapsed += interval

# raise TimeoutError(f"Container {container.id[:12]} failed to initialize within {timeout} seconds.")

raise TimeoutError(f"Container {container.id[:12]} failed to initialize within {timeout} seconds.")

def create_temp_file(content, suffix, writer=lambda content, file: file.write(content)):
"""
Expand All @@ -135,6 +127,7 @@ def create_temp_file(content, suffix, writer=lambda content, file: file.write(co
writer(content, temp_file) # write file contents (using provided function)
return temp_file.name


def parse_ansible_results(result):
"""
Parses Ansible results from a given AnsibleRunner object.
Expand All @@ -151,6 +144,7 @@ def parse_ansible_results(result):

return results


def cleanup_files(file_paths):
"""
Removes temporary files created during the application lifecycle.
Expand All @@ -159,6 +153,7 @@ def cleanup_files(file_paths):
if os.path.exists(file_path):
os.unlink(file_path)


def run_ansible(hosts, command):
"""
Runs an Ansible playbook on specified hosts with a given command.
Expand Down Expand Up @@ -207,7 +202,8 @@ def run_ansible(hosts, command):
finally:
cleanup_files([inventory_file_path, playbook_file_path])

def remove_or_stop_containers(containers, action):

def remove_or_stop_containers(containers, action, type='spawned'):
"""
Stops or removes the specified containers based on the action parameter.
"""
Expand All @@ -220,31 +216,43 @@ def remove_or_stop_containers(containers, action):
app.logger.info(f"spawned container {container.id[:12]} {'removed' if action == 'remove' else 'stopped'}")

except Exception as e:
app.logger.warning(f"Failed to {action} spawned container {container.id[:12]}: {str(e)}")
app.logger.warning(f"Failed to {action} {type} container {container.id[:12]}: {str(e)}")


## ROUTES FOR FLASK APP ##
# HTTP status codes: 200 - OK (default), 400 - Bad Request, 404 - Not Found, 500 - Internal Server Error
def cleanup(action='remove'):
    """Cleans up all spawned containers and orphaned containers (by stopping or removing them).

    Args:
        action: 'remove' (default) or 'stop'; passed through unchanged to
            remove_or_stop_containers for both groups of containers.
    """
    spawned = client.containers.list(all=True, filters={"label": "flask_app=spawned_container"})
    remove_or_stop_containers(spawned, action)

    # [OPTIONAL] remove all orphaned containers (if any, due to previous errors)
    # NOTE(review): the ancestor filter selects by image, not by who started the
    # container, so unrelated debian:bullseye-slim containers on this host are
    # presumably included too - confirm this is intended.
    handled_ids = {container.id for container in spawned}
    orphaned = [
        container
        for container in client.containers.list(all=True, filters={"ancestor": "debian:bullseye-slim"})
        # with action='stop' the spawned containers still exist and would be
        # listed again here; skip them so each container is processed once
        if container.id not in handled_ids
    ]
    remove_or_stop_containers(orphaned, action, 'orphaned')

    app.logger.info("Cleanup complete.")


## ROUTES (FOR FLASK APP) ##
# HTTP status codes: 200 - OK (default), 400 - Bad Request, 404 - Not Found, 500 - Internal Server Error
@app.route('/')
def MAIN_INDEX_ROUTE():
    """Serves the landing page by rendering the 'index.html' template."""
    page = render_template('index.html')
    return page


@app.route('/spawn', methods=['POST'])
def SPAWN_MACHINES_ROUTE():
try:
num_machines = request.form.get('num_machines', type=int)
if not isinstance(num_machines, int) or num_machines <= 0:
if not isinstance(num_machines, int) or num_machines <= 0: # [OPTIONAL] validation for preventing injection attacks
return jsonify({"error": "Invalid 'num_machines' value. Must be a positive integer."}), 400

public_key = ensure_ssh_key()
containers = [] # Track spawned containers
machine_info = [] # Track machine details for display
machine_info = session.get('machine_info', []) # Track machine details for display

for _ in range(num_machines):
try:
container = spawn_container(public_key)
host_port = wait_for_container(container)
host_port = get_port_mapping(container)
containers.append(container)
machine_info.append({
'container_id': container.id[:12],
Expand Down Expand Up @@ -274,7 +282,7 @@ def SPAWN_MACHINES_ROUTE():
def RUN_COMMAND_ROUTE():
try:
command = request.form['command']
if not command or ';' in command or '&&' in command: # prevent command injection
if not command: # [OPTIONAL] validation: `or ';' in command or '&&' in command or '|' in command:`, [IMPROVEMENT] use regex patterns
return jsonify({"error": "Invalid command."}), 400

machine_info = session.get('machine_info', [])
Expand All @@ -289,7 +297,7 @@ def RUN_COMMAND_ROUTE():
"ansible_ssh_private_key_file": ssh_key_path,
"ansible_ssh_extra_args": "-o StrictHostKeyChecking=no"
}
for machine in machine_info if machine['ssh_status'] == 'Ready'
for machine in machine_info if machine['ssh_status'] == 'Ready' # [NOTE] currently, only ready machines are considered, ssh_status is not reliable
}

if not ansible_hosts:
Expand All @@ -308,48 +316,66 @@ def RUN_COMMAND_ROUTE():
app.logger.error(f"Error running command: {str(e)}")
return jsonify({"error": str(e)}), 500


@app.route('/<action>_all_containers', methods=['POST'])
def STOP_OR_REMOVE_ALL_CONTAINERS_ROUTE(action):
    """Stops or removes every container via cleanup() and clears the session's machine list.

    Args:
        action: URL segment; must be 'stop' or 'remove'.

    Returns:
        JSON {"status": ...} on success; JSON {"error": ...} with 400 for an
        invalid action or an empty session, or 500 when cleanup fails.
    """
    # Validate the action first so a bad URL is reported as such, regardless of session state.
    if action not in ('stop', 'remove'):
        return jsonify({"error": f"Invalid action: {action}. Use 'stop' or 'remove'."}), 400

    try:
        if not session.get('machine_info'):  # safeguard against empty session data
            return jsonify({"error": "No machines spawned. Nothing to stop or remove."}), 400

        cleanup(action)
        session['machine_info'] = []  # clear session data for spawned containers

        # Build the phrase up front; interpolating an empty string left a double space for 'stop'.
        outcome = 'stopped and removed' if action == 'remove' else 'stopped'
        return jsonify({"status": f"All containers {outcome} successfully."})

    except Exception as e:
        app.logger.error(f"Error {'stopping' if action == 'stop' else 'removing'} all containers: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route('/<action>_container', methods=['POST'])
def STOP_OR_REMOVE_CONTAINER_ROUTE(action):
    """Stops or removes a single container and drops it from the session's machine list.

    Args:
        action: URL segment; must be 'stop' or 'remove'.

    Form fields:
        container_id: identifier passed to client.containers.get().

    Returns:
        JSON {"status": ...} on success; JSON {"error": ...} with 400 for an
        invalid action or missing id, 404 when Docker does not know the
        container, or 500 for any other failure.
    """
    # Same validation as the '<action>_all_containers' route.
    if action not in ('stop', 'remove'):
        return jsonify({"error": f"Invalid action: {action}. Use 'stop' or 'remove'."}), 400

    container_id = ''  # bound up front so the except handlers can always format it
    try:
        container_id = request.form.get('container_id')
        if not container_id:
            return jsonify({"error": "No container ID provided."}), 400

        container = client.containers.get(container_id)
        remove_or_stop_containers([container], action)

        # remove the stopped/removed container from the session data;
        # match both the submitted value and the resolved container's 12-character id
        stale_ids = {container_id, container.id[:12]}
        session['machine_info'] = [
            machine for machine in session.get('machine_info', [])
            if machine['container_id'] not in stale_ids
        ]

        # Build the phrase up front; interpolating an empty string left a double space for 'stop'.
        outcome = 'stopped and removed' if action == 'remove' else 'stopped'
        return jsonify({"status": f"Container {container_id[:12]} {outcome} successfully."})

    except docker.errors.NotFound:
        return jsonify({"error": f"Container {container_id[:12]} not found."}), 404

    except Exception as e:
        app.logger.error(f"Error {'stopping' if action == 'stop' else 'removing'} container {container_id[:12]}: {str(e)}")
        return jsonify({"error": str(e)}), 500


## MAIN APP ENTRY POINT ##
## EXIT POINT SETUP (FOR CLEANUP) ##
def handle_sigint(signum, frame):
    """Custom signal handler to catch Ctrl+C (SIGINT) interrupts.

    Runs cleanup() and then ends the process with exit status 0.

    Args:
        signum: signal number supplied by the signal machinery (unused).
        frame: interrupted stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: always, with code 0, after cleanup has run.
    """
    app.logger.info("Received SIGINT (Ctrl+C), cleaning up...")
    cleanup()  # run the cleanup before exiting
    # cleanup has just run; drop the atexit hook (if one is registered) so it does not run again on exit
    atexit.unregister(cleanup)
    # raise SystemExit directly: the exit() builtin is added by the site module for interactive use
    raise SystemExit(0)

def handle_sigstp(signum, frame):
    """Custom signal handler to catch SIGTSTP (Ctrl+Z) interrupts.

    Runs cleanup() and then ends the process with exit status 0, so Ctrl+Z
    terminates the app instead of suspending it.

    Args:
        signum: signal number supplied by the signal machinery (unused).
        frame: interrupted stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: always, with code 0, after cleanup has run.
    """
    app.logger.info("Received SIGTSTP (Ctrl+Z), cleaning up...")
    cleanup()
    # cleanup has just run; drop the atexit hook (if one is registered) so it does not run again on exit
    atexit.unregister(cleanup)
    # raise SystemExit directly: the exit() builtin is added by the site module for interactive use
    raise SystemExit(0)

atexit.register(cleanup)  # [OPTIONAL] register 'cleanup' to run when the app quits normally
signal.signal(signal.SIGINT, handle_sigint)  # register SIGINT handler to run 'cleanup' when Ctrl+C is pressed
if hasattr(signal, "SIGTSTP"):  # SIGTSTP is Unix-only; referencing it unguarded fails at import on Windows
    signal.signal(signal.SIGTSTP, handle_sigstp)  # register SIGTSTP handler to run 'cleanup' when Ctrl+Z is pressed


## MAIN FLASK APP ENTRY POINT ##
if __name__ == "__main__":
    # Only start the server when this file is executed directly (not on import).
    # NOTE(review): debug=True is meant for local development - confirm it is never used for a deployed instance.
    app.run(debug=True)

0 comments on commit d44581b

Please sign in to comment.