--[[
Copyright (c) 2011-2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- A plugin that implements ratelimits using redis

local E = {}
local N = 'ratelimit'
local redis_params

-- Module defaults; overridden by the `ratelimit` configuration section
local settings = {
  -- Senders that are considered as bounce
  bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' },
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = { 'postmaster', 'mailer-daemon' },
  prefix = 'RL',
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  max_rate_mult = 5,
  max_bucket_mult = 10,
  expire = 60 * 60 * 24 * 2, -- 2 days by default
  limits = {},
  allow_local = false,
}

-- Checks bucket, updating it if needed
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- return 1 if message should be ratelimited and 0 if not
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
-- NOTE(review): the final comparison multiplies the current fill level by the
-- dynamic burst multiplier instead of scaling the configured burst (KEYS[4]),
-- so a multiplier above 1 makes the limit stricter - confirm this is intended.
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  local dynr, dynb = 0, 0
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, 0, 1, 1}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  -- Perform leak
  if burst > 0 then
   if last < tonumber(KEYS[2]) then
    local rate = tonumber(KEYS[3])
    dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
    rate = rate * dynr
    local leaked = ((now - last) * rate)
    burst = burst - leaked
    redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
   end
  else
   burst = 0
   redis.call('HSET', KEYS[1], 'b', '0')
  end

  dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

  if (burst + 1) * dynb > tonumber(KEYS[4]) then
   return {1, tostring(burst), tostring(dynr), tostring(dynb)}
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb)}
]]
local bucket_check_id

-- Updates a bucket
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier
-- KEYS[4] - dynamic burst multiplier
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier
--   db - current dynamic burst multiplier
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
  local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000

  if dr < tonumber(KEYS[5]) and dr > 1.0 / tonumber(KEYS[5]) then
    dr = dr * tonumber(KEYS[3])
    redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
  end

  if db < tonumber(KEYS[6]) and db > 1.0 / tonumber(KEYS[6]) then
    db = db * tonumber(KEYS[4])
    redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
  end

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
local bucket_update_id

-- message_func(task, limit_type, prefix, bucket)
-- Builds the text passed to the soft reject; may be replaced from the
-- `message_func` configuration option.
local message_func = function(_, limit_type, _, _)
  return string.format('Ratelimit "%s" exceeded', limit_type)
end

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_lua_utils = require "lua_util"
local lua_redis = require "lua_redis"
local fun = require "fun"
local lua_maps = require "lua_maps"
local lua_util = require "lua_util"
local rspamd_hash = require "rspamd_cryptobox_hash"

-- Registers both Redis scripts and remembers the ids returned for them.
-- The arguments are accepted for the on-load hook signature but are unused.
local function load_scripts(cfg, ev_base)
  bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
  bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
end

-- LPeg grammar for string limits, built lazily on first use
local limit_parser

-- Parses a limit written as "<number>[k|m|g] / <number>[s|m|h|d]".
-- Returns the time part (in seconds) followed by the count part, or nil when
-- the string does not match; the failure is logged unless `no_error` is set.
local function parse_string_limit(lim, no_error)
  local function parse_time_suffix(s)
    if s == 's' then
      return 1
    elseif s == 'm' then
      return 60
    elseif s == 'h' then
      return 3600
    elseif s == 'd' then
      return 86400
    end
  end
  local function parse_num_suffix(s)
    if s == '' then
      return 1
    elseif s == 'k' then
      return 1000
    elseif s == 'm' then
      return 1000000
    elseif s == 'g' then
      return 1000000000
    end
  end
  local lpeg = require "lpeg"

  if not limit_parser then
    local digit = lpeg.R("09")
    limit_parser = {}
    limit_parser.integer =
    (lpeg.S("+-") ^ -1) *
            (digit ^ 1)
    limit_parser.fractional =
    (lpeg.P(".")) *
            (digit ^ 1)
    limit_parser.number =
    (limit_parser.integer *
            (limit_parser.fractional ^ -1)) +
            (lpeg.S("+-") * limit_parser.fractional)
    -- Fold "<number><suffix>" into a single value by multiplication
    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
            (limit_parser.number / tonumber) *
            ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
      function (acc, val) return acc * val end)
    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
            (limit_parser.number / tonumber) *
            ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
      function (acc, val) return acc * val end)
    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
            (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
            limit_parser.time)
  end
  local t = lpeg.match(limit_parser.limit, lim)

  -- A zero time part is rejected: the caller takes its reciprocal
  if t and t[1] and t[2] and t[2] ~= 0 then
    return t[2], t[1]
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end

-- Converts one configured limit into a list of buckets ({burst=, rate=}).
-- `name` is only used in log messages; `data` is the configured value.
-- Always returns a table (possibly empty) holding buckets whose `burst` and
-- `rate` are both numbers.
local function parse_limit(name, data)
  local buckets = {}

  if type(data) == 'table' then
    -- 3 cases here:
    --  * old limit in format [burst, rate]
    --  * vector of strings in Andrew's string format
    --  * proper bucket table
    if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
      -- Old style ratelimit
      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
      -- Store real numbers so that numeric strings survive the type filter below
      local burst, rate = tonumber(data[1]), tonumber(data[2])
      if burst > 0 and rate > 0 then
        table.insert(buckets, {
          burst = burst,
          rate = rate
        })
      elseif burst ~= 0 then
        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
      else
        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
      end
    elseif type(data.burst) == 'number' and type(data.rate) == 'number' then
      -- Proper bucket table: use it as is
      table.insert(buckets, data)
    else
      -- Recursively parse each element and flatten the resulting lists
      for _, d in ipairs(data) do
        for _, b in ipairs(parse_limit(name, d)) do
          table.insert(buckets, b)
        end
      end
    end
  elseif type(data) == 'string' then
    local rep_rate, burst = parse_string_limit(data)

    if rep_rate and burst then
      table.insert(buckets, {
        burst = burst,
        rate = 1.0 / rep_rate -- reciprocal
      })
    end
  end

  -- Filter valid
  return fun.totable(fun.filter(function(val)
    return type(val.burst) == 'number' and type(val.rate) == 'number'
  end, buckets))
end

--- Check whether this addr is bounce
local function check_bounce(from)
  return fun.any(function(b) return b == from end, settings.bounce_senders)
end

-- Key components usable in limit names (joined with '_', e.g. "to_ip").
-- Each `get_value(task)` returns the component value or nil, in which case
-- the limit is not applied to the message.
local keywords = {
  ['ip'] = {
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then return tostring(ip) end
      return nil
    end,
  },
  ['rip'] = {
    -- Same as `ip`, but skipped for local addresses
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end
      return nil
    end,
  },
  ['from'] = {
    ['get_value'] = function(task)
      local from = task:get_from(0)
      if ((from or E)[1] or E).addr then
        return string.lower(from[1]['addr'])
      end
      return nil
    end,
  },
  ['bounce'] = {
    -- Yields the placeholder '_' when there is no sender user part or when it
    -- is listed in settings.bounce_senders
    ['get_value'] = function(task)
      local from = task:get_from(0)
      if not ((from or E)[1] or E).user then
        return '_'
      end
      if check_bounce(from[1]['user']) then return '_' else return nil end
    end,
  },
  ['asn'] = {
    ['get_value'] = function(task)
      local asn = task:get_mempool():get_variable('asn')
      if not asn then
        return nil
      else
        return asn
      end
    end,
  },
  ['user'] = {
    ['get_value'] = function(task)
      local auser = task:get_user()
      if not auser then
        return nil
      else
        return auser
      end
    end,
  },
  ['to'] = {
    ['get_value'] = function(task)
      return task:get_principal_recipient()
    end,
  },
}

-- Builds the textual key for limit `rtype` and `bucket`: a number derived
-- from the bucket burst followed by the value of every keyword in `rtype`,
-- joined with ':'. Returns nil when any keyword is unknown or has no value.
local function gen_rate_key(task, rtype, bucket)
  local key_t = {tostring(lua_util.round(100000.0 / bucket.burst))}
  local key_keywords = lua_util.str_split(rtype, '_')
  local have_user = false

  for _, v in ipairs(key_keywords) do
    local ret

    if keywords[v] and type(keywords[v]['get_value']) == 'function' then
      ret = keywords[v]['get_value'](task)
    end
    if not ret then return nil end
    if v == 'user' then have_user = true end
    if type(ret) ~= 'string' then ret = tostring(ret) end
    table.insert(key_t, ret)
  end

  if have_user and not task:get_user() then
    return nil
  end

  return table.concat(key_t, ":")
end

-- Wraps a bucket into the record stored per key: the bucket itself, the limit
-- name and the Redis key (settings.prefix plus up to 24 characters of the
-- base32 hash of `redis_key`). Missing dynamic factors of the bucket are
-- filled in place from the module settings.
local function make_prefix(redis_key, name, bucket)
  local hash_len = 24
  if hash_len > #redis_key then hash_len = #redis_key end
  local hash = settings.prefix ..
      string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len)
  -- Fill defaults
  if not bucket.spam_factor_rate then
    bucket.spam_factor_rate = settings.spam_factor_rate
  end
  if not bucket.ham_factor_rate then
    bucket.ham_factor_rate = settings.ham_factor_rate
  end
  if not bucket.spam_factor_burst then
    bucket.spam_factor_burst = settings.spam_factor_burst
  end
  if not bucket.ham_factor_burst then
    bucket.ham_factor_burst = settings.ham_factor_burst
  end

  return {
    bucket = bucket,
    name = name,
    hash = hash
  }
end

-- Adds an entry to `prefixes` for every bucket of limit `k` that yields a key
-- for this task; returns the number of entries written.
local function limit_to_prefixes(task, k, v, prefixes)
  local n = 0
  for _,bucket in ipairs(v) do
    local prefix = gen_rate_key(task, k, bucket)

    if prefix then
      prefixes[prefix] = make_prefix(prefix, k, bucket)
      n = n + 1
    end
  end

  return n
end

-- Prefilter callback: collects the buckets that apply to the message and runs
-- the check script for each of them; an exceeded bucket inserts the configured
-- symbol and/or sets a soft reject.
local function ratelimit_cb(task)
  if not settings.allow_local and
          rspamd_lua_utils.is_rspamc_or_controller(task) then return end

  -- Get initial task data
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return
    end
  end

  -- Parse all rcpts
  local rcpts = task:get_recipients()
  local rcpts_user = {}
  if rcpts then
    -- Both the user part and the full address are matched against the map
    fun.each(function(r)
      fun.each(function(type) table.insert(rcpts_user, r[type]) end, {'user', 'addr'})
    end, rcpts)

    if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
      return
    end
  end

  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    -- Unauthenticated messages have no user to look up
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return
    end
  end

  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k,v in pairs(settings.limits) do
    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  end

  for k, hdl in pairs(settings.custom_keywords or E) do
    local ret, redis_key, bd = pcall(hdl, task)

    if ret then
      -- A handler may decline by returning no key; only count buckets that
      -- were actually stored
      if type(redis_key) == 'string' then
        local bucket = parse_limit(k, bd)
        if bucket[1] then
          prefixes[redis_key] = make_prefix(redis_key, k, bucket[1])
          nprefixes = nprefixes + 1
        end
      end
    else
      -- On failure pcall's second result is the error message
      rspamd_logger.errx(task, 'cannot call handler for %s: %s',
          k, redis_key)
    end
  end

  -- Returns the Redis callback for one bucket; data[1] == 1 means "limited"
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
      elseif type(data) == 'table' and data[1] and data[1] == 1 then
        -- set symbol only and do NOT soft reject
        if settings.symbol then
          task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
          rspamd_logger.infox(task,
              'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
              lim_name, prefix,
              bucket.burst, bucket.rate,
              data[2], data[3], data[4])
          return
        -- set INFO symbol and soft reject
        elseif settings.info_symbol then
          task:insert_result(settings.info_symbol, 1.0,
              lim_name .. "(" .. prefix .. ")")
        end
        rspamd_logger.infox(task,
            'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4])
        task:set_pre_result('soft reject',
                message_func(task, lim_name, prefix, bucket))
      end
    end
  end

  -- Don't do anything if pre-result has been already set
  if task:has_pre_result() then return end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)

    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds
    -- Now call check script for all defined prefixes

    for pr,value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
      rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
              {key = value.hash, task = task, is_write = true},
              gen_check_cb(pr, bucket, value.name),
              {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
                  tostring(settings.expire)})
    end
  end
end

-- Idempotent callback: runs the update script for every bucket saved by
-- ratelimit_cb, using the ham or spam factors depending on the final action.
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if prefixes then
    if task:has_pre_result() then
      -- Already rate limited/greylisted, do nothing
      rspamd_logger.debugm(N, task, 'pre-action has been set, do not update')
      return
    end

    -- Anything other than 'no action' is treated as spam
    local is_spam = not (task:get_metric_action() == 'no action')

    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds

    -- Update each bucket
    for k, v in pairs(prefixes) do
      local bucket = v.bucket
      local function update_bucket_cb(err, data)
        if err then
          rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
                  k, err)
        else
          rspamd_logger.debugm(N, task,
              "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
              v.name, k, v.hash,
              bucket.burst, bucket.rate,
              data[1], data[2], data[3])
        end
      end

      local mult_burst = bucket.ham_factor_burst or 1.0
      -- The rate multiplier comes from the rate factor, not the burst factor
      local mult_rate = bucket.ham_factor_rate or 1.0

      if is_spam then
        mult_burst = bucket.spam_factor_burst or 1.0
        mult_rate = bucket.spam_factor_rate or 1.0
      end

      lua_redis.exec_redis_script(bucket_update_id,
              {key = v.hash, task = task, is_write = true},
              update_bucket_cb,
              {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
               tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
               tostring(settings.expire)})
    end
  end
end

-- Module configuration: read options, build maps and register the symbols
local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      local buckets = parse_limit(t, lim)

      if buckets and #buckets > 0 then
        settings.limits[t] = buckets
      end
    end, opts['rates'])
  end

  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
          'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Ret, ret, ret: stupid legacy stuff:
  -- If we have a string with commas then load it as a static map
  -- otherwise, apply normal logic of Rspamd maps

  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
    end
  elseif type(opts['whitelisted_rcpts']) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
      'Ratelimit whitelisted rcpts')
  else
    -- Stupid default...
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
      'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
      'Ratelimit whitelist user map')
  end

  settings.custom_keywords = {}
  if opts['custom_keywords'] then
    -- loadfile returns nil plus a message on failure; check it before calling
    -- so that the real reason is logged
    local chunk, load_err = loadfile(opts['custom_keywords'])

    if not chunk then
      rspamd_logger.errx(rspamd_config, 'cannot load %s: %s',
          opts['custom_keywords'], load_err)
    else
      local ret, res_or_err = pcall(chunk)
      if ret then
        if type(res_or_err) == 'table' then
          for k,hdl in pairs(res_or_err) do
            settings['custom_keywords'][k] = hdl
          end
        elseif type(res_or_err) == 'function' then
          settings['custom_keywords']['custom'] = res_or_err
        end
      else
        rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
            opts['custom_keywords'], res_or_err)
        settings['custom_keywords'] = {}
      end
    end
  end

  if opts['message_func'] then
    -- Lua 5.1 compiles strings with loadstring, later versions with load
    local compile = loadstring or load
    message_func = assert(compile(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')

  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 7,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    -- The check symbol takes the name of the configured result symbol, if any
    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }
  end
end

rspamd_config:add_on_load(function(cfg, ev_base, worker)
  -- Without Redis parameters the module is disabled and needs no scripts
  if redis_params then
    load_scripts(cfg, ev_base)
  end
end)