mqttagent

MQTT Lua Agent
git clone https://git.instinctive.eu/mqttagent.git
Log | Files | Refs | README | LICENSE

mqttagent.go (13456B)


      1 /*
      2  * Copyright (c) 2025, Natacha Porté
      3  *
      4  * Permission to use, copy, modify, and distribute this software for any
      5  * purpose with or without fee is hereby granted, provided that the above
      6  * copyright notice and this permission notice appear in all copies.
      7  *
      8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
     10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
     11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
     13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
     14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     15  */
     16 
     17 package mqttagent
     18 
     19 import (
     20 	"errors"
     21 	"fmt"
     22 	"log"
     23 	"os"
     24 	"strings"
     25 	"time"
     26 
     27 	"github.com/go-mqtt/mqtt"
     28 	"github.com/yuin/gluamapper"
     29 	"github.com/yuin/gopher-lua"
     30 )
     31 
// MqttAgent is the set of hooks that Run calls on the agent it is given.
type MqttAgent interface {
	// Setup is called by Run on the freshly created Lua state, before the
	// MQTT client and timer types are registered and before the main
	// script is executed.
	Setup(L *lua.LState)
	// Log is called by processMsg for every received message, before the
	// message is dispatched to the Lua subscription callbacks.
	Log(L *lua.LState, msg *MqttMessage)
	// Teardown is deferred by Run, so it is called when Run returns.
	Teardown(L *lua.LState)
}
     37 
// MqttMessage is a received MQTT message, as sent by the reader goroutine
// (mqttRead) to the main loop in Run.
type MqttMessage struct {
	// Timestamp is the reception time in seconds since the Unix epoch
	// (mqttRead derives it from time.Now().UnixMicro()).
	Timestamp float64
	// ClientId is the integer key of the receiving connection in the
	// connection table of the Lua state.
	ClientId int
	// Topic is the topic the message was received on.
	Topic []byte
	// Message is the message payload.
	Message []byte
}
     44 
     45 func Run(agent MqttAgent, main_script string, capacity int) {
     46 	fromMqtt := make(chan MqttMessage, capacity)
     47 
     48 	L := lua.NewState()
     49 	defer L.Close()
     50 
     51 	agent.Setup(L)
     52 	defer agent.Teardown(L)
     53 
     54 	hostname, err := os.Hostname()
     55 	if err != nil {
     56 		hostname = "<unknown>"
     57 	}
     58 
     59 	idString := fmt.Sprintf("mqttagent-%s-%d", hostname, os.Getpid())
     60 	registerMqttClientType(L)
     61 	registerTimerType(L)
     62 	registerState(L, idString, fromMqtt)
     63 	defer cleanupClients(L)
     64 
     65 	if err := L.DoFile(main_script); err != nil {
     66 		panic(err)
     67 	}
     68 
     69 	timer := time.NewTimer(0)
     70 	defer timer.Stop()
     71 
     72 	log.Println(idString, "started")
     73 
     74 	for {
     75 		select {
     76 		case msg, ok := <-fromMqtt:
     77 
     78 			if !ok {
     79 				log.Println("fromMqtt is closed")
     80 				break
     81 			}
     82 
     83 			processMsg(L, agent, &msg)
     84 
     85 		case <-timer.C:
     86 		}
     87 
     88 		hasTimer, nextTimer := runTimers(L)
     89 
     90 		if hasTimer {
     91 			timer.Reset(time.Until(nextTimer))
     92 		} else {
     93 			timer.Stop()
     94 		}
     95 
     96 		if tableIsEmpty(stateCnxTable(L)) && tableIsEmpty(stateTimerTable(L)) {
     97 			break
     98 		}
     99 	}
    100 
    101 	log.Println(idString, "finished")
    102 }
    103 
    104 func cleanupClients(L *lua.LState) {
    105 	cnxTbl := stateCnxTable(L)
    106 	if cnxTbl == nil {
    107 		return
    108 	}
    109 
    110 	L.ForEach(cnxTbl, func(key, value lua.LValue) {
    111 		cnx := value.(*lua.LTable)
    112 		client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    113 		if err := client.Disconnect(nil); err != nil {
    114 			log.Printf("cleanup client %s: %v", lua.LVAsString(key), err)
    115 		}
    116 	})
    117 }
    118 
    119 func dispatchMsg(L *lua.LState, msg *MqttMessage, cnx, key, value lua.LValue) {
    120 	skey, ok := key.(lua.LString)
    121 	topic := string(msg.Topic)
    122 
    123 	if ok && match(topic, string(skey)) {
    124 		err := L.CallByParam(lua.P{Fn: value, NRet: 0, Protect: true},
    125 			cnx,
    126 			lua.LString(string(msg.Message)),
    127 			lua.LString(topic),
    128 			lua.LNumber(msg.Timestamp))
    129 		if err != nil {
    130 			panic(err)
    131 		}
    132 	}
    133 }
    134 
    135 func matchSliced(actual, filter []string) bool {
    136 	if len(filter) == 0 {
    137 		return len(actual) == 0
    138 	}
    139 
    140 	if filter[0] == "#" {
    141 		if len(filter) == 1 {
    142 			return true
    143 		}
    144 
    145 		for i := range actual {
    146 			if matchSliced(actual[i:], filter[1:]) {
    147 				return true
    148 			}
    149 		}
    150 
    151 		return false
    152 	}
    153 
    154 	if len(actual) > 0 && (filter[0] == "+" || filter[0] == actual[0]) {
    155 		return matchSliced(actual[1:], filter[1:])
    156 	}
    157 
    158 	return false
    159 }
    160 
    161 func match(actual, filter string) bool {
    162 	return matchSliced(strings.Split(actual, "/"), strings.Split(filter, "/"))
    163 }
    164 
    165 func tableIsEmpty(t *lua.LTable) bool {
    166 	key, _ := t.Next(lua.LNil)
    167 	return key == lua.LNil
    168 }
    169 
    170 func processMsg(L *lua.LState, agent MqttAgent, msg *MqttMessage) {
    171 	agent.Log(L, msg)
    172 
    173 	cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable)
    174 	subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    175 	L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, msg, cnx, key, value) })
    176 
    177 	if tableIsEmpty(subTbl) {
    178 		client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    179 		if err := client.Disconnect(nil); err != nil {
    180 			log.Println("disconnect empty client:", err)
    181 		}
    182 		L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil)
    183 	}
    184 }
    185 
    186 func mqttRead(client *mqtt.Client, toLua chan<- MqttMessage, id int) {
    187 	var big *mqtt.BigMessage
    188 
    189 	for {
    190 		message, topic, err := client.ReadSlices()
    191 		t := float64(time.Now().UnixMicro()) * 1.0e-6
    192 
    193 		switch {
    194 		case err == nil:
    195 			toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: dup(message)}
    196 
    197 		case errors.As(err, &big):
    198 			data, err := big.ReadAll()
    199 			if err != nil {
    200 				log.Println("mqttRead big message:", err)
    201 			} else {
    202 				toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: data}
    203 			}
    204 
    205 		case errors.Is(err, mqtt.ErrClosed):
    206 			log.Println("mqttRead finishing:", err)
    207 			return
    208 
    209 		case mqtt.IsConnectionRefused(err):
    210 			log.Println("mqttRead connection refused:", err)
    211 			time.Sleep(15 * time.Minute)
    212 
    213 		default:
    214 			log.Println("mqttRead:", err)
    215 			time.Sleep(2 * time.Second)
    216 		}
    217 	}
    218 }
    219 
    220 func dup(src []byte) []byte {
    221 	res := make([]byte, len(src))
    222 	copy(res, src)
    223 	return res
    224 }
    225 
    226 func newUserData(L *lua.LState, v interface{}) *lua.LUserData {
    227 	res := L.NewUserData()
    228 	res.Value = v
    229 	return res
    230 }
    231 
    232 /********** State Object in the Lua Interpreter **********/
    233 
    234 const luaStateName = "_mqttagent"
    235 const keyChanToLua = 1
    236 const keyClientPrefix = 2
    237 const keyClientNextId = 3
    238 const keyCfgMap = 4
    239 const keyCnxTable = 5
    240 const keyTimerTable = 6
    241 
    242 func registerState(L *lua.LState, clientPrefix string, toLua chan<- MqttMessage) {
    243 	st := L.NewTable()
    244 	L.RawSetInt(st, keyChanToLua, newUserData(L, toLua))
    245 	L.RawSetInt(st, keyClientPrefix, lua.LString(clientPrefix))
    246 	L.RawSetInt(st, keyClientNextId, lua.LNumber(1))
    247 	L.RawSetInt(st, keyCfgMap, newUserData(L, make(mqttConfigMap)))
    248 	L.RawSetInt(st, keyCnxTable, L.NewTable())
    249 	L.RawSetInt(st, keyTimerTable, L.NewTable())
    250 	L.SetGlobal(luaStateName, st)
    251 }
    252 
    253 func stateValue(L *lua.LState, key int) lua.LValue {
    254 	st := L.GetGlobal(luaStateName)
    255 	return L.RawGetInt(st.(*lua.LTable), key)
    256 }
    257 
    258 func stateChanToLua(L *lua.LState) chan<- MqttMessage {
    259 	ud := stateValue(L, keyChanToLua)
    260 	return ud.(*lua.LUserData).Value.(chan<- MqttMessage)
    261 }
    262 
    263 func stateClientNextId(L *lua.LState) (int, string) {
    264 	st := L.GetGlobal(luaStateName).(*lua.LTable)
    265 	result := int(L.RawGetInt(st, keyClientNextId).(lua.LNumber))
    266 	L.RawSetInt(st, keyClientNextId, lua.LNumber(result+1))
    267 	prefix := lua.LVAsString(L.RawGetInt(st, keyClientPrefix))
    268 	return result, fmt.Sprintf("%s-%d", prefix, result)
    269 }
    270 
    271 func stateCfgMap(L *lua.LState) mqttConfigMap {
    272 	return stateValue(L, keyCfgMap).(*lua.LUserData).Value.(mqttConfigMap)
    273 }
    274 
    275 func stateCnxTable(L *lua.LState) *lua.LTable {
    276 	return stateValue(L, keyCnxTable).(*lua.LTable)
    277 }
    278 
    279 func stateTimerTable(L *lua.LState) *lua.LTable {
    280 	return stateValue(L, keyTimerTable).(*lua.LTable)
    281 }
    282 
    283 /********** Lua Object for MQTT client **********/
    284 
    285 const luaMqttClientTypeName = "mqttclient"
    286 const keyClient = 1
    287 const keySubTable = 2
    288 
    289 func registerMqttClientType(L *lua.LState) {
    290 	mt := L.NewTypeMetatable(luaMqttClientTypeName)
    291 	L.SetGlobal(luaMqttClientTypeName, mt)
    292 	L.SetField(mt, "new", L.NewFunction(newMqttClient))
    293 	L.SetField(mt, "__call", L.NewFunction(luaPublish))
    294 	L.SetField(mt, "__index", L.NewFunction(luaQuery))
    295 	L.SetField(mt, "__newindex", L.NewFunction(luaSubscribe))
    296 }
    297 
// mqttConfig is the client configuration filled from a Lua table by
// newMqttClient, and converted into a mqtt.Config by newClient. It is also
// the key type of mqttConfigMap, so it must remain comparable.
type mqttConfig struct {
	// Connection is the address given to the TCP dialer.
	Connection string
	// PauseTimeout is parsed with time.ParseDuration; newClient falls
	// back to one second when parsing fails.
	PauseTimeout   string
	AtLeastOnceMax int
	ExactlyOnceMax int
	UserName       string
	// Password is passed as bytes, or as nil when empty.
	Password string
	// Will mirrors the Will field of mqtt.Config, with the message held
	// as a string (passed as nil when empty).
	Will struct {
		Topic       string
		Message     string
		Retain      bool
		AtLeastOnce bool
		ExactlyOnce bool
	}
	KeepAlive    uint16
	CleanSession bool
}
    315 
// mqttClientEntry records the client created for a given configuration.
type mqttClientEntry struct {
	// client is the MQTT client created by newClient.
	client *mqtt.Client
	// id is the key of the connection in the connection table.
	id int
}
    320 
// mqttConfigMap associates each configuration with the client created for
// it, so that newMqttClient returns the existing connection when the same
// configuration is requested again.
type mqttConfigMap map[mqttConfig]mqttClientEntry
    322 
    323 func mqttConfigBytes(src string) []byte {
    324 	if src == "" {
    325 		return nil
    326 	} else {
    327 		return []byte(src)
    328 	}
    329 }
    330 
    331 func newClient(config *mqttConfig, id string) (*mqtt.Client, error) {
    332 	pto, err := time.ParseDuration(config.PauseTimeout)
    333 	if err != nil {
    334 		pto = time.Second
    335 	}
    336 
    337 	processed_cfg := mqtt.Config{
    338 		Dialer:         mqtt.NewDialer("tcp", config.Connection),
    339 		PauseTimeout:   pto,
    340 		AtLeastOnceMax: config.AtLeastOnceMax,
    341 		ExactlyOnceMax: config.ExactlyOnceMax,
    342 		UserName:       config.UserName,
    343 		Password:       mqttConfigBytes(config.Password),
    344 		Will: struct {
    345 			Topic       string
    346 			Message     []byte
    347 			Retain      bool
    348 			AtLeastOnce bool
    349 			ExactlyOnce bool
    350 		}{
    351 			Topic:       config.Will.Topic,
    352 			Message:     mqttConfigBytes(config.Will.Message),
    353 			Retain:      config.Will.Retain,
    354 			AtLeastOnce: config.Will.AtLeastOnce,
    355 			ExactlyOnce: config.Will.ExactlyOnce,
    356 		},
    357 		KeepAlive:    config.KeepAlive,
    358 		CleanSession: config.CleanSession,
    359 	}
    360 
    361 	return mqtt.VolatileSession(id, &processed_cfg)
    362 }
    363 
    364 func newMqttClient(L *lua.LState) int {
    365 	var config mqttConfig
    366 	if err := gluamapper.Map(L.CheckTable(1), &config); err != nil {
    367 		log.Println("newMqttClient:", err)
    368 		L.Push(lua.LNil)
    369 		L.Push(lua.LString(err.Error()))
    370 		return 2
    371 	}
    372 
    373 	cfgMap := stateCfgMap(L)
    374 
    375 	if cfg, found := cfgMap[config]; found {
    376 		res := L.RawGetInt(stateCnxTable(L), cfg.id)
    377 		tbl := res.(*lua.LTable)
    378 		if L.RawGetInt(tbl, keyClient).(*lua.LUserData).Value.(*mqtt.Client) != cfg.client {
    379 			panic("Inconsistent configuration table")
    380 		}
    381 
    382 		L.Push(res)
    383 		return 1
    384 	}
    385 
    386 	id, idString := stateClientNextId(L)
    387 	client, err := newClient(&config, idString)
    388 	if err != nil {
    389 		log.Println("newMqttClient:", err)
    390 		L.Push(lua.LNil)
    391 		L.Push(lua.LString(err.Error()))
    392 		return 2
    393 	}
    394 	go mqttRead(client, stateChanToLua(L), id)
    395 
    396 	cfgMap[config] = mqttClientEntry{id: id, client: client}
    397 
    398 	res := L.NewTable()
    399 	L.RawSetInt(res, keyClient, newUserData(L, client))
    400 	L.RawSetInt(res, keySubTable, L.NewTable())
    401 	L.SetMetatable(res, L.GetTypeMetatable(luaMqttClientTypeName))
    402 	L.RawSetInt(stateCnxTable(L), id, res)
    403 	L.Push(res)
    404 	return 1
    405 }
    406 
    407 func luaPublish(L *lua.LState) int {
    408 	cnx := L.CheckTable(1)
    409 	client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    410 
    411 	if L.GetTop() == 1 {
    412 		if err := client.Ping(nil); err != nil {
    413 			log.Println("luaPing:", err)
    414 			L.Push(lua.LNil)
    415 			L.Push(lua.LString(err.Error()))
    416 			return 2
    417 		} else {
    418 			L.Push(lua.LTrue)
    419 			return 1
    420 		}
    421 	}
    422 
    423 	message := L.CheckString(2)
    424 	topic := L.CheckString(3)
    425 
    426 	if err := client.Publish(nil, []byte(message), topic); err != nil {
    427 		L.Push(lua.LNil)
    428 		L.Push(lua.LString(err.Error()))
    429 		return 2
    430 	} else {
    431 		L.Push(lua.LTrue)
    432 		return 1
    433 	}
    434 }
    435 
    436 func luaQuery(L *lua.LState) int {
    437 	cnx := L.CheckTable(1)
    438 	topic := L.CheckString(2)
    439 	subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    440 	L.Push(L.GetField(subTbl, topic))
    441 	return 1
    442 }
    443 
    444 func luaSubscribe(L *lua.LState) int {
    445 	var err error
    446 	cnx := L.CheckTable(1)
    447 	topic := L.CheckString(2)
    448 	callback := L.OptFunction(3, nil)
    449 	client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    450 	tbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    451 
    452 	_, is_new := L.GetField(tbl, topic).(*lua.LNilType)
    453 
    454 	if callback == nil {
    455 		err = client.Unsubscribe(nil, topic)
    456 	} else if is_new {
    457 		err = client.Subscribe(nil, topic)
    458 	}
    459 
    460 	if err != nil {
    461 		log.Println("luaSubscribe:", err)
    462 		L.Push(lua.LNil)
    463 		L.Push(lua.LString(err.Error()))
    464 		return 2
    465 	} else {
    466 		if callback == nil {
    467 			if is_new {
    468 				log.Printf("Not subscribed to %q", topic)
    469 			} else {
    470 				log.Printf("Unsubscribed from %q", topic)
    471 			}
    472 			L.SetField(tbl, topic, lua.LNil)
    473 		} else {
    474 			if is_new {
    475 				log.Printf("Subscribed to %q", topic)
    476 			} else {
    477 				log.Printf("Updating subscription to %q", topic)
    478 			}
    479 			L.SetField(tbl, topic, callback)
    480 		}
    481 
    482 		L.Push(lua.LTrue)
    483 		return 1
    484 	}
    485 }
    486 
    487 /********** Lua Object for timers **********/
    488 
// luaTimerTypeName is the name of the timer type metatable, which is also
// the name of the Lua global it is exposed as.
const luaTimerTypeName = "timer"
    490 
    491 func registerTimerType(L *lua.LState) {
    492 	mt := L.NewTypeMetatable(luaTimerTypeName)
    493 	L.SetGlobal(luaTimerTypeName, mt)
    494 	L.SetField(mt, "new", L.NewFunction(newTimer))
    495 	L.SetField(mt, "schedule", L.NewFunction(timerSchedule))
    496 	L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods))
    497 }
    498 
    499 func newTimer(L *lua.LState) int {
    500 	atTime := L.Get(1)
    501 	cb := L.CheckFunction(2)
    502 	L.Pop(2)
    503 	L.SetMetatable(cb, L.GetTypeMetatable(luaTimerTypeName))
    504 	L.Push(cb)
    505 	L.Push(atTime)
    506 	return timerSchedule(L)
    507 }
    508 
// timerMethods are the methods installed by registerTimerType in the
// "__index" table of the timer metatable.
var timerMethods = map[string]lua.LGFunction{
	"cancel":   timerCancel,
	"schedule": timerSchedule,
}
    513 
    514 func timerCancel(L *lua.LState) int {
    515 	timer := L.CheckFunction(1)
    516 	L.RawSet(stateTimerTable(L), timer, lua.LNil)
    517 	return 0
    518 }
    519 
    520 func timerSchedule(L *lua.LState) int {
    521 	timer := L.CheckFunction(1)
    522 	atTime := lua.LNil
    523 	if L.Get(2) != lua.LNil {
    524 		atTime = L.CheckNumber(2)
    525 	}
    526 
    527 	L.RawSet(stateTimerTable(L), timer, atTime)
    528 	return 0
    529 }
    530 
    531 func toTime(lsec lua.LNumber) time.Time {
    532 	fsec := float64(lsec)
    533 	sec := int64(fsec)
    534 	nsec := int64((fsec - float64(sec)) * 1.0e9)
    535 
    536 	return time.Unix(sec, nsec)
    537 }
    538 
    539 func runTimers(L *lua.LState) (bool, time.Time) {
    540 	hasNext := false
    541 	var nextTime time.Time
    542 
    543 	now := time.Now()
    544 	timers := stateTimerTable(L)
    545 
    546 	timer, luaT := timers.Next(lua.LNil)
    547 	for timer != lua.LNil {
    548 		t := toTime(luaT.(lua.LNumber))
    549 		if t.Compare(now) <= 0 {
    550 			L.RawSet(timers, timer, lua.LNil)
    551 			err := L.CallByParam(lua.P{Fn: timer, NRet: 0, Protect: true}, timer, luaT)
    552 			if err != nil {
    553 				panic(err)
    554 			}
    555 			timer = lua.LNil
    556 			hasNext = false
    557 		} else if !hasNext || t.Compare(nextTime) < 0 {
    558 			hasNext = true
    559 			nextTime = t
    560 		}
    561 
    562 		timer, luaT = timers.Next(timer)
    563 	}
    564 
    565 	return hasNext, nextTime
    566 }