neural.lua 45 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456
  1. --[[
  2. Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. if confighelp then
  14. return
  15. end
  16. local rspamd_logger = require "rspamd_logger"
  17. local rspamd_util = require "rspamd_util"
  18. local rspamd_kann = require "rspamd_kann"
  19. local lua_redis = require "lua_redis"
  20. local lua_util = require "lua_util"
  21. local fun = require "fun"
  22. local lua_settings = require "lua_settings"
  23. local meta_functions = require "lua_meta"
  24. local ts = require("tableshape").types
  25. local lua_verdict = require "lua_verdict"
  26. local N = "neural"
  27. -- Module vars
  28. local default_options = {
  29. train = {
  30. max_trains = 1000,
  31. max_epoch = 1000,
  32. max_usages = 10,
  33. max_iterations = 25, -- Torch style
  34. mse = 0.001,
  35. autotrain = true,
  36. train_prob = 1.0,
  37. learn_threads = 1,
  38. learning_rate = 0.01,
  39. },
  40. watch_interval = 60.0,
  41. lock_expire = 600,
  42. learning_spawned = false,
  43. ann_expire = 60 * 60 * 24 * 2, -- 2 days
  44. symbol_spam = 'NEURAL_SPAM',
  45. symbol_ham = 'NEURAL_HAM',
  46. }
  47. local redis_profile_schema = ts.shape{
  48. digest = ts.string,
  49. symbols = ts.array_of(ts.string),
  50. version = ts.number,
  51. redis_key = ts.string,
  52. distance = ts.number:is_optional(),
  53. }
  54. -- Rule structure:
  55. -- * static config fields (see `default_options`)
  56. -- * prefix - name or defined prefix
  57. -- * settings - table of settings indexed by settings id, -1 is used when no settings defined
  58. -- Rule settings element defines elements for specific settings id:
  59. -- * symbols - static symbols profile (defined by config or extracted from symcache)
  60. -- * name - name of settings id
  61. -- * digest - digest of all symbols
  62. -- * ann - dynamic ANN configuration loaded from Redis
  63. -- * train - train data for ANN (e.g. the currently trained ANN)
  64. -- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
  65. -- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
  66. -- * version - version of ANN loaded from redis
  67. -- * redis_key - name of ANN key in Redis
  68. -- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
  69. -- * distance - distance between set.symbols and set.ann.symbols
  70. -- * ann - kann object
  71. local settings = {
  72. rules = {},
  73. prefix = 'rn', -- Neural network default prefix
  74. max_profiles = 3, -- Maximum number of NN profiles stored
  75. }
  76. local module_config = rspamd_config:get_all_opt("neural")
  77. if not module_config then
  78. -- Legacy
  79. module_config = rspamd_config:get_all_opt("fann_redis")
  80. end
  81. -- Lua script that checks if we can store a new training vector
  82. -- Uses the following keys:
  83. -- key1 - ann key
  84. -- key2 - spam or ham
  85. -- key3 - maximum trains
  86. -- key4 - sampling coin (as Redis scripts do not allow math.random calls)
  87. -- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
  88. local redis_lua_script_can_store_train_vec = [[
  89. local prefix = KEYS[1]
  90. local locked = redis.call('HGET', prefix, 'lock')
  91. if locked then return {tostring(-1),'locked by another process till: ' .. locked} end
  92. local nspam = 0
  93. local nham = 0
  94. local lim = tonumber(KEYS[3])
  95. local coin = tonumber(KEYS[4])
  96. local ret = redis.call('LLEN', prefix .. '_spam')
  97. if ret then nspam = tonumber(ret) end
  98. ret = redis.call('LLEN', prefix .. '_ham')
  99. if ret then nham = tonumber(ret) end
  100. if KEYS[2] == 'spam' then
  101. if nspam <= lim then
  102. if nspam > nham then
  103. -- Apply sampling
  104. local skip_rate = 1.0 - nham / (nspam + 1)
  105. if coin < skip_rate then
  106. return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
  107. end
  108. end
  109. return {tostring(nspam),'can learn'}
  110. else -- Enough learns
  111. return {tostring(-(nspam)),'too many spam samples'}
  112. end
  113. else
  114. if nham <= lim then
  115. if nham > nspam then
  116. -- Apply sampling
  117. local skip_rate = 1.0 - nspam / (nham + 1)
  118. if coin < skip_rate then
  119. return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
  120. end
  121. end
  122. return {tostring(nham),'can learn'}
  123. else
  124. return {tostring(-(nham)),'too many ham samples'}
  125. end
  126. end
  127. return {tostring(-1),'bad input'}
  128. ]]
  129. local redis_can_store_train_vec_id = nil
  130. -- Lua script to invalidate ANNs by rank
  131. -- Uses the following keys
  132. -- key1 - prefix for keys
  133. -- key2 - number of elements to leave
  134. local redis_lua_script_maybe_invalidate = [[
  135. local card = redis.call('ZCARD', KEYS[1])
  136. local lim = tonumber(KEYS[2])
  137. if card > lim then
  138. local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
  139. for _,k in ipairs(to_delete) do
  140. local tb = cjson.decode(k)
  141. redis.call('DEL', tb.redis_key)
  142. -- Also train vectors
  143. redis.call('DEL', tb.redis_key .. '_spam')
  144. redis.call('DEL', tb.redis_key .. '_ham')
  145. end
  146. redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
  147. return to_delete
  148. else
  149. return {}
  150. end
  151. ]]
  152. local redis_maybe_invalidate_id = nil
  153. -- Lua script to invalidate ANN from redis
  154. -- Uses the following keys
  155. -- key1 - prefix for keys
  156. -- key2 - current time
  157. -- key3 - key expire
  158. -- key4 - hostname
  159. local redis_lua_script_maybe_lock = [[
  160. local locked = redis.call('HGET', KEYS[1], 'lock')
  161. local now = tonumber(KEYS[2])
  162. if locked then
  163. locked = tonumber(locked)
  164. local expire = tonumber(KEYS[3])
  165. if now > locked and (now - locked) < expire then
  166. return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
  167. end
  168. end
  169. redis.call('HSET', KEYS[1], 'lock', tostring(now))
  170. redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
  171. return 1
  172. ]]
  173. local redis_maybe_lock_id = nil
  174. -- Lua script to save and unlock ANN in redis
  175. -- Uses the following keys
  176. -- key1 - prefix for ANN
  177. -- key2 - prefix for profile
  178. -- key3 - compressed ANN
  179. -- key4 - profile as JSON
  180. -- key5 - expire in seconds
  181. -- key6 - current time
  182. -- key7 - old key
  183. local redis_lua_script_save_unlock = [[
  184. local now = tonumber(KEYS[6])
  185. redis.call('ZADD', KEYS[2], now, KEYS[4])
  186. redis.call('HSET', KEYS[1], 'ann', KEYS[3])
  187. redis.call('DEL', KEYS[1] .. '_spam')
  188. redis.call('DEL', KEYS[1] .. '_ham')
  189. redis.call('HDEL', KEYS[1], 'lock')
  190. redis.call('HDEL', KEYS[7], 'lock')
  191. redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
  192. return 1
  193. ]]
  194. local redis_save_unlock_id = nil
  195. local redis_params
  196. local function load_scripts(params)
  197. redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
  198. params)
  199. redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
  200. params)
  201. redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
  202. params)
  203. redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
  204. params)
  205. end
  206. local function result_to_vector(task, profile)
  207. if not profile.zeros then
  208. -- Fill zeros vector
  209. local zeros = {}
  210. for i=1,meta_functions.count_metatokens() do
  211. zeros[i] = 0.0
  212. end
  213. for _,_ in ipairs(profile.symbols) do
  214. zeros[#zeros + 1] = 0.0
  215. end
  216. profile.zeros = zeros
  217. end
  218. local vec = lua_util.shallowcopy(profile.zeros)
  219. local mt = meta_functions.rspamd_gen_metatokens(task)
  220. for i,v in ipairs(mt) do
  221. vec[i] = v
  222. end
  223. task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
  224. return vec
  225. end
  226. -- Used to generate new ANN key for specific profile
  227. local function new_ann_key(rule, set, version)
  228. local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
  229. rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
  230. return ann_key
  231. end
  232. -- Extract settings element for a specific settings id
  233. local function get_rule_settings(task, rule)
  234. local sid = task:get_settings_id() or -1
  235. local set = rule.settings[sid]
  236. if not set then return nil end
  237. while type(set) == 'number' do
  238. -- Reference to another settings!
  239. set = rule.settings[set]
  240. end
  241. return set
  242. end
  243. -- Generate redis prefix for specific rule and specific settings
  244. local function redis_ann_prefix(rule, settings_name)
  245. -- We also need to count metatokens:
  246. local n = meta_functions.version
  247. return string.format('%s_%s_%d_%s',
  248. settings.prefix, rule.prefix, n, settings_name)
  249. end
  250. -- Creates and stores ANN profile in Redis
  251. local function new_ann_profile(task, rule, set, version)
  252. local ann_key = new_ann_key(rule, set, version)
  253. local profile = {
  254. symbols = set.symbols,
  255. redis_key = ann_key,
  256. version = version,
  257. digest = set.digest,
  258. distance = 0 -- Since we are using our own profile
  259. }
  260. local ucl = require "ucl"
  261. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  262. local function add_cb(err, _)
  263. if err then
  264. rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
  265. rule.prefix, set.name, profile.redis_key, err)
  266. else
  267. rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
  268. rule.prefix, set.name, profile.redis_key)
  269. end
  270. end
  271. lua_redis.redis_make_request(task,
  272. rule.redis,
  273. nil,
  274. true, -- is write
  275. add_cb, --callback
  276. 'ZADD', -- command
  277. {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  278. )
  279. return profile
  280. end
  281. -- ANN filter function, used to insert scores based on the existing symbols
  282. local function ann_scores_filter(task)
  283. for _,rule in pairs(settings.rules) do
  284. local sid = task:get_settings_id() or -1
  285. local ann
  286. local profile
  287. local set = get_rule_settings(task, rule)
  288. if set then
  289. if set.ann then
  290. ann = set.ann.ann
  291. profile = set.ann
  292. else
  293. lua_util.debugm(N, task, 'no ann loaded for %s:%s',
  294. rule.prefix, set.name)
  295. end
  296. else
  297. lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
  298. rule.prefix, sid)
  299. end
  300. if ann then
  301. local vec = result_to_vector(task, profile)
  302. local score
  303. local out = ann:apply1(vec)
  304. score = out[1]
  305. local symscore = string.format('%.3f', score)
  306. lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
  307. rule.prefix, set.name, set.ann.version, symscore)
  308. if score > 0 then
  309. local result = score
  310. task:insert_result(rule.symbol_spam, result, symscore)
  311. else
  312. local result = -(score)
  313. task:insert_result(rule.symbol_ham, result, symscore)
  314. end
  315. end
  316. end
  317. end
  318. local function create_ann(n, nlayers)
  319. -- We ignore number of layers so far when using kann
  320. local nhidden = math.floor((n + 1) / 2)
  321. local t = rspamd_kann.layer.input(n)
  322. t = rspamd_kann.transform.relu(t)
  323. t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden));
  324. t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse)
  325. return rspamd_kann.new.kann(t)
  326. end
  327. local function ann_push_task_result(rule, task, verdict, score, set)
  328. local train_opts = rule.train
  329. local learn_spam, learn_ham
  330. local skip_reason = 'unknown'
  331. if train_opts.autotrain then
  332. if train_opts.spam_score then
  333. learn_spam = score >= train_opts.spam_score
  334. if not learn_spam then
  335. skip_reason = string.format('score < spam_score: %f < %f',
  336. score, train_opts.spam_score)
  337. end
  338. else
  339. learn_spam = verdict == 'spam' or verdict == 'junk'
  340. if not learn_spam then
  341. skip_reason = string.format('verdict: %s',
  342. verdict)
  343. end
  344. end
  345. if train_opts.ham_score then
  346. learn_ham = score <= train_opts.ham_score
  347. if not learn_ham then
  348. skip_reason = string.format('score > ham_score: %f > %f',
  349. score, train_opts.ham_score)
  350. end
  351. else
  352. learn_ham = verdict == 'ham'
  353. if not learn_ham then
  354. skip_reason = string.format('verdict: %s',
  355. verdict)
  356. end
  357. end
  358. else
  359. -- Train by request header
  360. local hdr = task:get_request_header('ANN-Train')
  361. if hdr then
  362. if hdr:lower() == 'spam' then
  363. learn_spam = true
  364. elseif hdr:lower() == 'ham' then
  365. learn_ham = true
  366. else
  367. skip_reason = string.format('no explicit header')
  368. end
  369. end
  370. end
  371. if learn_spam or learn_ham then
  372. local learn_type
  373. if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
  374. local function can_train_cb(err, data)
  375. if not err and type(data) == 'table' then
  376. local nsamples,reason = tonumber(data[1]),data[2]
  377. if nsamples >= 0 then
  378. local coin = math.random()
  379. if coin < 1.0 - train_opts.train_prob then
  380. rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
  381. return
  382. end
  383. local vec = result_to_vector(task, set)
  384. local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
  385. local target_key = set.ann.redis_key .. '_' .. learn_type
  386. local function learn_vec_cb(_err)
  387. if _err then
  388. rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
  389. rule.prefix, set.name, _err)
  390. else
  391. lua_util.debugm(N, task,
  392. "add train data for ANN rule " ..
  393. "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
  394. rule.prefix, set.name, learn_type, #vec, target_key, #str)
  395. end
  396. end
  397. lua_redis.redis_make_request(task,
  398. rule.redis,
  399. nil,
  400. true, -- is write
  401. learn_vec_cb, --callback
  402. 'LPUSH', -- command
  403. { target_key, str } -- arguments
  404. )
  405. else
  406. -- Negative result returned
  407. rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)",
  408. learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples))
  409. end
  410. else
  411. if err then
  412. rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
  413. rule.prefix, set.name, err)
  414. else
  415. rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
  416. 'please remove this key from Redis manually if you perform upgrade from the previous version',
  417. rule.prefix, set.name, set.ann.redis_key, type(data))
  418. end
  419. end
  420. end
  421. -- Check if we can learn
  422. if set.can_store_vectors then
  423. if not set.ann then
  424. -- Need to create or load a profile corresponding to the current configuration
  425. set.ann = new_ann_profile(task, rule, set, 0)
  426. lua_util.debugm(N, task,
  427. 'requested new profile for %s, set.ann is missing',
  428. set.name)
  429. end
  430. lua_redis.exec_redis_script(redis_can_store_train_vec_id,
  431. {task = task, is_write = true},
  432. can_train_cb,
  433. {
  434. set.ann.redis_key,
  435. learn_type,
  436. tostring(train_opts.max_trains),
  437. tostring(math.random()),
  438. })
  439. else
  440. lua_util.debugm(N, task,
  441. 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
  442. end
  443. else
  444. lua_util.debugm(N, task,
  445. 'do not push data to key %s: train condition not satisfied; reason: %s',
  446. (set.ann or {}).redis_key,
  447. skip_reason)
  448. end
  449. end
  450. --- Offline training logic
  451. -- Closure generator for unlock function
  452. local function gen_unlock_cb(rule, set, ann_key)
  453. return function (err)
  454. if err then
  455. rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
  456. rule.prefix, set.name, ann_key, err)
  457. else
  458. lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
  459. rule.prefix, set.name, ann_key)
  460. end
  461. end
  462. end
  463. -- This function is intended to extend lock for ANN during training
  464. -- It registers periodic that increases locked key each 30 seconds unless
  465. -- `set.learning_spawned` is set to `true`
  466. local function register_lock_extender(rule, set, ev_base, ann_key)
  467. rspamd_config:add_periodic(ev_base, 30.0,
  468. function()
  469. local function redis_lock_extend_cb(_err, _)
  470. if _err then
  471. rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
  472. ann_key, _err)
  473. else
  474. rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
  475. ann_key)
  476. end
  477. end
  478. if set.learning_spawned then
  479. lua_redis.redis_make_request_taskless(ev_base,
  480. rspamd_config,
  481. rule.redis,
  482. nil,
  483. true, -- is write
  484. redis_lock_extend_cb, --callback
  485. 'HINCRBY', -- command
  486. {ann_key, 'lock', '30'}
  487. )
  488. else
  489. lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
  490. return false -- do not plan any more updates
  491. end
  492. return true
  493. end
  494. )
  495. end
  496. -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
  497. local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
  498. -- Check training data sanity
  499. -- Now we need to join inputs and create the appropriate test vectors
  500. local n = #set.symbols +
  501. meta_functions.rspamd_count_metatokens()
  502. -- Now we can train ann
  503. local train_ann = create_ann(n, 3)
  504. if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
  505. -- Invalidate ANN as it is definitely invalid
  506. -- TODO: add invalidation
  507. assert(false)
  508. else
  509. local inputs, outputs = {}, {}
  510. -- Used to show sparsed vectors in a convenient format (for debugging only)
  511. local function debug_vec(t)
  512. local ret = {}
  513. for i,v in ipairs(t) do
  514. if v ~= 0 then
  515. ret[#ret + 1] = string.format('%d=%.2f', i, v)
  516. end
  517. end
  518. return ret
  519. end
  520. -- Make training set by joining vectors
  521. -- KANN automatically shuffles those samples
  522. -- 1.0 is used for spam and -1.0 is used for ham
  523. -- It implies that output layer can express that (e.g. tanh output)
  524. for _,e in ipairs(spam_vec) do
  525. inputs[#inputs + 1] = e
  526. outputs[#outputs + 1] = {1.0}
  527. --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
  528. end
  529. for _,e in ipairs(ham_vec) do
  530. inputs[#inputs + 1] = e
  531. outputs[#outputs + 1] = {-1.0}
  532. --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
  533. end
  534. -- Called in child process
  535. local function train()
  536. local log_thresh = rule.train.max_iterations / 10
  537. local seen_nan = false
  538. local function train_cb(iter, train_cost, value_cost)
  539. if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
  540. if train_cost ~= train_cost and not seen_nan then
  541. -- We have nan :( try to log lot's of stuff to dig into a problem
  542. seen_nan = true
  543. rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
  544. rule.prefix, set.name,
  545. value_cost)
  546. for i,e in ipairs(inputs) do
  547. lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
  548. debug_vec(e), outputs[i][1])
  549. end
  550. end
  551. rspamd_logger.infox(rspamd_config,
  552. "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
  553. rule.prefix, set.name,
  554. ann_key,
  555. iter,
  556. train_cost,
  557. value_cost)
  558. end
  559. end
  560. train_ann:train1(inputs, outputs, {
  561. lr = rule.train.learning_rate,
  562. max_epoch = rule.train.max_iterations,
  563. cb = train_cb,
  564. })
  565. if not seen_nan then
  566. local out = train_ann:save()
  567. return out
  568. else
  569. return nil
  570. end
  571. end
  572. set.learning_spawned = true
  573. local function redis_save_cb(err)
  574. if err then
  575. rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
  576. rule.prefix, set.name, ann_key, err)
  577. lua_redis.redis_make_request_taskless(ev_base,
  578. rspamd_config,
  579. rule.redis,
  580. nil,
  581. false, -- is write
  582. gen_unlock_cb(rule, set, ann_key), --callback
  583. 'HDEL', -- command
  584. {ann_key, 'lock'}
  585. )
  586. else
  587. rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
  588. rule.prefix, set.name, set.ann.redis_key)
  589. end
  590. end
  591. local function ann_trained(err, data)
  592. set.learning_spawned = false
  593. if err then
  594. rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
  595. rule.prefix, set.name, err)
  596. lua_redis.redis_make_request_taskless(ev_base,
  597. rspamd_config,
  598. rule.redis,
  599. nil,
  600. true, -- is write
  601. gen_unlock_cb(rule, set, ann_key), --callback
  602. 'HDEL', -- command
  603. {ann_key, 'lock'}
  604. )
  605. else
  606. local ann_data = rspamd_util.zstd_compress(data)
  607. if not set.ann then
  608. set.ann = {
  609. symbols = set.symbols,
  610. distance = 0,
  611. digest = set.digest,
  612. redis_key = ann_key,
  613. }
  614. end
  615. -- Deserialise ANN from the child process
  616. ann_trained = rspamd_kann.load(data)
  617. local version = (set.ann.version or 0) + 1
  618. set.ann.version = version
  619. set.ann.ann = ann_trained
  620. set.ann.symbols = set.symbols
  621. set.ann.redis_key = new_ann_key(rule, set, version)
  622. local profile = {
  623. symbols = set.symbols,
  624. digest = set.digest,
  625. redis_key = set.ann.redis_key,
  626. version = version
  627. }
  628. local ucl = require "ucl"
  629. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  630. rspamd_logger.infox(rspamd_config,
  631. 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
  632. rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
  633. lua_redis.exec_redis_script(redis_save_unlock_id,
  634. {ev_base = ev_base, is_write = true},
  635. redis_save_cb,
  636. {profile.redis_key,
  637. redis_ann_prefix(rule, set.name),
  638. ann_data,
  639. profile_serialized,
  640. tostring(rule.ann_expire),
  641. tostring(os.time()),
  642. ann_key, -- old key to unlock...
  643. })
  644. end
  645. end
  646. worker:spawn_process{
  647. func = train,
  648. on_complete = ann_trained,
  649. proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
  650. }
  651. end
  652. -- Spawn learn and register lock extension
  653. set.learning_spawned = true
  654. register_lock_extender(rule, set, ev_base, ann_key)
  655. end
  656. -- Utility to extract and split saved training vectors to a table of tables
  657. local function process_training_vectors(data)
  658. return fun.totable(fun.map(function(tok)
  659. local _,str = rspamd_util.zstd_decompress(tok)
  660. return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';')))
  661. end, data))
  662. end
  663. -- This function does the following:
  664. -- * Tries to lock ANN
  665. -- * Loads spam and ham vectors
  666. -- * Spawn learning process
  667. local function do_train_ann(worker, ev_base, rule, set, ann_key)
  668. local spam_elts = {}
  669. local ham_elts = {}
  670. local function redis_ham_cb(err, data)
  671. if err or type(data) ~= 'table' then
  672. rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
  673. ann_key, err)
  674. -- Unlock on error
  675. lua_redis.redis_make_request_taskless(ev_base,
  676. rspamd_config,
  677. rule.redis,
  678. nil,
  679. true, -- is write
  680. gen_unlock_cb(rule, set, ann_key), --callback
  681. 'HDEL', -- command
  682. {ann_key, 'lock'}
  683. )
  684. else
  685. -- Decompress and convert to numbers each training vector
  686. ham_elts = process_training_vectors(data)
  687. spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
  688. end
  689. end
  690. -- Spam vectors received
  691. local function redis_spam_cb(err, data)
  692. if err or type(data) ~= 'table' then
  693. rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
  694. ann_key, err)
  695. -- Unlock ANN on error
  696. lua_redis.redis_make_request_taskless(ev_base,
  697. rspamd_config,
  698. rule.redis,
  699. nil,
  700. true, -- is write
  701. gen_unlock_cb(rule, set, ann_key), --callback
  702. 'HDEL', -- command
  703. {ann_key, 'lock'}
  704. )
  705. else
  706. -- Decompress and convert to numbers each training vector
  707. spam_elts = process_training_vectors(data)
  708. -- Now get ham vectors...
  709. lua_redis.redis_make_request_taskless(ev_base,
  710. rspamd_config,
  711. rule.redis,
  712. nil,
  713. false, -- is write
  714. redis_ham_cb, --callback
  715. 'LRANGE', -- command
  716. {ann_key .. '_ham', '0', '-1'}
  717. )
  718. end
  719. end
  720. local function redis_lock_cb(err, data)
  721. if err then
  722. rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
  723. ann_key, err)
  724. elseif type(data) == 'number' and data == 1 then
  725. -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
  726. lua_redis.redis_make_request_taskless(ev_base,
  727. rspamd_config,
  728. rule.redis,
  729. nil,
  730. false, -- is write
  731. redis_spam_cb, --callback
  732. 'LRANGE', -- command
  733. {ann_key .. '_spam', '0', '-1'}
  734. )
  735. rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
  736. rule.prefix, set.name, ann_key)
  737. else
  738. local lock_tm = tonumber(data[1])
  739. rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
  740. 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
  741. data[2], os.date('%c', lock_tm))
  742. end
  743. end
  744. -- Check if we are already learning this network
  745. if set.learning_spawned then
  746. rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
  747. ann_key)
  748. return
  749. end
  750. -- Call Redis script that tries to acquire a lock
  751. -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
  752. -- ANN is locked by another host (or a process, meh)
  753. lua_redis.exec_redis_script(redis_maybe_lock_id,
  754. {ev_base = ev_base, is_write = true},
  755. redis_lock_cb,
  756. {
  757. ann_key,
  758. tostring(os.time()),
  759. tostring(rule.watch_interval * 2),
  760. rspamd_util.get_hostname()
  761. })
  762. end
  763. -- This function loads new ann from Redis
  764. -- This is based on `profile` attribute.
  765. -- ANN is loaded from `profile.redis_key`
  766. -- Rank of `profile` key is also increased, unfortunately, it means that we need to
  767. -- serialize profile one more time and set its rank to the current time
  768. -- set.ann fields are set according to Redis data received
  769. local function load_new_ann(rule, ev_base, set, profile, min_diff)
  770. local ann_key = profile.redis_key
  771. local function data_cb(err, data)
  772. if err then
  773. rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
  774. ann_key, err)
  775. else
  776. if type(data) == 'string' then
  777. local _err,ann_data = rspamd_util.zstd_decompress(data)
  778. local ann
  779. if _err or not ann_data then
  780. rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
  781. rule.prefix .. ':' .. set.name, ann_key, _err)
  782. return
  783. else
  784. ann = rspamd_kann.load(ann_data)
  785. if ann then
  786. set.ann = {
  787. digest = profile.digest,
  788. version = profile.version,
  789. symbols = profile.symbols,
  790. distance = min_diff,
  791. redis_key = profile.redis_key
  792. }
  793. local ucl = require "ucl"
  794. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  795. set.ann.ann = ann -- To avoid serialization
  796. local function rank_cb(_, _)
  797. -- TODO: maybe add some logging
  798. end
  799. -- Also update rank for the loaded ANN to avoid removal
  800. lua_redis.redis_make_request_taskless(ev_base,
  801. rspamd_config,
  802. rule.redis,
  803. nil,
  804. true, -- is write
  805. rank_cb, --callback
  806. 'ZADD', -- command
  807. {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  808. )
  809. rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
  810. rule.prefix, set.name, ann_key, #ann_data, profile.version)
  811. else
  812. rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
  813. rule.prefix, set.name, ann_key)
  814. end
  815. end
  816. else
  817. lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
  818. rule.prefix, set.name, ann_key)
  819. end
  820. end
  821. end
  822. lua_redis.redis_make_request_taskless(ev_base,
  823. rspamd_config,
  824. rule.redis,
  825. nil,
  826. false, -- is write
  827. data_cb, --callback
  828. 'HGET', -- command
  829. {ann_key, 'ann'} -- arguments
  830. )
  831. end
  832. -- Used to check an element in Redis serialized as JSON
  833. -- for some specific rule + some specific setting
  834. -- This function tries to load more fresh or more specific ANNs in lieu of
  835. -- the existing ones.
  836. -- Use this function to load ANNs as `callback` parameter for `check_anns` function
  837. local function process_existing_ann(_, ev_base, rule, set, profiles)
  838. local my_symbols = set.symbols
  839. local min_diff = math.huge
  840. local sel_elt
  841. for _,elt in fun.iter(profiles) do
  842. if elt and elt.symbols then
  843. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  844. -- Check distance
  845. if dist < #my_symbols * .3 then
  846. if dist < min_diff then
  847. min_diff = dist
  848. sel_elt = elt
  849. end
  850. end
  851. end
  852. end
  853. if sel_elt then
  854. -- We can load element from ANN
  855. if set.ann then
  856. -- We have an existing ANN, probably the same...
  857. if set.ann.digest == sel_elt.digest then
  858. -- Same ANN, check version
  859. if set.ann.version < sel_elt.version then
  860. -- Load new ann
  861. rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
  862. 'our version = %s, remote version = %s',
  863. rule.prefix .. ':' .. set.name,
  864. set.ann.version,
  865. sel_elt.version)
  866. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  867. else
  868. lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
  869. 'our version = %s, remote version = %s',
  870. rule.prefix .. ':' .. set.name,
  871. set.ann.version,
  872. sel_elt.version)
  873. end
  874. else
  875. -- We have some different ANN, so we need to compare distance
  876. if set.ann.distance > min_diff then
  877. -- Load more specific ANN
  878. rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
  879. 'our distance = %s, remote distance = %s',
  880. rule.prefix .. ':' .. set.name,
  881. set.ann.distance,
  882. min_diff)
  883. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  884. else
  885. lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
  886. 'our distance = %s, remote distance = %s',
  887. rule.prefix .. ':' .. set.name,
  888. set.ann.distance,
  889. min_diff)
  890. end
  891. end
  892. else
  893. -- We have no ANN, load new one
  894. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  895. end
  896. end
  897. end
  898. -- This function checks all profiles and selects if we can train our
  899. -- ANN. By our we mean that it has exactly the same symbols in profile.
  900. -- Use this function to train ANN as `callback` parameter for `check_anns` function
  901. local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  902. local my_symbols = set.symbols
  903. local sel_elt
  904. for _,elt in fun.iter(profiles) do
  905. if elt and elt.symbols then
  906. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  907. -- Check distance
  908. if dist == 0 then
  909. sel_elt = elt
  910. break
  911. end
  912. end
  913. end
  914. if sel_elt then
  915. -- We have our ANN and that's train vectors, check if we can learn
  916. local ann_key = sel_elt.redis_key
  917. lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
  918. ann_key)
  919. -- Create continuation closure
  920. local redis_len_cb_gen = function(cont_cb, what, is_final)
  921. return function(err, data)
  922. if err then
  923. rspamd_logger.errx(rspamd_config,
  924. 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
  925. elseif data and type(data) == 'number' or type(data) == 'string' then
  926. if tonumber(data) and tonumber(data) >= rule.train.max_trains then
  927. if is_final then
  928. rspamd_logger.debugm(N, rspamd_config,
  929. 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
  930. ann_key, tonumber(data), rule.train.max_trains, what)
  931. else
  932. rspamd_logger.debugm(N, rspamd_config,
  933. 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
  934. what, ann_key, tonumber(data), rule.train.max_trains)
  935. end
  936. cont_cb()
  937. else
  938. rspamd_logger.debugm(N, rspamd_config,
  939. 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
  940. ann_key, what, tonumber(data), rule.train.max_trains)
  941. end
  942. end
  943. end
  944. end
  945. local function initiate_train()
  946. rspamd_logger.infox(rspamd_config,
  947. 'need to learn ANN %s after %s required learn vectors',
  948. ann_key, rule.train.max_trains)
  949. do_train_ann(worker, ev_base, rule, set, ann_key)
  950. end
  951. -- Spam vector is OK, check ham vector length
  952. local function check_ham_len()
  953. lua_redis.redis_make_request_taskless(ev_base,
  954. rspamd_config,
  955. rule.redis,
  956. nil,
  957. false, -- is write
  958. redis_len_cb_gen(initiate_train, 'ham', true), --callback
  959. 'LLEN', -- command
  960. {ann_key .. '_ham'}
  961. )
  962. end
  963. lua_redis.redis_make_request_taskless(ev_base,
  964. rspamd_config,
  965. rule.redis,
  966. nil,
  967. false, -- is write
  968. redis_len_cb_gen(check_ham_len, 'spam', false), --callback
  969. 'LLEN', -- command
  970. {ann_key .. '_spam'}
  971. )
  972. end
  973. end
  974. -- Used to deserialise ANN element from a list
  975. local function load_ann_profile(element)
  976. local ucl = require "ucl"
  977. local parser = ucl.parser()
  978. local res,ucl_err = parser:parse_string(element)
  979. if not res then
  980. rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
  981. ucl_err)
  982. return nil
  983. else
  984. local profile = parser:get_object()
  985. local checked,schema_err = redis_profile_schema:transform(profile)
  986. if not checked then
  987. rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)
  988. return nil
  989. end
  990. return checked
  991. end
  992. end
  993. -- Function to check or load ANNs from Redis
  994. local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  995. for _,set in pairs(rule.settings) do
  996. local function members_cb(err, data)
  997. if err then
  998. rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
  999. err)
  1000. set.can_store_vectors = true
  1001. elseif type(data) == 'table' then
  1002. lua_util.debugm(N, cfg, '%s: process element %s:%s',
  1003. what, rule.prefix, set.name)
  1004. process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
  1005. set.can_store_vectors = true
  1006. end
  1007. end
  1008. if type(set) == 'table' then
  1009. -- Extract all profiles for some specific settings id
  1010. -- Get the last `max_profiles` recently used
  1011. -- Select the most appropriate to our profile but it should not differ by more
  1012. -- than 30% of symbols
  1013. lua_redis.redis_make_request_taskless(ev_base,
  1014. cfg,
  1015. rule.redis,
  1016. nil,
  1017. false, -- is write
  1018. members_cb, --callback
  1019. 'ZREVRANGE', -- command
  1020. {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
  1021. )
  1022. end
  1023. end -- Cycle over all settings
  1024. return rule.watch_interval
  1025. end
  1026. -- Function to clean up old ANNs
  1027. local function cleanup_anns(rule, cfg, ev_base)
  1028. for _,set in pairs(rule.settings) do
  1029. local function invalidate_cb(err, data)
  1030. if err then
  1031. rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
  1032. err)
  1033. elseif type(data) == 'table' then
  1034. for _,expired in ipairs(data) do
  1035. local profile = load_ann_profile(expired)
  1036. rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
  1037. rule.prefix .. ':' .. set.name,
  1038. profile.redis_key,
  1039. profile.version)
  1040. end
  1041. end
  1042. end
  1043. if type(set) == 'table' then
  1044. lua_redis.exec_redis_script(redis_maybe_invalidate_id,
  1045. {ev_base = ev_base, is_write = true},
  1046. invalidate_cb,
  1047. {set.prefix, tostring(settings.max_profiles)})
  1048. end
  1049. end
  1050. end
  1051. local function ann_push_vector(task)
  1052. if task:has_flag('skip') then
  1053. lua_util.debugm(N, task, 'do not push data for skipped task')
  1054. return
  1055. end
  1056. if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
  1057. lua_util.debugm(N, task, 'do not push data for manual scan')
  1058. return
  1059. end
  1060. local verdict,score = lua_verdict.get_specific_verdict(N, task)
  1061. if verdict == 'passthrough' then
  1062. lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
  1063. verdict, score)
  1064. return
  1065. end
  1066. if score ~= score then
  1067. lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
  1068. verdict)
  1069. return
  1070. end
  1071. for _,rule in pairs(settings.rules) do
  1072. local set = get_rule_settings(task, rule)
  1073. if set then
  1074. ann_push_task_result(rule, task, verdict, score, set)
  1075. else
  1076. lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
  1077. end
  1078. end
  1079. end
  1080. -- This function is used to adjust profiles and allowed setting ids for each rule
  1081. -- It must be called when all settings are already registered (e.g. at post-init for config)
  1082. local function process_rules_settings()
  1083. local function process_settings_elt(rule, selt)
  1084. local profile = rule.profile[selt.name]
  1085. if profile then
  1086. -- Use static user defined profile
  1087. -- Ensure that we have an array...
  1088. lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
  1089. rule.prefix, selt.name, profile)
  1090. if not profile[1] then profile = lua_util.keys(profile) end
  1091. selt.symbols = profile
  1092. else
  1093. lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
  1094. rule.prefix, selt.name)
  1095. end
  1096. local function filter_symbols_predicate(sname)
  1097. local fl = rspamd_config:get_symbol_flags(sname)
  1098. if fl then
  1099. fl = lua_util.list_to_hash(fl)
  1100. return not (fl.nostat or fl.idempotent or fl.skip)
  1101. end
  1102. return false
  1103. end
  1104. -- Generic stuff
  1105. table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
  1106. selt.digest = lua_util.table_digest(selt.symbols)
  1107. selt.prefix = redis_ann_prefix(rule, selt.name)
  1108. lua_redis.register_prefix(selt.prefix, N,
  1109. string.format('NN prefix for rule "%s"; settings id "%s"',
  1110. rule.prefix, selt.name), {
  1111. persistent = true,
  1112. type = 'zlist',
  1113. })
  1114. -- Versions
  1115. lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
  1116. string.format('NN storage for rule "%s"; settings id "%s"',
  1117. rule.prefix, selt.name), {
  1118. persistent = true,
  1119. type = 'hash',
  1120. })
  1121. lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
  1122. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  1123. rule.prefix, selt.name), {
  1124. persistent = true,
  1125. type = 'list',
  1126. })
  1127. lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
  1128. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  1129. rule.prefix, selt.name), {
  1130. persistent = true,
  1131. type = 'list',
  1132. })
  1133. end
  1134. for k,rule in pairs(settings.rules) do
  1135. if not rule.allowed_settings then
  1136. rule.allowed_settings = {}
  1137. elseif rule.allowed_settings == 'all' then
  1138. -- Extract all settings ids
  1139. rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
  1140. end
  1141. -- Convert to a map <setting_id> -> true
  1142. rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
  1143. -- Check if we can work without settings
  1144. if k == 'default' or type(rule.default) ~= 'boolean' then
  1145. rule.default = true
  1146. end
  1147. rule.settings = {}
  1148. if rule.default then
  1149. local default_settings = {
  1150. symbols = lua_settings.default_symbols(),
  1151. name = 'default'
  1152. }
  1153. process_settings_elt(rule, default_settings)
  1154. rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
  1155. end
  1156. -- Now, for each allowed settings, we store sorted symbols + digest
  1157. -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
  1158. for s,_ in pairs(rule.allowed_settings) do
  1159. -- Here, we have a name, set of symbols and
  1160. local settings_id = s
  1161. if type(settings_id) ~= 'number' then
  1162. settings_id = lua_settings.numeric_settings_id(s)
  1163. end
  1164. local selt = lua_settings.settings_by_id(settings_id)
  1165. local nelt = {
  1166. symbols = selt.symbols, -- Already sorted
  1167. name = selt.name
  1168. }
  1169. process_settings_elt(rule, nelt)
  1170. for id,ex in pairs(rule.settings) do
  1171. if type(ex) == 'table' then
  1172. if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
  1173. -- Equal symbols, add reference
  1174. lua_util.debugm(N, rspamd_config,
  1175. 'added reference from settings id %s to %s; same symbols',
  1176. nelt.name, ex.name)
  1177. rule.settings[settings_id] = id
  1178. nelt = nil
  1179. end
  1180. end
  1181. end
  1182. if nelt then
  1183. rule.settings[settings_id] = nelt
  1184. lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
  1185. nelt.name, settings_id, rule.prefix)
  1186. end
  1187. end
  1188. end
  1189. end
  1190. redis_params = lua_redis.parse_redis_server('neural')
  1191. if not redis_params then
  1192. redis_params = lua_redis.parse_redis_server('fann_redis')
  1193. end
  1194. -- Initialization part
  1195. if not (module_config and type(module_config) == 'table') or not redis_params then
  1196. rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  1197. lua_util.disable_module(N, "redis")
  1198. return
  1199. end
  1200. local rules = module_config['rules']
  1201. if not rules then
  1202. -- Use legacy configuration
  1203. rules = {}
  1204. rules['default'] = module_config
  1205. end
  1206. local id = rspamd_config:register_symbol({
  1207. name = 'NEURAL_CHECK',
  1208. type = 'postfilter,nostat',
  1209. priority = 6,
  1210. callback = ann_scores_filter
  1211. })
  1212. settings = lua_util.override_defaults(settings, module_config)
  1213. settings.rules = {} -- Reset unless validated further in the cycle
  1214. -- Check all rules
  1215. for k,r in pairs(rules) do
  1216. local rule_elt = lua_util.override_defaults(default_options, r)
  1217. rule_elt['redis'] = redis_params
  1218. rule_elt['anns'] = {} -- Store ANNs here
  1219. if not rule_elt.prefix then
  1220. rule_elt.prefix = k
  1221. end
  1222. if not rule_elt.name then
  1223. rule_elt.name = k
  1224. end
  1225. if rule_elt.train.max_train then
  1226. rule_elt.train.max_trains = rule_elt.train.max_train
  1227. end
  1228. if not rule_elt.profile then rule_elt.profile = {} end
  1229. rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
  1230. settings.rules[k] = rule_elt
  1231. rspamd_config:set_metric_symbol({
  1232. name = rule_elt.symbol_spam,
  1233. score = 0.0,
  1234. description = 'Neural network SPAM',
  1235. group = 'neural'
  1236. })
  1237. rspamd_config:register_symbol({
  1238. name = rule_elt.symbol_spam,
  1239. type = 'virtual,nostat',
  1240. parent = id
  1241. })
  1242. rspamd_config:set_metric_symbol({
  1243. name = rule_elt.symbol_ham,
  1244. score = -0.0,
  1245. description = 'Neural network HAM',
  1246. group = 'neural'
  1247. })
  1248. rspamd_config:register_symbol({
  1249. name = rule_elt.symbol_ham,
  1250. type = 'virtual,nostat',
  1251. parent = id
  1252. })
  1253. end
  1254. rspamd_config:register_symbol({
  1255. name = 'NEURAL_LEARN',
  1256. type = 'idempotent,nostat,explicit_disable',
  1257. priority = 5,
  1258. callback = ann_push_vector
  1259. })
  1260. -- Add training scripts
  1261. for _,rule in pairs(settings.rules) do
  1262. load_scripts(rule.redis)
  1263. -- We also need to deal with settings
  1264. rspamd_config:add_post_init(process_rules_settings)
  1265. -- This function will check ANNs in Redis when a worker is loaded
  1266. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1267. if worker:is_scanner() then
  1268. rspamd_config:add_periodic(ev_base, 0.0,
  1269. function(_, _)
  1270. return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
  1271. 'try_load_ann')
  1272. end)
  1273. end
  1274. if worker:is_primary_controller() then
  1275. -- We also want to train neural nets when they have enough data
  1276. rspamd_config:add_periodic(ev_base, 0.0,
  1277. function(_, _)
  1278. -- Clean old ANNs
  1279. cleanup_anns(rule, cfg, ev_base)
  1280. return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
  1281. 'try_train_ann')
  1282. end)
  1283. end
  1284. end)
  1285. end