pathlib refactor cache

Thomas Waldmann committed 1 week ago
commit baecf9cccb
2 changed files with 29 additions and 26 deletions:
  1. src/borg/cache.py                 +26 -26
  2. src/borg/helpers/parseformat.py   +3 -0
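
For orientation, the hunks below apply one consistent os.path -> pathlib mapping. A minimal sketch of those equivalences, using a hypothetical scratch directory (illustrative only, not part of the diff):

    # Illustrative sketch: the os.path -> pathlib idioms used throughout this commit.
    from pathlib import Path

    base = Path("/tmp/borg-pathlib-demo")        # hypothetical scratch directory
    base.mkdir(parents=True, exist_ok=True)      # was: os.makedirs(base)
    (base / "config").write_text("demo\n")       # was: open(os.path.join(base, "config"), "w") ...

    assert (base / "config").exists()            # was: os.path.exists(os.path.join(base, "config"))
    assert base.is_dir()                         # was: os.path.isdir(base)
    names = [p.name for p in base.iterdir()]     # was: os.listdir(base)
    with (base / "config").open() as fd:         # was: open(os.path.join(base, "config"))
        assert fd.read() == "demo\n"
    (base / "config").unlink()                   # was: os.remove(os.path.join(base, "config"))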

+ 26 - 26
src/borg/cache.py

@@ -5,6 +5,7 @@ import shutil
 import stat
 from collections import namedtuple
 from datetime import datetime, timezone, timedelta
+from pathlib import Path
 from time import perf_counter
 
 from borgstore.backends.errors import PermissionDenied
@@ -63,7 +64,7 @@ def discover_files_cache_names(path, files_cache_name="files"):
     :param files_cache_name: base name of the files cache files
     :return: list of files cache file names
     """
-    return [fn for fn in os.listdir(path) if fn.startswith(files_cache_name + ".")]
+    return [p.name for p in Path(path).iterdir() if p.name.startswith(files_cache_name + ".")]
 
 
 # chunks is a list of ChunkListEntry
@@ -92,34 +93,33 @@ class SecurityManager:
 
     def __init__(self, repository):
         self.repository = repository
-        self.dir = get_security_dir(repository.id_str, legacy=(repository.version == 1))
-        self.cache_dir = cache_dir(repository)
-        self.key_type_file = os.path.join(self.dir, "key-type")
-        self.location_file = os.path.join(self.dir, "location")
-        self.manifest_ts_file = os.path.join(self.dir, "manifest-timestamp")
+        self.dir = Path(get_security_dir(repository.id_str, legacy=(repository.version == 1)))
+        self.key_type_file = self.dir / "key-type"
+        self.location_file = self.dir / "location"
+        self.manifest_ts_file = self.dir / "manifest-timestamp"
 
     @staticmethod
     def destroy(repository, path=None):
         """destroy the security dir for ``repository`` or at ``path``"""
         path = path or get_security_dir(repository.id_str, legacy=(repository.version == 1))
-        if os.path.exists(path):
+        if Path(path).exists():
             shutil.rmtree(path)
 
     def known(self):
-        return all(os.path.exists(f) for f in (self.key_type_file, self.location_file, self.manifest_ts_file))
+        return all(f.exists() for f in (self.key_type_file, self.location_file, self.manifest_ts_file))
 
     def key_matches(self, key):
         if not self.known():
             return False
         try:
-            with open(self.key_type_file) as fd:
+            with self.key_type_file.open() as fd:
                 type = fd.read()
                 return type == str(key.TYPE)
         except OSError as exc:
             logger.warning("Could not read/parse key type file: %s", exc)
 
     def save(self, manifest, key):
-        logger.debug("security: saving state for %s to %s", self.repository.id_str, self.dir)
+        logger.debug("security: saving state for %s to %s", self.repository.id_str, str(self.dir))
         current_location = self.repository._location.canonical_path()
         logger.debug("security: current location   %s", current_location)
         logger.debug("security: key type           %s", str(key.TYPE))
@@ -134,7 +134,7 @@ class SecurityManager:
     def assert_location_matches(self):
         # Warn user before sending data to a relocated repository
         try:
-            with open(self.location_file) as fd:
+            with self.location_file.open() as fd:
                 previous_location = fd.read()
             logger.debug("security: read previous location %r", previous_location)
         except FileNotFoundError:
@@ -167,7 +167,7 @@ class SecurityManager:
 
     def assert_no_manifest_replay(self, manifest, key):
         try:
-            with open(self.manifest_ts_file) as fd:
+            with self.manifest_ts_file.open() as fd:
                 timestamp = fd.read()
             logger.debug("security: read manifest timestamp %r", timestamp)
         except FileNotFoundError:
@@ -235,7 +235,7 @@ def assert_secure(repository, manifest):
 
 
 def cache_dir(repository, path=None):
-    return path or os.path.join(get_cache_dir(), repository.id_str)
+    return Path(path) if path else Path(get_cache_dir()) / repository.id_str
 
 
 class CacheConfig:
@@ -243,7 +243,7 @@ class CacheConfig:
         self.repository = repository
         self.path = cache_dir(repository, path)
         logger.debug("Using %s as cache", self.path)
-        self.config_path = os.path.join(self.path, "config")
+        self.config_path = self.path / "config"
 
     def __enter__(self):
         self.open()
@@ -253,7 +253,7 @@ class CacheConfig:
         self.close()
 
     def exists(self):
-        return os.path.exists(self.config_path)
+        return self.config_path.exists()
 
     def create(self):
         assert not self.exists()
@@ -272,7 +272,7 @@ class CacheConfig:
 
     def load(self):
         self._config = configparser.ConfigParser(interpolation=None)
-        with open(self.config_path) as fd:
+        with self.config_path.open() as fd:
             self._config.read_file(fd)
         self._check_upgrade(self.config_path)
         self.id = self._config.get("cache", "repository")
@@ -361,10 +361,10 @@ class Cache:
     @staticmethod
     def destroy(repository, path=None):
         """destroy the cache for ``repository`` or at ``path``"""
-        path = path or os.path.join(get_cache_dir(), repository.id_str)
-        config = os.path.join(path, "config")
-        if os.path.exists(config):
-            os.remove(config)  # kill config first
+        path = cache_dir(repository, path)
+        config = path / "config"
+        if config.exists():
+            config.unlink()  # kill config first
             shutil.rmtree(path)
 
     def __new__(
@@ -540,7 +540,7 @@ class FilesCacheMixin:
         msg = None
         try:
             with IntegrityCheckedFile(
-                path=os.path.join(self.path, self.files_cache_name()),
+                path=str(self.path / self.files_cache_name()),
                 write=False,
                 integrity_data=self.cache_config.integrity.get(self.files_cache_name()),
             ) as fd:
@@ -583,7 +583,7 @@ class FilesCacheMixin:
         ttl = int(os.environ.get("BORG_FILES_CACHE_TTL", 2))
         files_cache_logger.debug("FILES-CACHE-SAVE: starting...")
         # TODO: use something like SaveFile here, but that didn't work due to SyncFile missing .seek().
-        with IntegrityCheckedFile(path=os.path.join(self.path, self.files_cache_name()), write=True) as fd:
+        with IntegrityCheckedFile(path=str(self.path / self.files_cache_name()), write=True) as fd:
             entries = 0
             age_discarded = 0
             race_discarded = 0
@@ -983,7 +983,7 @@ class AdHocWithFilesCache(FilesCacheMixin, ChunksMixin):
         self.cache_config = CacheConfig(self.repository, self.path)
 
         # Warn user before sending data to a never seen before unencrypted repository
-        if not os.path.exists(self.path):
+        if not self.path.exists():
             self.security_manager.assert_access_unknown(warn_if_unencrypted, manifest, self.key)
             self.create()
 
@@ -1009,13 +1009,13 @@ class AdHocWithFilesCache(FilesCacheMixin, ChunksMixin):
 
     def create(self):
         """Create a new empty cache at `self.path`"""
-        os.makedirs(self.path)
-        with open(os.path.join(self.path, "README"), "w") as fd:
+        self.path.mkdir(parents=True, exist_ok=True)
+        with open(self.path / "README", "w") as fd:
             fd.write(CACHE_README)
         self.cache_config.create()
 
     def open(self):
-        if not os.path.isdir(self.path):
+        if not self.path.is_dir():
             raise Exception("%s Does not look like a Borg cache" % self.path)
         self.cache_config.open()
         self.cache_config.load()
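
Two details in the cache.py hunks are worth noting; the sketch below assumes IntegrityCheckedFile still expects a plain string path (its signature is not part of this diff):

    # Minimal sketch, not from the commit: pathlib next to string-based APIs.
    import os
    from pathlib import Path

    files_cache = Path("/tmp/borg-pathlib-demo") / "files"   # hypothetical path

    # Path -> str where a consumer wants a plain string (the path=str(...) calls above);
    # os.fspath() is the generic spelling and leaves str arguments unchanged.
    assert os.fspath(files_cache) == str(files_cache)

    # os.makedirs(path) -> Path.mkdir(parents=True, exist_ok=True): parents=True creates
    # intermediate directories, and exist_ok=True additionally tolerates a directory that
    # already exists (a slight behaviour change from the bare os.makedirs() it replaces).
    files_cache.parent.mkdir(parents=True, exist_ok=True)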

+ 3 - 0
src/borg/helpers/parseformat.py

@@ -10,6 +10,7 @@ import re
 import shlex
 import stat
 import uuid
+from pathlib import Path
 from typing import ClassVar, Any, TYPE_CHECKING, Literal
 from collections import OrderedDict
 from datetime import datetime, timezone
@@ -1163,6 +1164,8 @@ class BorgJsonEncoder(json.JSONEncoder):
             return o.info()
         if isinstance(o, (AdHocWithFilesCache,)):
             return {"path": o.path}
+        if isinstance(o, Path):
+            return str(o)
         if callable(getattr(o, "to_json", None)):
             return o.to_json()
         return super().default(o)
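
The new Path branch in BorgJsonEncoder is needed presumably because AdHocWithFilesCache.path is now a Path object, and json cannot serialize Path values by default. A minimal sketch with a hypothetical encoder and path, not borg's actual objects:

    # Minimal sketch: why a Path-aware default() is needed.
    import json
    from pathlib import Path

    class PathEncoder(json.JSONEncoder):                 # stand-in for BorgJsonEncoder
        def default(self, o):
            if isinstance(o, Path):
                return str(o)
            return super().default(o)

    p = Path("/home/user/.cache/borg/0123abcd")          # hypothetical cache path
    # json.dumps({"path": p})                            # raises TypeError (not JSON serializable)
    print(json.dumps({"path": p}, cls=PathEncoder))      # {"path": "/home/user/.cache/borg/0123abcd"}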