commit e87bdc3f121ca60d0400b74db0aedf8db20d6cb5
parent 2e521db0447700b8c8461537505b4c2278c37d9e
Author: Natasha Kerensikova <natgh@instinctive.eu>
Date:   Sat, 11 Jan 2025 13:48:26 +0000
Timers are drafted in
Diffstat:
| M | mqttagent.go | | | 142 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- | 
1 file changed, 128 insertions(+), 14 deletions(-)
diff --git a/mqttagent.go b/mqttagent.go
@@ -57,6 +57,7 @@ func Run(agent MqttAgent, main_script string) {
 	}
 
 	registerMqttClientType(L)
+	registerTimerType(L)
 	registerState(L, fmt.Sprintf("mqttagent-%s-%d", hostname, os.Getpid()), &fromMqtt)
 	defer cleanupClients(L)
 
@@ -64,25 +65,28 @@ func Run(agent MqttAgent, main_script string) {
 		panic(err)
 	}
 
+	timer := time.NewTimer(0)
+	defer timer.Stop()
+
 	for {
-		msg, ok := <-fromMqtt
+		select {
+		case msg, ok := <-fromMqtt:
 
-		if !ok {
-			break
-		}
+			if !ok {
+				break
+			}
 
-		agent.Log(L, &msg)
+			processMsg(L, agent, &msg)
 
-		cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable)
-		subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
-		L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, &msg, cnx, key, value) })
+		case <-timer.C:
+		}
 
-		if key, _ := subTbl.Next(lua.LNil); key == lua.LNil {
-			client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
-			if err := client.Disconnect(nil); err != nil {
-				log.Println(err)
-			}
-			L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil)
+		hasTimer, nextTimer := runTimers(L)
+
+		if hasTimer {
+			timer.Reset(time.Until(nextTimer))
+		} else {
+			timer.Stop()
 		}
 
 		if stateCnxTable(L).Len() == 0 {
@@ -152,6 +156,22 @@ func match(actual, filter string) bool {
 	return matchSliced(strings.Split(actual, "/"), strings.Split(filter, "/"))
 }
 
+func processMsg(L *lua.LState, agent MqttAgent, msg *MqttMessage) {
+	agent.Log(L, msg)
+
+	cnx := L.RawGetInt(stateCnxTable(L), msg.ClientId).(*lua.LTable)
+	subTbl := L.RawGetInt(cnx, keySubTable).(*lua.LTable)
+	L.ForEach(subTbl, func(key, value lua.LValue) { dispatchMsg(L, msg, cnx, key, value) })
+
+	if key, _ := subTbl.Next(lua.LNil); key == lua.LNil {
+		client := L.RawGetInt(cnx, keyClient).(*lua.LUserData).Value.(*mqtt.Client)
+		if err := client.Disconnect(nil); err != nil {
+			log.Println(err)
+		}
+		L.RawSetInt(stateCnxTable(L), msg.ClientId, lua.LNil)
+	}
+}
+
 func mqttRead(client *mqtt.Client, toLua chan MqttMessage, id int) {
 	var big *mqtt.BigMessage
 
@@ -188,6 +208,7 @@ const luaStateName = "_mqttagent"
 const keyChanToLua = 1
 const keyClientPrefix = 2
 const keyCnxTable = 3
+const keyTimerTable = 4
 
 func registerState(L *lua.LState, clientPrefix string, toLua *chan MqttMessage) {
 	ud := L.NewUserData()
@@ -197,6 +218,7 @@ func registerState(L *lua.LState, clientPrefix string, toLua *chan MqttMessage) 
 	L.RawSetInt(st, keyChanToLua, ud)
 	L.RawSetInt(st, keyClientPrefix, lua.LString(clientPrefix))
 	L.RawSetInt(st, keyCnxTable, L.NewTable())
+	L.RawSetInt(st, keyTimerTable, L.NewTable())
 	L.SetGlobal(luaStateName, st)
 }
 
@@ -218,6 +240,10 @@ func stateCnxTable(L *lua.LState) *lua.LTable {
 	return stateValue(L, keyCnxTable).(*lua.LTable)
 }
 
+func stateTimerTable(L *lua.LState) *lua.LTable {
+	return stateValue(L, keyTimerTable).(*lua.LTable)
+}
+
 /********** Lua Object for MQTT client **********/
 
 const luaMqttClientTypeName = "mqttclient"
@@ -377,3 +403,91 @@ func luaSubscribe(L *lua.LState) int {
 		return 1
 	}
 }
+
+/********** Lua Object for timers **********/
+
+const luaTimerTypeName = "timer"
+const keyTime = 1
+const keyCallback = 2
+
+func registerTimerType(L *lua.LState) {
+	mt := L.NewTypeMetatable(luaTimerTypeName)
+	L.SetGlobal(luaTimerTypeName, mt)
+	L.SetField(mt, "new", L.NewFunction(newTimer))
+	L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods))
+}
+
+func newTimer(L *lua.LState) int {
+	atTime := L.CheckNumber(1)
+	cb := L.CheckFunction(2)
+
+	tbl := L.NewTable()
+	L.RawSetInt(tbl, keyTime, atTime)
+	L.RawSetInt(tbl, keyCallback, cb)
+	L.SetMetatable(tbl, L.GetTypeMetatable(luaTimerTypeName))
+	stateTimerTable(L).Append(tbl)
+	L.Push(tbl)
+	return 1
+}
+
+var timerMethods = map[string]lua.LGFunction{
+	"cancel":   timerCancel,
+	"schedule": timerSchedule,
+}
+
+func timerCancel(L *lua.LState) int {
+	tbl := L.CheckTable(1)
+	L.RawSetInt(tbl, keyTime, lua.LNil)
+	return 0
+}
+
+func timerSchedule(L *lua.LState) int {
+	tbl := L.CheckTable(1)
+	atTime := L.CheckNumber(2)
+	L.RawSetInt(tbl, keyTime, atTime)
+	return 0
+}
+
+func toTime(v lua.LValue, d time.Time) (time.Time, bool) {
+	lsec, ok := v.(lua.LNumber)
+	if !ok {
+		return d, false
+	}
+
+	fsec := float64(lsec)
+	sec := int64(fsec)
+	nsec := int64((fsec - float64(sec)) * 1.0e9)
+
+	return time.Unix(sec, nsec), true
+}
+
+func runTimers(L *lua.LState) (bool, time.Time) {
+	hasNext := false
+	var nextTime time.Time
+
+	now := time.Now()
+	timers := stateTimerTable(L)
+
+	k, v := timers.Next(lua.LNil)
+	for k != lua.LNil {
+		tbl := v.(*lua.LTable)
+		t, ok := toTime(L.RawGetInt(tbl, keyTime), now)
+		if !ok {
+		} else if t.Compare(now) <= 0 {
+			L.RawSetInt(tbl, keyTime, lua.LNil)
+			err := L.CallByParam(lua.P{Fn: L.RawGetInt(tbl, keyCallback), NRet: 0, Protect: true}, v, k)
+			if err != nil {
+				panic(err)
+			}
+			k = lua.LNil
+			hasNext = false
+		} else if !hasNext || t.Compare(nextTime) < 0 {
+			hasNext = true
+			nextTime = t
+		}
+
+		k, v = timers.Next(k)
+	}
+
+	return hasNext, nextTime
+}