mqttagent

MQTT Lua Agent
git clone https://git.instinctive.eu/mqttagent.git
Log | Files | Refs | README | LICENSE

mqttagent.go (11949B)


      1 /*
      2  * Copyright (c) 2025, Natacha Porté
      3  *
      4  * Permission to use, copy, modify, and distribute this software for any
      5  * purpose with or without fee is hereby granted, provided that the above
      6  * copyright notice and this permission notice appear in all copies.
      7  *
      8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
     10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
     11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
     13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
     14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     15  */
     16 
     17 package mqttagent
     18 
     19 import (
     20 	"errors"
     21 	"fmt"
     22 	"log"
     23 	"os"
     24 	"strings"
     25 	"time"
     26 
     27 	"github.com/go-mqtt/mqtt"
     28 	"github.com/yuin/gluamapper"
     29 	"github.com/yuin/gopher-lua"
     30 )
     31 
     32 type MqttAgent interface {
     33 	Setup(L *lua.LState)
     34 	Log(L *lua.LState, msg *MqttMessage)
     35 	Teardown(L *lua.LState)
     36 }
     37 
     38 type MqttMessage struct {
     39 	Timestamp float64
     40 	ClientId  int
     41 	Topic     []byte
     42 	Message   []byte
     43 }
     44 
     45 func Run(agent MqttAgent, main_script string) {
     46 	fromMqtt := make(chan MqttMessage)
     47 
     48 	L := lua.NewState()
     49 	defer L.Close()
     50 
     51 	agent.Setup(L)
     52 	defer agent.Teardown(L)
     53 
     54 	hostname, err := os.Hostname()
     55 	if err != nil {
     56 		hostname = "<unknown>"
     57 	}
     58 
     59 	idString := fmt.Sprintf("mqttagent-%s-%d", hostname, os.Getpid())
     60 	registerMqttClientType(L)
     61 	registerTimerType(L)
     62 	registerState(L, idString, fromMqtt)
     63 	defer cleanupClients(L)
     64 
     65 	if err := L.DoFile(main_script); err != nil {
     66 		panic(err)
     67 	}
     68 
     69 	timer := time.NewTimer(0)
     70 	defer timer.Stop()
     71 
     72 	log.Println(idString, "started")
     73 
     74 	for {
     75 		select {
     76 		case msg, ok := <-fromMqtt:
     77 
     78 			if !ok {
     79 				log.Println("fromMqtt is closed")
     80 				break
     81 			}
     82 
     83 			processMsg(L, agent, &msg)
     84 
     85 		case <-timer.C:
     86 		}
     87 
     88 		hasTimer, nextTimer := runTimers(L)
     89 
     90 		if hasTimer {
     91 			timer.Reset(time.Until(nextTimer))
     92 		} else {
     93 			timer.Stop()
     94 		}
     95 
     96 		if tableIsEmpty(stateCnxTable(L)) && tableIsEmpty(stateTimerTable(L)) {
     97 			break
     98 		}
     99 	}
    100 
    101 	log.Println(idString, "finished")
    102 }
    103 
    104 func cleanupClients(L *lua.LState) {
    105 	cnxTbl := stateCnxTable(L)
    106 	if cnxTbl == nil {
    107 		return
    108 	}
    109 
    110 	L.ForEach(cnxTbl, func(key, value lua.LValue) {
    111 		cnx := value.(*lua.LTable)
    112 		client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    113 		if err := client.Disconnect(nil); err != nil {
    114 			log.Printf("cleanup client %s: %v", lua.LVAsString(key), err)
    115 		}
    116 	})
    117 }
    118 
    119 func dispatchMsg(L *lua.LState, msg *MqttMessage, cnx, key, value lua.LValue) {
    120 	skey, ok := key.(lua.LString)
    121 	topic := string(msg.Topic)
    122 
    123 	if ok && match(topic, string(skey)) {
    124 		err := L.CallByParam(lua.P{Fn: value, NRet: 0, Protect: true},
    125 			cnx,
    126 			lua.LString(string(msg.Message)),
    127 			lua.LString(topic),
    128 			lua.LNumber(msg.Timestamp))
    129 		if err != nil {
    130 			panic(err)
    131 		}
    132 	}
    133 }
    134 
    135 func matchSliced(actual, filter []string) bool {
    136 	if len(filter) == 0 {
    137 		return len(actual) == 0
    138 	}
    139 
    140 	if filter[0] == "#" {
    141 		if len(filter) == 1 {
    142 			return true
    143 		}
    144 
    145 		for i := range actual {
    146 			if matchSliced(actual[i:], filter[1:]) {
    147 				return true
    148 			}
    149 		}
    150 
    151 		return false
    152 	}
    153 
    154 	if len(actual) > 0 && (filter[0] == "+" || filter[0] == actual[0]) {
    155 		return matchSliced(actual[1:], filter[1:])
    156 	}
    157 
    158 	return false
    159 }
    160 
    161 func match(actual, filter string) bool {
    162 	return matchSliced(strings.Split(actual, "/"), strings.Split(filter, "/"))
    163 }
    164 
    165 func tableIsEmpty(t *lua.LTable) bool {
    166 	key, _ := t.Next(lua.LNil)
    167 	return key == lua.LNil
    168 }
    169 
    170 func processMsg(L *lua.LState, agent MqttAgent, msg *MqttMessage) {
    171 	agent.Log(L, msg)
    172 
    173 	cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable)
    174 	subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    175 	L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, msg, cnx, key, value) })
    176 
    177 	if tableIsEmpty(subTbl) {
    178 		client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    179 		if err := client.Disconnect(nil); err != nil {
    180 			log.Println("disconnect empty client:", err)
    181 		}
    182 		L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil)
    183 	}
    184 }
    185 
    186 func mqttRead(client *mqtt.Client, toLua chan<- MqttMessage, id int) {
    187 	var big *mqtt.BigMessage
    188 
    189 	for {
    190 		message, topic, err := client.ReadSlices()
    191 		t := float64(time.Now().UnixMicro()) * 1.0e-6
    192 
    193 		switch {
    194 		case err == nil:
    195 			toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: dup(message)}
    196 
    197 		case errors.As(err, &big):
    198 			data, err := big.ReadAll()
    199 			if err != nil {
    200 				log.Println("mqttRead big message:", err)
    201 			} else {
    202 				toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: data}
    203 			}
    204 
    205 		case errors.Is(err, mqtt.ErrClosed):
    206 			log.Println("mqttRead finishing:", err)
    207 			return
    208 
    209 		case mqtt.IsConnectionRefused(err):
    210 			log.Println("mqttRead connection refused:", err)
    211 			time.Sleep(15 * time.Minute)
    212 
    213 		default:
    214 			log.Println("mqttRead:", err)
    215 			time.Sleep(2 * time.Second)
    216 		}
    217 	}
    218 }
    219 
    220 func dup(src []byte) []byte {
    221 	res := make([]byte, len(src))
    222 	copy(res, src)
    223 	return res
    224 }
    225 
    226 /********** State Object in the Lua Interpreter **********/
    227 
    228 const luaStateName = "_mqttagent"
    229 const keyChanToLua = 1
    230 const keyClientPrefix = 2
    231 const keyCnxTable = 3
    232 const keyTimerTable = 4
    233 
    234 func registerState(L *lua.LState, clientPrefix string, toLua chan<- MqttMessage) {
    235 	ud := L.NewUserData()
    236 	ud.Value = toLua
    237 
    238 	st := L.NewTable()
    239 	L.RawSetInt(st, keyChanToLua, ud)
    240 	L.RawSetInt(st, keyClientPrefix, lua.LString(clientPrefix))
    241 	L.RawSetInt(st, keyCnxTable, L.NewTable())
    242 	L.RawSetInt(st, keyTimerTable, L.NewTable())
    243 	L.SetGlobal(luaStateName, st)
    244 }
    245 
    246 func stateValue(L *lua.LState, key int) lua.LValue {
    247 	st := L.GetGlobal(luaStateName)
    248 	return L.RawGetInt(st.(*lua.LTable), key)
    249 }
    250 
    251 func stateChanToLua(L *lua.LState) chan<- MqttMessage {
    252 	ud := stateValue(L, keyChanToLua)
    253 	return ud.(*lua.LUserData).Value.(chan<- MqttMessage)
    254 }
    255 
    256 func stateClientPrefix(L *lua.LState) string {
    257 	return lua.LVAsString(stateValue(L, keyClientPrefix))
    258 }
    259 
    260 func stateCnxTable(L *lua.LState) *lua.LTable {
    261 	return stateValue(L, keyCnxTable).(*lua.LTable)
    262 }
    263 
    264 func stateTimerTable(L *lua.LState) *lua.LTable {
    265 	return stateValue(L, keyTimerTable).(*lua.LTable)
    266 }
    267 
    268 /********** Lua Object for MQTT client **********/
    269 
    270 const luaMqttClientTypeName = "mqttclient"
    271 const keyClient = 1
    272 const keySubTable = 2
    273 
    274 func registerMqttClientType(L *lua.LState) {
    275 	mt := L.NewTypeMetatable(luaMqttClientTypeName)
    276 	L.SetGlobal(luaMqttClientTypeName, mt)
    277 	L.SetField(mt, "new", L.NewFunction(newMqttClient))
    278 	L.SetField(mt, "__gc", L.NewFunction(deleteMqttClient))
    279 	L.SetField(mt, "__call", L.NewFunction(luaPublish))
    280 	L.SetField(mt, "__index", L.NewFunction(luaQuery))
    281 	L.SetField(mt, "__newindex", L.NewFunction(luaSubscribe))
    282 }
    283 
    284 type mqttConfig struct {
    285 	Connection     string
    286 	PauseTimeout   string
    287 	AtLeastOnceMax int
    288 	ExactlyOnceMax int
    289 	UserName       string
    290 	Password       []byte
    291 	Will           struct {
    292 		Topic       string
    293 		Message     []byte
    294 		Retain      bool
    295 		AtLeastOnce bool
    296 		ExactlyOnce bool
    297 	}
    298 	KeepAlive    uint16
    299 	CleanSession bool
    300 }
    301 
    302 func newClient(L *lua.LState, id string) (*mqtt.Client, error) {
    303 	var config mqttConfig
    304 	if err := gluamapper.Map(L.CheckTable(1), &config); err != nil {
    305 		return nil, err
    306 	}
    307 
    308 	pto, err := time.ParseDuration(config.PauseTimeout)
    309 	if err != nil {
    310 		pto = time.Second
    311 	}
    312 
    313 	processed_cfg := mqtt.Config{
    314 		Dialer:         mqtt.NewDialer("tcp", config.Connection),
    315 		PauseTimeout:   pto,
    316 		AtLeastOnceMax: config.AtLeastOnceMax,
    317 		ExactlyOnceMax: config.ExactlyOnceMax,
    318 		UserName:       config.UserName,
    319 		Password:       config.Password,
    320 		Will: struct {
    321 			Topic       string
    322 			Message     []byte
    323 			Retain      bool
    324 			AtLeastOnce bool
    325 			ExactlyOnce bool
    326 		}{
    327 			Topic:       config.Will.Topic,
    328 			Message:     config.Will.Message,
    329 			Retain:      config.Will.Retain,
    330 			AtLeastOnce: config.Will.AtLeastOnce,
    331 			ExactlyOnce: config.Will.ExactlyOnce,
    332 		},
    333 		KeepAlive:    config.KeepAlive,
    334 		CleanSession: config.CleanSession,
    335 	}
    336 
    337 	return mqtt.VolatileSession(id, &processed_cfg)
    338 }
    339 
    340 func newMqttClient(L *lua.LState) int {
    341 	id := stateCnxTable(L).Len() + 1
    342 	idString := fmt.Sprintf("%s-%d", stateClientPrefix(L), id)
    343 
    344 	client, err := newClient(L, idString)
    345 	if err != nil {
    346 		log.Println("newMqttClient:", err)
    347 		L.Push(lua.LNil)
    348 		L.Push(lua.LString(err.Error()))
    349 		return 2
    350 	}
    351 	go mqttRead(client, stateChanToLua(L), id)
    352 
    353 	ud := L.NewUserData()
    354 	ud.Value = client
    355 
    356 	res := L.NewTable()
    357 	L.RawSetInt(res, keyClient, ud)
    358 	L.RawSetInt(res, keySubTable, L.NewTable())
    359 	L.SetMetatable(res, L.GetTypeMetatable(luaMqttClientTypeName))
    360 	L.RawSetInt(stateCnxTable(L), id, res)
    361 	L.Push(res)
    362 	return 1
    363 }
    364 
    365 func deleteMqttClient(L *lua.LState) int {
    366 	log.Println("deleteMqttClient: TODO")
    367 	return 0
    368 }
    369 
    370 func luaPublish(L *lua.LState) int {
    371 	cnx := L.CheckTable(1)
    372 	message := L.CheckString(2)
    373 	topic := L.CheckString(3)
    374 	client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    375 
    376 	err := client.Publish(nil, []byte(message), topic)
    377 
    378 	if err != nil {
    379 		L.Push(lua.LNil)
    380 		L.Push(lua.LString(err.Error()))
    381 		return 2
    382 	} else {
    383 		L.Push(lua.LTrue)
    384 		return 1
    385 	}
    386 }
    387 
    388 func luaQuery(L *lua.LState) int {
    389 	cnx := L.CheckTable(1)
    390 	topic := L.CheckString(2)
    391 	subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    392 	L.Push(L.GetField(subTbl, topic))
    393 	return 1
    394 }
    395 
    396 func luaSubscribe(L *lua.LState) int {
    397 	var err error
    398 	cnx := L.CheckTable(1)
    399 	topic := L.CheckString(2)
    400 	callback := L.OptFunction(3, nil)
    401 	client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
    402 
    403 	if callback == nil {
    404 		err = client.Unsubscribe(nil, topic)
    405 	} else {
    406 		err = client.Subscribe(nil, topic)
    407 	}
    408 
    409 	if err != nil {
    410 		log.Println("luaSubscribe:", err)
    411 		L.Push(lua.LNil)
    412 		L.Push(lua.LString(err.Error()))
    413 		return 2
    414 	} else {
    415 		tbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
    416 
    417 		if callback == nil {
    418 			L.SetField(tbl, topic, lua.LNil)
    419 		} else {
    420 			L.SetField(tbl, topic, callback)
    421 		}
    422 
    423 		L.Push(lua.LTrue)
    424 		return 1
    425 	}
    426 }
    427 
    428 /********** Lua Object for timers **********/
    429 
    430 const luaTimerTypeName = "timer"
    431 const keyTime = 1
    432 const keyCallback = 2
    433 
    434 func registerTimerType(L *lua.LState) {
    435 	mt := L.NewTypeMetatable(luaTimerTypeName)
    436 	L.SetGlobal(luaTimerTypeName, mt)
    437 	L.SetField(mt, "new", L.NewFunction(newTimer))
    438 	L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods))
    439 }
    440 
    441 func newTimer(L *lua.LState) int {
    442 	atTime := L.CheckNumber(1)
    443 	cb := L.CheckFunction(2)
    444 
    445 	tbl := L.NewTable()
    446 	L.RawSetInt(tbl, keyTime, atTime)
    447 	L.RawSetInt(tbl, keyCallback, cb)
    448 	L.SetMetatable(tbl, L.GetTypeMetatable(luaTimerTypeName))
    449 	stateTimerTable(L).Append(tbl)
    450 	L.Push(tbl)
    451 	return 1
    452 }
    453 
    454 var timerMethods = map[string]lua.LGFunction{
    455 	"cancel":   timerCancel,
    456 	"schedule": timerSchedule,
    457 }
    458 
    459 func timerCancel(L *lua.LState) int {
    460 	tbl := L.CheckTable(1)
    461 	L.RawSetInt(tbl, keyTime, lua.LNil)
    462 	return 0
    463 }
    464 
    465 func timerSchedule(L *lua.LState) int {
    466 	tbl := L.CheckTable(1)
    467 	atTime := L.CheckNumber(2)
    468 	L.RawSetInt(tbl, keyTime, atTime)
    469 	return 0
    470 }
    471 
    472 func toTime(v lua.LValue, d time.Time) (time.Time, bool) {
    473 	lsec, ok := v.(lua.LNumber)
    474 	if !ok {
    475 		return d, false
    476 	}
    477 
    478 	fsec := float64(lsec)
    479 	sec := int64(fsec)
    480 	nsec := int64((fsec - float64(sec)) * 1.0e9)
    481 
    482 	return time.Unix(sec, nsec), true
    483 }
    484 
    485 func runTimers(L *lua.LState) (bool, time.Time) {
    486 	hasNext := false
    487 	var nextTime time.Time
    488 
    489 	now := time.Now()
    490 	timers := stateTimerTable(L)
    491 
    492 	k, v := timers.Next(lua.LNil)
    493 	for k != lua.LNil {
    494 		tbl := v.(*lua.LTable)
    495 		luaT := L.RawGetInt(tbl, keyTime)
    496 		t, ok := toTime(luaT, now)
    497 		if !ok {
    498 		} else if t.Compare(now) <= 0 {
    499 			L.RawSetInt(tbl, keyTime, lua.LNil)
    500 			err := L.CallByParam(lua.P{Fn: L.RawGetInt(tbl, keyCallback), NRet: 0, Protect: true}, v, luaT)
    501 			if err != nil {
    502 				panic(err)
    503 			}
    504 			k = lua.LNil
    505 			hasNext = false
    506 		} else if !hasNext || t.Compare(nextTime) < 0 {
    507 			hasNext = true
    508 			nextTime = t
    509 		}
    510 
    511 		k, v = timers.Next(k)
    512 	}
    513 
    514 	return hasNext, nextTime
    515 }