Sfoglia il codice sorgente

Merge pull request #7615 from ThomasWaldmann/serve-socket2

implement unix domain socket support
TW 2 anni fa
parent
commit
cad57e70c3

+ 4 - 0
docs/usage/general/environment.rst.inc

@@ -150,6 +150,10 @@ Directories and files:
         `XDG env var`_ ``XDG_DATA_HOME`` is set, then ``$XDG_DATA_HOME/borg`` is being used instead.
         This directory contains all borg data directories, see the FAQ
         for a security advisory about the data in this directory: :ref:`home_data_borg`
+    BORG_RUNTIME_DIR
+        Defaults to ``$BORG_BASE_DIR/.cache/borg``. If ``BORG_BASE_DIR`` is not explicitly set while
+        `XDG env var`_ ``XDG_RUNTIME_DIR`` is set, then ``$XDG_RUNTIME_DIR/borg`` is being used instead.
+        This directory contains borg runtime files, e.g. the socket file.
     BORG_SECURITY_DIR
         Defaults to ``$BORG_DATA_DIR/security``.
         This directory contains security relevant data.

+ 11 - 1
src/borg/archiver/_common.py

@@ -29,7 +29,7 @@ logger = create_logger(__name__)
 
 
 def get_repository(location, *, create, exclusive, lock_wait, lock, append_only, make_parent_dirs, storage_quota, args):
-    if location.proto == "ssh":
+    if location.proto in ("ssh", "socket"):
         repository = RemoteRepository(
             location,
             create=create,
@@ -573,6 +573,16 @@ def define_common_options(add_common_option):
         action=Highlander,
         help="Use this command to connect to the 'borg serve' process (default: 'ssh')",
     )
+    add_common_option(
+        "--socket",
+        metavar="PATH",
+        dest="use_socket",
+        default=False,
+        const=True,
+        nargs="?",
+        action=Highlander,
+        help="Use UNIX DOMAIN (IPC) socket at PATH for client/server communication with socket: protocol.",
+    )
     add_common_option(
         "-r",
         "--repo",

+ 12 - 1
src/borg/archiver/serve_cmd.py

@@ -19,6 +19,7 @@ class ServeMixIn:
             restrict_to_repositories=args.restrict_to_repositories,
             append_only=args.append_only,
             storage_quota=args.storage_quota,
+            use_socket=args.use_socket,
         ).serve()
         return EXIT_SUCCESS
 
@@ -27,7 +28,17 @@ class ServeMixIn:
 
         serve_epilog = process_epilog(
             """
-        This command starts a repository server process. This command is usually not used manually.
+        This command starts a repository server process.
+
+        borg serve can currently support:
+
+        - Getting automatically started via ssh when the borg client uses a ssh://...
+          remote repository. In this mode, `borg serve` will live until that ssh connection
+          gets terminated.
+
+        - Getting started by some other means (not by the borg client) as a long-running socket
+          server to be used for borg clients using a socket://... repository (see the `--socket`
+          option if you do not want to use the default path for the socket and pid file).
         """
         )
         subparser = subparsers.add_parser(

+ 2 - 1
src/borg/helpers/__init__.py

@@ -11,7 +11,8 @@ from ..constants import *  # NOQA
 from .checks import check_extension_modules, check_python
 from .datastruct import StableDict, Buffer, EfficientCollectionQueue
 from .errors import Error, ErrorWithTraceback, IntegrityError, DecompressionError
-from .fs import ensure_dir, get_security_dir, get_keys_dir, get_base_dir, join_base_dir, get_cache_dir, get_config_dir
+from .fs import ensure_dir, join_base_dir, get_socket_filename
+from .fs import get_security_dir, get_keys_dir, get_base_dir, get_cache_dir, get_config_dir, get_runtime_dir
 from .fs import dir_is_tagged, dir_is_cachedir, make_path_safe, scandir_inorder
 from .fs import secure_erase, safe_unlink, dash_open, os_open, os_stat, umount
 from .fs import O_, flags_root, flags_dir, flags_special_follow, flags_special, flags_base, flags_normal, flags_noatime

+ 16 - 0
src/borg/helpers/fs.py

@@ -106,6 +106,22 @@ def get_data_dir(*, legacy=False):
     return data_dir
 
 
+def get_runtime_dir(*, legacy=False):
+    """Return the directory for runtime files (sockets, PID files, ...), creating it if needed.
+
+    The ``BORG_RUNTIME_DIR`` environment variable wins if it is set. Otherwise the result of
+    ``join_base_dir(".cache", "borg")`` is used if it is truthy, else the platform's user
+    runtime dir for "borg".
+    """
+    assert legacy is False, "there is no legacy variant of the borg runtime dir"
+    fallback_dir = join_base_dir(".cache", "borg", legacy=legacy) or platformdirs.user_runtime_dir("borg")
+    runtime_dir = os.environ.get("BORG_RUNTIME_DIR", fallback_dir)
+    ensure_dir(runtime_dir)  # create the directory if it does not exist yet
+    return runtime_dir
+
+
+def get_socket_filename():
+    """Return the path of the default socket file, ``borg.sock`` inside the runtime dir."""
+    runtime_dir = get_runtime_dir()
+    return os.path.join(runtime_dir, "borg.sock")
+
+
 def get_cache_dir(*, legacy=False):
     """Determine where to repository keys and cache"""
 

+ 16 - 3
src/borg/helpers/parseformat.py

@@ -390,7 +390,7 @@ class Location:
     # path must not contain :: (it ends at :: or string end), but may contain single colons.
     # to avoid ambiguities with other regexes, it must also not start with ":" nor with "//" nor with "ssh://".
     local_path_re = r"""
-        (?!(:|//|ssh://))                                   # not starting with ":" or // or ssh://
+        (?!(:|//|ssh://|socket://))                         # not starting with ":" or // or ssh:// or socket://
         (?P<path>([^:]|(:(?!:)))+)                          # any chars, but no "::"
         """
 
@@ -429,6 +429,14 @@ class Location:
         re.VERBOSE,
     )  # path
 
+    socket_re = re.compile(
+        r"""
+        (?P<proto>socket)://                                    # socket://
+        """
+        + abs_path_re,
+        re.VERBOSE,
+    )  # path
+
     file_re = re.compile(
         r"""
         (?P<proto>file)://                                      # file://
@@ -493,6 +501,11 @@ class Location:
             self.path = normpath_special(m.group("path"))
             return True
         m = self.file_re.match(text)
+        if m:
+            self.proto = m.group("proto")
+            self.path = normpath_special(m.group("path"))
+            return True
+        m = self.socket_re.match(text)
         if m:
             self.proto = m.group("proto")
             self.path = normpath_special(m.group("path"))
@@ -516,7 +529,7 @@ class Location:
 
     def to_key_filename(self):
         name = re.sub(r"[^\w]", "_", self.path).strip("_")
-        if self.proto != "file":
+        if self.proto not in ("file", "socket"):
             name = re.sub(r"[^\w]", "_", self.host) + "__" + name
         if len(name) > 100:
             # Limit file names to some reasonable length. Most file systems
@@ -535,7 +548,7 @@ class Location:
             return self._host.lstrip("[").rstrip("]")
 
     def canonical_path(self):
-        if self.proto == "file":
+        if self.proto in ("file", "socket"):
             return self.path
         else:
             if self.path and self.path.startswith("~"):

+ 26 - 4
src/borg/logger.py

@@ -28,6 +28,24 @@ The way to use this is as follows:
 
 * what is output on INFO level is additionally controlled by commandline
   flags
+
+Logging setup is a bit complicated in borg, as it needs to work under misc. conditions:
+- purely local, not client/server (easy)
+- client/server: RemoteRepository ("borg serve" process) writes log records into a global
+  queue, which is then sent to the client side by the main serve loop (via the RPC protocol,
+  either over ssh stdout, more directly via process stdout without ssh [used in the tests],
+  or via a socket). On the client side, the log records are fed into the client-side logging
+  system. When remote_repo.close() is called, server side must send all queued log records
+  via the RPC channel before returning the close() call's return value (as the client will
+  then shut down the connection).
+- progress output is always given as json to the logger (including the plain text inside
+  the json), but then formatted by the logging system's formatter as either plain text or
+  json depending on the cli args given (--log-json?).
+- tests: potentially running in parallel via pytest-xdist, capturing borg output into a
+  given stream.
+- logging might be short-lived (e.g. when invoking a single borg command via the cli)
+  or long-lived (e.g. borg serve --socket or when running the tests)
+- logging is global and exists only once per process.
 """
 
 import inspect
@@ -115,10 +133,14 @@ def remove_handlers(logger):
         logger.removeHandler(handler)
 
 
-def teardown_logging():
-    global configured
-    logging.shutdown()
-    configured = False
+def flush_logging():
+    """Flush every handler of the "borg.output.progress" logger and of the root logger.
+
+    This matters especially for the "borg serve" RemoteRepository logging: all log output
+    needs to be sent via the ssh / socket connection before that connection is closed.
+    """
+    for logger_name in ("borg.output.progress", ""):
+        for handler in logging.getLogger(logger_name).handlers:
+            handler.flush()
 
 
 def setup_logging(

+ 193 - 108
src/borg/remote.py

@@ -1,3 +1,4 @@
+import atexit
 import errno
 import functools
 import inspect
@@ -7,6 +8,7 @@ import queue
 import select
 import shlex
 import shutil
+import socket
 import struct
 import sys
 import tempfile
@@ -27,6 +29,7 @@ from .helpers import sysinfo
 from .helpers import format_file_size
 from .helpers import safe_unlink
 from .helpers import prepare_subprocess_env, ignore_sigint
+from .helpers import get_socket_filename
 from .logger import create_logger, borg_serve_log_queue
 from .helpers import msgpack
 from .repository import Repository
@@ -136,7 +139,7 @@ class RepositoryServer:  # pragma: no cover
         "inject_exception",
     )
 
-    def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota):
+    def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota, use_socket):
         self.repository = None
         self.restrict_to_paths = restrict_to_paths
         self.restrict_to_repositories = restrict_to_repositories
@@ -147,6 +150,12 @@ class RepositoryServer:  # pragma: no cover
         self.append_only = append_only
         self.storage_quota = storage_quota
         self.client_version = None  # we update this after client sends version information
+        if use_socket is False:
+            self.socket_path = None
+        elif use_socket is True:  # --socket
+            self.socket_path = get_socket_filename()
+        else:  # --socket=/some/path
+            self.socket_path = use_socket
 
     def filter_args(self, f, kwargs):
         """Remove unknown named parameters from call, because client did (implicitly) say it's ok."""
@@ -165,95 +174,133 @@ class RepositoryServer:  # pragma: no cover
                 os_write(self.stdout_fd, msg)
 
     def serve(self):
-        self.stdin_fd = sys.stdin.fileno()
-        self.stdout_fd = sys.stdout.fileno()
-        os.set_blocking(self.stdin_fd, False)
-        os.set_blocking(self.stdout_fd, True)
-        unpacker = get_limited_unpacker("server")
-        shutdown_serve = False
-        while True:
-            # before processing any new RPCs, send out all pending log output
-            self.send_queued_log()
-
-            if shutdown_serve:
-                # shutdown wanted! get out of here after sending all log output.
-                if self.repository is not None:
-                    self.repository.close()
-                return
-
-            # process new RPCs
-            r, w, es = select.select([self.stdin_fd], [], [], 10)
-            if r:
-                data = os.read(self.stdin_fd, BUFSIZE)
-                if not data:
-                    shutdown_serve = True
-                    continue
-                unpacker.feed(data)
-                for unpacked in unpacker:
-                    if isinstance(unpacked, dict):
-                        msgid = unpacked[MSGID]
-                        method = unpacked[MSG]
-                        args = unpacked[ARGS]
-                    else:
-                        if self.repository is not None:
-                            self.repository.close()
-                        raise UnexpectedRPCDataFormatFromClient(__version__)
-                    try:
-                        if method not in self.rpc_methods:
-                            raise InvalidRPCMethod(method)
+        def inner_serve():
+            os.set_blocking(self.stdin_fd, False)
+            assert not os.get_blocking(self.stdin_fd)
+            os.set_blocking(self.stdout_fd, True)
+            assert os.get_blocking(self.stdout_fd)
+
+            unpacker = get_limited_unpacker("server")
+            shutdown_serve = False
+            while True:
+                # before processing any new RPCs, send out all pending log output
+                self.send_queued_log()
+
+                if shutdown_serve:
+                    # shutdown wanted! get out of here after sending all log output.
+                    assert self.repository is None
+                    return
+
+                # process new RPCs
+                r, w, es = select.select([self.stdin_fd], [], [], 10)
+                if r:
+                    data = os.read(self.stdin_fd, BUFSIZE)
+                    if not data:
+                        shutdown_serve = True
+                        continue
+                    unpacker.feed(data)
+                    for unpacked in unpacker:
+                        if isinstance(unpacked, dict):
+                            msgid = unpacked[MSGID]
+                            method = unpacked[MSG]
+                            args = unpacked[ARGS]
+                        else:
+                            if self.repository is not None:
+                                self.repository.close()
+                            raise UnexpectedRPCDataFormatFromClient(__version__)
                         try:
-                            f = getattr(self, method)
-                        except AttributeError:
-                            f = getattr(self.repository, method)
-                        args = self.filter_args(f, args)
-                        res = f(**args)
-                    except BaseException as e:
-                        ex_short = traceback.format_exception_only(e.__class__, e)
-                        ex_full = traceback.format_exception(*sys.exc_info())
-                        ex_trace = True
-                        if isinstance(e, Error):
-                            ex_short = [e.get_message()]
-                            ex_trace = e.traceback
-                        if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)):
-                            # These exceptions are reconstructed on the client end in RemoteRepository.call_many(),
-                            # and will be handled just like locally raised exceptions. Suppress the remote traceback
-                            # for these, except ErrorWithTraceback, which should always display a traceback.
-                            pass
+                            if method not in self.rpc_methods:
+                                raise InvalidRPCMethod(method)
+                            try:
+                                f = getattr(self, method)
+                            except AttributeError:
+                                f = getattr(self.repository, method)
+                            args = self.filter_args(f, args)
+                            res = f(**args)
+                        except BaseException as e:
+                            ex_short = traceback.format_exception_only(e.__class__, e)
+                            ex_full = traceback.format_exception(*sys.exc_info())
+                            ex_trace = True
+                            if isinstance(e, Error):
+                                ex_short = [e.get_message()]
+                                ex_trace = e.traceback
+                            if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)):
+                                # These exceptions are reconstructed on the client end in RemoteRepository.call_many(),
+                                # and will be handled just like locally raised exceptions. Suppress the remote traceback
+                                # for these, except ErrorWithTraceback, which should always display a traceback.
+                                pass
+                            else:
+                                logging.debug("\n".join(ex_full))
+
+                            sys_info = sysinfo()
+                            try:
+                                msg = msgpack.packb(
+                                    {
+                                        MSGID: msgid,
+                                        "exception_class": e.__class__.__name__,
+                                        "exception_args": e.args,
+                                        "exception_full": ex_full,
+                                        "exception_short": ex_short,
+                                        "exception_trace": ex_trace,
+                                        "sysinfo": sys_info,
+                                    }
+                                )
+                            except TypeError:
+                                msg = msgpack.packb(
+                                    {
+                                        MSGID: msgid,
+                                        "exception_class": e.__class__.__name__,
+                                        "exception_args": [
+                                            x if isinstance(x, (str, bytes, int)) else None for x in e.args
+                                        ],
+                                        "exception_full": ex_full,
+                                        "exception_short": ex_short,
+                                        "exception_trace": ex_trace,
+                                        "sysinfo": sys_info,
+                                    }
+                                )
+                            os_write(self.stdout_fd, msg)
                         else:
-                            logging.debug("\n".join(ex_full))
+                            os_write(self.stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res}))
+                if es:
+                    shutdown_serve = True
+                    continue
 
-                        sys_info = sysinfo()
-                        try:
-                            msg = msgpack.packb(
-                                {
-                                    MSGID: msgid,
-                                    "exception_class": e.__class__.__name__,
-                                    "exception_args": e.args,
-                                    "exception_full": ex_full,
-                                    "exception_short": ex_short,
-                                    "exception_trace": ex_trace,
-                                    "sysinfo": sys_info,
-                                }
-                            )
-                        except TypeError:
-                            msg = msgpack.packb(
-                                {
-                                    MSGID: msgid,
-                                    "exception_class": e.__class__.__name__,
-                                    "exception_args": [x if isinstance(x, (str, bytes, int)) else None for x in e.args],
-                                    "exception_full": ex_full,
-                                    "exception_short": ex_short,
-                                    "exception_trace": ex_trace,
-                                    "sysinfo": sys_info,
-                                }
-                            )
-
-                        os_write(self.stdout_fd, msg)
-                    else:
-                        os_write(self.stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res}))
-            if es:
-                shutdown_serve = True
-                continue
+        if self.socket_path:  # server for socket:// connections
+            try:
+                # remove any left-over socket file
+                os.unlink(self.socket_path)
+            except OSError:
+                # a missing file is fine, anything else (e.g. permission problems) is not
+                if os.path.exists(self.socket_path):
+                    raise
+            sock_dir = os.path.dirname(self.socket_path)
+            if sock_dir:
+                # a bare relative socket name has no directory part; os.makedirs("") would raise
+                os.makedirs(sock_dir, exist_ok=True)
+            # the pid file lives next to the socket file: only a trailing ".sock" is swapped for
+            # ".pid" (str.replace would also rewrite a ".sock" inside a directory name).
+            if self.socket_path.endswith(".sock"):
+                pid_file = self.socket_path[: -len(".sock")] + ".pid"
+            else:
+                pid_file = self.socket_path + ".pid"
+            pid = os.getpid()
+            with open(pid_file, "w") as f:
+                f.write(str(pid))
+            # clean up pid file and socket file when the server process exits
+            atexit.register(functools.partial(os.remove, pid_file))
+            atexit.register(functools.partial(os.remove, self.socket_path))
+            sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+            sock.bind(self.socket_path)  # this creates the socket file in the fs
+            sock.listen(0)  # no backlog
+            os.chmod(self.socket_path, mode=0o0770)  # group members may use the socket, too.
+            print(f"borg serve: PID {pid}, listening on socket {self.socket_path} ...", file=sys.stderr)
+
+            # serve one client connection after the other, until the process gets terminated
+            while True:
+                connection, client_address = sock.accept()
+                print(f"Accepted a connection on socket {self.socket_path} ...", file=sys.stderr)
+                try:
+                    self.stdin_fd = connection.makefile("rb").fileno()
+                    self.stdout_fd = connection.makefile("wb").fileno()
+                    inner_serve()
+                finally:
+                    # release the connection now instead of keeping it open until the next accept()
+                    connection.close()
+                print(f"Finished with connection on socket {self.socket_path} .", file=sys.stderr)
+        else:  # server for one ssh:// connection
+            self.stdin_fd = sys.stdin.fileno()
+            self.stdout_fd = sys.stdout.fileno()
+            inner_serve()
 
     def negotiate(self, client_data):
         if isinstance(client_data, dict):
@@ -318,7 +365,8 @@ class RepositoryServer:  # pragma: no cover
     def close(self):
         if self.repository is not None:
             self.repository.__exit__(None, None, None)
-        borg.logger.teardown_logging()
+            self.repository = None
+        borg.logger.flush_logging()
         self.send_queued_log()
 
     def inject_exception(self, kind):
@@ -489,6 +537,7 @@ class RemoteRepository:
         self.rx_bytes = 0
         self.tx_bytes = 0
         self.to_send = EfficientCollectionQueue(1024 * 1024, bytes)
+        self.stdin_fd = self.stdout_fd = self.stderr_fd = None
         self.stderr_received = b""  # incomplete stderr line bytes received (no \n yet)
         self.chunkid_to_msgids = {}
         self.ignore_responses = set()
@@ -499,27 +548,54 @@ class RemoteRepository:
         self.upload_buffer_size_limit = args.upload_buffer * 1024 * 1024 if args and args.upload_buffer else 0
         self.unpacker = get_limited_unpacker("client")
         self.server_version = None  # we update this after server sends its version
-        self.p = None
+        self.p = self.sock = None
         self._args = args
-        testing = location.host == "__testsuite__"
-        # when testing, we invoke and talk to a borg process directly (no ssh).
-        # when not testing, we invoke the system-installed ssh binary to talk to a remote borg.
-        env = prepare_subprocess_env(system=not testing)
-        borg_cmd = self.borg_cmd(args, testing)
-        if not testing:
-            borg_cmd = self.ssh_cmd(location) + borg_cmd
-        logger.debug("SSH command line: %s", borg_cmd)
-        # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg.
-        # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection.
-        self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint)
-        self.stdin_fd = self.p.stdin.fileno()
-        self.stdout_fd = self.p.stdout.fileno()
-        self.stderr_fd = self.p.stderr.fileno()
+        if self.location.proto == "ssh":
+            testing = location.host == "__testsuite__"
+            # when testing, we invoke and talk to a borg process directly (no ssh).
+            # when not testing, we invoke the system-installed ssh binary to talk to a remote borg.
+            env = prepare_subprocess_env(system=not testing)
+            borg_cmd = self.borg_cmd(args, testing)
+            if not testing:
+                borg_cmd = self.ssh_cmd(location) + borg_cmd
+            logger.debug("SSH command line: %s", borg_cmd)
+            # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg.
+            # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection.
+            self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint)
+            self.stdin_fd = self.p.stdin.fileno()
+            self.stdout_fd = self.p.stdout.fileno()
+            self.stderr_fd = self.p.stderr.fileno()
+            self.r_fds = [self.stdout_fd, self.stderr_fd]
+            self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
+        elif self.location.proto == "socket":
+            if args.use_socket is False or args.use_socket is True:  # nothing or --socket
+                socket_path = get_socket_filename()
+            else:  # --socket=/some/path
+                socket_path = args.use_socket
+            self.sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+            try:
+                self.sock.connect(socket_path)  # note: socket_path length is rather limited.
+            except FileNotFoundError:
+                self.sock = None
+                raise Error(f"The socket file {socket_path} does not exist.")
+            except ConnectionRefusedError:
+                self.sock = None
+                raise Error(f"There is no borg serve running for the socket file {socket_path}.")
+            self.stdin_fd = self.sock.makefile("wb").fileno()
+            self.stdout_fd = self.sock.makefile("rb").fileno()
+            self.stderr_fd = None
+            self.r_fds = [self.stdout_fd]
+            self.x_fds = [self.stdin_fd, self.stdout_fd]
+        else:
+            raise Error(f"Unsupported protocol {location.proto}")
+
         os.set_blocking(self.stdin_fd, False)
+        assert not os.get_blocking(self.stdin_fd)
         os.set_blocking(self.stdout_fd, False)
-        os.set_blocking(self.stderr_fd, False)
-        self.r_fds = [self.stdout_fd, self.stderr_fd]
-        self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
+        assert not os.get_blocking(self.stdout_fd)
+        if self.stderr_fd is not None:
+            os.set_blocking(self.stderr_fd, False)
+            assert not os.get_blocking(self.stderr_fd)
 
         try:
             try:
@@ -551,7 +627,7 @@ class RemoteRepository:
     def __del__(self):
         if len(self.responses):
             logging.debug("still %d cached responses left in RemoteRepository" % (len(self.responses),))
-        if self.p:
+        if self.p or self.sock:
             self.close()
             assert False, "cleanup happened in Repository.__del__"
 
@@ -906,12 +982,21 @@ class RemoteRepository:
         """actual remoting is done via self.call in the @api decorator"""
 
     def close(self):
-        self.call("close", {}, wait=True)
+        if self.p or self.sock:
+            self.call("close", {}, wait=True)
         if self.p:
             self.p.stdin.close()
             self.p.stdout.close()
             self.p.wait()
             self.p = None
+        if self.sock:
+            try:
+                self.sock.shutdown(socket.SHUT_RDWR)
+            except OSError as e:
+                if e.errno != errno.ENOTCONN:
+                    raise
+            self.sock.close()
+            self.sock = None
 
     def async_response(self, wait=True):
         for resp in self.call_many("async_responses", calls=[], wait=True, async_wait=wait):

+ 3 - 3
src/borg/testsuite/archiver/__init__.py

@@ -21,7 +21,7 @@ from ...constants import *  # NOQA
 from ...helpers import Location
 from ...helpers import EXIT_SUCCESS
 from ...helpers import bin_to_hex
-from ...logger import teardown_logging
+from ...logger import flush_logging
 from ...manifest import Manifest
 from ...remote import RemoteRepository
 from ...repository import Repository
@@ -83,9 +83,9 @@ def exec_cmd(*args, archiver=None, fork=False, exe=None, input=b"", binary_outpu
                 output_text.flush()
                 return e.code, output.getvalue() if binary_output else output.getvalue().decode()
             try:
-                ret = archiver.run(args)
+                ret = archiver.run(args)  # calls setup_logging internally
             finally:
-                teardown_logging()  # usually done via atexit, but we do not exit here
+                flush_logging()  # usually done via atexit, but we do not exit here
             output_text.flush()
             return ret, output.getvalue() if binary_output else output.getvalue().decode()
         finally:

+ 50 - 0
src/borg/testsuite/archiver/serve_cmd.py

@@ -0,0 +1,50 @@
+import os
+import subprocess
+import tempfile
+import time
+
+import pytest
+import platformdirs
+
+from . import exec_cmd
+from ...platformflags import is_win32
+from ...helpers import get_runtime_dir
+
+
+def have_a_short_runtime_dir(mp):
+    """Point BORG_RUNTIME_DIR at a short path below the platform's user runtime dir.
+
+    Under pytest, BORG_BASE_DIR is used to keep stuff away from the user's normal borg dirs.
+    That leads to a very long get_runtime_dir() path - too long for a socket file - so it is
+    overridden again here via BORG_RUNTIME_DIR.
+
+    :param mp: the monkeypatch fixture used to set the environment variable.
+    """
+    short_dir = os.path.join(platformdirs.user_runtime_dir(), "pytest")
+    mp.setenv("BORG_RUNTIME_DIR", short_dir)
+
+
+@pytest.fixture
+def serve_socket(monkeypatch):
+    """Start ``borg serve --socket=...`` in a subprocess and yield the socket file path.
+
+    The fixture waits until the socket file shows up. It fails the test instead of hanging
+    forever if the server process exits early or does not create the socket in time. The
+    server process is terminated on teardown.
+    """
+    have_a_short_runtime_dir(monkeypatch)
+    # use a random unique socket filename, so tests can run in parallel.
+    socket_file = tempfile.mktemp(suffix=".sock", prefix="borg-", dir=get_runtime_dir())
+    with subprocess.Popen(["borg", "serve", f"--socket={socket_file}"]) as p:
+        deadline = time.monotonic() + 30  # upper bound for the server startup
+        while not os.path.exists(socket_file):
+            if p.poll() is not None:
+                pytest.fail(f"borg serve exited early with return code {p.returncode}")
+            if time.monotonic() > deadline:
+                p.terminate()
+                pytest.fail(f"borg serve did not create the socket file {socket_file} in time")
+            time.sleep(0.01)  # wait until socket server has started
+        yield socket_file
+        p.terminate()
+
+
+@pytest.mark.skipif(is_win32, reason="hangs on win32")
+def test_with_socket(serve_socket, tmpdir, monkeypatch):
+    """Create, inspect and delete a repository through a socket:// location."""
+    have_a_short_runtime_dir(monkeypatch)
+    repo_path = str(tmpdir.join("repo"))
+    # every command talks to the same server socket and the same repository
+    common_args = [f"--socket={serve_socket}", f"--repo=socket://{repo_path}"]
+    rc, _ = exec_cmd(*common_args, "rcreate", "--encryption=none")
+    assert rc == 0
+    rc, out = exec_cmd(*common_args, "rinfo")
+    assert rc == 0
+    assert "Repository ID: " in out
+    monkeypatch.setenv("BORG_DELETE_I_KNOW_WHAT_I_AM_DOING", "YES")
+    rc, _ = exec_cmd(*common_args, "rdelete")
+    assert rc == 0
+
+
+@pytest.mark.skipif(is_win32, reason="hangs on win32")
+def test_socket_permissions(serve_socket):
+    """The socket file must be usable by user and group only (mode 0770).
+
+    The serve_socket fixture returns as soon as the socket file exists, but the server adjusts
+    the file mode only shortly after creating it, so the mode is polled for a bounded time.
+    """
+    deadline = time.monotonic() + 1.0
+    mode = os.stat(serve_socket).st_mode & 0o0777
+    while mode != 0o0770 and time.monotonic() < deadline:
+        time.sleep(0.01)
+        mode = os.stat(serve_socket).st_mode & 0o0777
+    assert mode == 0o0770  # user and group are permitted to use the socket

+ 34 - 1
src/borg/testsuite/helpers.py

@@ -27,7 +27,7 @@ from ..helpers import (
 )
 from ..helpers import make_path_safe, clean_lines
 from ..helpers import interval
-from ..helpers import get_base_dir, get_cache_dir, get_keys_dir, get_security_dir, get_config_dir
+from ..helpers import get_base_dir, get_cache_dir, get_keys_dir, get_security_dir, get_config_dir, get_runtime_dir
 from ..helpers import is_slow_msgpack
 from ..helpers import msgpack
 from ..helpers import yes, TRUISH, FALSISH, DEFAULTISH
@@ -184,6 +184,14 @@ class TestLocationWithoutEnv:
             == "Location(proto='ssh', user='user', host='2a02:0001:0002:0003:0004:0005:0006:0007', port=1234, path='/some/path')"
         )
 
+    def test_socket(self, monkeypatch, keys_dir):
+        """A socket:// location has only proto and path set and yields a path-based key filename."""
+        monkeypatch.delenv("BORG_REPO", raising=False)
+        expected_repr = "Location(proto='socket', user=None, host=None, port=None, path='/repo/path')"
+        assert repr(Location("socket:///repo/path")) == expected_repr
+        key_filename = Location("socket:///some/path").to_key_filename()
+        assert key_filename == keys_dir + "some_path"
+
     def test_file(self, monkeypatch, keys_dir):
         monkeypatch.delenv("BORG_REPO", raising=False)
         assert (
@@ -275,6 +283,7 @@ class TestLocationWithoutEnv:
             "file://some/path",
             "host:some/path",
             "host:~user/some/path",
+            "socket:///some/path",
             "ssh://host/some/path",
             "ssh://user@host:1234/some/path",
         ]
@@ -752,6 +761,30 @@ def test_get_security_dir(monkeypatch):
         assert get_security_dir() == "/var/tmp"
 
 
+def test_get_runtime_dir(monkeypatch):
+    """get_runtime_dir() must honour the environment and use the per-platform default otherwise."""
+    home_dir = os.path.expanduser("~")
+    monkeypatch.delenv("BORG_BASE_DIR", raising=False)
+    monkeypatch.delenv("BORG_RUNTIME_DIR", raising=False)
+    # pick the expected platform default and the value used for the BORG_RUNTIME_DIR override
+    if is_win32:
+        expected_default = os.path.join(home_dir, "AppData", "Local", "Temp", "borg", "borg")
+        override = home_dir
+    elif is_darwin:
+        expected_default = os.path.join(home_dir, "Library", "Caches", "TemporaryItems", "borg")
+        override = "/var/tmp"
+    else:
+        monkeypatch.delenv("XDG_RUNTIME_DIR", raising=False)
+        expected_default = os.path.join("/run/user", str(os.getuid()), "borg")
+        override = "/var/tmp"
+    assert get_runtime_dir() == expected_default
+    if not (is_win32 or is_darwin):
+        # only checked on the remaining platforms: XDG_RUNTIME_DIR is used when set
+        monkeypatch.setenv("XDG_RUNTIME_DIR", "/var/tmp/.cache")
+        assert get_runtime_dir() == os.path.join("/var/tmp/.cache", "borg")
+    # an explicit BORG_RUNTIME_DIR always wins
+    monkeypatch.setenv("BORG_RUNTIME_DIR", override)
+    assert get_runtime_dir() == override
+
+
 def test_file_size():
     """test the size formatting routines"""
     si_size_map = {