Ver código fonte

Optimize python bootstrapper

FreddleSpl0it 3 meses atrás
pai
commit
eb7d2628ac

+ 16 - 23
data/Dockerfiles/bootstrap/main.py

@@ -10,42 +10,35 @@ def main():
   signal.signal(signal.SIGTERM, handle_sigterm)
 
   container_name = os.getenv("CONTAINER_NAME")
+  service_name = container_name.replace("-mailcow", "").replace("-", "")
+  module_name = f"Bootstrap{service_name.capitalize()}"
 
-  if container_name == "sogo-mailcow":
-    from modules.BootstrapSogo import Bootstrap
-  elif container_name == "nginx-mailcow":
-    from modules.BootstrapNginx import Bootstrap
-  elif container_name == "postfix-mailcow":
-    from modules.BootstrapPostfix import Bootstrap
-  elif container_name == "dovecot-mailcow":
-    from modules.BootstrapDovecot import Bootstrap
-  elif container_name == "rspamd-mailcow":
-    from modules.BootstrapRspamd import Bootstrap
-  elif container_name == "clamd-mailcow":
-    from modules.BootstrapClamd import Bootstrap
-  elif container_name == "mysql-mailcow":
-    from modules.BootstrapMysql import Bootstrap
-  elif container_name == "php-fpm-mailcow":
-    from modules.BootstrapPhpfpm import Bootstrap
-  else:
-    print(f"No bootstrap handler for container: {container_name}", file=sys.stderr)
+  try:
+    mod = __import__(f"modules.{module_name}", fromlist=[module_name])
+    Bootstrap = getattr(mod, module_name)
+  except (ImportError, AttributeError) as e:
+    print(f"Failed to load bootstrap module for: {container_name} → {module_name}")
+    print(str(e))
     sys.exit(1)
 
   b = Bootstrap(
     container=container_name,
+    service=service_name,
     db_config={
       "host": "localhost",
       "user": os.getenv("DBUSER") or os.getenv("MYSQL_USER"),
       "password": os.getenv("DBPASS") or os.getenv("MYSQL_PASSWORD"),
       "database": os.getenv("DBNAME") or os.getenv("MYSQL_DATABASE"),
       "unix_socket": "/var/run/mysqld/mysqld.sock",
-      'connection_timeout': 2
+      'connection_timeout': 2,
+      'service_table': "service_settings",
+      'service_types': [service_name]
     },
-    db_table="service_settings",
-    db_settings=['sogo'],
     redis_config={
-      "host": os.getenv("REDIS_SLAVEOF_IP") or "redis-mailcow",
-      "port": int(os.getenv("REDIS_SLAVEOF_PORT") or 6379),
+      "read_host": "redis-mailcow",
+      "read_port": 6379,
+      "write_host": os.getenv("REDIS_SLAVEOF_IP") or "redis-mailcow",
+      "write_port": int(os.getenv("REDIS_SLAVEOF_PORT") or 6379),
       "password": os.getenv("REDISPASS"),
       "db": 0
     }

+ 124 - 52
data/Dockerfiles/bootstrap/modules/BootstrapBase.py

@@ -7,28 +7,28 @@ import string
 import subprocess
 import time
 import socket
-import signal
 import re
 import redis
 import hashlib
 import json
+import psutil
+import signal
 from pathlib import Path
 import dns.resolver
 import mysql.connector
-from jinja2 import Environment, FileSystemLoader
 
 class BootstrapBase:
-  def __init__(self, container, db_config, db_table, db_settings, redis_config):
+  def __init__(self, container, service, db_config, redis_config):
     self.container = container
+    self.service = service
     self.db_config = db_config
-    self.db_table = db_table
-    self.db_settings = db_settings
     self.redis_config = redis_config
 
     self.env = None
     self.env_vars = None
     self.mysql_conn = None
-    self.redis_conn = None
+    self.redis_connr = None
+    self.redis_connw = None
 
   def render_config(self, config_dir):
     """
@@ -43,9 +43,6 @@ class BootstrapBase:
       - Also copies the rendered file to: <config_dir>/rendered_configs/<relative_output_path>
     """
 
-    import json
-    from pathlib import Path
-
     config_dir = Path(config_dir)
     config_path = config_dir / "config.json"
 
@@ -67,7 +64,13 @@ class BootstrapBase:
         continue
 
       output_path.parent.mkdir(parents=True, exist_ok=True)
-      template = self.env.get_template(template_name)
+
+      try:
+        template = self.env.get_template(template_name)
+      except Exception as e:
+        print(f"Template not found: {template_name} ({e})")
+        continue
+
       rendered = template.render(self.env_vars)
 
       if clean_blank_lines:
@@ -112,12 +115,12 @@ class BootstrapBase:
     try:
       cursor = self.mysql_conn.cursor()
 
-      if self.db_settings:
-        placeholders = ','.join(['%s'] * len(self.db_settings))
-        sql = f"SELECT `key`, `value` FROM {self.db_table} WHERE `type` IN ({placeholders})"
-        cursor.execute(sql, self.db_settings)
+      if self.db_config['service_types']:
+        placeholders = ','.join(['%s'] * len(self.db_config['service_types']))
+        sql = f"SELECT `key`, `value` FROM {self.db_config['service_table']} WHERE `type` IN ({placeholders})"
+        cursor.execute(sql, self.db_config['service_types'])
       else:
-        cursor.execute(f"SELECT `key`, `value` FROM {self.db_table}")
+        cursor.execute(f"SELECT `key`, `value` FROM {self.db_config['service_table']}")
 
       for key, value in cursor.fetchall():
         env_vars[key] = value
@@ -247,6 +250,23 @@ class BootstrapBase:
         os.chown(sub_path, uid, gid)
     os.chown(p, uid, gid)
 
+  def fix_permissions(self, path, user=None, group=None, mode=None, recursive=False):
+    """
+    Sets owner and/or permissions on a file or directory.
+
+    Args:
+      path (str or Path): Target path.
+      user (str|int, optional): Username or UID.
+      group (str|int, optional): Group name or GID.
+      mode (int, optional): File mode (e.g. 0o644).
+      recursive (bool): Apply recursively if path is a directory.
+    """
+
+    if user or group:
+      self.set_owner(path, user, group, recursive)
+    if mode:
+      self.set_permissions(path, mode)
+
   def move_file(self, src, dst, overwrite=True):
     """
     Moves a file from src to dst, optionally overwriting existing files.
@@ -458,25 +478,28 @@ class BootstrapBase:
     except Exception as e:
       raise Exception(f"Failed to resolve {record_type} record for {hostname}: {e}")
 
-  def kill_proc(self, process):
+  def kill_proc(self, process_name):
     """
-    Sends a SIGTERM signal to all processes matching the given name using `killall`.
+    Sends SIGTERM to all running processes matching the given name.
 
     Args:
-        process (str): The name of the process to terminate.
+      process_name (str): Name of the process to terminate.
 
     Returns:
-        True if the signal was sent successfully, or the subprocess error if it failed.
+      int: Number of processes successfully signaled.
     """
 
-    try:
-      subprocess.run(["killall", "-TERM", process], check=True)
-    except subprocess.CalledProcessError as e:
-      return e
-
-    return True
+    killed = 0
+    for proc in psutil.process_iter(['name']):
+      try:
+        if proc.info['name'] == process_name:
+          proc.send_signal(signal.SIGTERM)
+          killed += 1
+      except (psutil.NoSuchProcess, psutil.AccessDenied):
+        continue
+    return killed
 
-  def connect_mysql(self):
+  def connect_mysql(self, socket=None):
     """
     Establishes a connection to the MySQL database using the provided configuration.
 
@@ -485,13 +508,24 @@ class BootstrapBase:
 
     Logs:
         Connection status and retry errors to stdout.
+
+    Args:
+      socket (str, optional): Custom UNIX socket path to override the default.
     """
 
     print("Connecting to MySQL...")
+    config = {
+      "host": self.db_config['host'],
+      "user": self.db_config['user'],
+      "password": self.db_config['password'],
+      "database": self.db_config['database'],
+      "unix_socket": socket or self.db_config['unix_socket'],
+      'connection_timeout': self.db_config['connection_timeout']
+    }
 
     while True:
       try:
-        self.mysql_conn = mysql.connector.connect(**self.db_config)
+        self.mysql_conn = mysql.connector.connect(**config)
         if self.mysql_conn.is_connected():
           print("MySQL is up and ready!")
           break
@@ -509,48 +543,86 @@ class BootstrapBase:
     if self.mysql_conn and self.mysql_conn.is_connected():
       self.mysql_conn.close()
 
-  def connect_redis(self, retries=10, delay=2):
+  def connect_redis(self, max_retries=10, delay=2):
     """
-    Establishes a Redis connection and stores it in `self.redis_conn`.
+    Connects to both read and write Redis servers and stores the connections.
 
-    Args:
-      retries (int): Number of ping retries before giving up.
-      delay (int): Seconds between retries.
+    Read server: tries indefinitely until successful.
+    Write server: tries up to `max_retries` before giving up.
+
+    Sets:
+      self.redis_connr: Redis client for read
+      self.redis_connw: Redis client for write
     """
 
-    client = redis.Redis(
-      host=self.redis_config['host'],
-      port=self.redis_config['port'],
-      password=self.redis_config['password'],
-      db=self.redis_config['db'],
-      decode_responses=True
-    )
+    use_rw = self.redis_config['read_host'] == self.redis_config['write_host'] and self.redis_config['read_port'] == self.redis_config['write_port']
+
+    if use_rw:
+      print("Connecting to Redis read server...")
+    else:
+      print("Connecting to Redis server...")
 
-    for _ in range(retries):
+    while True:
       try:
-        if client.ping():
-          self.redis_conn = client
-          return
+        clientr = redis.Redis(
+          host=self.redis_config['read_host'],
+          port=self.redis_config['read_port'],
+          password=self.redis_config['password'],
+          db=self.redis_config['db'],
+          decode_responses=True
+        )
+        if clientr.ping():
+          self.redis_connr = clientr
+          print("Redis read server is up and ready!")
+          if use_rw:
+            break
+          else:
+            self.redis_connw = clientr
+            return
       except redis.RedisError as e:
-        print(f"Waiting for Redis... ({e})")
+        print(f"Waiting for Redis read... ({e})")
         time.sleep(delay)
 
-    raise ConnectionError("Redis is not available after multiple attempts.")
+
+    print("Connecting to Redis write server...")
+    for attempt in range(max_retries):
+      try:
+        clientw = redis.Redis(
+          host=self.redis_config['write_host'],
+          port=self.redis_config['write_port'],
+          password=self.redis_config['password'],
+          db=self.redis_config['db'],
+          decode_responses=True
+        )
+        if clientw.ping():
+          self.redis_connw = clientw
+          print("Redis write server is up and ready!")
+          return
+      except redis.RedisError as e:
+        print(f"Waiting for Redis write... (attempt {attempt + 1}/{max_retries}) ({e})")
+        time.sleep(delay)
+    print("Redis write server is unreachable.")
 
   def close_redis(self):
     """
-    Closes the Redis connection if it's open.
-
-    Safe to call even if Redis was never connected or already closed.
+    Closes the Redis read/write connections if open.
     """
 
-    if self.redis_conn:
+    if self.redis_connr:
+      try:
+        self.redis_connr.close()
+      except Exception as e:
+        print(f"Error while closing Redis read connection: {e}")
+      finally:
+        self.redis_connr = None
+
+    if self.redis_connw:
       try:
-        self.redis_conn.close()
+        self.redis_connw.close()
       except Exception as e:
-        print(f"Error while closing Redis connection: {e}")
+        print(f"Error while closing Redis write connection: {e}")
       finally:
-        self.redis_conn = None
+        self.redis_connw = None
 
   def wait_for_schema_update(self, init_file_path="init_db.inc.php", check_interval=5):
     """

+ 1 - 2
data/Dockerfiles/bootstrap/modules/BootstrapClamd.py

@@ -4,9 +4,8 @@ from pathlib import Path
 import os
 import sys
 import time
-import platform
 
-class Bootstrap(BootstrapBase):
+class BootstrapClamd(BootstrapBase):
   def bootstrap(self):
     # Skip Clamd if set
     if self.isYes(os.getenv("SKIP_CLAMD", "")):

+ 3 - 4
data/Dockerfiles/bootstrap/modules/BootstrapDovecot.py

@@ -2,12 +2,10 @@ from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
 from pathlib import Path
 import os
-import sys
-import time
 import pwd
 import hashlib
 
-class Bootstrap(BootstrapBase):
+class BootstrapDovecot(BootstrapBase):
   def bootstrap(self):
     # Connect to MySQL
     self.connect_mysql()
@@ -15,7 +13,8 @@ class Bootstrap(BootstrapBase):
 
     # Connect to Redis
     self.connect_redis()
-    self.redis_conn.set("DOVECOT_REPL_HEALTH", 1)
+    if self.redis_connw:
+      self.redis_connw.set("DOVECOT_REPL_HEALTH", 1)
 
     # Wait for DNS
     self.wait_for_dns("mailcow.email")

+ 4 - 6
data/Dockerfiles/bootstrap/modules/BootstrapMysql.py

@@ -1,27 +1,25 @@
 from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
-from pathlib import Path
 import os
-import sys
 import time
-import platform
 import subprocess
 
-class Bootstrap(BootstrapBase):
+class BootstrapMysql(BootstrapBase):
   def bootstrap(self):
     dbuser = "root"
     dbpass = os.getenv("MYSQL_ROOT_PASSWORD", "")
-    socket = "/var/run/mysqld/mysqld.sock"
+    socket = "/tmp/mysql-temp.sock"
 
     print("Starting temporary mysqld for upgrade...")
     self.start_temporary(socket)
 
-    self.connect_mysql()
+    self.connect_mysql(socket)
 
     print("Running mysql_upgrade...")
     self.upgrade_mysql(dbuser, dbpass, socket)
     print("Checking timezone support with CONVERT_TZ...")
     self.check_and_import_timezone_support(dbuser, dbpass, socket)
+    time.sleep(15)
 
     print("Shutting down temporary mysqld...")
     self.close_mysql()

+ 1 - 4
data/Dockerfiles/bootstrap/modules/BootstrapNginx.py

@@ -1,11 +1,8 @@
 from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
-from pathlib import Path
 import os
-import sys
-import time
 
-class Bootstrap(BootstrapBase):
+class BootstrapNginx(BootstrapBase):
   def bootstrap(self):
     # Connect to MySQL
     self.connect_mysql()

+ 11 - 14
data/Dockerfiles/bootstrap/modules/BootstrapPhpfpm.py

@@ -1,14 +1,9 @@
 from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
-from pathlib import Path
 import os
 import ipaddress
-import sys
-import time
-import platform
-import subprocess
 
-class Bootstrap(BootstrapBase):
+class BootstrapPhpfpm(BootstrapBase):
   def bootstrap(self):
     self.connect_mysql()
     self.connect_redis()
@@ -63,16 +58,16 @@ class Bootstrap(BootstrapBase):
     print("Setting default Redis keys if missing...")
 
     # Q_RELEASE_FORMAT
-    if self.redis_conn.get("Q_RELEASE_FORMAT") is None:
-      self.redis_conn.set("Q_RELEASE_FORMAT", "raw")
+    if self.redis_connw and self.redis_connr.get("Q_RELEASE_FORMAT") is None:
+        self.redis_connw.set("Q_RELEASE_FORMAT", "raw")
 
     # Q_MAX_AGE
-    if self.redis_conn.get("Q_MAX_AGE") is None:
-      self.redis_conn.set("Q_MAX_AGE", 365)
+    if self.redis_connw and self.redis_connr.get("Q_MAX_AGE") is None:
+      self.redis_connw.set("Q_MAX_AGE", 365)
 
     # PASSWD_POLICY hash defaults
-    if self.redis_conn.hget("PASSWD_POLICY", "length") is None:
-      self.redis_conn.hset("PASSWD_POLICY", mapping={
+    if self.redis_connw and self.redis_connr.hget("PASSWD_POLICY", "length") is None:
+      self.redis_connw.hset("PASSWD_POLICY", mapping={
         "length": 6,
         "chars": 0,
         "special_chars": 0,
@@ -82,7 +77,8 @@ class Bootstrap(BootstrapBase):
 
     # DOMAIN_MAP
     print("Rebuilding DOMAIN_MAP from MySQL...")
-    self.redis_conn.delete("DOMAIN_MAP")
+    if self.redis_connw:
+      self.redis_connw.delete("DOMAIN_MAP")
     domains = set()
     try:
       cursor = self.mysql_conn.cursor()
@@ -96,7 +92,8 @@ class Bootstrap(BootstrapBase):
 
       if domains:
         for domain in domains:
-          self.redis_conn.hset("DOMAIN_MAP", domain, 1)
+          if self.redis_connw:
+            self.redis_conn.hset("DOMAIN_MAP", domain, 1)
         print(f"{len(domains)} domains added to DOMAIN_MAP.")
       else:
         print("No domains found to insert into DOMAIN_MAP.")

+ 1 - 4
data/Dockerfiles/bootstrap/modules/BootstrapPostfix.py

@@ -1,11 +1,8 @@
 from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
 from pathlib import Path
-import os
-import sys
-import time
 
-class Bootstrap(BootstrapBase):
+class BootstrapPostfix(BootstrapBase):
   def bootstrap(self):
     # Connect to MySQL
     self.connect_mysql()

+ 1 - 3
data/Dockerfiles/bootstrap/modules/BootstrapRspamd.py

@@ -1,12 +1,10 @@
 from jinja2 import Environment, FileSystemLoader
 from modules.BootstrapBase import BootstrapBase
 from pathlib import Path
-import os
-import sys
 import time
 import platform
 
-class Bootstrap(BootstrapBase):
+class BootstrapRspamd(BootstrapBase):
   def bootstrap(self):
     # Connect to MySQL
     self.connect_mysql()

+ 2 - 2
data/Dockerfiles/bootstrap/modules/BootstrapSogo.py

@@ -5,7 +5,7 @@ import os
 import sys
 import time
 
-class Bootstrap(BootstrapBase):
+class BootstrapSogo(BootstrapBase):
   def bootstrap(self):
     # Skip SOGo if set
     if self.isYes(os.getenv("SKIP_SOGO", "")):
@@ -135,4 +135,4 @@ class Bootstrap(BootstrapBase):
       return iam_settings
     except Exception as e:
       print(f"Error fetching identity provider settings: {e}")
-      return {}
+      return {}