Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exit cleanup, get_host_port logics modified for efficiency & spawn ma… #2

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 96 additions & 70 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,10 @@

logging.basicConfig(level=logging.DEBUG) # gives us access to 'app.logger'

## SSH KEY SETUP ##
# Directory holding the key pair: "~/.ssh" with "~" expanded for the current user.
ssh_key_dir = os.path.expanduser("~/.ssh")
# Private key file; elsewhere in this file it is handed to paramiko (key_filename)
# and to Ansible (ansible_ssh_private_key_file).
ssh_key_path = os.path.join(ssh_key_dir, "docker_container_key")
# Matching public key file: the private key path with a ".pub" suffix.
public_key_path = ssh_key_path + ".pub"

## CLEANUP ON EXIT SETUP ##
def cleanup_on_exit():
    """
    Shutdown hook: POSTs to the '/remove_all_containers' route through Flask's
    test client and logs whatever body the route answers with.

    NOTE(review): the test client starts without a session cookie, so a route
    that first checks session data may answer without touching any container.
    Confirm against the route implementation.
    """
    with app.test_client() as http:
        reply = http.post('/remove_all_containers')  # in-process call, no real network request
        body_text = reply.get_data(as_text=True)
        app.logger.info(f"Cleanup result: {body_text}")

atexit.register(cleanup_on_exit) # register 'cleanup' to run when the app quits or receives SIGINT

def handle_sigint(signum, frame):
    """
    Custom signal handler to catch Ctrl+C (SIGINT) interrupts.

    Logs the interrupt, runs the container cleanup, then terminates the
    process with exit status 0.

    Args:
        signum: signal number delivered by the interpreter (unused).
        frame: stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: always, with code 0, once cleanup has run.
    """
    app.logger.info("Received SIGINT (Ctrl+C), cleaning up...")
    cleanup_on_exit()  # run the cleanup before exiting
    # `exit()` is a helper injected by the `site` module for interactive use; it is
    # not guaranteed to exist (e.g. `python -S`) and closes stdin as a side effect.
    # Raising SystemExit directly is the equivalent that works in any program.
    raise SystemExit(0)

signal.signal(signal.SIGINT, handle_sigint) # register the SIGINT handler


## HELPER FUNCTIONS ##
def ensure_ssh_key():
Expand All @@ -55,6 +37,7 @@ def ensure_ssh_key():
with open(public_key_path, "r") as key_file:
return key_file.read().strip()


def spawn_container(public_key):
try:
container = client.containers.run(
Expand Down Expand Up @@ -85,46 +68,55 @@ def spawn_container(public_key):
app.logger.error(f"Error spawning container: {str(e)}")
raise

def get_port_mapping(container):
    """
    Returns the host port mapping for a given container.

    Reloads the container so its attributes are current, then looks up the
    host-side binding of the container's SSH port ('22/tcp').

    Args:
        container: object exposing `reload()` and an `attrs` dict shaped like
            {'NetworkSettings': {'Ports': {...}}}.

    Returns:
        The 'HostPort' value of the first '22/tcp' binding, or None when the
        port is absent, has no bindings yet (None / empty list), or the
        binding carries no host port.
    """
    container.reload()  # refresh container status
    network = container.attrs.get('NetworkSettings') or {}
    bindings = (network.get('Ports') or {}).get('22/tcp')
    if not bindings:  # key missing, None, or [] -> nothing published (yet)
        return None
    return bindings[0].get('HostPort') or None


## [NOTE] SSH readiness check is not reliable:
def wait_for_ssh(host, port, retries=5, delay=5):
    """
    Waits for SSH to be available on the specified host and port, with exponential backoff.

    Args:
        host: hostname to connect to.
        port: TCP port to connect to.
        retries: maximum number of connection attempts.
        delay: base delay in seconds; the wait after failed attempt k (0-based)
            is `delay * 2 ** k`.

    Returns:
        True as soon as one connection attempt succeeds, False once all
        attempts have failed.
    """
    for attempt in range(retries):
        try:
            # named `ssh_client` so it is not confused with the module-level `client`
            with paramiko.SSHClient() as ssh_client:
                ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                ssh_client.connect(hostname=host, port=port, username='root', key_filename=ssh_key_path, timeout=5)

            app.logger.info(f"SSH is ready on {host}:{port}.")
            return True

        except Exception as e:
            app.logger.debug(f"SSH attempt {attempt + 1}/{retries} failed: {e}")
            if attempt < retries - 1:  # no further attempt follows the last one, so do not sleep
                time.sleep(delay * (2 ** attempt))  # exponential backoff

    app.logger.error(f"SSH not ready after {retries} attempts.")
    return False


def wait_for_container(container, timeout=60, interval=5):
    """
    Waits for a container to initialize and report its SSH readiness.

    Polls the container's port mapping every `interval` seconds. Once a host
    port is published, it is returned as soon as `wait_for_ssh` succeeds on it.
    `elapsed` counts polling intervals only; time spent inside `wait_for_ssh`
    is not included in the budget.

    Args:
        container: container object accepted by `get_port_mapping`, with an `id`.
        timeout: polling budget in seconds.
        interval: seconds between polls; must be positive.

    Returns:
        The host port mapped to the container's SSH port.

    Raises:
        ValueError: if `interval` is not positive (the loop could never finish).
        TimeoutError: if the container is not ready within the polling budget.
    """
    if interval <= 0:
        raise ValueError("interval must be a positive number of seconds.")

    elapsed = 0
    while elapsed < timeout:
        host_port = get_port_mapping(container)

        if host_port and wait_for_ssh('localhost', host_port):
            return host_port

        app.logger.info(f"Waiting for container {container.id[:12]} to initialize...")
        elapsed += interval
        if elapsed < timeout:  # skip the pointless sleep right before giving up
            time.sleep(interval)

    raise TimeoutError(f"Container {container.id[:12]} failed to initialize within {timeout} seconds.")

def create_temp_file(content, suffix, writer=lambda content, file: file.write(content)):
"""
Expand All @@ -135,6 +127,7 @@ def create_temp_file(content, suffix, writer=lambda content, file: file.write(co
writer(content, temp_file) # write file contents (using provided function)
return temp_file.name


def parse_ansible_results(result):
"""
Parses Ansible results from a given AnsibleRunner object.
Expand All @@ -151,6 +144,7 @@ def parse_ansible_results(result):

return results


def cleanup_files(file_paths):
"""
Removes temporary files created during the application lifecycle.
Expand All @@ -159,6 +153,7 @@ def cleanup_files(file_paths):
if os.path.exists(file_path):
os.unlink(file_path)


def run_ansible(hosts, command):
"""
Runs an Ansible playbook on specified hosts with a given command.
Expand Down Expand Up @@ -207,7 +202,8 @@ def run_ansible(hosts, command):
finally:
cleanup_files([inventory_file_path, playbook_file_path])

def remove_or_stop_containers(containers, action):

def remove_or_stop_containers(containers, action, type='spawned'):
"""
Stops or removes the specified containers based on the action parameter.
"""
Expand All @@ -220,31 +216,43 @@ def remove_or_stop_containers(containers, action):
app.logger.info(f"spawned container {container.id[:12]} {'removed' if action == 'remove' else 'stopped'}")

except Exception as e:
app.logger.warning(f"Failed to {action} spawned container {container.id[:12]}: {str(e)}")
app.logger.warning(f"Failed to {action} {type} container {container.id[:12]}: {str(e)}")


## ROUTES FOR FLASK APP ##
# HTTP status codes: 200 - OK (default), 400 - Bad Request, 404 - Not Found, 500 - Internal Server Error
def cleanup(action='remove'):
    """
    Two-pass cleanup of Docker containers.

    Pass 1 handles containers labelled 'flask_app=spawned_container'. Pass 2
    handles any remaining container whose ancestor image is
    'debian:bullseye-slim' and reports them as 'orphaned'. Both passes include
    non-running containers (all=True) and delegate the work to
    `remove_or_stop_containers`.

    Args:
        action: forwarded unchanged to `remove_or_stop_containers`
            (defaults to 'remove').
    """
    # pass 1: containers carrying this app's label
    remove_or_stop_containers(
        client.containers.list(all=True, filters={"label": "flask_app=spawned_container"}),
        action,
    )

    # pass 2 [OPTIONAL]: leftovers from earlier failed runs, matched by base image
    leftovers = client.containers.list(all=True, filters={"ancestor": "debian:bullseye-slim"})
    remove_or_stop_containers(leftovers, action, 'orphaned')

    app.logger.info("Cleanup complete.")


## ROUTES (FOR FLASK APP) ##
# HTTP status codes: 200 - OK (default), 400 - Bad Request, 404 - Not Found, 500 - Internal Server Error
@app.route('/')
def MAIN_INDEX_ROUTE():
    """Serve the root URL by rendering the 'index.html' template."""
    return render_template('index.html')


@app.route('/spawn', methods=['POST'])
def SPAWN_MACHINES_ROUTE():
try:
num_machines = request.form.get('num_machines', type=int)
if not isinstance(num_machines, int) or num_machines <= 0:
if not isinstance(num_machines, int) or num_machines <= 0: # [OPTIONAL] validation for preventing injection attacks
return jsonify({"error": "Invalid 'num_machines' value. Must be a positive integer."}), 400

public_key = ensure_ssh_key()
containers = [] # Track spawned containers
machine_info = [] # Track machine details for display
machine_info = session.get('machine_info', []) # Track machine details for display

for _ in range(num_machines):
try:
container = spawn_container(public_key)
host_port = wait_for_container(container)
host_port = get_port_mapping(container)
containers.append(container)
machine_info.append({
'container_id': container.id[:12],
Expand Down Expand Up @@ -274,7 +282,7 @@ def SPAWN_MACHINES_ROUTE():
def RUN_COMMAND_ROUTE():
try:
command = request.form['command']
if not command or ';' in command or '&&' in command: # prevent command injection
if not command: # [OPTIONAL] validation: `or ';' in command or '&&' in command or '|' in command:`, [IMPROVEMENT] use regex patterns
return jsonify({"error": "Invalid command."}), 400

machine_info = session.get('machine_info', [])
Expand All @@ -289,7 +297,7 @@ def RUN_COMMAND_ROUTE():
"ansible_ssh_private_key_file": ssh_key_path,
"ansible_ssh_extra_args": "-o StrictHostKeyChecking=no"
}
for machine in machine_info if machine['ssh_status'] == 'Ready'
for machine in machine_info if machine['ssh_status'] == 'Ready' # [NOTE] currently, only ready machines are considered, ssh_status is not reliable
}

if not ansible_hosts:
Expand All @@ -308,48 +316,66 @@ def RUN_COMMAND_ROUTE():
app.logger.error(f"Error running command: {str(e)}")
return jsonify({"error": str(e)}), 500


@app.route('/<action>_all_containers', methods=['POST'])
def STOP_OR_REMOVE_ALL_CONTAINERS_ROUTE(action):
    """
    Stops or removes every container managed by the app.

    Args:
        action: URL fragment selecting the operation; must be 'stop' or 'remove'.

    Returns:
        A JSON response:
        - 400 with an "error" key when the action is unsupported or the
          session holds no spawned machines,
        - 200 with a "status" key on success,
        - 500 with an "error" key when cleanup raises.
    """
    # Validate the URL-supplied action before anything else so an unsupported
    # verb is always reported as such, whatever the session holds.
    if action not in ('stop', 'remove'):
        return jsonify({"error": f"Invalid action: {action}. Use 'stop' or 'remove'."}), 400

    try:
        if not session.get('machine_info'):  # safeguard against empty session data
            return jsonify({"error": "No machines spawned. Nothing to stop or remove."}), 400

        cleanup(action)
        session['machine_info'] = []  # clear session data for spawned containers

        # build the phrase first so the 'stop' message has no stray double space
        outcome = 'stopped and removed' if action == 'remove' else 'stopped'
        return jsonify({"status": f"All containers {outcome} successfully."})

    except Exception as e:
        app.logger.error(f"Error {'stopping' if action == 'stop' else 'removing'} all containers: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route('/<action>_container', methods=['POST'])
def STOP_OR_REMOVE_CONTAINER_ROUTE(action):
    """
    Stops or removes a single container identified by the 'container_id' form field.

    Args:
        action: URL fragment selecting the operation; must be 'stop' or 'remove'.
            The route rule supplies it as a keyword argument, so the view must
            accept it.

    Returns:
        A JSON response:
        - 400 when the action is unsupported or no container ID was posted,
        - 404 when Docker does not know the container,
        - 200 with a "status" key on success,
        - 500 with an "error" key on any other failure.
    """
    # same validation as the '<action>_all_containers' route
    if action not in ('stop', 'remove'):
        return jsonify({"error": f"Invalid action: {action}. Use 'stop' or 'remove'."}), 400

    container_id = request.form.get('container_id')
    if not container_id:
        return jsonify({"error": "No container ID provided."}), 400

    short_id = container_id[:12]  # the session stores the 12-character form of the id

    try:
        container = client.containers.get(container_id)
        remove_or_stop_containers([container], action)

        # remove the stopped/removed container from the session data; compare on the
        # short id so a full-length id posted by the client still matches
        session['machine_info'] = [
            machine for machine in session.get('machine_info', [])
            if machine['container_id'] != short_id
        ]

        outcome = 'stopped and removed' if action == 'remove' else 'stopped'
        return jsonify({"status": f"Container {short_id} {outcome} successfully."})

    except docker.errors.NotFound:
        return jsonify({"error": f"Container {short_id} not found."}), 404

    except Exception as e:
        app.logger.error(f"Error {'stopping' if action == 'stop' else 'removing'} container {short_id}: {str(e)}")
        return jsonify({"error": str(e)}), 500


## MAIN APP ENTRY POINT ##
## EXIT POINT SETUP (FOR CLEANUP) ##
def handle_sigint(signum, frame):
    """
    Custom signal handler to catch Ctrl+C (SIGINT) interrupts.

    Logs the interrupt, runs `cleanup()` with its default action, then
    terminates the process with exit status 0.

    Args:
        signum: signal number delivered by the interpreter (unused).
        frame: stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: always, with code 0, once cleanup has run.
    """
    app.logger.info("Received SIGINT (Ctrl+C), cleaning up...")
    cleanup()  # run the cleanup before exiting
    # `exit()` is injected by the `site` module for interactive sessions and may be
    # missing (e.g. `python -S`); raising SystemExit is the portable equivalent.
    raise SystemExit(0)

def handle_sigstp(signum, frame):
    """
    Custom signal handler to catch SIGTSTP (Ctrl+Z) interrupts.

    Instead of letting the process be suspended, logs the signal, runs
    `cleanup()` with its default action, and terminates with exit status 0.

    Args:
        signum: signal number delivered by the interpreter (unused).
        frame: stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: always, with code 0, once cleanup has run.
    """
    app.logger.info("Received SIGTSTP (Ctrl+Z), cleaning up...")
    cleanup()
    # `exit()` is injected by the `site` module for interactive sessions and may be
    # missing (e.g. `python -S`); raising SystemExit is the portable equivalent.
    raise SystemExit(0)

atexit.register(cleanup)  # [OPTIONAL] register 'cleanup' to run when the app quits normally
signal.signal(signal.SIGINT, handle_sigint)  # register SIGINT handler to run 'cleanup' when Ctrl+C is pressed
# SIGTSTP is not defined by the signal module on every platform (e.g. Windows);
# referencing it unconditionally would raise AttributeError at import time.
if hasattr(signal, 'SIGTSTP'):
    signal.signal(signal.SIGTSTP, handle_sigstp)  # register SIGTSTP handler to run 'cleanup' when Ctrl+Z is pressed


## MAIN FLASK APP ENTRY POINT ##
if __name__ == '__main__':
    # Start the Flask server only when this file is executed as a script, not on import.
    # NOTE(review): debug=True switches on Flask's debug mode; presumably meant for
    # local development only. Confirm before deploying.
    app.run(debug=True)