Просмотр исходного кода

refactor confirmation code, reduce code duplication, add tests

Thomas Waldmann 9 лет назад
Родитель
Сommit
0a6e6cfe2e
4 измененных файлов с 184 добавлено и 36 удалено
  1. 16 16
      borg/archiver.py
  2. 11 18
      borg/cache.py
  3. 81 0
      borg/helpers.py
  4. 76 2
      borg/testsuite/helpers.py

+ 16 - 16
borg/archiver.py

@@ -19,7 +19,7 @@ from .helpers import Error, location_validator, format_time, format_file_size, \
     format_file_mode, ExcludePattern, IncludePattern, exclude_path, adjust_patterns, to_localtime, timestamp, \
     get_cache_dir, get_keys_dir, format_timedelta, prune_within, prune_split, \
     Manifest, remove_surrogates, update_excludes, format_archive, check_extension_modules, Statistics, \
-    is_cachedir, bigint_to_int, ChunkerParams, CompressionSpec, have_cython, is_slow_msgpack, \
+    is_cachedir, bigint_to_int, ChunkerParams, CompressionSpec, have_cython, is_slow_msgpack, yes, \
     EXIT_SUCCESS, EXIT_WARNING, EXIT_ERROR
 from .logger import create_logger, setup_logging
 logger = create_logger()
@@ -88,13 +88,12 @@ class Archiver:
         """Check repository consistency"""
         repository = self.open_repository(args.repository, exclusive=args.repair)
         if args.repair:
-            while not os.environ.get('BORG_CHECK_I_KNOW_WHAT_I_AM_DOING'):
-                self.print_warning("""'check --repair' is an experimental feature that might result
-in data loss.
-
-Type "Yes I am sure" if you understand this and want to continue.\n""")
-                if input('Do you want to continue? ') == 'Yes I am sure':
-                    break
+            msg = ("'check --repair' is an experimental feature that might result in data loss." +
+                   "\n" +
+                   "Type 'YES' if you understand this and want to continue: ")
+            if not yes(msg, false_msg="Aborting.",
+                       env_var_override='BORG_CHECK_I_KNOW_WHAT_I_AM_DOING', truish=('YES', )):
+                return EXIT_ERROR
         if not args.archives_only:
             logger.info('Starting repository check...')
             if repository.check(repair=args.repair):
@@ -330,15 +329,16 @@ Type "Yes I am sure" if you understand this and want to continue.\n""")
                 logger.info(str(cache))
         else:
             if not args.cache_only:
-                print("You requested to completely DELETE the repository *including* all archives it contains:", file=sys.stderr)
+                msg = []
+                msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
                 for archive_info in manifest.list_archive_infos(sort_by='ts'):
-                    print(format_archive(archive_info), file=sys.stderr)
-                if not os.environ.get('BORG_CHECK_I_KNOW_WHAT_I_AM_DOING'):
-                    print("""Type "YES" if you understand this and want to continue.\n""", file=sys.stderr)
-                    # XXX: prompt may end up on stdout, but we'll assume that input() does the right thing
-                    if input('Do you want to continue? ') != 'YES':
-                        self.exit_code = EXIT_ERROR
-                        return self.exit_code
+                    msg.append(format_archive(archive_info))
+                msg.append("Type 'YES' if you understand this and want to continue: ")
+                msg = '\n'.join(msg)
+                if not yes(msg, false_msg="Aborting.",
+                           env_var_override='BORG_CHECK_I_KNOW_WHAT_I_AM_DOING', truish=('YES', )):
+                    self.exit_code = EXIT_ERROR
+                    return self.exit_code
                 repository.destroy()
                 logger.info("Repository deleted.")
             cache.destroy()

+ 11 - 18
borg/cache.py

@@ -14,7 +14,7 @@ from .key import PlaintextKey
 from .logger import create_logger
 logger = create_logger()
 from .helpers import Error, get_cache_dir, decode_dict, st_mtime_ns, unhexlify, int_to_bigint, \
-    bigint_to_int, format_file_size, have_cython
+    bigint_to_int, format_file_size, have_cython, yes
 from .locking import UpgradableLock
 from .hashindex import ChunkIndex
 
@@ -51,15 +51,21 @@ class Cache:
         # Warn user before sending data to a never seen before unencrypted repository
         if not os.path.exists(self.path):
             if warn_if_unencrypted and isinstance(key, PlaintextKey):
-                if not self._confirm('Warning: Attempting to access a previously unknown unencrypted repository',
-                                     'BORG_UNKNOWN_UNENCRYPTED_REPO_ACCESS_IS_OK'):
+                msg = ("Warning: Attempting to access a previously unknown unencrypted repository!" +
+                       "\n" +
+                       "Do you want to continue? [yN] ")
+                if not yes(msg, false_msg="Aborting.",
+                           env_var_override='BORG_UNKNOWN_UNENCRYPTED_REPO_ACCESS_IS_OK'):
                     raise self.CacheInitAbortedError()
             self.create()
         self.open()
         # Warn user before sending data to a relocated repository
         if self.previous_location and self.previous_location != repository._location.canonical_path():
-            msg = 'Warning: The repository at location {} was previously located at {}'.format(repository._location.canonical_path(), self.previous_location)
-            if not self._confirm(msg, 'BORG_RELOCATED_REPO_ACCESS_IS_OK'):
+            msg = ("Warning: The repository at location {} was previously located at {}".format(repository._location.canonical_path(), self.previous_location) +
+                   "\n" +
+                   "Do you want to continue? [yN] ")
+            if not yes(msg, false_msg="Aborting.",
+                       env_var_override='BORG_RELOCATED_REPO_ACCESS_IS_OK'):
                 raise self.RepositoryAccessAborted()
 
         if sync and self.manifest.id != self.manifest_id:
@@ -92,19 +98,6 @@ Chunk index:    {0.total_unique_chunks:20d} {0.total_chunks:20d}"""
             stats[field] = format_file_size(stats[field])
         return Summary(**stats)
 
-    def _confirm(self, message, env_var_override=None):
-        print(message, file=sys.stderr)
-        if env_var_override and os.environ.get(env_var_override):
-            print("Yes (From {})".format(env_var_override), file=sys.stderr)
-            return True
-        if not sys.stdin.isatty():
-            return False
-        try:
-            answer = input('Do you want to continue? [yN] ')
-        except EOFError:
-            return False
-        return answer and answer in 'Yy'
-
     def create(self):
         """Create a new empty cache at `self.path`
         """

+ 81 - 0
borg/helpers.py

@@ -804,3 +804,84 @@ def int_to_bigint(value):
 
 def is_slow_msgpack():
     return msgpack.Packer is msgpack.fallback.Packer
+
+
+def yes(msg=None, retry_msg=None, false_msg=None, true_msg=None,
+        default=False, default_notty=None, default_eof=None,
+        falsish=('No', 'no', 'N', 'n'), truish=('Yes', 'yes', 'Y', 'y'),
+        env_var_override=None, ifile=None, ofile=None, input=input):
+    """
+    Output <msg> (usually a question) and let user input an answer.
+    Qualifies the answer according to falsish and truish as True or False.
+    If it didn't qualify and retry_msg is None (no retries wanted),
+    return the default [which defaults to False]. Otherwise let user retry
+    answering until answer is qualified.
+
+    If env_var_override is given and it is non-empty, counts as truish answer
+    and won't ask user for an answer.
+    If we don't have a tty as input and default_notty is not None, return its value.
+    Otherwise read input from non-tty and proceed as normal.
+    If EOF is received instead an input, return default_eof [or default, if not given].
+
+    :param msg: introducing message to output on ofile, no \n is added [None]
+    :param retry_msg: retry message to output on ofile, no \n is added [None]
+           (also enforces retries instead of returning default)
+    :param false_msg: message to output before returning False [None]
+    :param true_msg: message to output before returning True [None]
+    :param default: default return value (empty answer is given) [False]
+    :param default_notty: if not None, return its value if no tty is connected [None]
+    :param default_eof: return value if EOF was read as answer [same as default]
+    :param falsish: sequence of answers qualifying as False
+    :param truish: sequence of answers qualifying as True
+    :param env_var_override: environment variable name [None]
+    :param ifile: input stream [sys.stdin] (only for testing!)
+    :param ofile: output stream [sys.stderr]
+    :param input: input function [input from builtins]
+    :return: boolean answer value, True or False
+    """
+    # note: we do not assign sys.stdin/stderr as defaults above, so they are
+    # really evaluated NOW,  not at function definition time.
+    if ifile is None:
+        ifile = sys.stdin
+    if ofile is None:
+        ofile = sys.stderr
+    if default not in (True, False):
+        raise ValueError("invalid default value, must be True or False")
+    if default_notty not in (None, True, False):
+        raise ValueError("invalid default_notty value, must be None, True or False")
+    if default_eof not in (None, True, False):
+        raise ValueError("invalid default_eof value, must be None, True or False")
+    if msg:
+        print(msg, file=ofile, end='')
+        ofile.flush()
+    if env_var_override:
+        value = os.environ.get(env_var_override)
+        # currently, any non-empty value counts as truish
+        # TODO: change this so one can give y/n there?
+        if value:
+            value = bool(value)
+            value_str = truish[0] if value else falsish[0]
+            print("{} (from {})".format(value_str, env_var_override), file=ofile)
+            return value
+    if default_notty is not None and not ifile.isatty():
+        # looks like ifile is not a terminal (but e.g. a pipe)
+        return default_notty
+    while True:
+        try:
+            answer = input()  # XXX how can we use ifile?
+        except EOFError:
+            return default_eof if default_eof is not None else default
+        if answer in truish:
+            if true_msg:
+                print(true_msg, file=ofile)
+            return True
+        if answer in falsish:
+            if false_msg:
+                print(false_msg, file=ofile)
+            return False
+        if retry_msg is None:
+            # no retries wanted, we just return the default
+            return default
+        if retry_msg:
+            print(retry_msg, file=ofile, end='')
+            ofile.flush()

+ 76 - 2
borg/testsuite/helpers.py

@@ -10,9 +10,9 @@ import msgpack
 import msgpack.fallback
 
 from ..helpers import adjust_patterns, exclude_path, Location, format_file_size, format_timedelta, IncludePattern, ExcludePattern, make_path_safe, \
-    prune_within, prune_split, get_cache_dir, Statistics, is_slow_msgpack, \
+    prune_within, prune_split, get_cache_dir, Statistics, is_slow_msgpack, yes, \
     StableDict, int_to_bigint, bigint_to_int, parse_timestamp, CompressionSpec, ChunkerParams
-from . import BaseTestCase
+from . import BaseTestCase, environment_variable, FakeInputs
 
 
 class BigIntTestCase(BaseTestCase):
@@ -492,3 +492,77 @@ def test_is_slow_msgpack():
         msgpack.Packer = saved_packer
     # this assumes that we have fast msgpack on test platform:
     assert not is_slow_msgpack()
+
+
+def test_yes_simple():
+    input = FakeInputs(['y', 'Y', 'yes', 'Yes', ])
+    assert yes(input=input)
+    assert yes(input=input)
+    assert yes(input=input)
+    assert yes(input=input)
+    input = FakeInputs(['n', 'N', 'no', 'No', ])
+    assert not yes(input=input)
+    assert not yes(input=input)
+    assert not yes(input=input)
+    assert not yes(input=input)
+
+
+def test_yes_custom():
+    input = FakeInputs(['YES', 'SURE', 'NOPE', ])
+    assert yes(truish=('YES', ), input=input)
+    assert yes(truish=('SURE', ), input=input)
+    assert not yes(falsish=('NOPE', ), input=input)
+
+
+def test_yes_env():
+    input = FakeInputs(['n', 'n'])
+    with environment_variable(OVERRIDE_THIS='nonempty'):
+        assert yes(env_var_override='OVERRIDE_THIS', input=input)
+    with environment_variable(OVERRIDE_THIS=None):  # env not set
+        assert not yes(env_var_override='OVERRIDE_THIS', input=input)
+
+
+def test_yes_defaults():
+    input = FakeInputs(['invalid', '', ' '])
+    assert not yes(input=input)  # default=False
+    assert not yes(input=input)
+    assert not yes(input=input)
+    input = FakeInputs(['invalid', '', ' '])
+    assert yes(default=True, input=input)
+    assert yes(default=True, input=input)
+    assert yes(default=True, input=input)
+    ifile = StringIO()
+    assert yes(default_notty=True, ifile=ifile)
+    assert not yes(default_notty=False, ifile=ifile)
+    input = FakeInputs([])
+    assert yes(default_eof=True, input=input)
+    assert not yes(default_eof=False, input=input)
+    with pytest.raises(ValueError):
+        yes(default=None)
+    with pytest.raises(ValueError):
+        yes(default_notty='invalid')
+    with pytest.raises(ValueError):
+        yes(default_eof='invalid')
+
+
+def test_yes_retry():
+    input = FakeInputs(['foo', 'bar', 'y', ])
+    assert yes(retry_msg='Retry: ', input=input)
+    input = FakeInputs(['foo', 'bar', 'N', ])
+    assert not yes(retry_msg='Retry: ', input=input)
+
+
+def test_yes_output(capfd):
+    input = FakeInputs(['invalid', 'y', 'n'])
+    assert yes(msg='intro-msg', false_msg='false-msg', true_msg='true-msg', retry_msg='retry-msg', input=input)
+    out, err = capfd.readouterr()
+    assert out == ''
+    assert 'intro-msg' in err
+    assert 'retry-msg' in err
+    assert 'true-msg' in err
+    assert not yes(msg='intro-msg', false_msg='false-msg', true_msg='true-msg', retry_msg='retry-msg', input=input)
+    out, err = capfd.readouterr()
+    assert out == ''
+    assert 'intro-msg' in err
+    assert 'retry-msg' not in err
+    assert 'false-msg' in err