mqttagent.go (11949B)
1 /* 2 * Copyright (c) 2025, Natacha Porté 3 * 4 * Permission to use, copy, modify, and distribute this software for any 5 * purpose with or without fee is hereby granted, provided that the above 6 * copyright notice and this permission notice appear in all copies. 7 * 8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 */ 16 17 package mqttagent 18 19 import ( 20 "errors" 21 "fmt" 22 "log" 23 "os" 24 "strings" 25 "time" 26 27 "github.com/go-mqtt/mqtt" 28 "github.com/yuin/gluamapper" 29 "github.com/yuin/gopher-lua" 30 ) 31 32 type MqttAgent interface { 33 Setup(L *lua.LState) 34 Log(L *lua.LState, msg *MqttMessage) 35 Teardown(L *lua.LState) 36 } 37 38 type MqttMessage struct { 39 Timestamp float64 40 ClientId int 41 Topic []byte 42 Message []byte 43 } 44 45 func Run(agent MqttAgent, main_script string) { 46 fromMqtt := make(chan MqttMessage) 47 48 L := lua.NewState() 49 defer L.Close() 50 51 agent.Setup(L) 52 defer agent.Teardown(L) 53 54 hostname, err := os.Hostname() 55 if err != nil { 56 hostname = "<unknown>" 57 } 58 59 idString := fmt.Sprintf("mqttagent-%s-%d", hostname, os.Getpid()) 60 registerMqttClientType(L) 61 registerTimerType(L) 62 registerState(L, idString, fromMqtt) 63 defer cleanupClients(L) 64 65 if err := L.DoFile(main_script); err != nil { 66 panic(err) 67 } 68 69 timer := time.NewTimer(0) 70 defer timer.Stop() 71 72 log.Println(idString, "started") 73 74 for { 75 select { 76 case msg, ok := <-fromMqtt: 77 78 if !ok { 79 log.Println("fromMqtt is closed") 80 break 81 } 82 83 processMsg(L, agent, &msg) 84 85 case <-timer.C: 86 } 87 88 hasTimer, nextTimer := runTimers(L) 89 90 if hasTimer { 91 timer.Reset(time.Until(nextTimer)) 92 } else { 93 timer.Stop() 94 } 95 96 if tableIsEmpty(stateCnxTable(L)) && tableIsEmpty(stateTimerTable(L)) { 97 break 98 } 99 } 100 101 log.Println(idString, "finished") 102 } 103 104 func cleanupClients(L *lua.LState) { 105 cnxTbl := stateCnxTable(L) 106 if cnxTbl == nil { 107 return 108 } 109 110 L.ForEach(cnxTbl, func(key, value lua.LValue) { 111 cnx := value.(*lua.LTable) 112 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 113 if err := client.Disconnect(nil); err != nil { 114 log.Printf("cleanup client %s: %v", lua.LVAsString(key), err) 115 } 116 }) 117 } 118 119 func dispatchMsg(L *lua.LState, msg *MqttMessage, cnx, key, value lua.LValue) { 120 skey, ok := key.(lua.LString) 121 topic := string(msg.Topic) 122 123 if ok && match(topic, string(skey)) { 124 err := L.CallByParam(lua.P{Fn: value, NRet: 0, Protect: true}, 125 cnx, 126 lua.LString(string(msg.Message)), 127 lua.LString(topic), 128 lua.LNumber(msg.Timestamp)) 129 if err != nil { 130 panic(err) 131 } 132 } 133 } 134 135 func matchSliced(actual, filter []string) bool { 136 if len(filter) == 0 { 137 return len(actual) == 0 138 } 139 140 if filter[0] == "#" { 141 if len(filter) == 1 { 142 return true 143 } 144 145 for i := range actual { 146 if matchSliced(actual[i:], filter[1:]) { 147 return true 148 } 149 } 150 151 return false 152 } 153 154 if len(actual) > 0 && (filter[0] == "+" || filter[0] == actual[0]) { 155 return matchSliced(actual[1:], filter[1:]) 156 } 157 158 return false 159 } 160 161 func match(actual, filter string) bool { 162 return matchSliced(strings.Split(actual, "/"), strings.Split(filter, "/")) 163 } 164 165 func tableIsEmpty(t *lua.LTable) bool { 166 key, _ := t.Next(lua.LNil) 167 return key == lua.LNil 168 } 169 170 func processMsg(L *lua.LState, agent MqttAgent, msg *MqttMessage) { 171 agent.Log(L, msg) 172 173 cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable) 174 subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 175 L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, msg, cnx, key, value) }) 176 177 if tableIsEmpty(subTbl) { 178 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 179 if err := client.Disconnect(nil); err != nil { 180 log.Println("disconnect empty client:", err) 181 } 182 L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil) 183 } 184 } 185 186 func mqttRead(client *mqtt.Client, toLua chan<- MqttMessage, id int) { 187 var big *mqtt.BigMessage 188 189 for { 190 message, topic, err := client.ReadSlices() 191 t := float64(time.Now().UnixMicro()) * 1.0e-6 192 193 switch { 194 case err == nil: 195 toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: dup(message)} 196 197 case errors.As(err, &big): 198 data, err := big.ReadAll() 199 if err != nil { 200 log.Println("mqttRead big message:", err) 201 } else { 202 toLua <- MqttMessage{Timestamp: t, ClientId: id, Topic: dup(topic), Message: data} 203 } 204 205 case errors.Is(err, mqtt.ErrClosed): 206 log.Println("mqttRead finishing:", err) 207 return 208 209 case mqtt.IsConnectionRefused(err): 210 log.Println("mqttRead connection refused:", err) 211 time.Sleep(15 * time.Minute) 212 213 default: 214 log.Println("mqttRead:", err) 215 time.Sleep(2 * time.Second) 216 } 217 } 218 } 219 220 func dup(src []byte) []byte { 221 res := make([]byte, len(src)) 222 copy(res, src) 223 return res 224 } 225 226 /********** State Object in the Lua Interpreter **********/ 227 228 const luaStateName = "_mqttagent" 229 const keyChanToLua = 1 230 const keyClientPrefix = 2 231 const keyCnxTable = 3 232 const keyTimerTable = 4 233 234 func registerState(L *lua.LState, clientPrefix string, toLua chan<- MqttMessage) { 235 ud := L.NewUserData() 236 ud.Value = toLua 237 238 st := L.NewTable() 239 L.RawSetInt(st, keyChanToLua, ud) 240 L.RawSetInt(st, keyClientPrefix, lua.LString(clientPrefix)) 241 L.RawSetInt(st, keyCnxTable, L.NewTable()) 242 L.RawSetInt(st, keyTimerTable, L.NewTable()) 243 L.SetGlobal(luaStateName, st) 244 } 245 246 func stateValue(L *lua.LState, key int) lua.LValue { 247 st := L.GetGlobal(luaStateName) 248 return L.RawGetInt(st.(*lua.LTable), key) 249 } 250 251 func stateChanToLua(L *lua.LState) chan<- MqttMessage { 252 ud := stateValue(L, keyChanToLua) 253 return ud.(*lua.LUserData).Value.(chan<- MqttMessage) 254 } 255 256 func stateClientPrefix(L *lua.LState) string { 257 return lua.LVAsString(stateValue(L, keyClientPrefix)) 258 } 259 260 func stateCnxTable(L *lua.LState) *lua.LTable { 261 return stateValue(L, keyCnxTable).(*lua.LTable) 262 } 263 264 func stateTimerTable(L *lua.LState) *lua.LTable { 265 return stateValue(L, keyTimerTable).(*lua.LTable) 266 } 267 268 /********** Lua Object for MQTT client **********/ 269 270 const luaMqttClientTypeName = "mqttclient" 271 const keyClient = 1 272 const keySubTable = 2 273 274 func registerMqttClientType(L *lua.LState) { 275 mt := L.NewTypeMetatable(luaMqttClientTypeName) 276 L.SetGlobal(luaMqttClientTypeName, mt) 277 L.SetField(mt, "new", L.NewFunction(newMqttClient)) 278 L.SetField(mt, "__gc", L.NewFunction(deleteMqttClient)) 279 L.SetField(mt, "__call", L.NewFunction(luaPublish)) 280 L.SetField(mt, "__index", L.NewFunction(luaQuery)) 281 L.SetField(mt, "__newindex", L.NewFunction(luaSubscribe)) 282 } 283 284 type mqttConfig struct { 285 Connection string 286 PauseTimeout string 287 AtLeastOnceMax int 288 ExactlyOnceMax int 289 UserName string 290 Password []byte 291 Will struct { 292 Topic string 293 Message []byte 294 Retain bool 295 AtLeastOnce bool 296 ExactlyOnce bool 297 } 298 KeepAlive uint16 299 CleanSession bool 300 } 301 302 func newClient(L *lua.LState, id string) (*mqtt.Client, error) { 303 var config mqttConfig 304 if err := gluamapper.Map(L.CheckTable(1), &config); err != nil { 305 return nil, err 306 } 307 308 pto, err := time.ParseDuration(config.PauseTimeout) 309 if err != nil { 310 pto = time.Second 311 } 312 313 processed_cfg := mqtt.Config{ 314 Dialer: mqtt.NewDialer("tcp", config.Connection), 315 PauseTimeout: pto, 316 AtLeastOnceMax: config.AtLeastOnceMax, 317 ExactlyOnceMax: config.ExactlyOnceMax, 318 UserName: config.UserName, 319 Password: config.Password, 320 Will: struct { 321 Topic string 322 Message []byte 323 Retain bool 324 AtLeastOnce bool 325 ExactlyOnce bool 326 }{ 327 Topic: config.Will.Topic, 328 Message: config.Will.Message, 329 Retain: config.Will.Retain, 330 AtLeastOnce: config.Will.AtLeastOnce, 331 ExactlyOnce: config.Will.ExactlyOnce, 332 }, 333 KeepAlive: config.KeepAlive, 334 CleanSession: config.CleanSession, 335 } 336 337 return mqtt.VolatileSession(id, &processed_cfg) 338 } 339 340 func newMqttClient(L *lua.LState) int { 341 id := stateCnxTable(L).Len() + 1 342 idString := fmt.Sprintf("%s-%d", stateClientPrefix(L), id) 343 344 client, err := newClient(L, idString) 345 if err != nil { 346 log.Println("newMqttClient:", err) 347 L.Push(lua.LNil) 348 L.Push(lua.LString(err.Error())) 349 return 2 350 } 351 go mqttRead(client, stateChanToLua(L), id) 352 353 ud := L.NewUserData() 354 ud.Value = client 355 356 res := L.NewTable() 357 L.RawSetInt(res, keyClient, ud) 358 L.RawSetInt(res, keySubTable, L.NewTable()) 359 L.SetMetatable(res, L.GetTypeMetatable(luaMqttClientTypeName)) 360 L.RawSetInt(stateCnxTable(L), id, res) 361 L.Push(res) 362 return 1 363 } 364 365 func deleteMqttClient(L *lua.LState) int { 366 log.Println("deleteMqttClient: TODO") 367 return 0 368 } 369 370 func luaPublish(L *lua.LState) int { 371 cnx := L.CheckTable(1) 372 message := L.CheckString(2) 373 topic := L.CheckString(3) 374 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 375 376 err := client.Publish(nil, []byte(message), topic) 377 378 if err != nil { 379 L.Push(lua.LNil) 380 L.Push(lua.LString(err.Error())) 381 return 2 382 } else { 383 L.Push(lua.LTrue) 384 return 1 385 } 386 } 387 388 func luaQuery(L *lua.LState) int { 389 cnx := L.CheckTable(1) 390 topic := L.CheckString(2) 391 subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 392 L.Push(L.GetField(subTbl, topic)) 393 return 1 394 } 395 396 func luaSubscribe(L *lua.LState) int { 397 var err error 398 cnx := L.CheckTable(1) 399 topic := L.CheckString(2) 400 callback := L.OptFunction(3, nil) 401 client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client) 402 403 if callback == nil { 404 err = client.Unsubscribe(nil, topic) 405 } else { 406 err = client.Subscribe(nil, topic) 407 } 408 409 if err != nil { 410 log.Println("luaSubscribe:", err) 411 L.Push(lua.LNil) 412 L.Push(lua.LString(err.Error())) 413 return 2 414 } else { 415 tbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable) 416 417 if callback == nil { 418 L.SetField(tbl, topic, lua.LNil) 419 } else { 420 L.SetField(tbl, topic, callback) 421 } 422 423 L.Push(lua.LTrue) 424 return 1 425 } 426 } 427 428 /********** Lua Object for timers **********/ 429 430 const luaTimerTypeName = "timer" 431 const keyTime = 1 432 const keyCallback = 2 433 434 func registerTimerType(L *lua.LState) { 435 mt := L.NewTypeMetatable(luaTimerTypeName) 436 L.SetGlobal(luaTimerTypeName, mt) 437 L.SetField(mt, "new", L.NewFunction(newTimer)) 438 L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods)) 439 } 440 441 func newTimer(L *lua.LState) int { 442 atTime := L.CheckNumber(1) 443 cb := L.CheckFunction(2) 444 445 tbl := L.NewTable() 446 L.RawSetInt(tbl, keyTime, atTime) 447 L.RawSetInt(tbl, keyCallback, cb) 448 L.SetMetatable(tbl, L.GetTypeMetatable(luaTimerTypeName)) 449 stateTimerTable(L).Append(tbl) 450 L.Push(tbl) 451 return 1 452 } 453 454 var timerMethods = map[string]lua.LGFunction{ 455 "cancel": timerCancel, 456 "schedule": timerSchedule, 457 } 458 459 func timerCancel(L *lua.LState) int { 460 tbl := L.CheckTable(1) 461 L.RawSetInt(tbl, keyTime, lua.LNil) 462 return 0 463 } 464 465 func timerSchedule(L *lua.LState) int { 466 tbl := L.CheckTable(1) 467 atTime := L.CheckNumber(2) 468 L.RawSetInt(tbl, keyTime, atTime) 469 return 0 470 } 471 472 func toTime(v lua.LValue, d time.Time) (time.Time, bool) { 473 lsec, ok := v.(lua.LNumber) 474 if !ok { 475 return d, false 476 } 477 478 fsec := float64(lsec) 479 sec := int64(fsec) 480 nsec := int64((fsec - float64(sec)) * 1.0e9) 481 482 return time.Unix(sec, nsec), true 483 } 484 485 func runTimers(L *lua.LState) (bool, time.Time) { 486 hasNext := false 487 var nextTime time.Time 488 489 now := time.Now() 490 timers := stateTimerTable(L) 491 492 k, v := timers.Next(lua.LNil) 493 for k != lua.LNil { 494 tbl := v.(*lua.LTable) 495 luaT := L.RawGetInt(tbl, keyTime) 496 t, ok := toTime(luaT, now) 497 if !ok { 498 } else if t.Compare(now) <= 0 { 499 L.RawSetInt(tbl, keyTime, lua.LNil) 500 err := L.CallByParam(lua.P{Fn: L.RawGetInt(tbl, keyCallback), NRet: 0, Protect: true}, v, luaT) 501 if err != nil { 502 panic(err) 503 } 504 k = lua.LNil 505 hasNext = false 506 } else if !hasNext || t.Compare(nextTime) < 0 { 507 hasNext = true 508 nextTime = t 509 } 510 511 k, v = timers.Next(k) 512 } 513 514 return hasNext, nextTime 515 }