diff --git a/src/borg/archiver/_common.py b/src/borg/archiver/_common.py index 00729e2217..1e5d5bd2c5 100644 --- a/src/borg/archiver/_common.py +++ b/src/borg/archiver/_common.py @@ -29,7 +29,7 @@ def get_repository(location, *, create, exclusive, lock_wait, lock, append_only, make_parent_dirs, storage_quota, args): - if location.proto == "ssh": + if location.proto in ("ssh", "socket"): repository = RemoteRepository( location, create=create, diff --git a/src/borg/archiver/serve_cmd.py b/src/borg/archiver/serve_cmd.py index b69d60329c..f26ca4ce3c 100644 --- a/src/borg/archiver/serve_cmd.py +++ b/src/borg/archiver/serve_cmd.py @@ -19,6 +19,7 @@ def do_serve(self, args): restrict_to_repositories=args.restrict_to_repositories, append_only=args.append_only, storage_quota=args.storage_quota, + socket_path=args.socket_path, ).serve() return EXIT_SUCCESS @@ -82,3 +83,10 @@ def build_parser_serve(self, subparsers, common_parser, mid_common_parser): "When a new repository is initialized, sets the storage quota on the new " "repository as well. Default: no quota.", ) + subparser.add_argument( + "--socket", + metavar="PATH", + dest="socket_path", + action=Highlander, + help="create a UNIX DOMAIN (IPC) socket at PATH and listen on it.", + ) diff --git a/src/borg/helpers/parseformat.py b/src/borg/helpers/parseformat.py index 85dc017895..2b48fcd66a 100644 --- a/src/borg/helpers/parseformat.py +++ b/src/borg/helpers/parseformat.py @@ -382,7 +382,7 @@ class Location: # path must not contain :: (it ends at :: or string end), but may contain single colons. # to avoid ambiguities with other regexes, it must also not start with ":" nor with "//" nor with "ssh://". local_path_re = r""" - (?!(:|//|ssh://)) # not starting with ":" or // or ssh:// + (?!(:|//|ssh://|socket://)) # not starting with ":" or // or ssh:// or socket:// (?P([^:]|(:(?!:)))+) # any chars, but no "::" """ @@ -421,6 +421,14 @@ class Location: re.VERBOSE, ) # path + socket_re = re.compile( + r""" + (?Psocket):// # socket:// + """ + + abs_path_re, + re.VERBOSE, + ) # path + file_re = re.compile( r""" (?Pfile):// # file:// @@ -485,6 +493,11 @@ def normpath_special(p): self.path = normpath_special(m.group("path")) return True m = self.file_re.match(text) + if m: + self.proto = m.group("proto") + self.path = normpath_special(m.group("path")) + return True + m = self.socket_re.match(text) if m: self.proto = m.group("proto") self.path = normpath_special(m.group("path")) @@ -508,7 +521,7 @@ def __str__(self): def to_key_filename(self): name = re.sub(r"[^\w]", "_", self.path).strip("_") - if self.proto != "file": + if self.proto not in ("file", "socket"): name = re.sub(r"[^\w]", "_", self.host) + "__" + name if len(name) > 100: # Limit file names to some reasonable length. Most file systems @@ -527,7 +540,7 @@ def host(self): return self._host.lstrip("[").rstrip("]") def canonical_path(self): - if self.proto == "file": + if self.proto in ("file", "socket"): return self.path else: if self.path and self.path.startswith("~"): diff --git a/src/borg/remote.py b/src/borg/remote.py index 9349d9bc6b..a32b2fcfc1 100644 --- a/src/borg/remote.py +++ b/src/borg/remote.py @@ -7,6 +7,7 @@ import select import shlex import shutil +import socket import struct import sys import tempfile @@ -164,7 +165,7 @@ class RepositoryServer: # pragma: no cover "inject_exception", ) - def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota): + def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota, socket_path): self.repository = None self.restrict_to_paths = restrict_to_paths self.restrict_to_repositories = restrict_to_repositories @@ -174,6 +175,7 @@ def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, sto # (see RepositoryServer.open below). self.append_only = append_only self.storage_quota = storage_quota + self.socket_path = socket_path self.client_version = parse_version( "1.0.8" ) # fallback version if client is too old to send version information @@ -196,123 +198,151 @@ def filter_args(self, f, kwargs): return {name: kwargs[name] for name in kwargs if name in known} def serve(self): - stdin_fd = sys.stdin.fileno() - stdout_fd = sys.stdout.fileno() - stderr_fd = sys.stdout.fileno() - os.set_blocking(stdin_fd, False) - os.set_blocking(stdout_fd, True) - os.set_blocking(stderr_fd, True) - unpacker = get_limited_unpacker("server") - while True: - r, w, es = select.select([stdin_fd], [], [], 10) - if r: - data = os.read(stdin_fd, BUFSIZE) - if not data: - if self.repository is not None: - self.repository.close() - else: - os_write( - stderr_fd, - "Borg {}: Got connection close before repository was opened.\n".format( - __version__ - ).encode(), - ) - return - unpacker.feed(data) - for unpacked in unpacker: - if isinstance(unpacked, dict): - dictFormat = True - msgid = unpacked[MSGID] - method = unpacked[MSG] - args = unpacked[ARGS] - elif isinstance(unpacked, tuple) and len(unpacked) == 4: - dictFormat = False - # The first field 'type' was always 1 and has always been ignored - _, msgid, method, args = unpacked - args = self.positional_to_named(method, args) - else: + def setup_blocking(stdin_fd, stdout_fd, stderr_fd): + os.set_blocking(stdin_fd, False) + assert not os.get_blocking(stdin_fd) + os.set_blocking(stdout_fd, True) + assert os.get_blocking(stdout_fd) + if stderr_fd != stdout_fd: + os.set_blocking(stderr_fd, True) + assert os.get_blocking(stderr_fd) + + def inner_serve(): + unpacker = get_limited_unpacker("server") + while True: + r, w, es = select.select([stdin_fd], [], [], 10) + if r: + data = os.read(stdin_fd, BUFSIZE) + if not data: if self.repository is not None: self.repository.close() - raise UnexpectedRPCDataFormatFromClient(__version__) - try: - if method not in self.rpc_methods: - raise InvalidRPCMethod(method) + else: + os_write( + stderr_fd, + "Borg {}: Got connection close before repository was opened.\n".format( + __version__ + ).encode(), + ) + return + unpacker.feed(data) + for unpacked in unpacker: + if isinstance(unpacked, dict): + dictFormat = True + msgid = unpacked[MSGID] + method = unpacked[MSG] + args = unpacked[ARGS] + elif isinstance(unpacked, tuple) and len(unpacked) == 4: + dictFormat = False + # The first field 'type' was always 1 and has always been ignored + _, msgid, method, args = unpacked + args = self.positional_to_named(method, args) + else: + if self.repository is not None: + self.repository.close() + raise UnexpectedRPCDataFormatFromClient(__version__) try: - f = getattr(self, method) - except AttributeError: - f = getattr(self.repository, method) - args = self.filter_args(f, args) - res = f(**args) - except BaseException as e: - if dictFormat: - ex_short = traceback.format_exception_only(e.__class__, e) - ex_full = traceback.format_exception(*sys.exc_info()) - ex_trace = True - if isinstance(e, Error): - ex_short = [e.get_message()] - ex_trace = e.traceback - if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)): - # These exceptions are reconstructed on the client end in RemoteRepository.call_many(), - # and will be handled just like locally raised exceptions. Suppress the remote traceback - # for these, except ErrorWithTraceback, which should always display a traceback. - pass - else: - logging.debug("\n".join(ex_full)) - + if method not in self.rpc_methods: + raise InvalidRPCMethod(method) try: - msg = msgpack.packb( - { - MSGID: msgid, - "exception_class": e.__class__.__name__, - "exception_args": e.args, - "exception_full": ex_full, - "exception_short": ex_short, - "exception_trace": ex_trace, - "sysinfo": sysinfo(), - } - ) - except TypeError: - msg = msgpack.packb( - { - MSGID: msgid, - "exception_class": e.__class__.__name__, - "exception_args": [ - x if isinstance(x, (str, bytes, int)) else None for x in e.args - ], - "exception_full": ex_full, - "exception_short": ex_short, - "exception_trace": ex_trace, - "sysinfo": sysinfo(), - } - ) - - os_write(stdout_fd, msg) - else: - if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)): - # These exceptions are reconstructed on the client end in RemoteRepository.call_many(), - # and will be handled just like locally raised exceptions. Suppress the remote traceback - # for these, except ErrorWithTraceback, which should always display a traceback. - pass - else: + f = getattr(self, method) + except AttributeError: + f = getattr(self.repository, method) + args = self.filter_args(f, args) + res = f(**args) + except BaseException as e: + if dictFormat: + ex_short = traceback.format_exception_only(e.__class__, e) + ex_full = traceback.format_exception(*sys.exc_info()) + ex_trace = True if isinstance(e, Error): - tb_log_level = logging.ERROR if e.traceback else logging.DEBUG - msg = e.get_message() + ex_short = [e.get_message()] + ex_trace = e.traceback + if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)): + # These exceptions are reconstructed on the client end in RemoteRepository.call_many(), + # and will be handled just like locally raised exceptions. Suppress the remote traceback + # for these, except ErrorWithTraceback, which should always display a traceback. + pass else: - tb_log_level = logging.ERROR - msg = "%s Exception in RPC call" % e.__class__.__name__ - tb = f"{traceback.format_exc()}\n{sysinfo()}" - logging.error(msg) - logging.log(tb_log_level, tb) - exc = "Remote Exception (see remote log for the traceback)" - os_write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc))) - else: - if dictFormat: - os_write(stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res})) + logging.debug("\n".join(ex_full)) + + try: + msg = msgpack.packb( + { + MSGID: msgid, + "exception_class": e.__class__.__name__, + "exception_args": e.args, + "exception_full": ex_full, + "exception_short": ex_short, + "exception_trace": ex_trace, + "sysinfo": sysinfo(), + } + ) + except TypeError: + msg = msgpack.packb( + { + MSGID: msgid, + "exception_class": e.__class__.__name__, + "exception_args": [ + x if isinstance(x, (str, bytes, int)) else None for x in e.args + ], + "exception_full": ex_full, + "exception_short": ex_short, + "exception_trace": ex_trace, + "sysinfo": sysinfo(), + } + ) + + os_write(stdout_fd, msg) + else: + if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)): + # These exceptions are reconstructed on the client end in RemoteRepository.call_many(), + # and will be handled just like locally raised exceptions. Suppress the remote traceback + # for these, except ErrorWithTraceback, which should always display a traceback. + pass + else: + if isinstance(e, Error): + tb_log_level = logging.ERROR if e.traceback else logging.DEBUG + msg = e.get_message() + else: + tb_log_level = logging.ERROR + msg = "%s Exception in RPC call" % e.__class__.__name__ + tb = f"{traceback.format_exc()}\n{sysinfo()}" + logging.error(msg) + logging.log(tb_log_level, tb) + exc = "Remote Exception (see remote log for the traceback)" + os_write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc))) else: - os_write(stdout_fd, msgpack.packb((1, msgid, None, res))) - if es: - self.repository.close() - return + if dictFormat: + os_write(stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res})) + else: + os_write(stdout_fd, msgpack.packb((1, msgid, None, res))) + if es: + self.repository.close() + return + + if self.socket_path: # server for socket:// connections + try: + # remove any left-over socket file + os.unlink(self.socket_path) + except OSError: + if os.path.exists(self.socket_path): + raise + sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + sock.bind(self.socket_path) # this creates the socket file in the fs + sock.listen(0) # no backlog + while True: + connection, client_address = sock.accept() + stdin_fd = connection.makefile("rb").fileno() + stdout_fd = connection.makefile("wb").fileno() + stderr_fd = stdout_fd # TODO log output on sys.stderr is not going over the socket + setup_blocking(stdin_fd, stdout_fd, stderr_fd) + inner_serve() + else: # server for one ssh:// connection + stdin_fd = sys.stdin.fileno() + stdout_fd = sys.stdout.fileno() + stderr_fd = stdout_fd + setup_blocking(stdin_fd, stdout_fd, stderr_fd) + inner_serve() def negotiate(self, client_data): # old format used in 1.0.x @@ -574,27 +604,49 @@ def __init__( self.server_version = parse_version( "1.0.8" ) # fallback version if server is too old to send version information - self.p = None + self.p = self.sock = None self._args = args - testing = location.host == "__testsuite__" - # when testing, we invoke and talk to a borg process directly (no ssh). - # when not testing, we invoke the system-installed ssh binary to talk to a remote borg. - env = prepare_subprocess_env(system=not testing) - borg_cmd = self.borg_cmd(args, testing) - if not testing: - borg_cmd = self.ssh_cmd(location) + borg_cmd - logger.debug("SSH command line: %s", borg_cmd) - # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg. - # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection. - self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint) - self.stdin_fd = self.p.stdin.fileno() - self.stdout_fd = self.p.stdout.fileno() - self.stderr_fd = self.p.stderr.fileno() + if self.location.proto == "ssh": + testing = location.host == "__testsuite__" + # when testing, we invoke and talk to a borg process directly (no ssh). + # when not testing, we invoke the system-installed ssh binary to talk to a remote borg. + env = prepare_subprocess_env(system=not testing) + borg_cmd = self.borg_cmd(args, testing) + if not testing: + borg_cmd = self.ssh_cmd(location) + borg_cmd + logger.debug("SSH command line: %s", borg_cmd) + # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg. + # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection. + self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint) + self.stdin_fd = self.p.stdin.fileno() + self.stdout_fd = self.p.stdout.fileno() + self.stderr_fd = self.p.stderr.fileno() + self.r_fds = [self.stdout_fd, self.stderr_fd] + self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd] + elif self.location.proto == "socket": + socket_path = os.path.join(location.path, "socket") + self.sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + try: + self.sock.connect(socket_path) + except FileNotFoundError: + raise Error(f"The socket file {socket_path} does not exist.") + except ConnectionRefusedError: + raise Error(f"There is no borg serve running for the socket file {socket_path}.") + self.stdin_fd = self.sock.makefile("wb").fileno() + self.stdout_fd = self.sock.makefile("rb").fileno() + self.stderr_fd = None + self.r_fds = [self.stdout_fd] + self.x_fds = [self.stdin_fd, self.stdout_fd] + else: + raise Error(f"Unsupported protocol {location.proto}") + os.set_blocking(self.stdin_fd, False) + assert not os.get_blocking(self.stdin_fd) os.set_blocking(self.stdout_fd, False) - os.set_blocking(self.stderr_fd, False) - self.r_fds = [self.stdout_fd, self.stderr_fd] - self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd] + assert not os.get_blocking(self.stdout_fd) + if self.stderr_fd is not None: + os.set_blocking(self.stderr_fd, False) + assert not os.get_blocking(self.stderr_fd) try: try: @@ -653,7 +705,7 @@ def do_open(): def __del__(self): if len(self.responses): logging.debug("still %d cached responses left in RemoteRepository" % (len(self.responses),)) - if self.p: + if self.p or self.sock: self.close() assert False, "cleanup happened in Repository.__del__" @@ -1042,6 +1094,14 @@ def close(self): self.p.stdout.close() self.p.wait() self.p = None + if self.sock: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError as e: + if e.errno != errno.ENOTCONN: + raise + self.sock.close() + self.sock = None def async_response(self, wait=True): for resp in self.call_many("async_responses", calls=[], wait=True, async_wait=wait): diff --git a/src/borg/testsuite/helpers.py b/src/borg/testsuite/helpers.py index df9b2ea2db..3735482151 100644 --- a/src/borg/testsuite/helpers.py +++ b/src/borg/testsuite/helpers.py @@ -184,6 +184,14 @@ def test_ssh(self, monkeypatch, keys_dir): == "Location(proto='ssh', user='user', host='2a02:0001:0002:0003:0004:0005:0006:0007', port=1234, path='/some/path')" ) + def test_socket(self, monkeypatch, keys_dir): + monkeypatch.delenv("BORG_REPO", raising=False) + assert ( + repr(Location("socket:///repo/path")) + == "Location(proto='socket', user=None, host=None, port=None, path='/repo/path')" + ) + assert Location("socket:///some/path").to_key_filename() == keys_dir + "some_path" + def test_file(self, monkeypatch, keys_dir): monkeypatch.delenv("BORG_REPO", raising=False) assert ( @@ -275,6 +283,7 @@ def test_canonical_path(self, monkeypatch): "file://some/path", "host:some/path", "host:~user/some/path", + "socket:///some/path", "ssh://host/some/path", "ssh://user@host:1234/some/path", ]