ratelimit.lua 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864
  1. --[[
  2. Copyright (c) 2011-2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. ]]--
  -- Presumably set when rspamd only renders configuration help: nothing to set up then
  if confighelp then return end
  17. local rspamd_logger = require "rspamd_logger"
  18. local rspamd_util = require "rspamd_util"
  19. local rspamd_lua_utils = require "lua_util"
  20. local lua_redis = require "lua_redis"
  21. local fun = require "fun"
  22. local lua_maps = require "lua_maps"
  23. local lua_util = require "lua_util"
  24. local rspamd_hash = require "rspamd_cryptobox_hash"
  25. local lua_selectors = require "lua_selectors"
  26. local ts = require("tableshape").types
  -- A plugin that implements ratelimits using redis

  -- Shared empty table, used as a fallback when a value is nil
  local E = {}
  -- Module name: configuration section, debug module and pre-result source
  local N = 'ratelimit'
  -- Redis connection parameters, filled in when the configuration is parsed
  local redis_params

  -- Module defaults; overridden by the `ratelimit` configuration section
  local settings = {
    -- Prefix of every bucket hash key in Redis
    prefix = 'RL',
    -- Senders that are considered as bounce
    bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' },
    -- Do not check ratelimits for these recipients
    whitelisted_rcpts = { 'postmaster', 'mailer-daemon' },
    -- Multipliers applied to a bucket's dynamic rate/burst after a ham or spam verdict
    ham_factor_rate = 1.01,
    spam_factor_rate = 0.99,
    ham_factor_burst = 1.02,
    spam_factor_burst = 0.98,
    -- Bounds for the dynamic rate and burst multipliers (see the update script)
    max_rate_mult = 5,
    max_bucket_mult = 10,
    -- Expiry of bucket hashes in seconds: 2 days by default
    expire = 2 * 24 * 3600,
    -- Configured limits, filled from the `rates` option
    limits = {},
    -- When true, messages from rspamc/controller are ratelimited as well
    allow_local = false,
  }
  -- Redis script: leak a bucket and report whether the current message exceeds it.
  -- It does NOT count the current message; bucket_update_script does that later.
  -- Arguments (all passed through KEYS by ratelimit_cb):
  -- KEYS[1] - bucket hash key (the `hash` built by make_prefix)
  -- KEYS[2] - current time in milliseconds
  -- KEYS[3] - bucket leak rate (messages per millisecond)
  -- KEYS[4] - bucket burst
  -- KEYS[5] - expire for a newly created bucket (seconds)
  -- NOTE(review): only KEYS[1] is a Redis key; KEYS[2..5] are plain arguments, which
  -- Redis Cluster expects in ARGV - confirm single-node/proxy deployments only.
  -- Returns {limited, burst, dyn_rate, dyn_burst, leaked}:
  --   limited is 1 if the message should be ratelimited and 0 if not; the other
  --   fields are strings. A brand new bucket returns {0, '0', '1', '1', '0'}.
  --   dyn_rate is "0" unless a leak was performed, dyn_burst is "0" when burst <= 0.
  -- A message is limited when burst > 0 and burst + 1 > KEYS[4] * dyn_burst.
  -- Redis hash fields used:
  -- l - last hit (ms)
  -- b - current burst
  -- dr - current dynamic rate multiplier (*10000)
  -- db - current dynamic burst multiplier (*10000)
  local bucket_check_script = [[
local last = redis.call('HGET', KEYS[1], 'l')
local now = tonumber(KEYS[2])
local dynr, dynb, leaked = 0, 0, 0
if not last then
-- New bucket
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('HSET', KEYS[1], 'b', '0')
redis.call('HSET', KEYS[1], 'dr', '10000')
redis.call('HSET', KEYS[1], 'db', '10000')
redis.call('EXPIRE', KEYS[1], KEYS[5])
return {0, '0', '1', '1', '0'}
end
last = tonumber(last)
local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
-- Perform leak
if burst > 0 then
if last < tonumber(KEYS[2]) then
local rate = tonumber(KEYS[3])
dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
if dynr == 0 then dynr = 0.0001 end
rate = rate * dynr
leaked = ((now - last) * rate)
if leaked > burst then leaked = burst end
burst = burst - leaked
redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
redis.call('HSET', KEYS[1], 'l', KEYS[2])
end
dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0
if dynb == 0 then dynb = 0.0001 end
if burst > 0 and (burst + 1) > tonumber(KEYS[4]) * dynb then
return {1, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
end
else
burst = 0
redis.call('HSET', KEYS[1], 'b', '0')
end
return {0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
]]
  -- Script id returned by lua_redis.add_redis_script (set in load_scripts)
  local bucket_check_id
  -- Redis script: count the current message in a bucket and adjust its dynamic
  -- multipliers. No leak is performed here; bucket_check_script handles leaking.
  -- Arguments (all passed through KEYS by ratelimit_update_cb):
  -- KEYS[1] - bucket hash key (the `hash` built by make_prefix)
  -- KEYS[2] - current time in milliseconds
  -- KEYS[3] - dynamic rate multiplier for this verdict
  -- KEYS[4] - dynamic burst multiplier for this verdict
  -- KEYS[5] - max dyn rate: dr only grows while below it and only shrinks while
  --           above 1/KEYS[5]; a value <= 1 disables rate adjustment
  -- KEYS[6] - max dyn burst: same rule for db; a value <= 1 disables it
  -- KEYS[7] - expire for a bucket (seconds), refreshed on every update
  -- Returns {burst_before_increment, dr, db} as strings, or {1, 1, 1} (numbers)
  -- when the bucket did not exist yet.
  -- NOTE(review): a negative stored 'b' is clamped to 0 only in the reply;
  -- HINCRBYFLOAT still adds 1 to the stored negative value.
  -- Redis hash fields used:
  -- l - last hit (ms)
  -- b - current burst
  -- dr - current dynamic rate multiplier (*10000, floored at 1 = 0.0001)
  -- db - current dynamic burst multiplier (*10000, floored at 1 = 0.0001)
  local bucket_update_script = [[
local last = redis.call('HGET', KEYS[1], 'l')
local now = tonumber(KEYS[2])
if not last then
-- New bucket
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('HSET', KEYS[1], 'b', '1')
redis.call('HSET', KEYS[1], 'dr', '10000')
redis.call('HSET', KEYS[1], 'db', '10000')
redis.call('EXPIRE', KEYS[1], KEYS[7])
return {1, 1, 1}
end
local dr, db = 1.0, 1.0
if tonumber(KEYS[5]) > 1 then
local rate_mult = tonumber(KEYS[3])
local rate_limit = tonumber(KEYS[5])
dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000
if rate_mult > 1.0 and dr < rate_limit then
dr = dr * rate_mult
if dr > 0.0001 then
redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
else
redis.call('HSET', KEYS[1], 'dr', '1')
end
elseif rate_mult < 1.0 and dr > (1.0 / rate_limit) then
dr = dr * rate_mult
if dr > 0.0001 then
redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
else
redis.call('HSET', KEYS[1], 'dr', '1')
end
end
end
if tonumber(KEYS[6]) > 1 then
local rate_mult = tonumber(KEYS[4])
local rate_limit = tonumber(KEYS[6])
db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
if rate_mult > 1.0 and db < rate_limit then
db = db * rate_mult
if db > 0.0001 then
redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
else
redis.call('HSET', KEYS[1], 'db', '1')
end
elseif rate_mult < 1.0 and db > (1.0 / rate_limit) then
db = db * rate_mult
if db > 0.0001 then
redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
else
redis.call('HSET', KEYS[1], 'db', '1')
end
end
end
local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
if burst < 0 then burst = 0 end
redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('EXPIRE', KEYS[1], KEYS[7])
return {tostring(burst), tostring(dr), tostring(db)}
]]
  -- Script id returned by lua_redis.add_redis_script (set in load_scripts)
  local bucket_update_id
  --- Default soft-reject message for an exceeded limit.
  -- Called as message_func(task, limit_type, prefix, bucket, limit_key); may be
  -- replaced by the `message_func` option. Only limit_type is used here.
  local message_func = function(_, limit_type, _, _, _)
    local template = 'Ratelimit "%s" exceeded'
    return template:format(limit_type)
  end
  --- Register the bucket check/update scripts with lua_redis.
  -- This is hooked via add_on_load unconditionally, including when no Redis servers
  -- are configured and the module is disabled. In that case redis_params is nil
  -- and there is nothing to register the scripts against, so return early.
  -- @param cfg rspamd config (unused)
  -- @param ev_base event base (unused)
  local function load_scripts(cfg, ev_base)
    if not redis_params then
      return
    end
    bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
    bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
  end
  -- Compiled LPeg grammar for limit strings; built on first use by parse_string_limit
  local limit_parser

  -- Seconds per period suffix (note: 'm' means minutes here)
  local limit_time_units = { s = 1, m = 60, h = 3600, d = 86400 }
  -- Multipliers for count suffixes ('' means no suffix; 'm' means million here)
  local limit_count_units = { [''] = 1, k = 1000, m = 1000000, g = 1000000000 }

  --- Parse a limit string of the form "<count>[kmg] / <period>[smhd]".
  -- Examples: "100 / 1h" is 100 per 3600 s; "1.5k/30m" is 1500 per 1800 s.
  -- The match is not anchored at the end, so trailing text after a valid limit
  -- is ignored.
  -- @param lim limit string
  -- @param no_error when truthy, a parse failure is not logged
  -- @return period in seconds and count (in this order), or nil when `lim` cannot
  --   be parsed or the period is zero
  local function parse_string_limit(lim, no_error)
    local lpeg = require "lpeg"

    if not limit_parser then
      local digit = lpeg.R("09")
      local sign = lpeg.S("+-")
      local blanks = lpeg.S(" ") ^ 0
      local function product(acc, val) return acc * val end
      local function time_unit(s) return limit_time_units[s] end
      local function count_unit(s) return limit_count_units[s] end

      local integer = (sign ^ -1) * (digit ^ 1)
      local fractional = lpeg.P(".") * (digit ^ 1)
      local number = (integer * (fractional ^ -1)) + (sign * fractional)
      -- Fold "<number>[suffix]" into number * multiplier
      local time = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) *
          ((lpeg.S("smhd") / time_unit) ^ -1), product)
      local suffixed_number = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) *
          ((lpeg.S("kmg") / count_unit) ^ -1), product)

      limit_parser = {
        integer = integer,
        fractional = fractional,
        number = number,
        time = time,
        suffixed_number = suffixed_number,
        -- Captures {count, period}
        limit = lpeg.Ct(suffixed_number * blanks * lpeg.S("/") * blanks * time),
      }
    end

    local caps = lpeg.match(limit_parser.limit, lim)
    local count, period
    if caps then
      count, period = caps[1], caps[2]
    end
    if count and period and period ~= 0 then
      return period, count
    end

    if not no_error then
      rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
    end
    return nil
  end
  --- Convert a limit string such as "100 / 1h" into a rate in messages per second.
  -- @param str limit string in the format accepted by parse_string_limit
  -- @return count / period (msgs/sec), or nil (logged) when `str` cannot be parsed
  local function str_to_rate(str)
    -- Suppress parse_string_limit's own error so a bad value is reported once, here
    local period, count = parse_string_limit(str, true)
    if not count then
      rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
      return nil
    end
    return count / period
  end

  -- Shape of a table bucket definition: numbers are taken as-is, strings are
  -- converted (burst via dehumanize_number, rate via str_to_rate)
  local bucket_schema = ts.shape{
    burst = ts.number + ts.string / lua_util.dehumanize_number,
    rate = ts.number + ts.string / str_to_rate,
  }
  --- Parse one bucket definition from the configuration.
  -- Three forms are accepted:
  --  * legacy pair {burst, rate} of two numbers (numeric strings are accepted);
  --  * a table converted by bucket_schema ({burst = ..., rate = ...});
  --  * legacy string such as "100 / 1h" (burst 100, rate 100/3600 msgs/sec).
  -- (A vector of strings in Andrew's string format was removed in 1.8.2.)
  -- @param name limit name, used in log messages only
  -- @param data raw bucket definition
  -- @return table {burst = number, rate = number}, or nil when the definition is
  --   invalid or disables the limit (zero burst)
  local function parse_limit(name, data)
    if type(data) == 'table' then
      if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
        -- Old style ratelimit
        local burst, rate = tonumber(data[1]), tonumber(data[2])
        rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
        if burst > 0 and rate > 0 then
          -- Return the converted numbers, not the raw (possibly string) values
          return {
            burst = burst,
            rate = rate,
          }
        elseif burst ~= 0 then
          -- Compare the number, so a string "0" is treated as a zero burst too
          rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
        else
          rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
        end
        return nil
      end

      local parsed_bucket, err = bucket_schema:transform(data)
      if not parsed_bucket or err then
        rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
            name, err, data)
        return nil
      end
      return parsed_bucket
    elseif type(data) == 'string' then
      local period, burst = parse_string_limit(data)
      rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
          name, data)
      if period and burst then
        return {
          burst = burst,
          rate = burst / period, -- reciprocal of the period per message
        }
      end
    end
    return nil
  end
  --- Tell whether a sender local part is one of settings.bounce_senders.
  -- @param from sender user part (may be '')
  -- @return true if it matches a configured bounce sender, false otherwise
  local function check_bounce(from)
    local function is_same_sender(candidate)
      return candidate == from
    end
    return fun.any(is_same_sender, settings.bounce_senders)
  end
  -- Collect extract(part, filename) for every MIME part that has a file name and
  -- join the results with `sep`; returns nil when no part has a file name
  local function join_named_parts(task, sep, extract)
    local collected = {}
    for _, part in ipairs(task:get_parts() or E) do
      local fname = part:get_filename()
      if fname then
        collected[#collected + 1] = extract(part, fname)
      end
    end
    if #collected == 0 then
      return nil
    end
    return table.concat(collected, sep)
  end

  -- Keyword name -> { get_value = function(task) } as used by gen_rate_key.
  -- get_value returns the value for the task, or nil when the keyword does not apply.
  local keywords = {
    ip = {
      get_value = function(task)
        local addr = task:get_ip()
        if not (addr and addr:is_valid()) then
          return nil
        end
        return tostring(addr)
      end,
    },
    -- Like `ip`, but ignores local addresses
    rip = {
      get_value = function(task)
        local addr = task:get_ip()
        if not addr or not addr:is_valid() or addr:is_local() then
          return nil
        end
        return tostring(addr)
      end,
    },
    -- Lowercased SMTP sender address
    from = {
      get_value = function(task)
        local sender = (task:get_from(0) or E)[1] or E
        local addr = sender.addr
        if not addr then
          return nil
        end
        return string.lower(addr)
      end,
    },
    -- '_' for an empty sender user or a configured bounce sender, nil otherwise
    bounce = {
      get_value = function(task)
        local user = ((task:get_from(0) or E)[1] or E).user
        if not user or check_bounce(user) then
          return '_'
        end
        return nil
      end,
    },
    asn = {
      get_value = function(task)
        return task:get_mempool():get_variable('asn') or nil
      end,
    },
    user = {
      get_value = function(task)
        return task:get_user() or nil
      end,
    },
    to = {
      get_value = function(task)
        return task:get_principal_recipient()
      end,
    },
    digest = {
      get_value = function(task)
        return task:get_digest()
      end,
    },
    -- Concatenated digests of all parts that have a file name
    attachments = {
      get_value = function(task)
        return join_named_parts(task, '', function(part)
          return part:get_digest()
        end)
      end,
    },
    -- File names of all parts, separated by ':'
    files = {
      get_value = function(task)
        return join_named_parts(task, ':', function(_, fname)
          return fname
        end)
      end,
    },
  }
  --- Build the raw key for a keyword-style limit such as "ip" or "to_ip".
  -- The limit name is split on '_' and each part is looked up in `keywords`.
  -- The key is "<round(100000 / burst)>:<value1>:<value2>...", so buckets with
  -- different bursts get different keys.
  -- @param task current task
  -- @param rtype limit name made of keywords joined by '_'
  -- @param bucket parsed bucket (only burst is used)
  -- @return key string, or nil if a keyword is unknown or yields no value
  local function gen_rate_key(task, rtype, bucket)
    local parts = { tostring(lua_util.round(100000.0 / bucket.burst)) }
    local needs_user = false

    for _, kw in ipairs(lua_util.str_split(rtype, '_')) do
      local handler = keywords[kw]
      local value = nil
      if handler and type(handler.get_value) == 'function' then
        value = handler.get_value(task)
      end
      if not value then
        return nil
      end
      needs_user = needs_user or (kw == 'user')
      parts[#parts + 1] = type(value) == 'string' and value or tostring(value)
    end

    if needs_user and not task:get_user() then
      return nil
    end
    return table.concat(parts, ':')
  end
  --- Wrap a bucket into the record used by the check and update callbacks.
  -- Missing spam/ham factor fields are filled in on `bucket` itself (the table is
  -- mutated) from the module-wide settings.
  -- @param redis_key raw key (selector output or gen_rate_key result)
  -- @param name limit name
  -- @param bucket parsed bucket table
  -- @return { bucket = bucket, name = name, hash = <Redis key of the bucket hash> }
  local function make_prefix(redis_key, name, bucket)
    -- NOTE(review): the hash is truncated to at most #redis_key characters, so very
    -- short raw keys get very short (collision-prone) Redis keys - confirm intent.
    local hash_len = math.min(24, #redis_key)
    local hash = settings.prefix ..
        rspamd_hash.create(redis_key):base32():sub(1, hash_len)

    -- Fill defaults
    local defaulted = { 'spam_factor_rate', 'ham_factor_rate',
                        'spam_factor_burst', 'ham_factor_burst' }
    for _, field in ipairs(defaulted) do
      if not bucket[field] then
        bucket[field] = settings[field]
      end
    end

    return { bucket = bucket, name = name, hash = hash }
  end
  --- Expand one configured limit into prefix records for the current task.
  -- Selector-based limits use the selector output (a string or a list of strings)
  -- as the raw key; other limits use gen_rate_key. Records are stored in `prefixes`
  -- under their raw key.
  -- @param task current task
  -- @param k limit name
  -- @param v limit definition: { buckets = {...}, selector = <optional> }
  -- @param prefixes output table: raw key -> record from make_prefix
  -- @return number of records added
  local function limit_to_prefixes(task, k, v, prefixes)
    local n = 0

    -- Add one record per key; `keys` is a string or a list of strings
    local function add_keys(keys, bucket, suffix)
      local function add(key)
        local full_key = key .. suffix
        prefixes[full_key] = make_prefix(full_key, k, bucket)
        n = n + 1
      end
      if type(keys) == 'string' then
        add(keys)
      else
        fun.each(add, keys)
      end
    end

    -- The selector output does not depend on the bucket: evaluate it once
    local combined
    if v.selector then
      local selectors = lua_selectors.process_selectors(task, v.selector)
      if selectors then
        combined = lua_selectors.combine_selectors(task, selectors, ':')
      end
    end

    for i, bucket in ipairs(v.buckets) do
      if v.selector then
        if combined then
          -- Every bucket of a selector limit shares the same selector key. Without
          -- a per-bucket suffix, the buckets overwrite each other in `prefixes` and
          -- share one Redis hash. The first bucket keeps the bare key, so
          -- single-bucket limits keep their existing Redis keys.
          add_keys(combined, bucket, i == 1 and '' or (':' .. tostring(i)))
        end
      else
        local key = gen_rate_key(task, k, bucket)
        if key then
          add_keys(key, bucket, '')
        end
      end
    end

    return n
  end
  --- Whitelist checks performed before any limit is evaluated.
  -- Logs the reason and returns true when the task's IP, one of its recipients
  -- (user part or full address) or its authenticated user is whitelisted.
  local function is_whitelisted(task)
    local ip = task:get_from_ip()
    if ip and ip:is_valid() and settings.whitelisted_ip and
        settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return true
    end

    local rcpts = task:get_recipients()
    if rcpts and settings.whitelisted_rcpts then
      for _, rcpt in ipairs(rcpts) do
        for _, field in ipairs({ 'user', 'addr' }) do
          local value = rcpt[field]
          if value and settings.whitelisted_rcpts:get_key(value) then
            rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
            return true
          end
        end
      end
    end

    if settings.whitelisted_user then
      local auser = task:get_user()
      -- Unauthenticated tasks have no user to look up
      if auser and settings.whitelisted_user:get_key(auser) then
        rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
        return true
      end
    end

    return false
  end

  --- Build every ratelimit prefix for the task: configured limits plus custom
  -- keyword handlers (each handler returns redis_key, bucket_definition).
  -- @return prefixes (raw key -> make_prefix record), number of records added
  local function collect_prefixes(task)
    local prefixes = {}
    local nprefixes = 0
    for k, v in pairs(settings.limits) do
      nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
    end
    for k, hdl in pairs(settings.custom_keywords or E) do
      local ok, redis_key, bd = pcall(hdl, task)
      if not ok then
        rspamd_logger.errx(task, 'cannot call handler for %s: %s',
            k, redis_key)
      else
        local bucket = parse_limit(k, bd)
        -- A handler may return no key; never index `prefixes` with nil
        if bucket and type(redis_key) == 'string' then
          prefixes[redis_key] = make_prefix(redis_key, k, bucket)
          nprefixes = nprefixes + 1
        end
      end
    end
    return prefixes, nprefixes
  end

  --- Build the Redis callback that handles the bucket_check_script reply.
  -- When the limit is exceeded it either only inserts settings.symbol (no reject),
  -- or inserts settings.info_symbol (if set) and soft-rejects the task.
  local function gen_check_cb(task, prefix, bucket, lim_name, lim_key)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
        return
      end
      if type(data) ~= 'table' or not data[1] then
        return
      end
      lua_util.debugm(N, task,
          "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked",
          prefix, bucket.burst, bucket.rate,
          data[2], data[3], data[4], data[5])
      if data[1] ~= 1 then
        return
      end

      if settings.symbol then
        -- set symbol only and do NOT soft reject
        task:insert_result(settings.symbol, 0.0,
            string.format('%s(%s)', lim_name, lim_key))
        rspamd_logger.infox(task,
            'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn); redis key: %s',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4], lim_key)
        return
      elseif settings.info_symbol then
        -- set INFO symbol and soft reject
        task:insert_result(settings.info_symbol, 1.0,
            string.format('%s(%s)', lim_name, lim_key))
      end
      rspamd_logger.infox(task,
          'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn); redis key: %s',
          lim_name, prefix,
          bucket.burst, bucket.rate,
          data[2], data[3], data[4], lim_key)
      task:set_pre_result('soft reject',
          message_func(task, lim_name, prefix, bucket, lim_key), N)
    end
  end

  --- Prefilter: check every applicable bucket in Redis for this task.
  -- Prefixes are cached in the task so ratelimit_update_cb updates the same buckets.
  local function ratelimit_cb(task)
    if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
      return
    end
    if is_whitelisted(task) then
      return
    end

    local prefixes, nprefixes = collect_prefixes(task)
    -- Don't do anything if pre-result has been already set
    if task:has_pre_result() or nprefixes == 0 then
      return
    end

    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = lua_util.round(rspamd_util.get_time() * 1000.0) -- milliseconds

    -- Now call check script for all defined prefixes
    for pr, value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
      lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
          { key = value.hash, task = task, is_write = true },
          gen_check_cb(task, pr, bucket, value.name, value.hash),
          { value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
            tostring(settings.expire) })
    end
  end
  --- Post-scan callback: count the message in every bucket checked by ratelimit_cb
  -- and adjust the dynamic multipliers by verdict (spam/junk -> spam_factor_*,
  -- ham -> ham_factor_*, anything else -> 1.0, i.e. unchanged).
  local function ratelimit_update_cb(task)
    if task:has_flag('skip') then return end
    if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end

    local prefixes = task:cache_get('ratelimit_prefixes')
    if not prefixes then
      return
    end
    if task:has_pre_result() then
      -- Already rate limited/greylisted, do nothing
      lua_util.debugm(N, task, 'pre-action has been set, do not update')
      return
    end

    local verdict = lua_util.get_task_verdict(task)
    -- One timestamp (milliseconds) for every bucket of this task
    local now = lua_util.round(rspamd_util.get_time() * 1000.0)

    -- Update each bucket
    for k, v in pairs(prefixes) do
      local bucket = v.bucket
      local mult_burst, mult_rate = 1.0, 1.0
      if verdict == 'spam' or verdict == 'junk' then
        mult_burst = bucket.spam_factor_burst or 1.0
        mult_rate = bucket.spam_factor_rate or 1.0
      elseif verdict == 'ham' then
        mult_burst = bucket.ham_factor_burst or 1.0
        mult_rate = bucket.ham_factor_rate or 1.0
      end

      local function update_bucket_cb(err, data)
        if err then
          rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
              k, err)
        elseif type(data) == 'table' then
          lua_util.debugm(N, task,
              "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
              v.name, k, v.hash,
              bucket.burst, bucket.rate,
              data[1], data[2], data[3])
        end
      end

      lua_redis.exec_redis_script(bucket_update_id,
          { key = v.hash, task = task, is_write = true },
          update_bucket_cb,
          { v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
            tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
            tostring(settings.expire) })
    end
  end
  --- Parse the `rates` option into settings.limits.
  -- An entry is either a table with `bucket` (one definition or a list of them)
  -- and an optional `selector`, or a legacy definition handled by parse_limit.
  -- Invalid entries are logged and skipped.
  local function parse_rates(rates)
    fun.each(function(t, lim)
      if type(lim) == 'table' and lim.bucket then
        local buckets = {}
        if lim.bucket[1] then
          for _, bucket in ipairs(lim.bucket) do
            local b = parse_limit(t, bucket)
            if not b then
              -- Log the offending definition, not the (nil) parse result
              rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
                  t, bucket)
              return
            end
            table.insert(buckets, b)
          end
        else
          local bucket = parse_limit(t, lim.bucket)
          if not bucket then
            rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
                t, lim.bucket)
            return
          end
          buckets = { bucket }
        end

        local limit = { buckets = buckets }
        if lim.selector then
          local selector = lua_selectors.parse_selector(rspamd_config, lim.selector)
          if not selector then
            rspamd_logger.errx(rspamd_config, 'bad ratelimit selector for %s: "%s"',
                t, lim.selector)
            settings.limits[t] = nil
            return
          end
          limit.selector = selector
        end
        settings.limits[t] = limit
      else
        rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim)
        local bucket = parse_limit(t, lim)
        if bucket then
          settings.limits[t] = { buckets = { bucket } }
        end
      end
    end, rates)
  end

  --- Build the whitelisted recipients map.
  -- Ret, ret, ret: stupid legacy stuff:
  -- a string with commas is loaded as a static set; any other string or a table
  -- goes through the normal Rspamd map logic; otherwise the default list is used.
  local function load_whitelisted_rcpts(wrcpts)
    if type(wrcpts) == 'string' and string.find(wrcpts, ',') then
      return lua_maps.rspamd_map_add_from_ucl(
          lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    elseif type(wrcpts) == 'string' or type(wrcpts) == 'table' then
      return lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
          'Ratelimit whitelisted rcpts')
    end
    -- Stupid default...
    return lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  --- Load custom keyword handlers from a Lua file.
  -- The file must return a table (name -> handler) or a single handler function,
  -- which is registered as 'custom'. Returns a handlers table; it is empty on error.
  local function load_custom_keywords(path)
    local handlers = {}
    -- Check loadfile separately so its real error is reported instead of
    -- "attempt to call a nil value" from pcall(nil)
    local chunk, load_err = loadfile(path)
    if not chunk then
      rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s', path, load_err)
      return handlers
    end
    local ok, res_or_err = pcall(chunk)
    if not ok then
      rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s', path, res_or_err)
    elseif type(res_or_err) == 'table' then
      for k, hdl in pairs(res_or_err) do
        handlers[k] = hdl
      end
    elseif type(res_or_err) == 'function' then
      handlers['custom'] = res_or_err
    else
      rspamd_logger.warnx(rspamd_config,
          'custom keywords file %s returned %s, expected table or function',
          path, type(res_or_err))
    end
    return handlers
  end

  local opts = rspamd_config:get_all_opt(N)
  if opts then
    settings = lua_util.override_defaults(settings, opts)
    if opts['limit'] then
      rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
    end
    if type(opts['rates']) == 'table' then
      -- new way of setting limits
      parse_rates(opts['rates'])
    end

    -- Display what's enabled
    fun.each(function(s)
      rspamd_logger.infox(rspamd_config, 'enabled ratelimit: %s', s)
    end, fun.map(function(n, d)
      return string.format('%s [%s]', n,
          table.concat(fun.totable(fun.map(function(v)
            return string.format('%s msgs burst, %s msgs/sec rate',
                v.burst, v.rate)
          end, d.buckets)), '; ')
      )
    end, settings.limits))

    settings.whitelisted_rcpts = load_whitelisted_rcpts(opts['whitelisted_rcpts'])
    if opts['whitelisted_ip'] then
      settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
          'Ratelimit whitelist ip map')
    end
    if opts['whitelisted_user'] then
      settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
          'Ratelimit whitelist user map')
    end

    settings.custom_keywords = {}
    if opts['custom_keywords'] then
      settings.custom_keywords = load_custom_keywords(opts['custom_keywords'])
    end

    if opts['message_func'] then
      -- NOTE: executes code from the configuration (trusted input)
      message_func = assert(load(opts['message_func']))()
    end

    redis_params = lua_redis.parse_redis_server('ratelimit')
    if not redis_params then
      rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
      lua_util.disable_module(N, "redis")
    else
      local s = {
        type = 'prefilter,nostat',
        name = 'RATELIMIT_CHECK',
        priority = 7,
        callback = ratelimit_cb,
        flags = 'empty',
      }
      if settings.symbol then
        s.name = settings.symbol
      elseif settings.info_symbol then
        s.name = settings.info_symbol
      end
      rspamd_config:register_symbol(s)
      rspamd_config:register_symbol {
        type = 'idempotent',
        name = 'RATELIMIT_UPDATE',
        callback = ratelimit_update_cb,
      }
    end
  end
  -- Ask rspamd to run load_scripts once the configuration has been loaded
  local function on_config_load(cfg, ev_base, _)
    load_scripts(cfg, ev_base)
  end
  rspamd_config:add_on_load(on_config_load)