Przeglądaj źródła

Add python bootstrapper for containers

FreddleSpl0it 5 miesięcy temu
rodzic
commit
d3185c3c68

+ 30 - 0
data/Dockerfiles/bootstrap/main.py

@@ -0,0 +1,30 @@
+import os
+import sys
+
+def main():
+  container_name = os.getenv("CONTAINER_NAME")
+
+  if container_name == "sogo-mailcow":
+    from modules.BootstrapSogo import Bootstrap
+  else:
+    print(f"No bootstrap handler for container: {container_name}", file=sys.stderr)
+    sys.exit(1)
+
+  b = Bootstrap(
+    container=container_name,
+    db_config = {
+      "host": "localhost",
+      "user": os.getenv("DBUSER"),
+      "password": os.getenv("DBPASS"),
+      "database": os.getenv("DBNAME"),
+      "unix_socket": "/var/run/mysqld/mysqld.sock",
+      'connection_timeout': 2
+    },
+    db_table="service_settings",
+    db_settings=['sogo']
+  )
+
+  b.bootstrap()
+
+if __name__ == "__main__":
+  main()

+ 456 - 0
data/Dockerfiles/bootstrap/modules/BootstrapBase.py

@@ -0,0 +1,456 @@
+import os
+import pwd
+import grp
+import shutil
+import secrets
+import string
+import subprocess
+import time
+import socket
+import signal
+import re
+import json
+from pathlib import Path
+import mysql.connector
+from jinja2 import Environment, FileSystemLoader
+
+class BootstrapBase:
+  def __init__(self, container, db_config, db_table, db_settings):
+    self.container = container
+    self.db_config = db_config
+    self.db_table = db_table
+    self.db_settings = db_settings
+
+    self.env = None
+    self.env_vars = None
+    self.mysql_conn = None
+
+  def render_config(self, template_name, output_path):
+    """
+    Renders a Jinja2 template and writes it to the specified output path.
+
+    The method uses the class's `self.env` Jinja2 environment and `self.env_vars`
+    for rendering template variables.
+
+    Args:
+        template_name (str): Name of the template file.
+        output_path (str or Path): Path to write the rendered output file.
+    """
+
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    template = self.env.get_template(template_name)
+    rendered = template.render(self.env_vars)
+
+    with open(output_path, "w") as f:
+      f.write(rendered)
+
+  def prepare_template_vars(self, overwrite_path, extra_vars = None):
+    """
+    Loads and merges environment variables for Jinja2 templates from multiple sources.
+
+    This method combines:
+      1. System environment variables
+      2. Key/value pairs from the MySQL `service_settings` table
+      3. An optional dictionary of extra_vars
+      4. A JSON file with overrides (if the file exists)
+
+    Args:
+        overwrite_path (str or Path): Path to a JSON file containing key-value overrides.
+        extra_vars (dict, optional): A dictionary of additional variables to include.
+
+    Returns:
+        dict: A dictionary containing all resolved template variables.
+
+    Raises:
+        Prints errors if database fetch or JSON parsing fails, but does not raise exceptions.
+    """
+
+    # 1. Load env vars
+    env_vars = dict(os.environ)
+
+    # 2. Load from MySQL
+    try:
+      cursor = self.mysql_conn.cursor()
+
+      if self.db_settings:
+        placeholders = ','.join(['%s'] * len(self.db_settings))
+        sql = f"SELECT `key`, `value` FROM {self.db_table} WHERE `type` IN ({placeholders})"
+        cursor.execute(sql, self.db_settings)
+      else:
+        cursor.execute(f"SELECT `key`, `value` FROM {self.db_table}")
+
+      for key, value in cursor.fetchall():
+        env_vars[key] = value
+
+      cursor.close()
+    except Exception as e:
+      print(f"Failed to fetch DB service settings: {e}")
+
+    # 3. Load extra vars
+    if extra_vars:
+      env_vars.update(extra_vars)
+
+    # 4. Load overwrites
+    overwrite_path = Path(overwrite_path)
+    if overwrite_path.exists():
+      try:
+        with overwrite_path.open("r") as f:
+          overwrite_data = json.load(f)
+          env_vars.update(overwrite_data)
+      except Exception as e:
+        print(f"Failed to parse overwrites: {e}")
+
+    return env_vars
+
+  def set_timezone(self):
+    """
+    Sets the system timezone based on the TZ environment variable.
+
+    If the TZ variable is set, writes its value to /etc/timezone.
+    """
+
+    timezone = os.getenv("TZ")
+    if timezone:
+      with open("/etc/timezone", "w") as f:
+        f.write(timezone + "\n")
+
+  def set_syslog_redis(self):
+    """
+    Reconfigures syslog-ng to use a Redis slave configuration.
+
+    If the REDIS_SLAVEOF_IP environment variable is set, replaces the syslog-ng config
+    with the Redis slave-specific config.
+    """
+
+    redis_slave_ip = os.getenv("REDIS_SLAVEOF_IP")
+    if redis_slave_ip:
+      shutil.copy("/etc/syslog-ng/syslog-ng-redis_slave.conf", "/etc/syslog-ng/syslog-ng.conf")
+
+  def rsync_file(self, src, dst, recursive=False, owner=None, mode=None):
+    """
+    Copies files or directories using rsync, with optional ownership and permissions.
+
+    Args:
+        src (str or Path): Source file or directory.
+        dst (str or Path): Destination directory.
+        recursive (bool): If True, copies contents recursively.
+        owner (tuple): Tuple of (user, group) to set ownership.
+        mode (int): File mode (e.g., 0o644) to set permissions after sync.
+    """
+
+    src_path = Path(src)
+    dst_path = Path(dst)
+    dst_path.mkdir(parents=True, exist_ok=True)
+
+    rsync_cmd = ["rsync", "-a"]
+    if recursive:
+      rsync_cmd.append(str(src_path) + "/")
+    else:
+      rsync_cmd.append(str(src_path))
+    rsync_cmd.append(str(dst_path))
+
+    try:
+      subprocess.run(rsync_cmd, check=True)
+    except Exception as e:
+      print(f"Rsync failed: {e}")
+
+    if owner:
+      self.set_owner(dst_path, *owner, recursive=True)
+    if mode:
+      self.set_permissions(dst_path, mode)
+
+  def set_permissions(self, path, mode):
+    """
+    Sets file or directory permissions.
+
+    Args:
+        path (str or Path): Path to the file or directory.
+        mode (int): File mode to apply, e.g., 0o644.
+
+    Raises:
+        FileNotFoundError: If the path does not exist.
+    """
+
+    file_path = Path(path)
+    if not file_path.exists():
+      raise FileNotFoundError(f"Cannot chmod: {file_path} does not exist")
+    os.chmod(file_path, mode)
+
+  def set_owner(self, path, user, group=None, recursive=False):
+    """
+    Changes ownership of a file or directory.
+
+    Args:
+        path (str or Path): Path to the file or directory.
+        user (str): Username for new owner.
+        group (str, optional): Group name; defaults to user's group if not provided.
+        recursive (bool): If True and path is a directory, ownership is applied recursively.
+
+    Raises:
+        FileNotFoundError: If the path does not exist.
+    """
+
+    uid = pwd.getpwnam(user).pw_uid
+    gid = grp.getgrnam(group or user).gr_gid
+
+    p = Path(path)
+    if not p.exists():
+      raise FileNotFoundError(f"{path} does not exist")
+
+    if recursive and p.is_dir():
+      for sub_path in p.rglob("*"):
+        os.chown(sub_path, uid, gid)
+    os.chown(p, uid, gid)
+
+  def move_file(self, src, dst, overwrite=True):
+    """
+    Moves a file from src to dst, optionally overwriting existing files.
+
+    Args:
+        src (str or Path): Source file path.
+        dst (str or Path): Destination path.
+        overwrite (bool): If False, raises error if dst exists.
+
+    Raises:
+        FileNotFoundError: If the source file does not exist.
+        FileExistsError: If the destination file exists and overwrite is False.
+    """
+
+    src_path = Path(src)
+    dst_path = Path(dst)
+
+    if not src_path.exists():
+      raise FileNotFoundError(f"Source file does not exist: {src}")
+
+    dst_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if dst_path.exists() and not overwrite:
+      raise FileExistsError(f"Destination already exists: {dst} (set overwrite=True to overwrite)")
+
+    shutil.move(str(src_path), str(dst_path))
+
+  def patch_exists(self, target_file, patch_file, reverse=False):
+    """
+    Checks whether a patch can be applied (or reversed) to a target file.
+
+    Args:
+        target_file (str): File to test the patch against.
+        patch_file (str): Patch file to apply.
+        reverse (bool): If True, checks whether the patch can be reversed.
+
+    Returns:
+        bool: True if patch is applicable, False otherwise.
+    """
+
+    cmd = ["patch", "-sfN", "--dry-run", target_file, "<", patch_file]
+    if reverse:
+      cmd.insert(1, "-R")
+    try:
+      result = subprocess.run(
+        " ".join(cmd),
+        shell=True,
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.DEVNULL
+      )
+      return result.returncode == 0
+    except Exception as e:
+      print(f"Patch dry-run failed: {e}")
+      return False
+
+  def apply_patch(self, target_file, patch_file, reverse=False):
+    """
+    Applies a patch file to a target file.
+
+    Args:
+        target_file (str): File to be patched.
+        patch_file (str): Patch file containing the diff.
+        reverse (bool): If True, applies the patch in reverse (rollback).
+
+    Logs:
+        Success or failure of the patching operation.
+    """
+
+    cmd = ["patch", target_file, "<", patch_file]
+    if reverse:
+      cmd.insert(0, "-R")
+    try:
+      subprocess.run(" ".join(cmd), shell=True, check=True)
+      print(f"Applied patch {'(reverse)' if reverse else ''} to {target_file}")
+    except subprocess.CalledProcessError as e:
+      print(f"Patch failed: {e}")
+
+  def isYes(self, value):
+    """
+    Determines whether a given string represents a "yes"-like value.
+
+    Args:
+        value (str): Input string to evaluate.
+
+    Returns:
+        bool: True if value is "yes" or "y" (case-insensitive), otherwise False.
+    """
+    return value.lower() in ["yes", "y"]
+
+  def is_port_open(self, host, port):
+    """
+    Checks whether a TCP port is open on a given host.
+
+    Args:
+        host (str): The hostname or IP address to check.
+        port (int): The TCP port number to test.
+
+    Returns:
+        bool: True if the port is open and accepting connections, False otherwise.
+    """
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+      sock.settimeout(1)
+      result = sock.connect_ex((host, port))
+      return result == 0
+
+  def kill_proc(self, process):
+    """
+    Sends a SIGTERM signal to all processes matching the given name using `killall`.
+
+    Args:
+        process (str): The name of the process to terminate.
+
+    Returns:
+        True if the signal was sent successfully, or the subprocess error if it failed.
+    """
+
+    try:
+      subprocess.run(["killall", "-TERM", process], check=True)
+    except subprocess.CalledProcessError as e:
+      return e
+
+    return True
+
+  def connect_mysql(self):
+    """
+    Establishes a connection to the MySQL database using the provided configuration.
+
+    Continuously retries the connection until the database is reachable. Stores
+    the connection in `self.mysql_conn` once successful.
+
+    Logs:
+        Connection status and retry errors to stdout.
+    """
+
+    print("Connecting to MySQL...")
+
+    while True:
+      try:
+        self.mysql_conn = mysql.connector.connect(**self.db_config)
+        if self.mysql_conn.is_connected():
+          print("MySQL is up and ready!")
+          break
+      except Error as e:
+        print(f"Waiting for MySQL... ({e})")
+        time.sleep(2)
+
+  def close_mysql(self):
+    """
+    Closes the MySQL connection if it's currently open and connected.
+
+    Safe to call even if the connection has already been closed.
+    """
+
+    if self.mysql_conn and self.mysql_conn.is_connected():
+      self.mysql_conn.close()
+
+  def wait_for_schema_update(self, init_file_path="init_db.inc.php", check_interval=5):
+    """
+    Waits until the current database schema version matches the expected version
+    defined in a PHP initialization file.
+
+    Compares the `version` value in the `versions` table for `application = 'db_schema'`
+    with the `$db_version` value extracted from the specified PHP file.
+
+    Args:
+        init_file_path (str): Path to the PHP file containing the expected version string.
+        check_interval (int): Time in seconds to wait between version checks.
+
+    Logs:
+        Current vs. expected schema versions until they match.
+    """
+
+    print("Checking database schema version...")
+
+    while True:
+      current_version = self._get_current_db_version()
+      expected_version = self._get_expected_schema_version(init_file_path)
+
+      if current_version == expected_version:
+        print(f"DB schema is up to date: {current_version}")
+        break
+
+      print(f"Waiting for schema update... (DB: {current_version}, Expected: {expected_version})")
+      time.sleep(check_interval)
+
+  def _get_current_db_version(self):
+    """
+    Fetches the current schema version from the database.
+
+    Executes a SELECT query on the `versions` table where `application = 'db_schema'`.
+
+    Returns:
+        str or None: The current schema version as a string, or None if not found or on error.
+
+    Logs:
+        Error message if the query fails.
+    """
+
+    try:
+      cursor = self.mysql_conn.cursor()
+      cursor.execute("SELECT version FROM versions WHERE application = 'db_schema'")
+      result = cursor.fetchone()
+      cursor.close()
+      return result[0] if result else None
+    except Exception as e:
+      print(f"Error fetching current DB schema version: {e}")
+      return None
+
+  def _get_expected_schema_version(self, filepath):
+    """
+    Extracts the expected database schema version from a PHP initialization file.
+
+    Looks for a line in the form of: `$db_version = "..."` and extracts the version string.
+
+    Args:
+        filepath (str): Path to the PHP file containing the `$db_version` definition.
+
+    Returns:
+        str or None: The extracted version string, or None if not found or on error.
+
+    Logs:
+        Error message if the file cannot be read or parsed.
+    """
+
+    try:
+      with open(filepath, "r") as f:
+        content = f.read()
+        match = re.search(r'\$db_version\s*=\s*"([^"]+)"', content)
+        if match:
+          return match.group(1)
+    except Exception as e:
+      print(f"Error reading expected schema version from {filepath}: {e}")
+    return None
+
+  def rand_pass(self, length=22):
+    """
+    Generates a secure random password using allowed characters.
+
+    Allowed characters include upper/lowercase letters, digits, underscores, and hyphens.
+
+    Args:
+        length (int): Length of the password to generate. Default is 22.
+
+    Returns:
+        str: A securely generated random password string.
+    """
+
+    allowed_chars = string.ascii_letters + string.digits + "_-"
+    return ''.join(secrets.choice(allowed_chars) for _ in range(length))

+ 136 - 0
data/Dockerfiles/bootstrap/modules/BootstrapSogo.py

@@ -0,0 +1,136 @@
+from jinja2 import Environment, FileSystemLoader
+from modules.BootstrapBase import BootstrapBase
+from pathlib import Path
+import os
+import sys
+import time
+
+class Bootstrap(BootstrapBase):
+  def bootstrap(self):
+    # Skip SOGo if set
+    if self.isYes(os.getenv("SKIP_SOGO", "")):
+      print("SKIP_SOGO is set, skipping SOGo startup...")
+      time.sleep(365 * 24 * 60 * 60)
+      sys.exit(1)
+
+    # Connect to MySQL
+    self.connect_mysql()
+
+    # Wait until port is free
+    while self.is_port_open("sogo-mailcow", 20000):
+      print("Port 20000 still in use — terminating sogod...")
+      self.kill_proc("sogod")
+      time.sleep(3)
+
+    # Wait for schema to update to expected version
+    self.wait_for_schema_update(init_file_path="init_db.inc.php")
+
+    # Setup Jinja2 Environment and load vars
+    self.env = Environment(
+      loader=FileSystemLoader("./etc/sogo/config_templates"),
+      keep_trailing_newline=True,
+      lstrip_blocks=False,
+      trim_blocks=False
+    )
+    extra_vars = {
+      "SQL_DOMAINS": self.get_domains(),
+      "IAM_SETTINGS": self.get_identity_provider_settings()
+    }
+    self.env_vars = self.prepare_template_vars('/overwrites.json', extra_vars)
+
+    print("Set Timezone")
+    self.set_timezone()
+
+    print("Set Syslog redis")
+    self.set_syslog_redis()
+
+    print("Render config")
+    self.render_config("sogod.plist.j2", "/var/lib/sogo/GNUstep/Defaults/sogod.plist")
+    self.render_config("UIxTopnavToolbar.wox.j2", "/usr/lib/GNUstep/SOGo/Templates/UIxTopnavToolbar.wox")
+
+    print("Fix permissions")
+    self.set_owner("/var/lib/sogo", "sogo", "sogo", recursive=True)
+    self.set_permissions("/var/lib/sogo/GNUstep/Defaults/sogod.plist", 0o600)
+
+    # Rename custom logo
+    logo_src = Path("/etc/sogo/sogo-full.svg")
+    if logo_src.exists():
+      print("Set Logo")
+      self.move_file(logo_src, "/etc/sogo/custom-fulllogo.svg")
+
+    # Rsync web content
+    print("Syncing web content")
+    self.rsync_file("/usr/lib/GNUstep/SOGo/", "/sogo_web/", recursive=True)
+
+    # Chown backup path
+    self.set_owner("/sogo_backup", "sogo", "sogo", recursive=True)
+
+  def get_domains(self):
+    """
+    Retrieves a list of domains and their GAL (Global Address List) status.
+
+    Executes a SQL query to select:
+      - `domain`
+      - a human-readable GAL status ("YES" or "NO")
+      - `ldap_gal` as a boolean (True/False)
+
+    Returns:
+      list[dict]: A list of dicts with keys: domain, gal_status, ldap_gal.
+                  Example: [{"domain": "example.com", "gal_status": "YES", "ldap_gal": True}]
+
+    Logs:
+      Error messages if the query fails.
+    """
+
+    query = """
+      SELECT domain,
+             CASE gal WHEN '1' THEN 'YES' ELSE 'NO' END AS gal_status,
+             ldap_gal = 1 AS ldap_gal
+      FROM domain;
+    """
+    try:
+      cursor = self.mysql_conn.cursor()
+      cursor.execute(query)
+      result = cursor.fetchall()
+      cursor.close()
+
+      return [
+        {
+          "domain": row[0],
+          "gal_status": row[1],
+          "ldap_gal": bool(row[2])
+        }
+        for row in result
+      ]
+    except Exception as e:
+      print(f"Error fetching domains: {e}")
+      return []
+
+  def get_identity_provider_settings(self):
+    """
+    Retrieves all key-value identity provider settings.
+
+    Returns:
+      dict: Settings in the format { key: value }
+
+    Logs:
+      Error messages if the query fails.
+    """
+    query = "SELECT `key`, `value` FROM identity_provider;"
+    try:
+      cursor = self.mysql_conn.cursor()
+      cursor.execute(query)
+      result = cursor.fetchall()
+      cursor.close()
+
+      iam_settings = {row[0]: row[1] for row in result}
+
+      if iam_settings['authsource'] == "ldap":
+        protocol = "ldaps" if iam_settings.get("use_ssl") else "ldap"
+        starttls = "/????!StartTLS" if iam_settings.get("use_tls") else ""
+        iam_settings['ldap_url'] = f"{protocol}://{iam_settings['host']}:{iam_settings['port']}{starttls}"
+
+      return iam_settings
+    except Exception as e:
+      print(f"Error fetching identity provider settings: {e}")
+      return {}

+ 0 - 0
data/Dockerfiles/bootstrap/modules/__init__.py