natsbot.go (23454B)
1 /* 2 * Copyright (c) 2025, Natacha Porté 3 * 4 * Permission to use, copy, modify, and distribute this software for any 5 * purpose with or without fee is hereby granted, provided that the above 6 * copyright notice and this permission notice appear in all copies. 7 * 8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 */ 16 17 package natsbot 18 19 import ( 20 "errors" 21 "fmt" 22 "log" 23 "strings" 24 "time" 25 26 "github.com/nats-io/nats.go" 27 "github.com/yuin/gopher-lua" 28 ) 29 30 type NatsBot interface { 31 Setup(L *lua.LState) 32 Teardown(L *lua.LState) 33 } 34 35 type internalEvent struct { 36 name string 37 nc *nats.Conn 38 subs *nats.Subscription 39 err error 40 } 41 42 func Loop(cb NatsBot, mainScript string, capacity int) { 43 evtChan := make(chan *internalEvent, capacity) 44 msgChan := make(chan *nats.Msg, capacity) 45 toClean := make(map[*nats.Subscription]bool) 46 47 L := lua.NewState() 48 defer L.Close() 49 50 cb.Setup(L) 51 defer cb.Teardown(L) 52 53 registerConnType(L) 54 registerTimerType(L) 55 registerState(L, evtChan, msgChan) 56 57 if err := L.DoFile(mainScript); err != nil { 58 panic(err) 59 } 60 61 timer := time.NewTimer(0) 62 defer timer.Stop() 63 64 log.Println("natsbot started") 65 66 for { 67 select { 68 case evt, ok := <-evtChan: 69 if !ok { 70 log.Println("evtChan is closed") 71 break 72 } 73 74 processEvt(L, evt) 75 76 case msg, ok := <-msgChan: 77 78 if !ok { 79 log.Println("msgChan is closed") 80 break 81 } 82 83 processMsg(L, msg) 84 85 if !msg.Sub.IsValid() { 86 toClean[msg.Sub] = true 87 } 88 89 case <-timer.C: 90 } 91 92 runTimers(L, timer) 93 94 if len(msgChan) == 0 { 95 for s := range toClean { 96 if !s.IsValid() { 97 log.Printf("Pruning subscription %q", s.Subject) 98 tbl, idx := stateSubsTable(L) 99 L.RawSetInt(tbl, idx[s], lua.LNil) 100 delete(idx, s) 101 deleteFromSubsCfgMap(L, s) 102 } else { 103 log.Printf("Subscription %q is still valid", s.Subject) 104 } 105 } 106 toClean = make(map[*nats.Subscription]bool) 107 } 108 109 if stateReloadRequested(L) { 110 L = reload(cb, L, mainScript) 111 runTimers(L, timer) 112 stateRequestReload(L, lua.LNil) 113 } 114 115 if tableWithIndexIsEmpty(stateSubsTable(L)) && tableIsEmpty(stateTimerTable(L)) { 116 log.Println("No callback remaining") 117 break 118 } 119 } 120 121 stateClean(L, nil) 122 log.Println("natsbot finished") 123 } 124 125 func processEvt(L *lua.LState, evt *internalEvent) { 126 tbl, idx := stateConnTable(L) 127 connLua := L.RawGetInt(tbl, idx[evt.nc]) 128 fn := L.GetField(L.GetField(L.GetMetatable(connLua), "__index").(*lua.LTable), evt.name) 129 130 if lua.LVIsFalse(fn) { 131 connS := strings.Join(evt.nc.Servers(), "|") 132 subsS := "" 133 errS := "" 134 135 if evt.subs != nil { 136 subsS = fmt.Sprintf(" on %q", evt.subs.Subject) 137 } 138 if evt.err != nil { 139 errS = ": " + evt.err.Error() 140 } 141 142 log.Printf("Event %s on %s%s%s", evt.name, connS, subsS, errS) 143 return 144 } 145 146 var subLua lua.LValue 147 if evt.subs == nil { 148 subLua = lua.LNil 149 } else { 150 tbl, idx := stateSubsTable(L) 151 subLua = L.RawGetInt(tbl, idx[evt.subs]) 152 } 153 154 var errLua lua.LValue 155 if evt.err == nil { 156 errLua = lua.LNil 157 } else { 158 errLua = lua.LString(evt.err.Error()) 159 } 160 161 err := L.CallByParam(lua.P{Fn: fn, NRet: 0, Protect: true}, 162 connLua, 163 subLua, 164 errLua) 165 if err != nil { 166 panic(err) 167 } 168 } 169 170 func processMsg(L *lua.LState, msg *nats.Msg) { 171 tbl, idx := stateSubsTable(L) 172 id, found := idx[msg.Sub] 173 if !found { 174 log.Printf("Got message for stale subscription to %q", msg.Sub.Subject) 175 return 176 } 177 178 subs := L.RawGetInt(tbl, id) 179 fn := L.GetField(L.GetMetatable(subs), "__call") 180 err := L.CallByParam(lua.P{Fn: fn, NRet: 0, Protect: true}, 181 subs, 182 lua.LString(msg.Subject), 183 lua.LString(string(msg.Data)), 184 luaHeaders(L, msg.Reply, msg.Header)) 185 if err != nil { 186 panic(err) 187 } 188 } 189 190 func reload(cb NatsBot, oldL *lua.LState, mainScript string) *lua.LState { 191 log.Println("Reloading", mainScript) 192 193 newL := lua.NewState() 194 cb.Setup(newL) 195 registerConnType(newL) 196 registerTimerType(newL) 197 stateReloadBegin(oldL, newL) 198 199 if err := newL.DoFile(mainScript); err != nil { 200 log.Println("Reload failed:", err) 201 stateReloadAbort(oldL, newL) 202 newL.Close() 203 return oldL 204 } else { 205 stateReloadEnd(oldL, newL) 206 oldL.Close() 207 log.Println("Reload successful") 208 return newL 209 } 210 } 211 212 /********** State Object in the Lua Interpreter **********/ 213 214 const luaStateName = "_natsbot" 215 const ( 216 _ = iota 217 keyEvtChan 218 keyMsgChan 219 keyCfgMap 220 keyConnTable 221 keySubsCfgMap 222 keySubsTable 223 keyTimerTable 224 keyReloadRequest 225 keyOldCfgMap 226 keyOldSubsCfgMap 227 ) 228 229 func registerState(L *lua.LState, evtChan chan *internalEvent, msgChan chan *nats.Msg) { 230 st := L.NewTable() 231 L.RawSetInt(st, keyEvtChan, newUserData(L, evtChan)) 232 L.RawSetInt(st, keyMsgChan, newUserData(L, msgChan)) 233 L.RawSetInt(st, keyCfgMap, newUserData(L, make(natsConfigMap))) 234 L.RawSetInt(st, keyConnTable, newConnTbl(L)) 235 L.RawSetInt(st, keySubsCfgMap, newUserData(L, make(subsCfgMap))) 236 L.RawSetInt(st, keySubsTable, newSubsTbl(L)) 237 L.RawSetInt(st, keyTimerTable, L.NewTable()) 238 stateSet(L, st) 239 240 L.SetGlobal("reload", L.NewFunction(requestReload)) 241 } 242 243 func stateReloadBegin(oldL, newL *lua.LState) { 244 evtChan := stateEvtChan(oldL) 245 msgChan := stateMsgChan(oldL) 246 cfgMap := stateCfgMap(oldL) 247 subsCfgMap := stateSubsCfgMap(oldL) 248 249 registerState(newL, evtChan, msgChan) 250 newL.RawSetInt(stateGet(newL), keyOldCfgMap, newUserData(newL, cfgMap)) 251 newL.RawSetInt(stateGet(newL), keyOldSubsCfgMap, newUserData(newL, subsCfgMap)) 252 } 253 254 func stateReloadAbort(oldL, newL *lua.LState) { 255 stateClean(newL, oldL) 256 } 257 258 func stateReloadEnd(oldL, newL *lua.LState) { 259 stateClean(oldL, newL) 260 newL.RawSetInt(stateGet(newL), keyOldCfgMap, lua.LNil) 261 newL.RawSetInt(stateGet(newL), keyOldSubsCfgMap, lua.LNil) 262 } 263 264 func stateClean(L, keptL *lua.LState) { 265 _, connIdx := stateConnTable(L) 266 _, subsIdx := stateSubsTable(L) 267 268 st := stateGet(L) 269 L.RawSetInt(st, keyConnTable, newConnTbl(L)) 270 L.RawSetInt(st, keySubsTable, newSubsTbl(L)) 271 272 var keptConn connMap 273 var keptSubs subsMap 274 if keptL != nil { 275 _, keptConn = stateConnTable(keptL) 276 _, keptSubs = stateSubsTable(keptL) 277 } 278 279 for nc := range connIdx { 280 if _, found := keptConn[nc]; found { 281 continue 282 } 283 284 nc.SetClosedHandler(nil) 285 nc.SetDisconnectErrHandler(nil) 286 nc.SetDiscoveredServersHandler(nil) 287 nc.SetErrorHandler(nil) 288 nc.SetReconnectHandler(nil) 289 nc.Close() 290 } 291 292 for ns := range subsIdx { 293 if _, found := keptSubs[ns]; found { 294 continue 295 } 296 297 if ns.IsValid() { 298 if err := ns.Unsubscribe(); err != nil { 299 log.Println("Unsubscribe:", err) 300 } 301 } 302 } 303 } 304 305 func stateUncheckedGet(L *lua.LState) *lua.LTable { 306 v := L.GetField(L.Get(lua.RegistryIndex), luaStateName) 307 if result, ok := v.(*lua.LTable); ok { 308 return result 309 } else { 310 return nil 311 } 312 } 313 314 func stateGet(L *lua.LState) *lua.LTable { 315 result := stateUncheckedGet(L) 316 if result == nil { 317 panic("Missing internal state object") 318 } 319 return result 320 } 321 322 func stateSet(L *lua.LState, newState *lua.LTable) { 323 if stateUncheckedGet(L) != nil { 324 panic("Overwriting internal state object") 325 } 326 L.SetField(L.Get(lua.RegistryIndex), luaStateName, newState) 327 } 328 329 func stateValue(L *lua.LState, key int) lua.LValue { 330 return L.RawGetInt(stateGet(L), key) 331 } 332 333 func stateEvtChan(L *lua.LState) chan *internalEvent { 334 ud := stateValue(L, keyEvtChan) 335 return ud.(*lua.LUserData).Value.(chan *internalEvent) 336 } 337 338 func stateMsgChan(L *lua.LState) chan *nats.Msg { 339 ud := stateValue(L, keyMsgChan) 340 return ud.(*lua.LUserData).Value.(chan *nats.Msg) 341 } 342 343 func stateCfgMap(L *lua.LState) natsConfigMap { 344 return stateValue(L, keyCfgMap).(*lua.LUserData).Value.(natsConfigMap) 345 } 346 347 func stateConnTable(L *lua.LState) (*lua.LTable, connMap) { 348 tbl := stateValue(L, keyConnTable).(*lua.LTable) 349 idx := L.RawGetInt(tbl, keyIndex).(*lua.LUserData).Value.(connMap) 350 return tbl, idx 351 } 352 353 func stateSubsCfgMap(L *lua.LState) subsCfgMap { 354 ud := stateValue(L, keySubsCfgMap) 355 return ud.(*lua.LUserData).Value.(subsCfgMap) 356 } 357 358 func stateSubsTable(L *lua.LState) (*lua.LTable, subsMap) { 359 tbl := stateValue(L, keySubsTable).(*lua.LTable) 360 idx := L.RawGetInt(tbl, keyIndex).(*lua.LUserData).Value.(subsMap) 361 return tbl, idx 362 } 363 364 func stateTimerTable(L *lua.LState) *lua.LTable { 365 return stateValue(L, keyTimerTable).(*lua.LTable) 366 } 367 368 func stateReloadRequested(L *lua.LState) bool { 369 return lua.LVAsBool(stateValue(L, keyReloadRequest)) 370 } 371 372 func stateRequestReload(L *lua.LState, v lua.LValue) { 373 L.RawSetInt(stateGet(L), keyReloadRequest, v) 374 } 375 376 func stateOldCfgMap(L *lua.LState) natsConfigMap { 377 v := stateValue(L, keyOldCfgMap) 378 if v == lua.LNil { 379 return nil 380 } else { 381 return v.(*lua.LUserData).Value.(natsConfigMap) 382 } 383 } 384 385 func requestReload(L *lua.LState) int { 386 stateRequestReload(L, lua.LTrue) 387 return 0 388 } 389 390 func addToSubsCfgMap(L *lua.LState, nc *nats.Conn, s *nats.Subscription) { 391 cfgMap := stateSubsCfgMap(L) 392 subsKey := subsCfg{subject: s.Subject, queue: s.Queue} 393 cmap, found := cfgMap[subsKey] 394 if !found { 395 cmap = make(map[*nats.Conn][]*nats.Subscription) 396 cfgMap[subsKey] = cmap 397 } 398 cmap[nc] = append(cmap[nc], s) 399 } 400 401 func deleteFromSubsCfgMap(L *lua.LState, s *nats.Subscription) { 402 wholeMap := stateSubsCfgMap(L) 403 subsKey := subsCfg{subject: s.Subject, queue: s.Queue} 404 cMap := wholeMap[subsKey] 405 for nc, subsArray := range cMap { 406 n := 0 407 for _, ns := range subsArray { 408 if ns != s { 409 subsArray[n] = s 410 n++ 411 } 412 } 413 if n > 0 { 414 cMap[nc] = subsArray[:n] 415 } else { 416 delete(cMap, nc) 417 } 418 } 419 if len(cMap) == 0 { 420 delete(wholeMap, subsKey) 421 } 422 } 423 424 func findInOldSubsCfgMap(L *lua.LState, nc *nats.Conn, key subsCfg) *nats.Subscription { 425 ud := stateValue(L, keyOldSubsCfgMap) 426 if ud == lua.LNil { 427 return nil 428 } 429 430 cMap, found := ud.(*lua.LUserData).Value.(subsCfgMap)[key] 431 if !found { 432 return nil 433 } 434 435 subsArray, found := cMap[nc] 436 if !found { 437 return nil 438 } 439 440 _, knownSubsMap := stateSubsTable(L) 441 for _, s := range subsArray { 442 _, found = knownSubsMap[s] 443 if !found { 444 return s 445 } 446 } 447 448 return nil 449 } 450 451 /********** NATS Connection Configuration **********/ 452 453 type natsConfig struct { 454 url string 455 name string 456 nkey string 457 user string 458 password string 459 token string 460 retry bool 461 } 462 463 type natsConfigMap map[natsConfig]*nats.Conn 464 465 type natsCbMap map[string]*lua.LFunction 466 467 func connConfig(L *lua.LState) (*natsConfig, natsCbMap, error) { 468 arg := L.Get(1) 469 if url, ok := arg.(lua.LString); ok { 470 if cfg, cbmap, err := toConfig(L, L.Get(2)); err != nil { 471 return nil, nil, err 472 } else if cfg.url == "" { 473 cfg.url = string(url) 474 return cfg, cbmap, nil 475 } else if cfg.url != string(url) { 476 return nil, nil, fmt.Errorf("incompatible URLs %q and %q", cfg.url, string(url)) 477 } else { 478 return cfg, cbmap, nil 479 } 480 } else { 481 return toConfig(L, arg) 482 } 483 } 484 485 func toConfig(L *lua.LState, lv lua.LValue) (*natsConfig, natsCbMap, error) { 486 tbl, ok := lv.(*lua.LTable) 487 if !ok { 488 return nil, nil, errors.New("configuration is not a table") 489 } 490 491 var result natsConfig 492 cbmap := make(natsCbMap) 493 var errStr []string 494 495 L.ForEach(tbl, func(key, value lua.LValue) { 496 skey, ok := key.(lua.LString) 497 if !ok { 498 errStr = append(errStr, fmt.Sprintf("bad key: %q", lua.LVAsString(key))) 499 return 500 } 501 502 switch skey { 503 case "closed": 504 fallthrough 505 case "disconnected": 506 fallthrough 507 case "error": 508 fallthrough 509 case "reconnect_error": 510 fallthrough 511 case "reconnected": 512 if s, ok := value.(*lua.LFunction); ok { 513 cbmap[string(skey)] = s 514 } else { 515 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 516 } 517 518 case "url": 519 if s, ok := value.(lua.LString); ok { 520 result.url = string(s) 521 } else { 522 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 523 } 524 525 case "name": 526 if s, ok := value.(lua.LString); ok { 527 result.name = string(s) 528 } else { 529 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 530 } 531 532 case "nkey": 533 if s, ok := value.(lua.LString); ok { 534 result.nkey = string(s) 535 } else { 536 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 537 } 538 539 case "user": 540 if s, ok := value.(lua.LString); ok { 541 result.user = string(s) 542 } else { 543 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 544 } 545 546 case "password": 547 if s, ok := value.(lua.LString); ok { 548 result.password = string(s) 549 } else { 550 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 551 } 552 553 case "token": 554 if s, ok := value.(lua.LString); ok { 555 result.token = string(s) 556 } else { 557 errStr = append(errStr, fmt.Sprintf("bad value for key %q: %q", skey, lua.LVAsString(value))) 558 } 559 560 case "retry": 561 result.retry = lua.LVAsBool(value) 562 563 default: 564 errStr = append(errStr, fmt.Sprintf("Unknown key %q", skey)) 565 } 566 }) 567 568 if len(errStr) > 0 { 569 return nil, nil, errors.New(strings.Join(errStr, ", ")) 570 } else { 571 return &result, cbmap, nil 572 } 573 } 574 575 func newConn(evtChan chan *internalEvent, cfg *natsConfig) (*nats.Conn, error) { 576 opt := []nats.Option{ 577 nats.DisconnectErrHandler(newConnErrHandler(evtChan, "disconnected")), 578 nats.ReconnectHandler(newConnHandler(evtChan, "reconnected")), 579 nats.ReconnectErrHandler(newConnErrHandler(evtChan, "reconnect_error")), 580 nats.ClosedHandler(newConnHandler(evtChan, "closed")), 581 nats.ErrorHandler(newErrHandler(evtChan, "error")), 582 nats.RetryOnFailedConnect(cfg.retry), 583 } 584 585 if cfg.name != "" { 586 opt = append(opt, nats.Name(cfg.name)) 587 } 588 if cfg.nkey != "" { 589 if o, err := nats.NkeyOptionFromSeed(cfg.nkey); err != nil { 590 return nil, err 591 } else { 592 opt = append(opt, o) 593 } 594 } 595 if cfg.user != "" || cfg.password != "" { 596 opt = append(opt, nats.UserInfo(cfg.user, cfg.password)) 597 } 598 if cfg.token != "" { 599 opt = append(opt, nats.Token(cfg.token)) 600 } 601 602 return nats.Connect(cfg.url, opt...) 603 } 604 605 /********** Lua Object for NATS connection **********/ 606 607 const luaNatsConnTypeName = "natsconn" 608 const keyIndex = 1 609 610 type connMap map[*nats.Conn]int 611 612 func registerConnType(L *lua.LState) { 613 L.SetGlobal(luaNatsConnTypeName, L.NewFunction(natsConnect)) 614 } 615 616 func newConnTbl(L *lua.LState) lua.LValue { 617 connTbl := L.NewTable() 618 L.RawSetInt(connTbl, keyIndex, newUserData(L, make(connMap))) 619 return connTbl 620 } 621 622 func natsConnect(L *lua.LState) int { 623 pcfg, cbmap, err := connConfig(L) 624 if err != nil { 625 log.Println(err) 626 L.Push(lua.LNil) 627 L.Push(lua.LString(err.Error())) 628 return 2 629 } 630 cfg := *pcfg 631 632 cfgMap := stateCfgMap(L) 633 if nc, found := cfgMap[cfg]; found { 634 tbl, idx := stateConnTable(L) 635 if id, ok := idx[nc]; ok { 636 res := L.RawGetInt(tbl, id) 637 if lua.LVIsFalse(res) { 638 panic("Inconsistent connection table") 639 } 640 L.Push(res) 641 return 1 642 } else { 643 panic("Inconsistent connection table") 644 } 645 } 646 647 if oldCfgMap := stateOldCfgMap(L); oldCfgMap != nil { 648 if nc, found := oldCfgMap[cfg]; found { 649 cfgMap[cfg] = nc 650 L.Push(wrapConn(L, nc, cbmap)) 651 return 1 652 } 653 } 654 655 nc, err := newConn(stateEvtChan(L), pcfg) 656 if err != nil { 657 log.Println("newConn", err) 658 L.Push(lua.LNil) 659 L.Push(lua.LString(err.Error())) 660 return 2 661 } 662 663 cfgMap[cfg] = nc 664 L.Push(wrapConn(L, nc, cbmap)) 665 return 1 666 } 667 668 func wrapConn(L *lua.LState, nc *nats.Conn, cbmap natsCbMap) lua.LValue { 669 luaConn := newUserData(L, nc) 670 671 index := L.NewTable() 672 L.SetField(index, "publish", L.NewFunction(natsPublish)) 673 L.SetField(index, "subscribe", L.NewFunction(natsSubscribe)) 674 675 for key, fn := range cbmap { 676 L.SetField(index, key, fn) 677 } 678 679 mt := L.NewTable() 680 L.SetField(mt, "__index", index) 681 L.SetField(mt, "__newindex", L.NewFunction(connNewIndex)) 682 L.SetMetatable(luaConn, mt) 683 684 tbl, idx := stateConnTable(L) 685 id := tbl.Len() + 1 686 L.RawSetInt(tbl, id, luaConn) 687 688 if _, found := idx[nc]; found { 689 panic("id already in connection table") 690 } 691 idx[nc] = id 692 693 return luaConn 694 } 695 696 func connNewIndex(L *lua.LState) int { 697 ud := L.CheckUserData(1) 698 699 if _, ok := ud.Value.(*nats.Conn); !ok { 700 L.ArgError(1, "connection expected") 701 return 0 702 } 703 704 event := L.CheckString(2) 705 if event != "disconnected" && 706 event != "reconnected" && 707 event != "reconnect_error" && 708 event != "closed" && 709 event != "error" { 710 L.ArgError(2, "unsupported callback name") 711 } 712 713 fn := L.Get(3) 714 if _, ok := fn.(*lua.LFunction); ok && fn != lua.LNil { 715 L.ArgError(3, "function expected") 716 } 717 718 mt := L.GetMetatable(ud) 719 idx := L.GetField(mt, "__index").(*lua.LTable) 720 L.SetField(idx, event, fn) 721 return 0 722 } 723 724 func newConnHandler(evtChan chan *internalEvent, name string) nats.ConnHandler { 725 return func(nc *nats.Conn) { 726 evtChan <- &internalEvent{name: name, nc: nc} 727 } 728 } 729 730 func newConnErrHandler(evtChan chan *internalEvent, name string) nats.ConnErrHandler { 731 return func(nc *nats.Conn, err error) { 732 evtChan <- &internalEvent{name: name, nc: nc, err: err} 733 } 734 } 735 736 func newErrHandler(evtChan chan *internalEvent, name string) nats.ErrHandler { 737 return func(nc *nats.Conn, s *nats.Subscription, err error) { 738 evtChan <- &internalEvent{name: name, nc: nc, subs: s, err: err} 739 } 740 } 741 742 func checkConn(L *lua.LState, index int) *nats.Conn { 743 ud := L.CheckUserData(index) 744 745 if v, ok := ud.Value.(*natsSubs); ok { 746 return v.nc 747 } 748 749 if v, ok := ud.Value.(*nats.Conn); ok { 750 return v 751 } 752 753 L.ArgError(index, "connection expected") 754 return nil 755 } 756 757 func natsPublish(L *lua.LState) int { 758 nc := checkConn(L, 1) 759 subject := L.CheckString(2) 760 data := L.OptString(3, "") 761 762 if err := nc.Publish(subject, []byte(data)); err != nil { 763 log.Println("Publish:", err) 764 L.Push(lua.LNil) 765 L.Push(lua.LString(err.Error())) 766 return 2 767 } else { 768 L.Push(lua.LTrue) 769 return 1 770 } 771 } 772 773 func natsSubscribe(L *lua.LState) int { 774 nc := checkConn(L, 1) 775 subject := L.CheckString(2) 776 fn := L.CheckFunction(3) 777 778 if s := findInOldSubsCfgMap(L, nc, subsCfg{subject: subject}); s != nil { 779 L.Push(wrapSubs(L, fn, s, nc)) 780 return 1 781 } else if s, err := nc.ChanSubscribe(subject, stateMsgChan(L)); err != nil { 782 log.Println("Subscribe:", err) 783 L.Push(lua.LNil) 784 L.Push(lua.LString(err.Error())) 785 return 2 786 } else { 787 L.Push(wrapSubs(L, fn, s, nc)) 788 return 1 789 } 790 } 791 792 /********** Lua Object for NATS subscription **********/ 793 794 type natsSubs struct { 795 id int 796 nc *nats.Conn 797 subs *nats.Subscription 798 } 799 800 type subsMap map[*nats.Subscription]int 801 802 type subsCfg struct { 803 subject string 804 queue string 805 } 806 807 type subsCfgMap map[subsCfg]map[*nats.Conn][]*nats.Subscription 808 809 func newSubsTbl(L *lua.LState) lua.LValue { 810 subsTbl := L.NewTable() 811 L.RawSetInt(subsTbl, keyIndex, newUserData(L, make(subsMap))) 812 return subsTbl 813 } 814 815 func wrapSubs(L *lua.LState, fn lua.LValue, ns *nats.Subscription, nc *nats.Conn) lua.LValue { 816 tbl, subsIdx := stateSubsTable(L) 817 id := tbl.Len() + 1 818 luaSub := newUserData(L, &natsSubs{id: id, nc: nc, subs: ns}) 819 820 subsIdx[ns] = id 821 L.RawSetInt(tbl, id, luaSub) 822 823 addToSubsCfgMap(L, nc, ns) 824 825 index := L.NewTable() 826 L.SetField(index, "callback", fn) 827 L.SetField(index, "id", lua.LNumber(id)) 828 L.SetField(index, "subject", lua.LString(string(ns.Subject))) 829 830 L.SetField(index, "publish", L.NewFunction(natsPublish)) 831 L.SetField(index, "subscribe", L.NewFunction(natsSubscribe)) 832 L.SetField(index, "unsubscribe", L.NewFunction(natsUnsubscribe)) 833 834 mt := L.NewTable() 835 L.SetField(mt, "__call", fn) 836 L.SetField(mt, "__index", index) 837 L.SetField(mt, "__newindex", L.NewFunction(subsNewIndex)) 838 839 L.SetMetatable(luaSub, mt) 840 return luaSub 841 } 842 843 func subsNewIndex(L *lua.LState) int { 844 _ = checkSubs(L, 1) 845 mt := L.GetMetatable(L.Get(1)) 846 key := L.Get(2) 847 848 if s, ok := key.(lua.LString); !ok || string(s) != "callback" { 849 L.RaiseError("attempt to change bad subscription field") 850 return 0 851 } 852 853 fn := L.CheckFunction(3) 854 L.SetField(mt, "__call", fn) 855 index := L.GetField(mt, "__index").(*lua.LTable) 856 L.SetField(index, "callback", fn) 857 return 0 858 } 859 860 func checkSubs(L *lua.LState, index int) *natsSubs { 861 ud := L.CheckUserData(index) 862 863 if v, ok := ud.Value.(*natsSubs); ok { 864 return v 865 } 866 867 L.ArgError(index, "subscription expected") 868 return nil 869 } 870 871 func natsUnsubscribe(L *lua.LState) int { 872 s := checkSubs(L, 1) 873 count := L.OptInt(2, 0) 874 875 if err := s.subs.AutoUnsubscribe(count); err != nil { 876 log.Println("Unsubscribe:", err) 877 L.Push(lua.LNil) 878 L.Push(lua.LString(err.Error())) 879 return 2 880 } else { 881 L.Push(lua.LTrue) 882 return 1 883 } 884 } 885 886 /********** Lua Object for timers **********/ 887 888 const luaTimerTypeName = "timer" 889 890 func registerTimerType(L *lua.LState) { 891 mt := L.NewTypeMetatable(luaTimerTypeName) 892 L.SetGlobal(luaTimerTypeName, mt) 893 L.SetField(mt, "new", L.NewFunction(newTimer)) 894 L.SetField(mt, "schedule", L.NewFunction(timerSchedule)) 895 L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), timerMethods)) 896 } 897 898 func newTimer(L *lua.LState) int { 899 atTime := L.Get(1) 900 cb := L.CheckFunction(2) 901 L.Pop(2) 902 L.SetMetatable(cb, L.GetTypeMetatable(luaTimerTypeName)) 903 L.Push(cb) 904 L.Push(atTime) 905 return timerSchedule(L) 906 } 907 908 var timerMethods = map[string]lua.LGFunction{ 909 "cancel": timerCancel, 910 "schedule": timerSchedule, 911 } 912 913 func timerCancel(L *lua.LState) int { 914 timer := L.CheckFunction(1) 915 L.RawSet(stateTimerTable(L), timer, lua.LNil) 916 return 0 917 } 918 919 func timerSchedule(L *lua.LState) int { 920 timer := L.CheckFunction(1) 921 atTime := lua.LNil 922 if L.Get(2) != lua.LNil { 923 atTime = L.CheckNumber(2) 924 } 925 926 L.RawSet(stateTimerTable(L), timer, atTime) 927 return 0 928 } 929 930 func toTime(lsec lua.LNumber) time.Time { 931 fsec := float64(lsec) 932 sec := int64(fsec) 933 nsec := int64((fsec - float64(sec)) * 1.0e9) 934 935 return time.Unix(sec, nsec) 936 } 937 938 func runTimers(L *lua.LState, parentTimer *time.Timer) { 939 hasNext := false 940 var nextTime time.Time 941 942 now := time.Now() 943 timers := stateTimerTable(L) 944 945 timer, luaT := timers.Next(lua.LNil) 946 for timer != lua.LNil { 947 t := toTime(luaT.(lua.LNumber)) 948 if t.Compare(now) <= 0 { 949 L.RawSet(timers, timer, lua.LNil) 950 err := L.CallByParam(lua.P{Fn: timer, NRet: 0, Protect: true}, timer, luaT) 951 if err != nil { 952 panic(err) 953 } 954 timer = lua.LNil 955 hasNext = false 956 } else if !hasNext || t.Compare(nextTime) < 0 { 957 hasNext = true 958 nextTime = t 959 } 960 961 timer, luaT = timers.Next(timer) 962 } 963 964 if hasNext { 965 parentTimer.Reset(time.Until(nextTime)) 966 } else { 967 parentTimer.Stop() 968 } 969 } 970 971 /********** Tools **********/ 972 973 func luaHeader(L *lua.LState, header []string) lua.LValue { 974 switch len(header) { 975 case 0: 976 return lua.LNil 977 case 1: 978 return lua.LString(header[0]) 979 default: 980 result := L.CreateTable(len(header), 0) 981 for i, v := range header { 982 L.RawSetInt(result, i+1, lua.LString(v)) 983 } 984 return result 985 } 986 } 987 988 func luaHeaders(L *lua.LState, reply string, headers map[string][]string) lua.LValue { 989 result := L.NewTable() 990 991 if reply != "" { 992 L.RawSetInt(result, 1, lua.LString(reply)) 993 } 994 995 for key, values := range headers { 996 L.SetField(result, key, luaHeader(L, values)) 997 } 998 999 return result 1000 } 1001 1002 func newUserData(L *lua.LState, v interface{}) *lua.LUserData { 1003 res := L.NewUserData() 1004 res.Value = v 1005 return res 1006 } 1007 1008 func tableIsEmpty(t *lua.LTable) bool { 1009 key, _ := t.Next(lua.LNil) 1010 return key == lua.LNil 1011 } 1012 1013 func tableWithIndexIsEmpty(t *lua.LTable, idx interface{}) bool { 1014 key, _ := t.Next(lua.LNil) 1015 1016 if n, ok := key.(lua.LNumber); ok && n == lua.LNumber(keyIndex) { 1017 key, _ = t.Next(key) 1018 } 1019 1020 return key == lua.LNil 1021 }