mqttagent

MQTT Lua Agent
git clone https://git.instinctive.eu/mqttagent.git
Log | Files | Refs | README | LICENSE

main.go (9879B)


      1 /*
      2  * Copyright (c) 2025, Natacha Porté
      3  *
      4  * Permission to use, copy, modify, and distribute this software for any
      5  * purpose with or without fee is hereby granted, provided that the above
      6  * copyright notice and this permission notice appear in all copies.
      7  *
      8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
     10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
     11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
     13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
     14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     15  */
     16 
     17 package main
     18 
     19 import (
     20 	"database/sql"
     21 	"log"
     22 	"net/http"
     23 	"net/url"
     24 	"os"
     25 	"runtime/debug"
     26 	"strings"
     27 	"time"
     28 
     29 	"github.com/cjoudrey/gluahttp"
     30 	_ "github.com/glebarez/go-sqlite"
     31 	luajson "github.com/layeh/gopher-json"
     32 	"github.com/yuin/gopher-lua"
     33 	"instinctive.eu/go/mqttagent"
     34 )
     35 
     36 type fullMqttAgent struct {
     37 	loggers    map[string]*sqlogger
     38 	oldLoggers map[string]*sqlogger
     39 }
     40 
     41 func (agent *fullMqttAgent) Setup(L *lua.LState) {
     42 	luajson.Preload(L)
     43 	L.PreloadModule("http", gluahttp.NewHttpModule(&http.Client{}).Loader)
     44 	L.SetGlobal("urlencode", L.NewFunction(luaUrlEncode))
     45 	setBuildInfo(L, "buildinfo")
     46 	setVersion(L, "version")
     47 	agent.loggers = make(map[string]*sqlogger)
     48 
     49 	mt := L.NewTypeMetatable("sqlogger")
     50 	L.SetGlobal("sqlogger", mt)
     51 	L.SetField(mt, "new", L.NewFunction(func(L *lua.LState) int {
     52 		return luaSqloggerNew(L, agent)
     53 	}))
     54 	L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), luaSqloggerMethods))
     55 }
     56 
     57 func (agent *fullMqttAgent) Teardown(L *lua.LState) {
     58 	if agent.oldLoggers != nil {
     59 		panic("Unexpected state")
     60 	}
     61 	for _, logger := range agent.loggers {
     62 		logger.Close()
     63 	}
     64 	agent.loggers = nil
     65 }
     66 
     67 func (agent *fullMqttAgent) ReloadBegin(oldL, newL *lua.LState) {
     68 	if agent.oldLoggers != nil {
     69 		panic("Unexpected state")
     70 	}
     71 	agent.oldLoggers = agent.loggers
     72 	agent.Setup(newL)
     73 }
     74 
     75 func (agent *fullMqttAgent) ReloadAbort(oldL, newL *lua.LState) {
     76 	for key, logger := range agent.loggers {
     77 		if _, found := agent.oldLoggers[key]; !found {
     78 			logger.Close()
     79 		}
     80 	}
     81 	agent.loggers = agent.oldLoggers
     82 	agent.oldLoggers = nil
     83 }
     84 
     85 func (agent *fullMqttAgent) ReloadEnd(oldL, newL *lua.LState) {
     86 	for key, logger := range agent.oldLoggers {
     87 		if _, found := agent.loggers[key]; !found {
     88 			logger.Close()
     89 		}
     90 	}
     91 	agent.oldLoggers = nil
     92 }
     93 
     94 func luaUrlEncode(L *lua.LState) int {
     95 	tbl := L.CheckTable(1)
     96 	result := make([]string, 0, tbl.Len())
     97 
     98 	L.ForEach(tbl, func(key, value lua.LValue) {
     99 		skey := url.QueryEscape(lua.LVAsString(key))
    100 		sval := url.QueryEscape(lua.LVAsString(value))
    101 		result = append(result, skey+"="+sval)
    102 	})
    103 
    104 	L.Push(lua.LString(strings.Join(result, "&")))
    105 	return 1
    106 }
    107 
    108 type sqlogger struct {
    109 	db             *sql.DB
    110 	insertReceived *sql.Stmt
    111 	insertSent     *sql.Stmt
    112 }
    113 
    114 func logErr(context string, err error) {
    115 	if err != nil {
    116 		log.Println(context, err)
    117 	}
    118 }
    119 
    120 func (logger *sqlogger) Close() {
    121 	if logger == nil {
    122 		return
    123 	}
    124 
    125 	if logger.insertReceived != nil {
    126 		logErr("Close insertReceived:", logger.insertReceived.Close())
    127 	}
    128 
    129 	if logger.insertSent != nil {
    130 		logErr("Close insertSend:", logger.insertSent.Close())
    131 	}
    132 
    133 	if logger.db != nil {
    134 		logErr("Close DB:", logger.db.Close())
    135 	}
    136 }
    137 
    138 func connect(connectionString string) (*sqlogger, error) {
    139 	db, err := sql.Open("sqlite", connectionString)
    140 	if err != nil {
    141 		return nil, err
    142 	}
    143 
    144 	for _, cmd := range []string{
    145 		"CREATE TABLE IF NOT EXISTS topics" +
    146 			"(id INTEGER PRIMARY KEY AUTOINCREMENT," +
    147 			" name TEXT NOT NULL);",
    148 		"CREATE UNIQUE INDEX IF NOT EXISTS i_topics ON topics(name);",
    149 		"CREATE TABLE IF NOT EXISTS received" +
    150 			"(timestamp REAL NOT NULL," +
    151 			" topic_id INTEGER NOT NULL," +
    152 			" message TEXT NOT NULL," +
    153 			" FOREIGN KEY (topic_id) REFERENCES topics (id));",
    154 		"CREATE TABLE IF NOT EXISTS sent" +
    155 			"(timestamp REAL NOT NULL," +
    156 			" topic_id INTEGER NOT NULL," +
    157 			" message TEXT NOT NULL," +
    158 			" FOREIGN KEY (topic_id) REFERENCES topics (id));",
    159 		"CREATE INDEX IF NOT EXISTS i_rtime ON received(timestamp);",
    160 		"CREATE INDEX IF NOT EXISTS i_rtopicid ON received(topic_id);",
    161 		"CREATE INDEX IF NOT EXISTS i_stime ON received(timestamp);",
    162 		"CREATE INDEX IF NOT EXISTS i_stopicid ON received(topic_id);",
    163 		"CREATE VIEW IF NOT EXISTS receivedf" +
    164 			"(timestamp,topic,message)" +
    165 			" AS SELECT datetime(timestamp),topics.name,message" +
    166 			" FROM received LEFT OUTER JOIN topics" +
    167 			" ON topics.id = topic_id;",
    168 		"CREATE VIEW IF NOT EXISTS sentf" +
    169 			"(timestamp,topic,message)" +
    170 			" AS SELECT datetime(timestamp),topics.name,message" +
    171 			" FROM sent LEFT OUTER JOIN topics" +
    172 			" ON topics.id = topic_id;",
    173 		"CREATE TRIGGER IF NOT EXISTS insert_received" +
    174 			" INSTEAD OF INSERT ON receivedf BEGIN" +
    175 			" INSERT INTO topics(name)" +
    176 			" SELECT NEW.topic WHERE NOT EXISTS" +
    177 			" (SELECT 1 FROM topics WHERE name = NEW.topic);" +
    178 			" INSERT INTO received(timestamp,topic_id,message)" +
    179 			" VALUES (NEW.timestamp," +
    180 			" (SELECT id FROM topics WHERE name = NEW.topic)," +
    181 			" NEW.message); END;",
    182 		"CREATE TRIGGER IF NOT EXISTS insert_sent" +
    183 			" INSTEAD OF INSERT ON sentf BEGIN" +
    184 			" INSERT INTO topics(name)" +
    185 			" SELECT NEW.topic WHERE NOT EXISTS" +
    186 			" (SELECT 1 FROM topics WHERE name = NEW.topic);" +
    187 			" INSERT INTO sent(timestamp,topic_id,message)" +
    188 			" VALUES (NEW.timestamp," +
    189 			" (SELECT id FROM topics WHERE name = NEW.topic)," +
    190 			" NEW.message); END;",
    191 	} {
    192 		if _, err = db.Exec(cmd); err != nil {
    193 			logErr("Close DB:", db.Close())
    194 			return nil, err
    195 		}
    196 	}
    197 
    198 	s1, err := db.Prepare("INSERT INTO receivedf(timestamp,topic,message)" +
    199 		" VALUES (?,?,?);")
    200 	if err != nil {
    201 		logErr("Close DB:", db.Close())
    202 		return nil, err
    203 	}
    204 
    205 	s2, err := db.Prepare("INSERT INTO sentf(timestamp,topic,message)" +
    206 		" VALUES (?,?,?);")
    207 	if err != nil {
    208 		logErr("Close insertReceived:", s1.Close())
    209 		logErr("Close DB:", db.Close())
    210 		return nil, err
    211 	}
    212 
    213 	return &sqlogger{db: db, insertReceived: s1, insertSent: s2}, nil
    214 }
    215 
    216 func checkSqlogger(L *lua.LState, index int) *sqlogger {
    217 	ud := L.CheckUserData(index)
    218 	if v, ok := ud.Value.(*sqlogger); ok {
    219 		return v
    220 	}
    221 	L.ArgError(index, "sqlogger expected")
    222 	return nil
    223 }
    224 
    225 func luaSqloggerNew(L *lua.LState, agent *fullMqttAgent) int {
    226 	arg := L.CheckString(1)
    227 	if logger, found := agent.loggers[arg]; found {
    228 		ud := L.NewUserData()
    229 		ud.Value = logger
    230 		L.SetMetatable(ud, L.GetTypeMetatable("sqlogger"))
    231 		L.Push(ud)
    232 		return 1
    233 	} else if logger, found := agent.oldLoggers[arg]; found {
    234 		agent.loggers[arg] = logger
    235 		ud := L.NewUserData()
    236 		ud.Value = logger
    237 		L.SetMetatable(ud, L.GetTypeMetatable("sqlogger"))
    238 		L.Push(ud)
    239 		return 1
    240 	} else if logger, err := connect(arg); err != nil {
    241 		log.Println(err)
    242 		L.Push(lua.LNil)
    243 		L.Push(lua.LString(err.Error()))
    244 		return 2
    245 	} else {
    246 		agent.loggers[arg] = logger
    247 		ud := L.NewUserData()
    248 		ud.Value = logger
    249 		L.SetMetatable(ud, L.GetTypeMetatable("sqlogger"))
    250 		L.Push(ud)
    251 		return 1
    252 	}
    253 }
    254 
    255 func luaSqloggerInsert(L *lua.LState, stmt *sql.Stmt) int {
    256 	message := L.CheckString(2)
    257 	topic := L.CheckString(3)
    258 	timestamp := L.OptNumber(4, lua.LNumber(time.Now().UnixMicro())*1.0e-6)
    259 	julian := (float64(timestamp) / 86400.0) + 2440587.5
    260 
    261 	if _, err := stmt.Exec(julian, topic, message); err != nil {
    262 		log.Println(err)
    263 	}
    264 
    265 	return 0
    266 }
    267 
    268 func luaSqloggerReceived(L *lua.LState) int {
    269 	logger := checkSqlogger(L, 1)
    270 	return luaSqloggerInsert(L, logger.insertReceived)
    271 }
    272 
    273 func luaSqloggerSent(L *lua.LState) int {
    274 	logger := checkSqlogger(L, 1)
    275 	return luaSqloggerInsert(L, logger.insertSent)
    276 }
    277 
    278 var luaSqloggerMethods = map[string]lua.LGFunction{
    279 	"received": luaSqloggerReceived,
    280 	"sent":     luaSqloggerSent,
    281 }
    282 
    283 func setBuildInfo(L *lua.LState, name string) {
    284 	info, ok := debug.ReadBuildInfo()
    285 	if !ok {
    286 		return
    287 	}
    288 
    289 	L.SetGlobal(name, luaBuildInfo(L, info))
    290 }
    291 
    292 func setVersion(L *lua.LState, name string) {
    293 	info, ok := debug.ReadBuildInfo()
    294 	if !ok {
    295 		return
    296 	}
    297 
    298 	version := info.Main.Version
    299 
    300 	if version == "(devel)" {
    301 		vcs := ""
    302 		rev := ""
    303 		dirty := ""
    304 		for _, setting := range info.Settings {
    305 			switch setting.Key {
    306 			case "vcs":
    307 				vcs = setting.Value + "-"
    308 			case "vcs.revision":
    309 				rev = setting.Value[0:8]
    310 			case "vcs.modified":
    311 				if setting.Value == "true" {
    312 					dirty = "*"
    313 				}
    314 			}
    315 		}
    316 
    317 		if rev != "" {
    318 			version = vcs + rev + dirty
    319 		}
    320 	}
    321 
    322 	L.SetGlobal(name, lua.LString(version))
    323 }
    324 
    325 func luaBuildInfo(L *lua.LState, info *debug.BuildInfo) lua.LValue {
    326 	tbl := L.NewTable()
    327 	tbl.RawSetString("go_version", lua.LString(info.GoVersion))
    328 	tbl.RawSetString("path", lua.LString(info.Path))
    329 	tbl.RawSetString("main", luaModule(L, &info.Main))
    330 	tbl.RawSetString("deps", luaModules(L, info.Deps))
    331 	tbl.RawSetString("settings", luaSettings(L, info.Settings))
    332 	return tbl
    333 }
    334 
    335 func luaModule(L *lua.LState, module *debug.Module) lua.LValue {
    336 	tbl := L.NewTable()
    337 	tbl.RawSetString("path", lua.LString(module.Path))
    338 	tbl.RawSetString("version", lua.LString(module.Version))
    339 	tbl.RawSetString("sum", lua.LString(module.Sum))
    340 
    341 	if module.Replace != nil {
    342 		tbl.RawSetString("replace", luaModule(L, module.Replace))
    343 	}
    344 	return tbl
    345 }
    346 
    347 func luaModules(L *lua.LState, modules []*debug.Module) lua.LValue {
    348 	tbl := L.NewTable()
    349 
    350 	for index, module := range modules {
    351 		tbl.RawSetInt(index+1, luaModule(L, module))
    352 	}
    353 
    354 	return tbl
    355 }
    356 
    357 func luaSettings(L *lua.LState, settings []debug.BuildSetting) lua.LValue {
    358 	tbl := L.NewTable()
    359 
    360 	for _, setting := range settings {
    361 		tbl.RawSetString(setting.Key, lua.LString(setting.Value))
    362 	}
    363 
    364 	return tbl
    365 }
    366 
    367 func main() {
    368 	var agent fullMqttAgent
    369 
    370 	main_script := "mqttagent.lua"
    371 	if len(os.Args) > 1 {
    372 		main_script = os.Args[1]
    373 	}
    374 
    375 	mqttagent.Run(&agent, main_script, 10)
    376 
    377 	os.Exit(0)
    378 }