diff --git a/core/src/datatypes/signal.cpp b/core/src/datatypes/signal.cpp index ed90797..4c79be4 100644 --- a/core/src/datatypes/signal.cpp +++ b/core/src/datatypes/signal.cpp @@ -84,10 +84,50 @@ SignalConnectionRef Signal::Connect(lua_State* state) { return SignalConnectionRef(conn); } +int __waitingThreads = 0; +int Signal::Wait(lua_State* thread) { + // If the table hasn't been constructed yet, make it + if (__waitingThreads == 0) { + lua_newtable(thread); + __waitingThreads = luaL_ref(thread, LUA_REGISTRYINDEX); + } + + // Get waitingThreads table + lua_rawgeti(thread, LUA_REGISTRYINDEX, __waitingThreads); + lua_pushthread(thread); + int threadId = luaL_ref(thread, -2); + lua_pop(thread, -1); // pop __waitingThreads + waitingThreads.push_back(std::make_pair(threadId, thread)); + + // Yield and return results + return lua_yield(thread, 0); +} + void Signal::Fire(std::vector args) { for (std::shared_ptr connection : connections) { connection->Call(args); } + + // Call waiting threads + auto prevThreads = std::move(waitingThreads); + waitingThreads = std::vector>(); + for (auto& [threadId, thread] : prevThreads) { + for (Data::Variant arg : args) { + arg.PushLuaValue(thread); + } + + int status = lua_resume(thread, args.size()); + if (status > LUA_YIELD) { + Logger::error(lua_tostring(thread, -1)); + lua_pop(thread, 1); // Pop return value + } + + // Remove thread from registry + lua_rawgeti(thread, LUA_REGISTRYINDEX, __waitingThreads); + luaL_unref(thread, -1, threadId); + lua_pop(thread, 1); // pop __waitingThreads + } + } void Signal::Fire() { @@ -116,6 +156,7 @@ void SignalConnection::Disconnect() { // static int signal_Connect(lua_State*); +static int signal_Wait(lua_State*); static int signal_gc(lua_State*); static int signal_index(lua_State*); @@ -169,7 +210,6 @@ result Data::SignalRef::FromLuaValue(lua_State* L, } static int signal_gc(lua_State* L) { - printf("Elle!\n"); // Destroy the contained shared_ptr auto userdata = (std::weak_ptr**)luaL_checkudata(L, 1, "__mt_signal"); delete *userdata; @@ -188,6 +228,9 @@ static int signal_index(lua_State* L) { if (key == "Connect") { lua_pushcfunction(L, signal_Connect); return 1; + } else if (key == "Wait") { + lua_pushcfunction(L, signal_Wait); + return 1; } return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "Signal"); @@ -210,6 +253,14 @@ static int signal_Connect(lua_State* L) { return 1; } +static int signal_Wait(lua_State* L) { + auto userdata = (std::weak_ptr**)luaL_checkudata(L, 1, "__mt_signal"); + // TODO: Add expiry check here and everywhere else + std::shared_ptr signal = (**userdata).lock(); + + return signal->Wait(L); +} + // static int signalconnection_gc(lua_State*); diff --git a/core/src/datatypes/signal.h b/core/src/datatypes/signal.h index 76c5f79..26702a0 100644 --- a/core/src/datatypes/signal.h +++ b/core/src/datatypes/signal.h @@ -58,6 +58,7 @@ public: class Signal : public std::enable_shared_from_this { std::vector> connections; + std::vector> waitingThreads; friend SignalConnection; public: @@ -71,6 +72,7 @@ public: void Fire(); Data::SignalConnectionRef Connect(std::function)> callback); Data::SignalConnectionRef Connect(lua_State*); + int Wait(lua_State*); }; class SignalSource : public std::shared_ptr {