mqttagent.go (13456B)
1 /* 2 * Copyright (c) 2025, Natacha Porté 3 * 4 * Permission to use, copy, modify, and distribute this software for any 5 * purpose with or without fee is hereby granted, provided that the above 6 * copyright notice and this permission notice appear in all copies. 7 * 8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 */ 16 17 package mqttagent 18 19 import ( 20 "errors" 21 "fmt" 22 "log" 23 "os" 24 "strings" 25 "time" 26 27 "github.com/go-mqtt/mqtt" 28 "github.com/yuin/gluamapper" 29 "github.com/yuin/gopher-lua" 30 ) 31 32 type MqttAgent interface { 33 Setup(L *lua.LState) 34 Log(L *lua.LState, msg *MqttMessage) 35 Teardown(L *lua.LState) 36 } 37 38 type MqttMessage struct { 39 Timestamp float64 40 ClientId int 41 Topic []byte 42 Message []byte 43 } 44 45 func Run(agent MqttAgent, main_script string, capacity int) { 46 fromMqtt := make(chan MqttMessage, capacity) 47 48 L := lua.NewState() 49 defer L.Close() 50 51 agent.Setup(L) 52 defer agent.Teardown(L) 53 54 hostname, err := os.Hostname() 55 if err != nil { 56 hostname = "<unknown>" 57 } 58 59 idString := fmt.Sprintf("mqttagent-%s-%d", hostname, os.Getpid()) 60 registerMqttClientType(L) 61 registerTimerType(L) 62 registerState(L, idString, fromMqtt) 63 defer cleanupClients(L) 64 65 if err := L.DoFile(main_script); err != nil { 66 panic(err) 67 } 68 69 timer := time.NewTimer(0) 70 defer timer.Stop() 71 72 log.Println(idString, "started") 73 74 for { 75 select { 76 case msg, ok := <-fromMqtt: 77 78 if !ok { 79 log.Println("fromMqtt is closed") 80 break 81 } 82 83 processMsg(L, agent, &msg) 84 85 case <-timer.C: 86 } 87 88 hasTimer, nextTimer := runTimers(L) 89 90 if hasTimer { 91 timer.Reset(time.Until(nextTimer)) 92 } else { 93 timer.Stop() 94 } 95 96 if tableIsEmpty(stateCnxTable(L)) && tableIsEmpty(stateTimerTable(L)) { 97 break 98 } 99 } 100 101 log.Println(idString, "finished") 102 } 103 104 func cleanupClients(L *lua.LState) { 105 cnxTbl := stateCnxTable(L) 106 if cnxTbl == nil { 107 return 108 } 109 110 L.ForEach(cnxTbl, func(key, value lua.LValue) { 111 cnx := value.(*lua.LTable) 112 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 113 if err := client.Disconnect(nil); err != nil { 114 log.Printf("cleanup client %s: %v", lua.LVAsString(key), err) 115 } 116 }) 117 } 118 119 func dispatchMsg(L *lua.LState, msg *MqttMessage, cnx, key, value lua.LValue) { 120 skey, ok := key.(lua.LString) 121 topic := string(msg.Topic) 122 123 if ok && match(topic, string(skey)) { 124 err := L.CallByParam(lua.P{Fn: value, NRet: 0, Protect: true}, 125 cnx, 126 lua.LString(string(msg.Message)), 127 lua.LString(topic), 128 lua.LNumber(msg.Timestamp)) 129 if err != nil { 130 panic(err) 131 } 132 } 133 } 134 135 func matchSliced(actual, filter []string) bool { 136 if len(filter) == 0 { 137 return len(actual) == 0 138 } 139 140 if filter[0] == "#" { 141 if len(filter) == 1 { 142 return true 143 } 144 145 for i := range actual { 146 if matchSliced(actual[i:], filter[1:]) { 147 return true 148 } 149 } 150 151 return false 152 } 153 154 if len(actual) > 0 && (filter[0] == "+" || filter[0] == actual[0]) { 155 return matchSliced(actual[1:], filter[1:]) 156 } 157 158 return false 159 } 160 161 func match(actual, filter string) bool { 162 return matchSliced(strings.Split(actual, "/"), strings.Split(filter, "/")) 163 } 164 165 func tableIsEmpty(t *lua.LTable) bool { 166 key, _ := t.Next(lua.LNil) 167 return key == lua.LNil 168 } 169 170 func processMsg(L *lua.LState, agent MqttAgent, msg *MqttMessage) { 171 agent.Log(L, msg) 172 173 cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable) 174 subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 175 L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, msg, cnx, key, value) }) 176 177 if tableIsEmpty(subTbl) { 178 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 179 if err := client.Disconnect(nil); err != nil { 180 log.Println("disconnect empty client:", err) 181 } 182 L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil) 183 } 184 } 185 186 func mqttRead(client *mqtt.Client, toLua chan<- MqttMessage, id int) { 187 var big *mqtt.BigMessage 188 189 for { 190 message, topic, err := client.ReadSlices() 191 t := float64(time.Now().UnixMicro()) * 1.0e-6 192 193 switch { 194 case err == nil: 195 toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: dup(message)} 196 197 case errors.As(err, &big): 198 data, err := big.ReadAll() 199 if err != nil { 200 log.Println("mqttRead big message:", err) 201 } else { 202 toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: data} 203 } 204 205 case errors.Is(err, mqtt.ErrClosed): 206 log.Println("mqttRead finishing:", err) 207 return 208 209 case mqtt.IsConnectionRefused(err): 210 log.Println("mqttRead connection refused:", err) 211 time.Sleep(15 * time.Minute) 212 213 default: 214 log.Println("mqttRead:", err) 215 time.Sleep(2 * time.Second) 216 } 217 } 218 } 219 220 func dup(src []byte) []byte { 221 res := make([]byte, len(src)) 222 copy(res, src) 223 return res 224 } 225 226 func newUserData(L *lua.LState, v interface{}) *lua.LUserData { 227 res := L.NewUserData() 228 res.Value = v 229 return res 230 } 231 232 /********** State Object in the Lua Interpreter **********/ 233 234 const luaStateName = "_mqttagent" 235 const keyChanToLua = 1 236 const keyClientPrefix = 2 237 const keyClientNextId = 3 238 const keyCfgMap = 4 239 const keyCnxTable = 5 240 const keyTimerTable = 6 241 242 func registerState(L *lua.LState, clientPrefix string, toLua chan<- MqttMessage) { 243 st := L.NewTable() 244 L.RawSetInt(st, keyChanToLua, newUserData(L, toLua)) 245 L.RawSetInt(st, keyClientPrefix, lua.LString(clientPrefix)) 246 L.RawSetInt(st, keyClientNextId, lua.LNumber(1)) 247 L.RawSetInt(st, keyCfgMap, newUserData(L, make(mqttConfigMap))) 248 L.RawSetInt(st, keyCnxTable, L.NewTable()) 249 L.RawSetInt(st, keyTimerTable, L.NewTable()) 250 L.SetGlobal(luaStateName, st) 251 } 252 253 func stateValue(L *lua.LState, key int) lua.LValue { 254 st := L.GetGlobal(luaStateName) 255 return L.RawGetInt(st.(*lua.LTable), key) 256 } 257 258 func stateChanToLua(L *lua.LState) chan<- MqttMessage { 259 ud := stateValue(L, keyChanToLua) 260 return ud.(*lua.LUserData).Value.(chan<- MqttMessage) 261 } 262 263 func stateClientNextId(L *lua.LState) (int, string) { 264 st := L.GetGlobal(luaStateName).(*lua.LTable) 265 result := int(L.RawGetInt(st, keyClientNextId).(lua.LNumber)) 266 L.RawSetInt(st, keyClientNextId, lua.LNumber(result+1)) 267 prefix := lua.LVAsString(L.RawGetInt(st, keyClientPrefix)) 268 return result, fmt.Sprintf("%s-%d", prefix, result) 269 } 270 271 func stateCfgMap(L *lua.LState) mqttConfigMap { 272 return stateValue(L, keyCfgMap).(*lua.LUserData).Value.(mqttConfigMap) 273 } 274 275 func stateCnxTable(L *lua.LState) *lua.LTable { 276 return stateValue(L, keyCnxTable).(*lua.LTable) 277 } 278 279 func stateTimerTable(L *lua.LState) *lua.LTable { 280 return stateValue(L, keyTimerTable).(*lua.LTable) 281 } 282 283 /********** Lua Object for MQTT client **********/ 284 285 const luaMqttClientTypeName = "mqttclient" 286 const keyClient = 1 287 const keySubTable = 2 288 289 func registerMqttClientType(L *lua.LState) { 290 mt := L.NewTypeMetatable(luaMqttClientTypeName) 291 L.SetGlobal(luaMqttClientTypeName, mt) 292 L.SetField(mt, "new", L.NewFunction(newMqttClient)) 293 L.SetField(mt, "__call", L.NewFunction(luaPublish)) 294 L.SetField(mt, "__index", L.NewFunction(luaQuery)) 295 L.SetField(mt, "__newindex", L.NewFunction(luaSubscribe)) 296 } 297 298 type mqttConfig struct { 299 Connection string 300 PauseTimeout string 301 AtLeastOnceMax int 302 ExactlyOnceMax int 303 UserName string 304 Password string 305 Will struct { 306 Topic string 307 Message string 308 Retain bool 309 AtLeastOnce bool 310 ExactlyOnce bool 311 } 312 KeepAlive uint16 313 CleanSession bool 314 } 315 316 type mqttClientEntry struct { 317 client *mqtt.Client 318 id int 319 } 320 321 type mqttConfigMap map[mqttConfig]mqttClientEntry 322 323 func mqttConfigBytes(src string) []byte { 324 if src == "" { 325 return nil 326 } else { 327 return []byte(src) 328 } 329 } 330 331 func newClient(config *mqttConfig, id string) (*mqtt.Client, error) { 332 pto, err := time.ParseDuration(config.PauseTimeout) 333 if err != nil { 334 pto = time.Second 335 } 336 337 processed_cfg := mqtt.Config{ 338 Dialer: mqtt.NewDialer("tcp", config.Connection), 339 PauseTimeout: pto, 340 AtLeastOnceMax: config.AtLeastOnceMax, 341 ExactlyOnceMax: config.ExactlyOnceMax, 342 UserName: config.UserName, 343 Password: mqttConfigBytes(config.Password), 344 Will: struct { 345 Topic string 346 Message []byte 347 Retain bool 348 AtLeastOnce bool 349 ExactlyOnce bool 350 }{ 351 Topic: config.Will.Topic, 352 Message: mqttConfigBytes(config.Will.Message), 353 Retain: config.Will.Retain, 354 AtLeastOnce: config.Will.AtLeastOnce, 355 ExactlyOnce: config.Will.ExactlyOnce, 356 }, 357 KeepAlive: config.KeepAlive, 358 CleanSession: config.CleanSession, 359 } 360 361 return mqtt.VolatileSession(id, &processed_cfg) 362 } 363 364 func newMqttClient(L *lua.LState) int { 365 var config mqttConfig 366 if err := gluamapper.Map(L.CheckTable(1), &config); err != nil { 367 log.Println("newMqttClient:", err) 368 L.Push(lua.LNil) 369 L.Push(lua.LString(err.Error())) 370 return 2 371 } 372 373 cfgMap := stateCfgMap(L) 374 375 if cfg, found := cfgMap[config]; found { 376 res := L.RawGetInt(stateCnxTable(L), cfg.id) 377 tbl := res.(*lua.LTable) 378 if L.RawGetInt(tbl, keyClient).(*lua.LUserData).Value.(*mqtt.Client) != cfg.client { 379 panic("Inconsistent configuration table") 380 } 381 382 L.Push(res) 383 return 1 384 } 385 386 id, idString := stateClientNextId(L) 387 client, err := newClient(&config, idString) 388 if err != nil { 389 log.Println("newMqttClient:", err) 390 L.Push(lua.LNil) 391 L.Push(lua.LString(err.Error())) 392 return 2 393 } 394 go mqttRead(client, stateChanToLua(L), id) 395 396 cfgMap[config] = mqttClientEntry{id: id, client: client} 397 398 res := L.NewTable() 399 L.RawSetInt(res, keyClient, newUserData(L, client)) 400 L.RawSetInt(res, keySubTable, L.NewTable()) 401 L.SetMetatable(res, L.GetTypeMetatable(luaMqttClientTypeName)) 402 L.RawSetInt(stateCnxTable(L), id, res) 403 L.Push(res) 404 return 1 405 } 406 407 func luaPublish(L *lua.LState) int { 408 cnx := L.CheckTable(1) 409 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 410 411 if L.GetTop() == 1 { 412 if err := client.Ping(nil); err != nil { 413 log.Println("luaPing:", err) 414 L.Push(lua.LNil) 415 L.Push(lua.LString(err.Error())) 416 return 2 417 } else { 418 L.Push(lua.LTrue) 419 return 1 420 } 421 } 422 423 message := L.CheckString(2) 424 topic := L.CheckString(3) 425 426 if err := client.Publish(nil, []byte(message), topic); err != nil { 427 L.Push(lua.LNil) 428 L.Push(lua.LString(err.Error())) 429 return 2 430 } else { 431 L.Push(lua.LTrue) 432 return 1 433 } 434 } 435 436 func luaQuery(L *lua.LState) int { 437 cnx := L.CheckTable(1) 438 topic := L.CheckString(2) 439 subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 440 L.Push(L.GetField(subTbl, topic)) 441 return 1 442 } 443 444 func luaSubscribe(L *lua.LState) int { 445 var err error 446 cnx := L.CheckTable(1) 447 topic := L.CheckString(2) 448 callback := L.OptFunction(3, nil) 449 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 450 tbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 451 452 _, is_new := L.GetField(tbl, topic).(*lua.LNilType) 453 454 if callback == nil { 455 err = client.Unsubscribe(nil, topic) 456 } else if is_new { 457 err = client.Subscribe(nil, topic) 458 } 459 460 if err != nil { 461 log.Println("luaSubscribe:", err) 462 L.Push(lua.LNil) 463 L.Push(lua.LString(err.Error())) 464 return 2 465 } else { 466 if callback == nil { 467 if is_new { 468 log.Printf("Not subscribed to %q", topic) 469 } else { 470 log.Printf("Unsubscribed from %q", topic) 471 } 472 L.SetField(tbl, topic, lua.LNil) 473 } else { 474 if is_new { 475 log.Printf("Subscribed to %q", topic) 476 } else { 477 log.Printf("Updating subscription to %q", topic) 478 } 479 L.SetField(tbl, topic, callback) 480 } 481 482 L.Push(lua.LTrue) 483 return 1 484 } 485 } 486 487 /********** Lua Object for timers **********/ 488 489 const luaTimerTypeName = "timer" 490 491 func registerTimerType(L *lua.LState) { 492 mt := L.NewTypeMetatable(luaTimerTypeName) 493 L.SetGlobal(luaTimerTypeName, mt) 494 L.SetField(mt, "new", L.NewFunction(newTimer)) 495 L.SetField(mt, "schedule", L.NewFunction(timerSchedule)) 496 L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods)) 497 } 498 499 func newTimer(L *lua.LState) int { 500 atTime := L.Get(1) 501 cb := L.CheckFunction(2) 502 L.Pop(2) 503 L.SetMetatable(cb, L.GetTypeMetatable(luaTimerTypeName)) 504 L.Push(cb) 505 L.Push(atTime) 506 return timerSchedule(L) 507 } 508 509 var timerMethods = map[string]lua.LGFunction{ 510 "cancel": timerCancel, 511 "schedule": timerSchedule, 512 } 513 514 func timerCancel(L *lua.LState) int { 515 timer := L.CheckFunction(1) 516 L.RawSet(stateTimerTable(L), timer, lua.LNil) 517 return 0 518 } 519 520 func timerSchedule(L *lua.LState) int { 521 timer := L.CheckFunction(1) 522 atTime := lua.LNil 523 if L.Get(2) != lua.LNil { 524 atTime = L.CheckNumber(2) 525 } 526 527 L.RawSet(stateTimerTable(L), timer, atTime) 528 return 0 529 } 530 531 func toTime(lsec lua.LNumber) time.Time { 532 fsec := float64(lsec) 533 sec := int64(fsec) 534 nsec := int64((fsec - float64(sec)) * 1.0e9) 535 536 return time.Unix(sec, nsec) 537 } 538 539 func runTimers(L *lua.LState) (bool, time.Time) { 540 hasNext := false 541 var nextTime time.Time 542 543 now := time.Now() 544 timers := stateTimerTable(L) 545 546 timer, luaT := timers.Next(lua.LNil) 547 for timer != lua.LNil { 548 t := toTime(luaT.(lua.LNumber)) 549 if t.Compare(now) <= 0 { 550 L.RawSet(timers, timer, lua.LNil) 551 err := L.CallByParam(lua.P{Fn: timer, NRet: 0, Protect: true}, timer, luaT) 552 if err != nil { 553 panic(err) 554 } 555 timer = lua.LNil 556 hasNext = false 557 } else if !hasNext || t.Compare(nextTime) < 0 { 558 hasNext = true 559 nextTime = t 560 } 561 562 timer, luaT = timers.Next(timer) 563 } 564 565 return hasNext, nextTime 566 }