Explorar o código

Merge pull request #1631 from ThomasWaldmann/improve-signal-handling

Improve signal handling
TW %!s(int64=8) %!d(string=hai) anos
pai
achega
6642dadfc6
Modificáronse 2 ficheiros con 128 adicións e 100 borrados
  1. 80 100
      borg/archiver.py
  2. 48 0
      borg/helpers.py

+ 80 - 100
borg/archiver.py

@@ -23,6 +23,7 @@ from .helpers import Error, location_validator, archivename_validator, format_li
     Manifest, NoManifestError, remove_surrogates, update_excludes, format_archive, check_extension_modules, Statistics, \
     dir_is_tagged, bigint_to_int, ChunkerParams, CompressionSpec, PrefixSpec, is_slow_msgpack, yes, sysinfo, \
     EXIT_SUCCESS, EXIT_WARNING, EXIT_ERROR, log_multi, PatternMatcher, ErrorIgnoringTextIOWrapper
+from .helpers import signal_handler, raising_signal_handler, SigHup, SigTerm
 from .logger import create_logger, setup_logging
 logger = create_logger()
 from .compress import Compressor
@@ -1697,59 +1698,28 @@ class Archiver:
         return args.func(args)
 
 
-def sig_info_handler(signum, stack):  # pragma: no cover
+def sig_info_handler(sig_no, stack):  # pragma: no cover
     """search the stack for infos about the currently processed file and print them"""
-    for frame in inspect.getouterframes(stack):
-        func, loc = frame[3], frame[0].f_locals
-        if func in ('process_file', '_process', ):  # create op
-            path = loc['path']
-            try:
-                pos = loc['fd'].tell()
-                total = loc['st'].st_size
-            except Exception:
-                pos, total = 0, 0
-            logger.info("{0} {1}/{2}".format(path, format_file_size(pos), format_file_size(total)))
-            break
-        if func in ('extract_item', ):  # extract op
-            path = loc['item'][b'path']
-            try:
-                pos = loc['fd'].tell()
-            except Exception:
-                pos = 0
-            logger.info("{0} {1}/???".format(path, format_file_size(pos)))
-            break
-
-
-class SIGTERMReceived(BaseException):
-    pass
-
-
-def sig_term_handler(signum, stack):
-    raise SIGTERMReceived
-
-
-class SIGHUPReceived(BaseException):
-    pass
-
-
-def sig_hup_handler(signum, stack):
-    raise SIGHUPReceived
-
-
-def setup_signal_handlers():  # pragma: no cover
-    sigs = []
-    if hasattr(signal, 'SIGUSR1'):
-        sigs.append(signal.SIGUSR1)  # kill -USR1 pid
-    if hasattr(signal, 'SIGINFO'):
-        sigs.append(signal.SIGINFO)  # kill -INFO pid (or ctrl-t)
-    for sig in sigs:
-        signal.signal(sig, sig_info_handler)
-    # If we received SIGTERM or SIGHUP, catch them and raise a proper exception
-    # that can be handled for an orderly exit. SIGHUP is important especially
-    # for systemd systems, where logind sends it when a session exits, in
-    # addition to any traditional use.
-    signal.signal(signal.SIGTERM, sig_term_handler)
-    signal.signal(signal.SIGHUP, sig_hup_handler)
+    with signal_handler(sig_no, signal.SIG_IGN):
+        for frame in inspect.getouterframes(stack):
+            func, loc = frame[3], frame[0].f_locals
+            if func in ('process_file', '_process', ):  # create op
+                path = loc['path']
+                try:
+                    pos = loc['fd'].tell()
+                    total = loc['st'].st_size
+                except Exception:
+                    pos, total = 0, 0
+                logger.info("{0} {1}/{2}".format(path, format_file_size(pos), format_file_size(total)))
+                break
+            if func in ('extract_item', ):  # extract op
+                path = loc['item'][b'path']
+                try:
+                    pos = loc['fd'].tell()
+                except Exception:
+                    pos = 0
+                logger.info("{0} {1}/???".format(path, format_file_size(pos)))
+                break
 
 
 def main():  # pragma: no cover
@@ -1757,54 +1727,64 @@ def main():  # pragma: no cover
     # issues when print()-ing unicode file names
     sys.stdout = ErrorIgnoringTextIOWrapper(sys.stdout.buffer, sys.stdout.encoding, 'replace', line_buffering=True)
     sys.stderr = ErrorIgnoringTextIOWrapper(sys.stderr.buffer, sys.stderr.encoding, 'replace', line_buffering=True)
-    setup_signal_handlers()
-    archiver = Archiver()
-    msg = None
-    try:
-        args = archiver.get_args(sys.argv, os.environ.get('SSH_ORIGINAL_COMMAND'))
-    except Error as e:
-        msg = e.get_message()
-        if e.traceback:
-            msg += "\n%s\n%s" % (traceback.format_exc(), sysinfo())
-        # we might not have logging setup yet, so get out quickly
-        print(msg, file=sys.stderr)
-        sys.exit(e.exit_code)
-    try:
-        exit_code = archiver.run(args)
-    except Error as e:
-        msg = e.get_message()
-        if e.traceback:
-            msg += "\n%s\n%s" % (traceback.format_exc(), sysinfo())
-        exit_code = e.exit_code
-    except RemoteRepository.RPCError as e:
-        msg = '%s\n%s' % (str(e), sysinfo())
-        exit_code = EXIT_ERROR
-    except Exception:
-        msg = 'Local Exception.\n%s\n%s' % (traceback.format_exc(), sysinfo())
-        exit_code = EXIT_ERROR
-    except KeyboardInterrupt:
-        msg = 'Keyboard interrupt.\n%s\n%s' % (traceback.format_exc(), sysinfo())
-        exit_code = EXIT_ERROR
-    except SIGTERMReceived:
-        msg = 'Received SIGTERM.'
-        exit_code = EXIT_ERROR
-    except SIGHUPReceived:
-        msg = 'Received SIGHUP.'
-        exit_code = EXIT_ERROR
-    if msg:
-        logger.error(msg)
-    if args.show_rc:
-        exit_msg = 'terminating with %s status, rc %d'
-        if exit_code == EXIT_SUCCESS:
-            logger.info(exit_msg % ('success', exit_code))
-        elif exit_code == EXIT_WARNING:
-            logger.warning(exit_msg % ('warning', exit_code))
-        elif exit_code == EXIT_ERROR:
-            logger.error(exit_msg % ('error', exit_code))
-        else:
-            # if you see 666 in output, it usually means exit_code was None
-            logger.error(exit_msg % ('abnormal', exit_code or 666))
-    sys.exit(exit_code)
+    # If we receive SIGINT (ctrl-c), SIGTERM (kill) or SIGHUP (kill -HUP),
+    # catch them and raise a proper exception that can be handled for an
+    # orderly exit.
+    # SIGHUP is important especially for systemd systems, where logind
+    # sends it when a session exits, in addition to any traditional use.
+    # Output some info if we receive SIGUSR1 or SIGINFO (ctrl-t).
+    with signal_handler('SIGINT', raising_signal_handler(KeyboardInterrupt)), \
+         signal_handler('SIGHUP', raising_signal_handler(SigHup)), \
+         signal_handler('SIGTERM', raising_signal_handler(SigTerm)), \
+         signal_handler('SIGUSR1', sig_info_handler), \
+         signal_handler('SIGINFO', sig_info_handler):
+        archiver = Archiver()
+        msg = None
+        try:
+            args = archiver.get_args(sys.argv, os.environ.get('SSH_ORIGINAL_COMMAND'))
+        except Error as e:
+            msg = e.get_message()
+            if e.traceback:
+                msg += "\n%s\n%s" % (traceback.format_exc(), sysinfo())
+            # we might not have logging setup yet, so get out quickly
+            print(msg, file=sys.stderr)
+            sys.exit(e.exit_code)
+        try:
+            exit_code = archiver.run(args)
+        except Error as e:
+            msg = e.get_message()
+            if e.traceback:
+                msg += "\n%s\n%s" % (traceback.format_exc(), sysinfo())
+            exit_code = e.exit_code
+        except RemoteRepository.RPCError as e:
+            msg = '%s\n%s' % (str(e), sysinfo())
+            exit_code = EXIT_ERROR
+        except Exception:
+            msg = 'Local Exception.\n%s\n%s' % (traceback.format_exc(), sysinfo())
+            exit_code = EXIT_ERROR
+        except KeyboardInterrupt:
+            msg = 'Keyboard interrupt.\n%s\n%s' % (traceback.format_exc(), sysinfo())
+            exit_code = EXIT_ERROR
+        except SigTerm:
+            msg = 'Received SIGTERM.'
+            exit_code = EXIT_ERROR
+        except SigHup:
+            msg = 'Received SIGHUP.'
+            exit_code = EXIT_ERROR
+        if msg:
+            logger.error(msg)
+        if args.show_rc:
+            exit_msg = 'terminating with %s status, rc %d'
+            if exit_code == EXIT_SUCCESS:
+                logger.info(exit_msg % ('success', exit_code))
+            elif exit_code == EXIT_WARNING:
+                logger.warning(exit_msg % ('warning', exit_code))
+            elif exit_code == EXIT_ERROR:
+                logger.error(exit_msg % ('error', exit_code))
+            else:
+                # if you see 666 in output, it usually means exit_code was None
+                logger.error(exit_msg % ('abnormal', exit_code or 666))
+        sys.exit(exit_code)
 
 
 if __name__ == '__main__':

+ 48 - 0
borg/helpers.py

@@ -1,5 +1,6 @@
 import argparse
 from collections import namedtuple
+import contextlib
 from functools import wraps
 import grp
 import os
@@ -10,6 +11,7 @@ import re
 from shutil import get_terminal_size
 import sys
 import platform
+import signal
 import threading
 import time
 import unicodedata
@@ -1160,3 +1162,49 @@ class ErrorIgnoringTextIOWrapper(io.TextIOWrapper):
                 except OSError:
                     pass
         return len(s)
+
+
+class SignalException(BaseException):
+    """base class for all signal-based exceptions"""
+
+
+class SigHup(SignalException):
+    """raised on SIGHUP signal"""
+
+
+class SigTerm(SignalException):
+    """raised on SIGTERM signal"""
+
+
+@contextlib.contextmanager
+def signal_handler(sig, handler):
+    """
+    when entering context, set up signal handler <handler> for signal <sig>.
+    when leaving context, restore original signal handler.
+
+    <sig> can bei either a str when giving a signal.SIGXXX attribute name (it
+    won't crash if the attribute name does not exist as some names are platform
+    specific) or a int, when giving a signal number.
+
+    <handler> is any handler value as accepted by the signal.signal(sig, handler).
+    """
+    if isinstance(sig, str):
+        sig = getattr(signal, sig, None)
+    if sig is not None:
+        orig_handler = signal.signal(sig, handler)
+    try:
+        yield
+    finally:
+        if sig is not None:
+            signal.signal(sig, orig_handler)
+
+
+def raising_signal_handler(exc_cls):
+    def handler(sig_no, frame):
+        # setting SIG_IGN avoids that an incoming second signal of this
+        # kind would raise a 2nd exception while we still process the
+        # exception handler for exc_cls for the 1st signal.
+        signal.signal(sig_no, signal.SIG_IGN)
+        raise exc_cls
+
+    return handler