--[[
Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_kann = require "rspamd_kann"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local fun = require "fun"
local lua_settings = require "lua_settings"
local meta_functions = require "lua_meta"
local ts = require("tableshape").types
local lua_verdict = require "lua_verdict"
local N = "neural"

-- Module vars
local default_options = {
  train = {
    max_trains = 1000,
    max_epoch = 1000,
    max_usages = 10,
    max_iterations = 25, -- Torch style
    mse = 0.001,
    autotrain = true,
    train_prob = 1.0,
    learn_threads = 1,
    learning_rate = 0.01,
  },
  watch_interval = 60.0,
  lock_expire = 600,
  learning_spawned = false,
  ann_expire = 60 * 60 * 24 * 2, -- 2 days
  symbol_spam = 'NEURAL_SPAM',
  symbol_ham = 'NEURAL_HAM',
}

-- Shape of an ANN profile as it is stored (JSON-serialized) in Redis
local redis_profile_schema = ts.shape{
  digest = ts.string,
  symbols = ts.array_of(ts.string),
  version = ts.number,
  redis_key = ts.string,
  distance = ts.number:is_optional(),
}

-- Rule structure:
-- * static config fields (see `default_options`)
-- * prefix - name or defined prefix
-- * settings - table of settings indexed by settings id, -1 is used when no settings defined

-- Rule settings element defines elements for specific settings id:
-- * symbols - static symbols profile (defined by config or extracted from symcache)
-- * name - name of settings id
-- * digest - digest of all symbols
-- * ann - dynamic ANN configuration loaded from Redis
-- * train - train data for ANN (e.g. the currently trained ANN)

-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
-- * version - version of ANN loaded from redis
-- * redis_key - name of ANN key in Redis
-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
-- * distance - distance between set.symbols and set.ann.symbols
-- * ann - kann object

local settings = {
  rules = {},
  prefix = 'rn', -- Neural network default prefix
  max_profiles = 3, -- Maximum number of NN profiles stored
}

local module_config = rspamd_config:get_all_opt("neural")
if not module_config then
  -- Legacy
  module_config = rspamd_config:get_all_opt("fann_redis")
end


-- Lua script that checks if we can store a new training vector
-- Uses the following keys:
-- key1 - ann key
-- key2 - spam or ham
-- key3 - maximum trains
-- key4 - sampling coin (as Redis scripts do not allow math.random calls)
-- Returns a pair {n, reason} (n as a string):
-- n >= 0 - learning is allowed (n vectors of this class are already stored),
-- n < 0 - learning is not allowed (locked, sampled out or enough vectors)
local redis_lua_script_can_store_train_vec = [[
  local prefix = KEYS[1]
  local locked = redis.call('HGET', prefix, 'lock')
  if locked then return {tostring(-1),'locked by another process till: ' .. locked} end
  local nspam = 0
  local nham = 0
  local lim = tonumber(KEYS[3])
  local coin = tonumber(KEYS[4])

  local ret = redis.call('LLEN', prefix .. '_spam')
  if ret then nspam = tonumber(ret) end
  ret = redis.call('LLEN', prefix .. '_ham')
  if ret then nham = tonumber(ret) end

  if KEYS[2] == 'spam' then
    if nspam <= lim then
      if nspam > nham then
        -- Apply sampling
        local skip_rate = 1.0 - nham / (nspam + 1)
        if coin < skip_rate then
          return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
        end
      end
      return {tostring(nspam),'can learn'}
    else -- Enough learns
      return {tostring(-(nspam)),'too many spam samples'}
    end
  else
    if nham <= lim then
      if nham > nspam then
        -- Apply sampling
        local skip_rate = 1.0 - nspam / (nham + 1)
        if coin < skip_rate then
          return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
        end
      end
      return {tostring(nham),'can learn'}
    else
      return {tostring(-(nham)),'too many ham samples'}
    end
  end

  return {tostring(-1),'bad input'}
]]
local redis_can_store_train_vec_id = nil

-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys
-- key2 - number of elements to leave
-- Deletes the lowest-ranked profiles (and their ANN and train vector keys)
-- beyond the limit and returns the deleted profiles
local redis_lua_script_maybe_invalidate = [[
  local card = redis.call('ZCARD', KEYS[1])
  local lim = tonumber(KEYS[2])
  if card > lim then
    local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
    for _,k in ipairs(to_delete) do
      local tb = cjson.decode(k)
      redis.call('DEL', tb.redis_key)
      -- Also train vectors
      redis.call('DEL', tb.redis_key .. '_spam')
      redis.call('DEL', tb.redis_key .. '_ham')
    end
    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
    return to_delete
  else
    return {}
  end
]]
local redis_maybe_invalidate_id = nil

-- Lua script that tries to lock ANN for learning
-- Uses the following keys
-- key1 - prefix for keys
-- key2 - current time
-- key3 - key expire
-- key4 - hostname
-- Returns 1 if the lock has been acquired, or a pair {lock_time, hostname}
-- if the ANN is already locked.
-- A lock is considered held while `now - lock_time < expire`; this includes lock
-- times at or after `now`, which appear when the lock is extended via HINCRBY
-- or when hosts' clocks are skewed.
local redis_lua_script_maybe_lock = [[
  local locked = redis.call('HGET', KEYS[1], 'lock')
  local now = tonumber(KEYS[2])
  if locked then
    locked = tonumber(locked)
    local expire = tonumber(KEYS[3])
    if locked and (now - locked) < expire then
      return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
    end
  end
  redis.call('HSET', KEYS[1], 'lock', tostring(now))
  redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
  return 1
]]
local redis_maybe_lock_id = nil

-- Lua script to save and unlock ANN in redis
-- Uses the following keys
-- key1 - prefix for ANN
-- key2 - prefix for profile
-- key3 - compressed ANN
-- key4 - profile as JSON
-- key5 - expire in seconds
-- key6 - current time
-- key7 - old key
local redis_lua_script_save_unlock = [[
  local now = tonumber(KEYS[6])
  redis.call('ZADD', KEYS[2], now, KEYS[4])
  redis.call('HSET', KEYS[1], 'ann', KEYS[3])
  redis.call('DEL', KEYS[1] .. '_spam')
  redis.call('DEL', KEYS[1] .. '_ham')
  redis.call('HDEL', KEYS[1], 'lock')
  redis.call('HDEL', KEYS[7], 'lock')
  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
  return 1
]]
local redis_save_unlock_id = nil

local redis_params

-- Registers all Redis scripts above and stores their ids in the module upvalues
local function load_scripts(params)
  redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
    params)
  redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
    params)
  redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
    params)
  redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
    params)
end

-- Builds the ANN input vector for a task:
-- metatokens first (from lua_meta), followed by one slot per symbol of
-- `profile.symbols`, which is filled by `task:process_ann_tokens`.
-- A zero-filled template is cached in `profile.zeros` on the first call.
local function result_to_vector(task, profile)
  if not profile.zeros then
    -- Fill zeros vector
    local zeros = {}
    for i=1,meta_functions.count_metatokens() do
      zeros[i] = 0.0
    end
    for _,_ in ipairs(profile.symbols) do
      zeros[#zeros + 1] = 0.0
    end
    profile.zeros = zeros
  end

  local vec = lua_util.shallowcopy(profile.zeros)
  local mt = meta_functions.rspamd_gen_metatokens(task)

  for i,v in ipairs(mt) do
    vec[i] = v
  end

  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)

  return vec
end

-- Used to generate new ANN key for specific profile
-- Format: <prefix>_<rule prefix>_<settings name>_<first 8 chars of digest>_<version>
local function new_ann_key(rule, set, version)
  local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
      rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))

  return ann_key
end

-- Extract settings element for a specific settings id
-- Numeric elements are references to another settings id and are followed
-- until a non-numeric element is found
local function get_rule_settings(task, rule)
  local sid = task:get_settings_id() or -1

  local set = rule.settings[sid]

  if not set then return nil end

  while type(set) == 'number' do
    -- Reference to another settings!
    set = rule.settings[set]
  end

  return set
end

-- Generate redis prefix for specific rule and specific settings
local function redis_ann_prefix(rule, settings_name)
  -- We also need to count metatokens:
  local n = meta_functions.version
  return string.format('%s_%s_%d_%s',
      settings.prefix, rule.prefix, n, settings_name)
end

-- Creates and stores ANN profile in Redis
-- The profile is added to the `set.prefix` sorted set ranked by the current time;
-- the returned profile has no `ann` object attached yet
local function new_ann_profile(task, rule, set, version)
  local ann_key = new_ann_key(rule, set, version)

  local profile = {
    symbols = set.symbols,
    redis_key = ann_key,
    version = version,
    digest = set.digest,
    distance = 0 -- Since we are using our own profile
  }

  local ucl = require "ucl"
  local profile_serialized = ucl.to_format(profile, 'json-compact', true)

  local function add_cb(err, _)
    if err then
      rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
          rule.prefix, set.name, profile.redis_key, err)
    else
      rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
          rule.prefix, set.name, profile.redis_key)
    end
  end

  lua_redis.redis_make_request(task,
      rule.redis,
      nil,
      true, -- is write
      add_cb, --callback
      'ZADD', -- command
      {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  )

  return profile
end


-- ANN filter function, used to insert scores based on the existing symbols
-- For each rule with a loaded ANN, a positive output inserts `symbol_spam`
-- and a non-positive output inserts `symbol_ham` with the absolute value
local function ann_scores_filter(task)

  for _,rule in pairs(settings.rules) do
    local sid = task:get_settings_id() or -1
    local ann
    local profile

    local set = get_rule_settings(task, rule)
    if set then
      if set.ann then
        ann = set.ann.ann
        profile = set.ann
      else
        lua_util.debugm(N, task, 'no ann loaded for %s:%s',
            rule.prefix, set.name)
      end
    else
      lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
          rule.prefix, sid)
    end

    if ann then
      local vec = result_to_vector(task, profile)

      local score
      local out = ann:apply1(vec)
      score = out[1]

      local symscore = string.format('%.3f', score)
      lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
          rule.prefix, set.name, set.ann.version, symscore)

      if score > 0 then
        local result = score
        task:insert_result(rule.symbol_spam, result, symscore)
      else
        local result = -(score)
        task:insert_result(rule.symbol_ham, result, symscore)
      end
    end
  end
end

-- Creates a KANN network with `n` inputs: relu over the input, one dense
-- hidden layer of floor((n + 1) / 2) units with tanh, and a single-output
-- MSE cost layer
local function create_ann(n, nlayers)
  -- We ignore number of layers so far when using kann
  local nhidden = math.floor((n + 1) / 2)
  local t = rspamd_kann.layer.input(n)
  t = rspamd_kann.transform.relu(t)
  t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden));
  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse)
  return rspamd_kann.new.kann(t)
end

-- Pushes the current task as a training vector for the settings element `set`.
-- The class (spam/ham) is chosen automatically (by `train.spam_score`/`train.ham_score`
-- or, if those are unset, by `verdict`) or, when `train.autotrain` is disabled,
-- by the `ANN-Train` request header.
-- The push is gated by the `can_store_train_vec` Redis script and by
-- `train.train_prob` sampling.
local function ann_push_task_result(rule, task, verdict, score, set)
  local train_opts = rule.train
  local learn_spam, learn_ham
  local skip_reason = 'unknown'

  if train_opts.autotrain then
    if train_opts.spam_score then
      learn_spam = score >= train_opts.spam_score

      if not learn_spam then
        skip_reason = string.format('score < spam_score: %f < %f',
            score, train_opts.spam_score)
      end
    else
      learn_spam = verdict == 'spam' or verdict == 'junk'

      if not learn_spam then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end

    if train_opts.ham_score then
      learn_ham = score <= train_opts.ham_score
      if not learn_ham then
        skip_reason = string.format('score > ham_score: %f > %f',
            score, train_opts.ham_score)
      end
    else
      learn_ham = verdict == 'ham'

      if not learn_ham then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end
  else
    -- Train by request header
    local hdr = task:get_request_header('ANN-Train')

    if hdr then
      if hdr:lower() == 'spam' then
        learn_spam = true
      elseif hdr:lower() == 'ham' then
        learn_ham = true
      else
        skip_reason = 'unknown ANN-Train header value'
      end
    else
      skip_reason = 'no explicit header'
    end
  end

  if learn_spam or learn_ham then
    local learn_type
    if learn_spam then learn_type = 'spam' else learn_type = 'ham' end

    -- Handles the reply of the `can_store_train_vec` script
    local function can_train_cb(err, data)
      if not err and type(data) == 'table' then
        local nsamples,reason = tonumber(data[1]),data[2]

        if nsamples >= 0 then
          local coin = math.random()

          if coin < 1.0 - train_opts.train_prob then
            rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
            return
          end

          local vec = result_to_vector(task, set)

          local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
          local target_key = set.ann.redis_key .. '_' .. learn_type

          local function learn_vec_cb(_err)
            if _err then
              rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
                  rule.prefix, set.name, _err)
            else
              lua_util.debugm(N, task,
                  "add train data for ANN rule " ..
                      "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
                  rule.prefix, set.name, learn_type, #vec, target_key, #str)
            end
          end

          lua_redis.redis_make_request(task,
              rule.redis,
              nil,
              true, -- is write
              learn_vec_cb, --callback
              'LPUSH', -- command
              { target_key, str } -- arguments
          )
        else
          -- Negative result returned
          rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)",
              learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples))
        end
      else
        if err then
          rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
              rule.prefix, set.name, err)
        else
          rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table; ' ..
              'please remove this key from Redis manually if you perform upgrade from the previous version',
              rule.prefix, set.name, set.ann.redis_key, type(data))
        end
      end
    end

    -- Check if we can learn
    if set.can_store_vectors then
      if not set.ann then
        -- Need to create or load a profile corresponding to the current configuration
        set.ann = new_ann_profile(task, rule, set, 0)
        lua_util.debugm(N, task,
            'requested new profile for %s, set.ann is missing',
            set.name)
      end

      lua_redis.exec_redis_script(redis_can_store_train_vec_id,
          {task = task, is_write = true},
          can_train_cb,
          {
            set.ann.redis_key,
            learn_type,
            tostring(train_opts.max_trains),
            tostring(math.random()),
          })
    else
      lua_util.debugm(N, task,
          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
    end
  else
    lua_util.debugm(N, task,
        'do not push data to key %s: train condition not satisfied; reason: %s',
        (set.ann or {}).redis_key,
        skip_reason)
  end
end

--- Offline training logic

-- Closure generator for unlock function
local function gen_unlock_cb(rule, set, ann_key)
  return function (err)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
          rule.prefix, set.name, ann_key, err)
    else
      lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
          rule.prefix, set.name, ann_key)
    end
  end
end

-- This function is intended to extend lock for ANN during training
-- It registers periodic that increases locked key each 30 seconds while
-- `set.learning_spawned` is `true`; the periodic stops once it becomes false
local function register_lock_extender(rule, set, ev_base, ann_key)
  rspamd_config:add_periodic(ev_base, 30.0,
      function()
        local function redis_lock_extend_cb(_err, _)
          if _err then
            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
                ann_key, _err)
          else
            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
                ann_key)
          end
        end

        if set.learning_spawned then
          lua_redis.redis_make_request_taskless(ev_base,
              rspamd_config,
              rule.redis,
              nil,
              true, -- is write
              redis_lock_extend_cb, --callback
              'HINCRBY', -- command
              {ann_key, 'lock', '30'}
          )
        else
          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
          return false -- do not plan any more updates
        end

        return true
      end
  )
end

-- This function receives training vectors, checks them, spawns learning and saves ANN in Redis
-- * ann_key - Redis key of the ANN being trained; `do_train_ann` acquires the
--   learning lock on it before calling this function
-- * ham_vec, spam_vec - arrays of numeric training vectors (see `process_training_vectors`)
-- The lock on `ann_key` is released here on every failure path; on success it is
-- released by the `save_unlock` Redis script.
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
  -- Releases the learning lock on `ann_key` (HDEL is a write command)
  local function unlock_ann()
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
    )
  end

  -- Check training data sanity
  if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
    -- Too few vectors to train this ANN
    -- TODO: add invalidation
    rspamd_logger.errx(rspamd_config, 'ANN %s:%s: cannot train from redis key %s: ' ..
        'too few vectors (%s spam, %s ham; at least %s required)',
        rule.prefix, set.name, ann_key, #spam_vec, #ham_vec, rule.train.max_trains / 2)
    unlock_ann()
    return
  end

  -- Now we need to join inputs and create the appropriate test vectors
  local n = #set.symbols +
      meta_functions.rspamd_count_metatokens()

  -- Now we can train ann
  local train_ann = create_ann(n, 3)
  local inputs, outputs = {}, {}

  -- Used to show sparsed vectors in a convenient format (for debugging only)
  local function debug_vec(t)
    local ret = {}
    for i,v in ipairs(t) do
      if v ~= 0 then
        ret[#ret + 1] = string.format('%d=%.2f', i, v)
      end
    end

    return ret
  end

  -- Make training set by joining vectors
  -- KANN automatically shuffles those samples
  -- 1.0 is used for spam and -1.0 is used for ham
  -- It implies that output layer can express that (e.g. tanh output)
  for _,e in ipairs(spam_vec) do
    inputs[#inputs + 1] = e
    outputs[#outputs + 1] = {1.0}
    --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
  end
  for _,e in ipairs(ham_vec) do
    inputs[#inputs + 1] = e
    outputs[#outputs + 1] = {-1.0}
    --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
  end

  -- Called in child process
  -- Returns the serialized ANN, or nil if NaN was observed in the training cost
  local function train()
    local log_thresh = rule.train.max_iterations / 10
    local seen_nan = false

    local function train_cb(iter, train_cost, value_cost)
      if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
        if train_cost ~= train_cost and not seen_nan then
          -- We have nan :( try to log lot's of stuff to dig into a problem
          seen_nan = true
          rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
              rule.prefix, set.name,
              value_cost)
          for i,e in ipairs(inputs) do
            lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
                debug_vec(e), outputs[i][1])
          end
        end

        rspamd_logger.infox(rspamd_config,
            "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
            rule.prefix, set.name,
            ann_key,
            iter,
            train_cost,
            value_cost)
      end
    end

    train_ann:train1(inputs, outputs, {
      lr = rule.train.learning_rate,
      max_epoch = rule.train.max_iterations,
      cb = train_cb,
    })

    if not seen_nan then
      local out = train_ann:save()
      return out
    else
      return nil
    end
  end

  local function redis_save_cb(err)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
          rule.prefix, set.name, ann_key, err)
      unlock_ann()
    else
      rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
          rule.prefix, set.name, set.ann.redis_key)
    end
  end

  -- Called in the parent process when the learning process completes
  local function ann_trained(err, data)
    set.learning_spawned = false

    -- `train` returns nil when NaN was observed, so a missing result is a failure too
    if err or not data then
      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
          rule.prefix, set.name, err or 'no data returned by the learning process')
      unlock_ann()
      return
    end

    -- Deserialise ANN from the child process
    local loaded_ann = rspamd_kann.load(data)

    if not loaded_ann then
      rspamd_logger.errx(rspamd_config, 'cannot deserialize trained ANN %s:%s',
          rule.prefix, set.name)
      unlock_ann()
      return
    end

    local ann_data = rspamd_util.zstd_compress(data)
    if not set.ann then
      set.ann = {
        symbols = set.symbols,
        distance = 0,
        digest = set.digest,
        redis_key = ann_key,
      }
    end

    local version = (set.ann.version or 0) + 1
    set.ann.version = version
    set.ann.ann = loaded_ann
    set.ann.symbols = set.symbols
    set.ann.redis_key = new_ann_key(rule, set, version)

    local profile = {
      symbols = set.symbols,
      digest = set.digest,
      redis_key = set.ann.redis_key,
      version = version
    }

    local ucl = require "ucl"
    local profile_serialized = ucl.to_format(profile, 'json-compact', true)

    rspamd_logger.infox(rspamd_config,
        'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
        rule.prefix, set.name, #data, set.ann.redis_key, ann_key)

    lua_redis.exec_redis_script(redis_save_unlock_id,
        {ev_base = ev_base, is_write = true},
        redis_save_cb,
        {profile.redis_key,
         redis_ann_prefix(rule, set.name),
         ann_data,
         profile_serialized,
         tostring(rule.ann_expire),
         tostring(os.time()),
         ann_key, -- old key to unlock...
        })
  end

  -- Mark learning as spawned before spawning: `ann_trained` resets this flag
  set.learning_spawned = true

  worker:spawn_process{
    func = train,
    on_complete = ann_trained,
    proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
  }

  -- Keep the lock alive while learning is running
  register_lock_extender(rule, set, ev_base, ann_key)
end

-- Utility to extract and split saved training vectors to a table of tables
-- Each element of `data` is a zstd-compressed, ';'-separated list of numbers
local function process_training_vectors(data)
  return fun.totable(fun.map(function(tok)
    local _,str = rspamd_util.zstd_decompress(tok)
    return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';')))
  end, data))
end

-- This function does the following:
-- * Tries to lock ANN
-- * Loads spam and ham vectors
-- * Spawn learning process
local function do_train_ann(worker, ev_base, rule, set, ann_key)
  local spam_elts = {}
  local ham_elts = {}

  -- Ham vectors received (spam vectors are already loaded at this point)
  local function redis_ham_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
        ann_key, err)
      -- Unlock on error
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
      )
    else
      -- Decompress and convert to numbers each training vector
      ham_elts = process_training_vectors(data)
      spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
    end
  end

  -- Spam vectors received
  local function redis_spam_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
        ann_key, err)
      -- Unlock ANN on error
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
      )
    else
      -- Decompress and convert to numbers each training vector
      spam_elts = process_training_vectors(data)
      -- Now get ham vectors...
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        redis_ham_cb, --callback
        'LRANGE', -- command
        {ann_key .. '_ham', '0', '-1'}
      )
    end
  end

  -- Handles the reply of the `maybe_lock` script
  local function redis_lock_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
        ann_key, err)
    elseif type(data) == 'number' and data == 1 then
      -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        redis_spam_cb, --callback
        'LRANGE', -- command
        {ann_key .. '_spam', '0', '-1'}
      )

      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
        rule.prefix, set.name, ann_key)
    elseif type(data) == 'table' then
      local lock_tm = tonumber(data[1])
      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
          data[2], os.date('%c', lock_tm))
    else
      rspamd_logger.errx(rspamd_config, 'unexpected reply from lock script for ANN %s: %s',
        ann_key, data)
    end
  end

  -- Check if we are already learning this network
  if set.learning_spawned then
    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
        ann_key)
    return
  end

  -- Call Redis script that tries to acquire a lock
  -- This script returns either 1 (lock acquired) or a pair {'lock_time', 'hostname'} when
  -- ANN is locked by another host (or a process, meh)
  lua_redis.exec_redis_script(redis_maybe_lock_id,
    {ev_base = ev_base, is_write = true},
    redis_lock_cb,
    {
      ann_key,
      tostring(os.time()),
      tostring(rule.watch_interval * 2),
      rspamd_util.get_hostname()
    })
end

-- This function loads new ann from Redis
-- This is based on `profile` attribute.
-- ANN is loaded from `profile.redis_key`
-- Rank of `profile` key is also increased, unfortunately, it means that we need to
-- serialize profile one more time and set its rank to the current time
-- set.ann fields are set according to Redis data received
local function load_new_ann(rule, ev_base, set, profile, min_diff)
  local ann_key = profile.redis_key

  local function data_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
          ann_key, err)
    else
      if type(data) == 'string' then
        local _err,ann_data = rspamd_util.zstd_decompress(data)
        local ann

        if _err or not ann_data then
          rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
              rule.prefix .. ':' .. set.name, ann_key, _err)
          return
        else
          ann = rspamd_kann.load(ann_data)

          if ann then
            set.ann = {
              digest = profile.digest,
              version = profile.version,
              symbols = profile.symbols,
              distance = min_diff,
              redis_key = profile.redis_key
            }

            local ucl = require "ucl"
            local profile_serialized = ucl.to_format(profile, 'json-compact', true)
            set.ann.ann = ann -- To avoid serialization

            local function rank_cb(_, _)
              -- TODO: maybe add some logging
            end
            -- Also update rank for the loaded ANN to avoid removal
            lua_redis.redis_make_request_taskless(ev_base,
                rspamd_config,
                rule.redis,
                nil,
                true, -- is write
                rank_cb, --callback
                'ZADD', -- command
                {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
            )
            -- `data` is the compressed payload as stored in Redis
            rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
                rule.prefix, set.name, ann_key, #data, profile.version)
          else
            rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
                rule.prefix, set.name, ann_key)
          end
        end
      else
        lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
            rule.prefix, set.name, ann_key)
      end
    end
  end
  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      data_cb, --callback
      'HGET', -- command
      {ann_key, 'ann'} -- arguments
  )
end

-- Used to check an element in Redis serialized as JSON
-- for some specific rule + some specific setting
-- This function tries to load more fresh or more specific ANNs in lieu of
-- the existing ones.
-- Use this function to load ANNs as `callback`
-- parameter for `check_anns` function.
-- `profiles` is an iterator over profiles produced by `load_ann_profile`
-- (entries may be nil). The profile closest to `set.symbols` (and differing
-- by less than 30% of them) is loaded if it is newer than the current ANN
-- (same digest) or more specific (smaller distance), or if no ANN is loaded.
local function process_existing_ann(_, ev_base, rule, set, profiles)
  local my_symbols = set.symbols
  local min_diff = math.huge
  local sel_elt

  for _,elt in fun.iter(profiles) do
    if elt and elt.symbols then
      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
      -- Check distance
      if dist < #my_symbols * .3 then
        if dist < min_diff then
          min_diff = dist
          sel_elt = elt
        end
      end
    end
  end

  if sel_elt then
    -- We can load element from ANN
    if set.ann then
      -- We have an existing ANN, probably the same...
      if set.ann.digest == sel_elt.digest then
        -- Same ANN, check version
        if set.ann.version < sel_elt.version then
          -- Load new ann
          rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
              'our version = %s, remote version = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.version,
              sel_elt.version)
          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
        else
          lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
              'our version = %s, remote version = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.version,
              sel_elt.version)
        end
      else
        -- We have some different ANN, so we need to compare distance
        if set.ann.distance > min_diff then
          -- Load more specific ANN
          rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
              'our distance = %s, remote distance = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.distance,
              min_diff)
          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
        else
          lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
              'our distance = %s, remote distance = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.distance,
              min_diff)
        end
      end
    else
      -- We have no ANN, load new one
      load_new_ann(rule, ev_base, set, sel_elt, min_diff)
    end
  end
end

-- This function checks all profiles and selects if we can train our
-- ANN. By our we mean that it has exactly the same symbols in profile.
-- Use this function to train ANN as `callback` parameter for `check_anns` function
-- Training (`do_train_ann`) starts only when both `<key>_spam` and `<key>_ham`
-- lists contain at least `rule.train.max_trains` vectors (checked via LLEN,
-- spam first, then ham).
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  local my_symbols = set.symbols
  local sel_elt

  for _,elt in fun.iter(profiles) do
    if elt and elt.symbols then
      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
      -- Check distance
      if dist == 0 then
        sel_elt = elt
        break
      end
    end
  end

  if not sel_elt then
    return
  end

  -- We have our ANN and that's train vectors, check if we can learn
  local ann_key = sel_elt.redis_key

  lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
      ann_key)

  -- Create continuation closure: inspects the LLEN reply for the `what`
  -- class and calls `cont_cb` once there are enough vectors
  local redis_len_cb_gen = function(cont_cb, what, is_final)
    return function(err, data)
      if err then
        rspamd_logger.errx(rspamd_config,
            'cannot get ANN %s trains %s from redis: %s', ann_key, what, err)
      elseif type(data) == 'number' or type(data) == 'string' then
        local nvecs = tonumber(data)
        if nvecs and nvecs >= rule.train.max_trains then
          if is_final then
            rspamd_logger.debugm(N, rspamd_config,
                'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
                ann_key, nvecs, rule.train.max_trains, what)
          else
            rspamd_logger.debugm(N, rspamd_config,
                'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
                what, ann_key, nvecs, rule.train.max_trains)
          end
          cont_cb()
        else
          rspamd_logger.debugm(N, rspamd_config,
              'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
              ann_key, what, nvecs, rule.train.max_trains)
        end
      end
    end
  end

  local function initiate_train()
    rspamd_logger.infox(rspamd_config,
        'need to learn ANN %s after %s required learn vectors',
        ann_key, rule.train.max_trains)
    do_train_ann(worker, ev_base, rule, set, ann_key)
  end

  -- Spam vector is OK, check ham vector length
  local function check_ham_len()
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        redis_len_cb_gen(initiate_train, 'ham', true), --callback
        'LLEN', -- command
        {ann_key .. '_ham'}
    )
  end

  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
      'LLEN', -- command
      {ann_key .. '_spam'}
  )
end

-- Used to deserialise ANN element from a list
-- Returns the schema-checked profile table, or nil (after logging) when the
-- element cannot be parsed or does not match `redis_profile_schema`
local function load_ann_profile(element)
  local ucl = require "ucl"

  local parser = ucl.parser()
  local res,ucl_err = parser:parse_string(element)
  if not res then
    rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
        ucl_err)
    return nil
  else
    local profile = parser:get_object()
    local checked,schema_err = redis_profile_schema:transform(profile)
    if not checked then
      rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)

      return nil
    end
    return checked
  end
end

-- Function to check or load ANNs from Redis
-- For every settings element of `rule` it fetches the most recent profiles
-- (ZREVRANGE on `set.prefix`) and passes them to `process_callback`.
-- Returns `rule.watch_interval`, used as the next periodic interval.
local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  for _,set in pairs(rule.settings) do
    local function members_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
            err)
        set.can_store_vectors = true
      elseif type(data) == 'table' then
        lua_util.debugm(N, cfg, '%s: process element %s:%s',
            what, rule.prefix, set.name)
        process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
        set.can_store_vectors = true
      end
    end

    -- Entries of rule.settings may be numeric references to other elements
    if type(set) == 'table' then
      -- Extract all profiles for some specific settings id
      -- Get the last `max_profiles` recently used
      -- Select the most appropriate to our profile but it should not differ by more
      -- than 30% of symbols
      lua_redis.redis_make_request_taskless(ev_base,
          cfg,
          rule.redis,
          nil,
          false, -- is write
          members_cb, --callback
          'ZREVRANGE', -- command
          {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
      )
    end
  end -- Cycle over all settings

  return rule.watch_interval
end

-- Function to clean up old ANNs
-- Runs the invalidate script for every settings element of `rule` and logs
-- each profile it reports as invalidated.
local function cleanup_anns(rule, cfg, ev_base)
  for _,set in pairs(rule.settings) do
    local function invalidate_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
            err)
      elseif type(data) == 'table' then
        for _,expired in ipairs(data) do
          local profile = load_ann_profile(expired)
          -- load_ann_profile has already logged why it returned nil; indexing
          -- nil here would abort the whole callback
          if profile then
            rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
                rule.prefix .. ':' .. set.name,
                profile.redis_key,
                profile.version)
          end
        end
      end
    end

    if type(set) == 'table' then
      lua_redis.exec_redis_script(redis_maybe_invalidate_id,
          {ev_base = ev_base, is_write = true},
          invalidate_cb,
          {set.prefix, tostring(settings.max_profiles)})
    end
  end
end

-- Symbol callback (NEURAL_LEARN): stores the task's vector for every rule
-- unless the task is skipped, manually scanned (without `allow_local`),
-- has a 'passthrough' verdict or a NaN score
local function ann_push_vector(task)
  if task:has_flag('skip') then
    lua_util.debugm(N, task, 'do not push data for skipped task')
    return
  end
  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
    lua_util.debugm(N, task, 'do not push data for manual scan')
    return
  end

  local verdict,score = lua_verdict.get_specific_verdict(N, task)

  if verdict == 'passthrough' then
    lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
        verdict, score)

    return
  end

  -- NaN is the only value not equal to itself
  if score ~= score then
    lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
        verdict)

    return
  end

  for _,rule in pairs(settings.rules) do
    local set = get_rule_settings(task, rule)

    if set then
      ann_push_task_result(rule, task, verdict, score, set)
    else
      lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
    end
  end
end

-- This function is used to adjust profiles and allowed setting ids for each rule
-- It must be called when all settings are already registered (e.g. at post-init for config)
local function process_rules_settings()
  -- Symbols flagged nostat/idempotent/skip (or without flags) are not used as
  -- ANN inputs
  local function filter_symbols_predicate(sname)
    local fl = rspamd_config:get_symbol_flags(sname)
    if fl then
      fl = lua_util.list_to_hash(fl)

      return not (fl.nostat or fl.idempotent or fl.skip)
    end

    return false
  end

  local function process_settings_elt(rule, selt)
    local profile = rule.profile[selt.name]
    if profile then
      -- Use static user defined profile
      -- Ensure that we have an array...
      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
          rule.prefix, selt.name, profile)
      if not profile[1] then profile = lua_util.keys(profile) end
      -- Copy, so that sorting below does not reorder the user configuration
      selt.symbols = fun.totable(profile)
    else
      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
          rule.prefix, selt.name)
      -- Filter merely dynamic profiles; fun.totable builds a new table, so the
      -- (shared) settings symbols list is left untouched
      selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
    end

    -- Generic stuff
    -- Symbols must be sorted: the digest and lua_util.distance_sorted
    -- comparisons rely on a stable order
    table.sort(selt.symbols)

    selt.digest = lua_util.table_digest(selt.symbols)
    selt.prefix = redis_ann_prefix(rule, selt.name)

    lua_redis.register_prefix(selt.prefix, N,
        string.format('NN prefix for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'zlist',
        })
    -- Versions
    lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
        string.format('NN storage for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'hash',
        })
    lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'list',
        })
    lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
        string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'list',
        })
  end

  for k,rule in pairs(settings.rules) do
    if not rule.allowed_settings then
      rule.allowed_settings = {}
    elseif rule.allowed_settings == 'all' then
      -- Extract all settings ids
      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
    end

    -- Convert to a map <setting_id> -> true
    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)

    -- Check if we can work without settings
    if k == 'default' or type(rule.default) ~= 'boolean' then
      rule.default = true
    end

    rule.settings = {}

    if rule.default then
      local default_settings = {
        symbols = lua_settings.default_symbols(),
        name = 'default'
      }

      process_settings_elt(rule, default_settings)
      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
    end

    -- Now, for each allowed settings, we store sorted symbols + digest
    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
    for s,_ in pairs(rule.allowed_settings) do
      -- Here, we have a name, set of symbols and
      local settings_id = s
      if type(settings_id) ~= 'number' then
        settings_id = lua_settings.numeric_settings_id(s)
      end
      local selt = lua_settings.settings_by_id(settings_id)

      if not selt then
        rspamd_logger.errx(rspamd_config, 'cannot find settings with id %s for rule %s',
            s, rule.prefix)
      else
        local nelt = {
          symbols = selt.symbols, -- Already sorted
          name = selt.name
        }

        process_settings_elt(rule, nelt)

        -- Look for an existing element with equal symbols. The reference is
        -- stored after the traversal: adding a new key to a table while it is
        -- traversed with `pairs` is undefined behaviour in Lua
        local ref_id
        for id,ex in pairs(rule.settings) do
          if type(ex) == 'table' and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
            -- Equal symbols, add reference
            lua_util.debugm(N, rspamd_config,
                'added reference from settings id %s to %s; same symbols',
                nelt.name, ex.name)
            ref_id = id
            break
          end
        end

        if ref_id then
          rule.settings[settings_id] = ref_id
        else
          rule.settings[settings_id] = nelt
          lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
              nelt.name, settings_id, rule.prefix)
        end
      end
    end
  end
end

redis_params = lua_redis.parse_redis_server('neural')

if not redis_params then
  redis_params = lua_redis.parse_redis_server('fann_redis')
end

-- Initialization part
if not (module_config and type(module_config) == 'table') or not redis_params then
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "redis")
  return
end

local rules = module_config['rules']

if not rules then
  -- Use legacy configuration
  rules = {}
  rules['default'] = module_config
end

local id = rspamd_config:register_symbol({
  name = 'NEURAL_CHECK',
  type = 'postfilter,nostat',
  priority = 6,
  callback = ann_scores_filter
})

settings = lua_util.override_defaults(settings, module_config)
settings.rules = {} -- Reset unless validated further in the cycle

-- Check all rules
for k,r in pairs(rules) do
  local rule_elt = lua_util.override_defaults(default_options, r)
  rule_elt['redis'] = redis_params
  rule_elt['anns'] = {} -- Store ANNs here

  if not rule_elt.prefix then
    rule_elt.prefix = k
  end
  if not rule_elt.name then
    rule_elt.name = k
  end
  -- Legacy option name
  if rule_elt.train.max_train then
    rule_elt.train.max_trains = rule_elt.train.max_train
  end

  if not rule_elt.profile then rule_elt.profile = {} end

  rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
  settings.rules[k] = rule_elt
  rspamd_config:set_metric_symbol({
    name = rule_elt.symbol_spam,
    score = 0.0,
    description = 'Neural network SPAM',
    group = 'neural'
  })
  rspamd_config:register_symbol({
    name = rule_elt.symbol_spam,
    type = 'virtual,nostat',
    parent = id
  })

  rspamd_config:set_metric_symbol({
    name = rule_elt.symbol_ham,
    score = -0.0,
    description = 'Neural network HAM',
    group = 'neural'
  })
  rspamd_config:register_symbol({
    name = rule_elt.symbol_ham,
    type = 'virtual,nostat',
    parent = id
  })
end

rspamd_config:register_symbol({
  name = 'NEURAL_LEARN',
  type = 'idempotent,nostat,explicit_disable',
  priority = 5,
  callback = ann_push_vector
})

-- We also need to deal with settings. process_rules_settings already iterates
-- over all rules, so it is registered once (registering it per rule made it
-- re-process already converted rule data)
rspamd_config:add_post_init(process_rules_settings)

-- Add training scripts
for _,rule in pairs(settings.rules) do
  load_scripts(rule.redis)

  -- This function will check ANNs in Redis when a worker is loaded
  rspamd_config:add_on_load(function(cfg, ev_base, worker)
    if worker:is_scanner() then
      rspamd_config:add_periodic(ev_base, 0.0,
          function(_, _)
            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
                'try_load_ann')
          end)
    end

    if worker:is_primary_controller() then
      -- We also want to train neural nets when they have enough data
      rspamd_config:add_periodic(ev_base, 0.0,
          function(_, _)
            -- Clean old ANNs
            cleanup_anns(rule, cfg, ev_base)
            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
                'try_train_ann')
          end)
    end
  end)
end
 |