Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sapphire] Add type hints #411

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions grizzly/common/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def test_runner_10(mocker, tmp_path):
inc3.write_bytes(b"a")
# build server map
smap = ServerMap()
smap.set_include("/", str(inc_path1))
smap.set_include("/test", str(inc_path2))
smap.set_include("/", inc_path1)
smap.set_include("/test", inc_path2)
with TestCase("a.b", "x") as test:
test.add_from_bytes(b"", test.entry_point)
serv_files = {
Expand Down
1 change: 1 addition & 0 deletions grizzly/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def run(
runner.launch(location, max_retries=launch_attempts, retry_delay=0)
runner.post_launch(delay=post_launch_delay)
# TODO: avoid running test case if runner.startup_failure is True
# especially if it is a hang!

# create and populate a test case
current_test = self.generate_testcase()
Expand Down
7 changes: 4 additions & 3 deletions sapphire/__main__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from logging import DEBUG, INFO, basicConfig
from pathlib import Path
from typing import List, Optional

from .core import Sapphire


def configure_logging(log_level):
def configure_logging(log_level: int) -> None:
if log_level == DEBUG:
date_fmt = None
log_fmt = "%(asctime)s %(levelname).1s %(name)s | %(message)s"
Expand All @@ -18,7 +19,7 @@ def configure_logging(log_level):
basicConfig(format=log_fmt, datefmt=date_fmt, level=log_level)


def parse_args(argv=None):
def parse_args(argv: Optional[List[str]] = None) -> Namespace:
# log levels for console logging
level_map = {"DEBUG": DEBUG, "INFO": INFO}
parser = ArgumentParser()
Expand Down
28 changes: 19 additions & 9 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from logging import getLogger
from socket import socket
from time import time
from traceback import format_exception
from typing import Any, Callable, List, Optional, Union

from .job import Job
from .worker import Worker

__author__ = "Tyson Smith"
Expand All @@ -27,24 +30,26 @@ class ConnectionManager:
"_socket",
)

def __init__(self, job, srv_socket, limit=1, poll=0.5):
def __init__(
self, job: Job, srv_socket: socket, limit: int = 1, poll: float = 0.5
) -> None:
assert limit > 0
assert poll > 0
self._deadline = None
self._deadline: Optional[float] = None
self._deadline_exceeded = False
self._job = job
self._limit = limit
self._next_poll = 0
self._next_poll = 0.0
self._poll = poll
self._socket = srv_socket

def __enter__(self):
def __enter__(self) -> "ConnectionManager":
return self

def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
self.close()

def _can_continue(self, continue_cb):
def _can_continue(self, continue_cb: Union[Callable[[], bool], None]) -> bool:
"""Check timeout and callback status.

Args:
Expand All @@ -68,7 +73,7 @@ def _can_continue(self, continue_cb):
return False
return True

def close(self):
def close(self) -> None:
"""Set job state to finished and raise any errors encountered by workers.

Args:
Expand All @@ -88,7 +93,7 @@ def close(self):
raise exc_obj

@staticmethod
def _join_workers(workers, timeout=0):
def _join_workers(workers: List[Worker], timeout: float = 0) -> List[Worker]:
"""Attempt to join workers.

Args:
Expand All @@ -106,7 +111,12 @@ def _join_workers(workers, timeout=0):
alive.append(worker)
return alive

def serve(self, timeout, continue_cb=None, shutdown_delay=SHUTDOWN_DELAY):
def serve(
self,
timeout: int,
continue_cb: Optional[Callable[[], bool]] = None,
shutdown_delay: float = SHUTDOWN_DELAY,
) -> bool:
"""Manage workers and serve job contents.

Args:
Expand Down
80 changes: 45 additions & 35 deletions sapphire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
"""
Sapphire HTTP server
"""
from argparse import Namespace
from logging import getLogger
from pathlib import Path
from socket import SO_REUSEADDR, SOL_SOCKET, gethostname, socket
from ssl import PROTOCOL_TLS_SERVER, SSLContext
from ssl import PROTOCOL_TLS_SERVER, SSLContext, SSLSocket
from time import sleep, time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

from .connection_manager import ConnectionManager
from .job import Job, Served
from .server_map import ServerMap

__all__ = (
"BLOCKED_PORTS",
Expand Down Expand Up @@ -47,16 +50,21 @@
LOG = getLogger(__name__)


def create_listening_socket(attempts=10, port=0, remote=False, timeout=None):
def create_listening_socket(
attempts: int = 10,
port: int = 0,
remote: bool = False,
timeout: Optional[float] = None,
) -> socket:
"""Create listening socket. Search for an open socket if needed and
configure the socket. If a specific port is unavailable or no
available ports can be found socket.error will be raised.

Args:
attempts (int): Number of attempts to configure the socket.
port (int): Port to listen on. Use 0 for system assigned port.
remote (bool): Accept all (non-local) incoming connections.
timeout (float): Used to set socket timeout.
attempts: Number of attempts to configure the socket.
port: Port to listen on. Use 0 for system assigned port.
remote: Accept all (non-local) incoming connections.
timeout: Used to set socket timeout.

Returns:
socket: A listening socket.
Expand Down Expand Up @@ -100,13 +108,13 @@ class Sapphire:

def __init__(
self,
allow_remote=False,
auto_close=-1,
allow_remote: bool = False,
auto_close: int = -1,
certs=None,
max_workers=10,
port=0,
timeout=60,
):
max_workers: int = 10,
port: int = 0,
timeout: int = 60,
) -> None:
assert timeout >= 0
self._auto_close = auto_close # call 'window.close()' on 4xx error pages
self._max_workers = max_workers # limit worker threads
Expand All @@ -119,20 +127,22 @@ def __init__(
if certs:
context = SSLContext(PROTOCOL_TLS_SERVER)
context.load_cert_chain(certs.host, certs.key)
self._socket = context.wrap_socket(sock, server_side=True)
self._socket: Union[socket, SSLSocket] = context.wrap_socket(
sock, server_side=True
)
self.scheme = "https"
else:
self._socket = sock
self.scheme = "http"
self.timeout = timeout

def __enter__(self):
def __enter__(self) -> "Sapphire":
return self

def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
self.close()

def clear_backlog(self):
def clear_backlog(self) -> None:
"""Remove all pending connections from backlog. This should only be
called when there isn't anything actively trying to connect.

Expand All @@ -158,7 +168,7 @@ def clear_backlog(self):
assert deadline > time()
self._socket.settimeout(self.LISTEN_TIMEOUT)

def close(self):
def close(self) -> None:
"""Close listening server socket.

Args:
Expand All @@ -170,25 +180,25 @@ def close(self):
self._socket.close()

@property
def port(self):
def port(self) -> int:
"""Port number of listening socket.

Args:
None

Returns:
int: Listening port number.
Listening port number.
"""
return self._socket.getsockname()[1]
return int(self._socket.getsockname()[1])

def serve_path(
self,
path,
continue_cb=None,
forever=False,
required_files=None,
server_map=None,
):
path: Path,
continue_cb: Optional[Callable[[], bool]] = None,
forever: bool = False,
required_files: Optional[Iterable[str]] = None,
server_map: Optional[ServerMap] = None,
) -> Tuple[Served, Dict[str, Path]]:
"""Serve files in path.

The status codes include:
Expand All @@ -197,17 +207,17 @@ def serve_path(
- Served.REQUEST: Some files were requested

Args:
path (Path): Directory to use as wwwroot.
continue_cb (callable): A callback that can be used to exit the serve loop.
This must be a callable that returns a bool.
forever (bool): Continue to handle requests even after all files have
been served. This is meant to be used with continue_cb.
required_files (list(str)): Files that need to be served in order to exit
the serve loop.
path: Directory to use as wwwroot.
continue_cb: A callback that can be used to exit the serve loop.
This must be a callable that returns a bool.
forever: Continue to handle requests even after all files have been served.
This is meant to be used with continue_cb.
required_files: Files that need to be served in order to exit the
serve loop.
server_map (ServerMap):

Returns:
tuple(int, dict[str, Path]): Status code and files served.
Status code and files served.
"""
assert isinstance(path, Path)
assert self.timeout >= 0
Expand All @@ -225,7 +235,7 @@ def serve_path(
return (Served.TIMEOUT if timed_out else job.status, job.served)

@classmethod
def main(cls, args):
def main(cls, args: Namespace) -> None:
try:
with cls(
allow_remote=args.remote, port=args.port, timeout=args.timeout
Expand Down
Loading
Loading