Forráskód Böngészése

implement unix domain (ipc) socket support

server (listening) side:
borg serve --socket  # default location
borg serve --socket=/path/to/socket

client side:
borg -r socket:///path/to/repo create ...
borg --socket=/path/to/socket -r socket:///path/to/repo ...

served connections:
- for ssh: proto: one connection
- for socket: proto: many connections (one after the other)

The socket has user and group permissions (770).

skip socket tests on win32, they hang infinitely, until
github CI terminates them after 60 minutes.

socket tests: use unique socket name

don't use the standard / default socket name, otherwise tests
running in parallel would interfere with each other by using
the same socket / the same borg serve process.

write a .pid file, clean up .pid and .sock file at exit

add stderr print for accepted/finished socket connection
Thomas Waldmann 2 éve
szülő
commit
ffc59dd071

+ 11 - 1
src/borg/archiver/_common.py

@@ -29,7 +29,7 @@ logger = create_logger(__name__)
 
 
 
 
 def get_repository(location, *, create, exclusive, lock_wait, lock, append_only, make_parent_dirs, storage_quota, args):
 def get_repository(location, *, create, exclusive, lock_wait, lock, append_only, make_parent_dirs, storage_quota, args):
-    if location.proto == "ssh":
+    if location.proto in ("ssh", "socket"):
         repository = RemoteRepository(
         repository = RemoteRepository(
             location,
             location,
             create=create,
             create=create,
@@ -573,6 +573,16 @@ def define_common_options(add_common_option):
         action=Highlander,
         action=Highlander,
         help="Use this command to connect to the 'borg serve' process (default: 'ssh')",
         help="Use this command to connect to the 'borg serve' process (default: 'ssh')",
     )
     )
+    add_common_option(
+        "--socket",
+        metavar="PATH",
+        dest="use_socket",
+        default=False,
+        const=True,
+        nargs="?",
+        action=Highlander,
+        help="Use UNIX DOMAIN (IPC) socket at PATH for client/server communication with socket: protocol.",
+    )
     add_common_option(
     add_common_option(
         "-r",
         "-r",
         "--repo",
         "--repo",

+ 12 - 1
src/borg/archiver/serve_cmd.py

@@ -19,6 +19,7 @@ class ServeMixIn:
             restrict_to_repositories=args.restrict_to_repositories,
             restrict_to_repositories=args.restrict_to_repositories,
             append_only=args.append_only,
             append_only=args.append_only,
             storage_quota=args.storage_quota,
             storage_quota=args.storage_quota,
+            use_socket=args.use_socket,
         ).serve()
         ).serve()
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
@@ -27,7 +28,17 @@ class ServeMixIn:
 
 
         serve_epilog = process_epilog(
         serve_epilog = process_epilog(
             """
             """
-        This command starts a repository server process. This command is usually not used manually.
+        This command starts a repository server process.
+
+        borg serve can currently support:
+
+        - Getting automatically started via ssh when the borg client uses a ssh://...
+          remote repository. In this mode, `borg serve` will live until that ssh connection
+          gets terminated.
+
+        - Getting started by some other means (not by the borg client) as a long-running socket
+          server to be used for borg clients using a socket://... repository (see the `--socket`
+          option if you do not want to use the default path for the socket and pid file).
         """
         """
         )
         )
         subparser = subparsers.add_parser(
         subparser = subparsers.add_parser(

+ 1 - 1
src/borg/helpers/__init__.py

@@ -11,7 +11,7 @@ from ..constants import *  # NOQA
 from .checks import check_extension_modules, check_python
 from .checks import check_extension_modules, check_python
 from .datastruct import StableDict, Buffer, EfficientCollectionQueue
 from .datastruct import StableDict, Buffer, EfficientCollectionQueue
 from .errors import Error, ErrorWithTraceback, IntegrityError, DecompressionError
 from .errors import Error, ErrorWithTraceback, IntegrityError, DecompressionError
-from .fs import ensure_dir, join_base_dir
+from .fs import ensure_dir, join_base_dir, get_socket_filename
 from .fs import get_security_dir, get_keys_dir, get_base_dir, get_cache_dir, get_config_dir, get_runtime_dir
 from .fs import get_security_dir, get_keys_dir, get_base_dir, get_cache_dir, get_config_dir, get_runtime_dir
 from .fs import dir_is_tagged, dir_is_cachedir, make_path_safe, scandir_inorder
 from .fs import dir_is_tagged, dir_is_cachedir, make_path_safe, scandir_inorder
 from .fs import secure_erase, safe_unlink, dash_open, os_open, os_stat, umount
 from .fs import secure_erase, safe_unlink, dash_open, os_open, os_stat, umount

+ 4 - 0
src/borg/helpers/fs.py

@@ -118,6 +118,10 @@ def get_runtime_dir(*, legacy=False):
     return runtime_dir
     return runtime_dir
 
 
 
 
+def get_socket_filename():
+    return os.path.join(get_runtime_dir(), "borg.sock")
+
+
 def get_cache_dir(*, legacy=False):
 def get_cache_dir(*, legacy=False):
     """Determine where to repository keys and cache"""
     """Determine where to repository keys and cache"""
 
 

+ 16 - 3
src/borg/helpers/parseformat.py

@@ -390,7 +390,7 @@ class Location:
     # path must not contain :: (it ends at :: or string end), but may contain single colons.
     # path must not contain :: (it ends at :: or string end), but may contain single colons.
     # to avoid ambiguities with other regexes, it must also not start with ":" nor with "//" nor with "ssh://".
     # to avoid ambiguities with other regexes, it must also not start with ":" nor with "//" nor with "ssh://".
     local_path_re = r"""
     local_path_re = r"""
-        (?!(:|//|ssh://))                                   # not starting with ":" or // or ssh://
+        (?!(:|//|ssh://|socket://))                         # not starting with ":" or // or ssh:// or socket://
         (?P<path>([^:]|(:(?!:)))+)                          # any chars, but no "::"
         (?P<path>([^:]|(:(?!:)))+)                          # any chars, but no "::"
         """
         """
 
 
@@ -429,6 +429,14 @@ class Location:
         re.VERBOSE,
         re.VERBOSE,
     )  # path
     )  # path
 
 
+    socket_re = re.compile(
+        r"""
+        (?P<proto>socket)://                                    # socket://
+        """
+        + abs_path_re,
+        re.VERBOSE,
+    )  # path
+
     file_re = re.compile(
     file_re = re.compile(
         r"""
         r"""
         (?P<proto>file)://                                      # file://
         (?P<proto>file)://                                      # file://
@@ -493,6 +501,11 @@ class Location:
             self.path = normpath_special(m.group("path"))
             self.path = normpath_special(m.group("path"))
             return True
             return True
         m = self.file_re.match(text)
         m = self.file_re.match(text)
+        if m:
+            self.proto = m.group("proto")
+            self.path = normpath_special(m.group("path"))
+            return True
+        m = self.socket_re.match(text)
         if m:
         if m:
             self.proto = m.group("proto")
             self.proto = m.group("proto")
             self.path = normpath_special(m.group("path"))
             self.path = normpath_special(m.group("path"))
@@ -516,7 +529,7 @@ class Location:
 
 
     def to_key_filename(self):
     def to_key_filename(self):
         name = re.sub(r"[^\w]", "_", self.path).strip("_")
         name = re.sub(r"[^\w]", "_", self.path).strip("_")
-        if self.proto != "file":
+        if self.proto not in ("file", "socket"):
             name = re.sub(r"[^\w]", "_", self.host) + "__" + name
             name = re.sub(r"[^\w]", "_", self.host) + "__" + name
         if len(name) > 100:
         if len(name) > 100:
             # Limit file names to some reasonable length. Most file systems
             # Limit file names to some reasonable length. Most file systems
@@ -535,7 +548,7 @@ class Location:
             return self._host.lstrip("[").rstrip("]")
             return self._host.lstrip("[").rstrip("]")
 
 
     def canonical_path(self):
     def canonical_path(self):
-        if self.proto == "file":
+        if self.proto in ("file", "socket"):
             return self.path
             return self.path
         else:
         else:
             if self.path and self.path.startswith("~"):
             if self.path and self.path.startswith("~"):

+ 192 - 107
src/borg/remote.py

@@ -1,3 +1,4 @@
+import atexit
 import errno
 import errno
 import functools
 import functools
 import inspect
 import inspect
@@ -7,6 +8,7 @@ import queue
 import select
 import select
 import shlex
 import shlex
 import shutil
 import shutil
+import socket
 import struct
 import struct
 import sys
 import sys
 import tempfile
 import tempfile
@@ -27,6 +29,7 @@ from .helpers import sysinfo
 from .helpers import format_file_size
 from .helpers import format_file_size
 from .helpers import safe_unlink
 from .helpers import safe_unlink
 from .helpers import prepare_subprocess_env, ignore_sigint
 from .helpers import prepare_subprocess_env, ignore_sigint
+from .helpers import get_socket_filename
 from .logger import create_logger, borg_serve_log_queue
 from .logger import create_logger, borg_serve_log_queue
 from .helpers import msgpack
 from .helpers import msgpack
 from .repository import Repository
 from .repository import Repository
@@ -136,7 +139,7 @@ class RepositoryServer:  # pragma: no cover
         "inject_exception",
         "inject_exception",
     )
     )
 
 
-    def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota):
+    def __init__(self, restrict_to_paths, restrict_to_repositories, append_only, storage_quota, use_socket):
         self.repository = None
         self.repository = None
         self.restrict_to_paths = restrict_to_paths
         self.restrict_to_paths = restrict_to_paths
         self.restrict_to_repositories = restrict_to_repositories
         self.restrict_to_repositories = restrict_to_repositories
@@ -147,6 +150,12 @@ class RepositoryServer:  # pragma: no cover
         self.append_only = append_only
         self.append_only = append_only
         self.storage_quota = storage_quota
         self.storage_quota = storage_quota
         self.client_version = None  # we update this after client sends version information
         self.client_version = None  # we update this after client sends version information
+        if use_socket is False:
+            self.socket_path = None
+        elif use_socket is True:  # --socket
+            self.socket_path = get_socket_filename()
+        else:  # --socket=/some/path
+            self.socket_path = use_socket
 
 
     def filter_args(self, f, kwargs):
     def filter_args(self, f, kwargs):
         """Remove unknown named parameters from call, because client did (implicitly) say it's ok."""
         """Remove unknown named parameters from call, because client did (implicitly) say it's ok."""
@@ -165,95 +174,133 @@ class RepositoryServer:  # pragma: no cover
                 os_write(self.stdout_fd, msg)
                 os_write(self.stdout_fd, msg)
 
 
     def serve(self):
     def serve(self):
-        self.stdin_fd = sys.stdin.fileno()
-        self.stdout_fd = sys.stdout.fileno()
-        os.set_blocking(self.stdin_fd, False)
-        os.set_blocking(self.stdout_fd, True)
-        unpacker = get_limited_unpacker("server")
-        shutdown_serve = False
-        while True:
-            # before processing any new RPCs, send out all pending log output
-            self.send_queued_log()
-
-            if shutdown_serve:
-                # shutdown wanted! get out of here after sending all log output.
-                if self.repository is not None:
-                    self.repository.close()
-                return
-
-            # process new RPCs
-            r, w, es = select.select([self.stdin_fd], [], [], 10)
-            if r:
-                data = os.read(self.stdin_fd, BUFSIZE)
-                if not data:
-                    shutdown_serve = True
-                    continue
-                unpacker.feed(data)
-                for unpacked in unpacker:
-                    if isinstance(unpacked, dict):
-                        msgid = unpacked[MSGID]
-                        method = unpacked[MSG]
-                        args = unpacked[ARGS]
-                    else:
-                        if self.repository is not None:
-                            self.repository.close()
-                        raise UnexpectedRPCDataFormatFromClient(__version__)
-                    try:
-                        if method not in self.rpc_methods:
-                            raise InvalidRPCMethod(method)
+        def inner_serve():
+            os.set_blocking(self.stdin_fd, False)
+            assert not os.get_blocking(self.stdin_fd)
+            os.set_blocking(self.stdout_fd, True)
+            assert os.get_blocking(self.stdout_fd)
+
+            unpacker = get_limited_unpacker("server")
+            shutdown_serve = False
+            while True:
+                # before processing any new RPCs, send out all pending log output
+                self.send_queued_log()
+
+                if shutdown_serve:
+                    # shutdown wanted! get out of here after sending all log output.
+                    assert self.repository is None
+                    return
+
+                # process new RPCs
+                r, w, es = select.select([self.stdin_fd], [], [], 10)
+                if r:
+                    data = os.read(self.stdin_fd, BUFSIZE)
+                    if not data:
+                        shutdown_serve = True
+                        continue
+                    unpacker.feed(data)
+                    for unpacked in unpacker:
+                        if isinstance(unpacked, dict):
+                            msgid = unpacked[MSGID]
+                            method = unpacked[MSG]
+                            args = unpacked[ARGS]
+                        else:
+                            if self.repository is not None:
+                                self.repository.close()
+                            raise UnexpectedRPCDataFormatFromClient(__version__)
                         try:
                         try:
-                            f = getattr(self, method)
-                        except AttributeError:
-                            f = getattr(self.repository, method)
-                        args = self.filter_args(f, args)
-                        res = f(**args)
-                    except BaseException as e:
-                        ex_short = traceback.format_exception_only(e.__class__, e)
-                        ex_full = traceback.format_exception(*sys.exc_info())
-                        ex_trace = True
-                        if isinstance(e, Error):
-                            ex_short = [e.get_message()]
-                            ex_trace = e.traceback
-                        if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)):
-                            # These exceptions are reconstructed on the client end in RemoteRepository.call_many(),
-                            # and will be handled just like locally raised exceptions. Suppress the remote traceback
-                            # for these, except ErrorWithTraceback, which should always display a traceback.
-                            pass
+                            if method not in self.rpc_methods:
+                                raise InvalidRPCMethod(method)
+                            try:
+                                f = getattr(self, method)
+                            except AttributeError:
+                                f = getattr(self.repository, method)
+                            args = self.filter_args(f, args)
+                            res = f(**args)
+                        except BaseException as e:
+                            ex_short = traceback.format_exception_only(e.__class__, e)
+                            ex_full = traceback.format_exception(*sys.exc_info())
+                            ex_trace = True
+                            if isinstance(e, Error):
+                                ex_short = [e.get_message()]
+                                ex_trace = e.traceback
+                            if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)):
+                                # These exceptions are reconstructed on the client end in RemoteRepository.call_many(),
+                                # and will be handled just like locally raised exceptions. Suppress the remote traceback
+                                # for these, except ErrorWithTraceback, which should always display a traceback.
+                                pass
+                            else:
+                                logging.debug("\n".join(ex_full))
+
+                            sys_info = sysinfo()
+                            try:
+                                msg = msgpack.packb(
+                                    {
+                                        MSGID: msgid,
+                                        "exception_class": e.__class__.__name__,
+                                        "exception_args": e.args,
+                                        "exception_full": ex_full,
+                                        "exception_short": ex_short,
+                                        "exception_trace": ex_trace,
+                                        "sysinfo": sys_info,
+                                    }
+                                )
+                            except TypeError:
+                                msg = msgpack.packb(
+                                    {
+                                        MSGID: msgid,
+                                        "exception_class": e.__class__.__name__,
+                                        "exception_args": [
+                                            x if isinstance(x, (str, bytes, int)) else None for x in e.args
+                                        ],
+                                        "exception_full": ex_full,
+                                        "exception_short": ex_short,
+                                        "exception_trace": ex_trace,
+                                        "sysinfo": sys_info,
+                                    }
+                                )
+                            os_write(self.stdout_fd, msg)
                         else:
                         else:
-                            logging.debug("\n".join(ex_full))
+                            os_write(self.stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res}))
+                if es:
+                    shutdown_serve = True
+                    continue
 
 
-                        sys_info = sysinfo()
-                        try:
-                            msg = msgpack.packb(
-                                {
-                                    MSGID: msgid,
-                                    "exception_class": e.__class__.__name__,
-                                    "exception_args": e.args,
-                                    "exception_full": ex_full,
-                                    "exception_short": ex_short,
-                                    "exception_trace": ex_trace,
-                                    "sysinfo": sys_info,
-                                }
-                            )
-                        except TypeError:
-                            msg = msgpack.packb(
-                                {
-                                    MSGID: msgid,
-                                    "exception_class": e.__class__.__name__,
-                                    "exception_args": [x if isinstance(x, (str, bytes, int)) else None for x in e.args],
-                                    "exception_full": ex_full,
-                                    "exception_short": ex_short,
-                                    "exception_trace": ex_trace,
-                                    "sysinfo": sys_info,
-                                }
-                            )
-
-                        os_write(self.stdout_fd, msg)
-                    else:
-                        os_write(self.stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res}))
-            if es:
-                shutdown_serve = True
-                continue
+        if self.socket_path:  # server for socket:// connections
+            try:
+                # remove any left-over socket file
+                os.unlink(self.socket_path)
+            except OSError:
+                if os.path.exists(self.socket_path):
+                    raise
+            sock_dir = os.path.dirname(self.socket_path)
+            os.makedirs(sock_dir, exist_ok=True)
+            if self.socket_path.endswith(".sock"):
+                pid_file = self.socket_path.replace(".sock", ".pid")
+            else:
+                pid_file = self.socket_path + ".pid"
+            pid = os.getpid()
+            with open(pid_file, "w") as f:
+                f.write(str(pid))
+            atexit.register(functools.partial(os.remove, pid_file))
+            atexit.register(functools.partial(os.remove, self.socket_path))
+            sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+            sock.bind(self.socket_path)  # this creates the socket file in the fs
+            sock.listen(0)  # no backlog
+            os.chmod(self.socket_path, mode=0o0770)  # group members may use the socket, too.
+            print(f"borg serve: PID {pid}, listening on socket {self.socket_path} ...", file=sys.stderr)
+
+            while True:
+                connection, client_address = sock.accept()
+                print(f"Accepted a connection on socket {self.socket_path} ...", file=sys.stderr)
+                self.stdin_fd = connection.makefile("rb").fileno()
+                self.stdout_fd = connection.makefile("wb").fileno()
+                inner_serve()
+                print(f"Finished with connection on socket {self.socket_path} .", file=sys.stderr)
+        else:  # server for one ssh:// connection
+            self.stdin_fd = sys.stdin.fileno()
+            self.stdout_fd = sys.stdout.fileno()
+            inner_serve()
 
 
     def negotiate(self, client_data):
     def negotiate(self, client_data):
         if isinstance(client_data, dict):
         if isinstance(client_data, dict):
@@ -318,6 +365,7 @@ class RepositoryServer:  # pragma: no cover
     def close(self):
     def close(self):
         if self.repository is not None:
         if self.repository is not None:
             self.repository.__exit__(None, None, None)
             self.repository.__exit__(None, None, None)
+            self.repository = None
         borg.logger.teardown_logging()
         borg.logger.teardown_logging()
         self.send_queued_log()
         self.send_queued_log()
 
 
@@ -489,6 +537,7 @@ class RemoteRepository:
         self.rx_bytes = 0
         self.rx_bytes = 0
         self.tx_bytes = 0
         self.tx_bytes = 0
         self.to_send = EfficientCollectionQueue(1024 * 1024, bytes)
         self.to_send = EfficientCollectionQueue(1024 * 1024, bytes)
+        self.stdin_fd = self.stdout_fd = self.stderr_fd = None
         self.stderr_received = b""  # incomplete stderr line bytes received (no \n yet)
         self.stderr_received = b""  # incomplete stderr line bytes received (no \n yet)
         self.chunkid_to_msgids = {}
         self.chunkid_to_msgids = {}
         self.ignore_responses = set()
         self.ignore_responses = set()
@@ -499,27 +548,54 @@ class RemoteRepository:
         self.upload_buffer_size_limit = args.upload_buffer * 1024 * 1024 if args and args.upload_buffer else 0
         self.upload_buffer_size_limit = args.upload_buffer * 1024 * 1024 if args and args.upload_buffer else 0
         self.unpacker = get_limited_unpacker("client")
         self.unpacker = get_limited_unpacker("client")
         self.server_version = None  # we update this after server sends its version
         self.server_version = None  # we update this after server sends its version
-        self.p = None
+        self.p = self.sock = None
         self._args = args
         self._args = args
-        testing = location.host == "__testsuite__"
-        # when testing, we invoke and talk to a borg process directly (no ssh).
-        # when not testing, we invoke the system-installed ssh binary to talk to a remote borg.
-        env = prepare_subprocess_env(system=not testing)
-        borg_cmd = self.borg_cmd(args, testing)
-        if not testing:
-            borg_cmd = self.ssh_cmd(location) + borg_cmd
-        logger.debug("SSH command line: %s", borg_cmd)
-        # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg.
-        # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection.
-        self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint)
-        self.stdin_fd = self.p.stdin.fileno()
-        self.stdout_fd = self.p.stdout.fileno()
-        self.stderr_fd = self.p.stderr.fileno()
+        if self.location.proto == "ssh":
+            testing = location.host == "__testsuite__"
+            # when testing, we invoke and talk to a borg process directly (no ssh).
+            # when not testing, we invoke the system-installed ssh binary to talk to a remote borg.
+            env = prepare_subprocess_env(system=not testing)
+            borg_cmd = self.borg_cmd(args, testing)
+            if not testing:
+                borg_cmd = self.ssh_cmd(location) + borg_cmd
+            logger.debug("SSH command line: %s", borg_cmd)
+            # we do not want the ssh getting killed by Ctrl-C/SIGINT because it is needed for clean shutdown of borg.
+            # borg's SIGINT handler tries to write a checkpoint and requires the remote repo connection.
+            self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env, preexec_fn=ignore_sigint)
+            self.stdin_fd = self.p.stdin.fileno()
+            self.stdout_fd = self.p.stdout.fileno()
+            self.stderr_fd = self.p.stderr.fileno()
+            self.r_fds = [self.stdout_fd, self.stderr_fd]
+            self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
+        elif self.location.proto == "socket":
+            if args.use_socket is False or args.use_socket is True:  # nothing or --socket
+                socket_path = get_socket_filename()
+            else:  # --socket=/some/path
+                socket_path = args.use_socket
+            self.sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+            try:
+                self.sock.connect(socket_path)  # note: socket_path length is rather limited.
+            except FileNotFoundError:
+                self.sock = None
+                raise Error(f"The socket file {socket_path} does not exist.")
+            except ConnectionRefusedError:
+                self.sock = None
+                raise Error(f"There is no borg serve running for the socket file {socket_path}.")
+            self.stdin_fd = self.sock.makefile("wb").fileno()
+            self.stdout_fd = self.sock.makefile("rb").fileno()
+            self.stderr_fd = None
+            self.r_fds = [self.stdout_fd]
+            self.x_fds = [self.stdin_fd, self.stdout_fd]
+        else:
+            raise Error(f"Unsupported protocol {location.proto}")
+
         os.set_blocking(self.stdin_fd, False)
         os.set_blocking(self.stdin_fd, False)
+        assert not os.get_blocking(self.stdin_fd)
         os.set_blocking(self.stdout_fd, False)
         os.set_blocking(self.stdout_fd, False)
-        os.set_blocking(self.stderr_fd, False)
-        self.r_fds = [self.stdout_fd, self.stderr_fd]
-        self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
+        assert not os.get_blocking(self.stdout_fd)
+        if self.stderr_fd is not None:
+            os.set_blocking(self.stderr_fd, False)
+            assert not os.get_blocking(self.stderr_fd)
 
 
         try:
         try:
             try:
             try:
@@ -551,7 +627,7 @@ class RemoteRepository:
     def __del__(self):
     def __del__(self):
         if len(self.responses):
         if len(self.responses):
             logging.debug("still %d cached responses left in RemoteRepository" % (len(self.responses),))
             logging.debug("still %d cached responses left in RemoteRepository" % (len(self.responses),))
-        if self.p:
+        if self.p or self.sock:
             self.close()
             self.close()
             assert False, "cleanup happened in Repository.__del__"
             assert False, "cleanup happened in Repository.__del__"
 
 
@@ -906,12 +982,21 @@ class RemoteRepository:
         """actual remoting is done via self.call in the @api decorator"""
         """actual remoting is done via self.call in the @api decorator"""
 
 
     def close(self):
     def close(self):
-        self.call("close", {}, wait=True)
+        if self.p or self.sock:
+            self.call("close", {}, wait=True)
         if self.p:
         if self.p:
             self.p.stdin.close()
             self.p.stdin.close()
             self.p.stdout.close()
             self.p.stdout.close()
             self.p.wait()
             self.p.wait()
             self.p = None
             self.p = None
+        if self.sock:
+            try:
+                self.sock.shutdown(socket.SHUT_RDWR)
+            except OSError as e:
+                if e.errno != errno.ENOTCONN:
+                    raise
+            self.sock.close()
+            self.sock = None
 
 
     def async_response(self, wait=True):
     def async_response(self, wait=True):
         for resp in self.call_many("async_responses", calls=[], wait=True, async_wait=wait):
         for resp in self.call_many("async_responses", calls=[], wait=True, async_wait=wait):

+ 50 - 0
src/borg/testsuite/archiver/serve_cmd.py

@@ -0,0 +1,50 @@
+import os
+import subprocess
+import tempfile
+import time
+
+import pytest
+import platformdirs
+
+from . import exec_cmd
+from ...platformflags import is_win32
+from ...helpers import get_runtime_dir
+
+
+def have_a_short_runtime_dir(mp):
+    # under pytest, we use BORG_BASE_DIR to keep stuff away from the user's normal borg dirs.
+    # this leads to a very long get_runtime_dir() path - too long for a socket file!
+    # thus, we override that again via BORG_RUNTIME_DIR to get a shorter path.
+    mp.setenv("BORG_RUNTIME_DIR", os.path.join(platformdirs.user_runtime_dir(), "pytest"))
+
+
+@pytest.fixture
+def serve_socket(monkeypatch):
+    have_a_short_runtime_dir(monkeypatch)
+    # use a random unique socket filename, so tests can run in parallel.
+    socket_file = tempfile.mktemp(suffix=".sock", prefix="borg-", dir=get_runtime_dir())
+    with subprocess.Popen(["borg", "serve", f"--socket={socket_file}"]) as p:
+        while not os.path.exists(socket_file):
+            time.sleep(0.01)  # wait until socket server has started
+        yield socket_file
+        p.terminate()
+
+
+@pytest.mark.skipif(is_win32, reason="hangs on win32")
+def test_with_socket(serve_socket, tmpdir, monkeypatch):
+    have_a_short_runtime_dir(monkeypatch)
+    repo_path = str(tmpdir.join("repo"))
+    ret, output = exec_cmd(f"--socket={serve_socket}", f"--repo=socket://{repo_path}", "rcreate", "--encryption=none")
+    assert ret == 0
+    ret, output = exec_cmd(f"--socket={serve_socket}", f"--repo=socket://{repo_path}", "rinfo")
+    assert ret == 0
+    assert "Repository ID: " in output
+    monkeypatch.setenv("BORG_DELETE_I_KNOW_WHAT_I_AM_DOING", "YES")
+    ret, output = exec_cmd(f"--socket={serve_socket}", f"--repo=socket://{repo_path}", "rdelete")
+    assert ret == 0
+
+
+@pytest.mark.skipif(is_win32, reason="hangs on win32")
+def test_socket_permissions(serve_socket):
+    st = os.stat(serve_socket)
+    assert st.st_mode & 0o0777 == 0o0770  # user and group are permitted to use the socket

+ 9 - 0
src/borg/testsuite/helpers.py

@@ -184,6 +184,14 @@ class TestLocationWithoutEnv:
             == "Location(proto='ssh', user='user', host='2a02:0001:0002:0003:0004:0005:0006:0007', port=1234, path='/some/path')"
             == "Location(proto='ssh', user='user', host='2a02:0001:0002:0003:0004:0005:0006:0007', port=1234, path='/some/path')"
         )
         )
 
 
+    def test_socket(self, monkeypatch, keys_dir):
+        monkeypatch.delenv("BORG_REPO", raising=False)
+        assert (
+            repr(Location("socket:///repo/path"))
+            == "Location(proto='socket', user=None, host=None, port=None, path='/repo/path')"
+        )
+        assert Location("socket:///some/path").to_key_filename() == keys_dir + "some_path"
+
     def test_file(self, monkeypatch, keys_dir):
     def test_file(self, monkeypatch, keys_dir):
         monkeypatch.delenv("BORG_REPO", raising=False)
         monkeypatch.delenv("BORG_REPO", raising=False)
         assert (
         assert (
@@ -275,6 +283,7 @@ class TestLocationWithoutEnv:
             "file://some/path",
             "file://some/path",
             "host:some/path",
             "host:some/path",
             "host:~user/some/path",
             "host:~user/some/path",
+            "socket:///some/path",
             "ssh://host/some/path",
             "ssh://host/some/path",
             "ssh://user@host:1234/some/path",
             "ssh://user@host:1234/some/path",
         ]
         ]