diff --git a/core/src/datatypes/signal.cpp b/core/src/datatypes/signal.cpp index 4c79be4..4389b86 100644 --- a/core/src/datatypes/signal.cpp +++ b/core/src/datatypes/signal.cpp @@ -84,6 +84,18 @@ SignalConnectionRef Signal::Connect(lua_State* state) { return SignalConnectionRef(conn); } +SignalConnectionRef Signal::Once(std::function)> callback) { + auto conn = std::dynamic_pointer_cast(std::make_shared(callback, weak_from_this())); + onceConnections.push_back(conn); + return SignalConnectionRef(conn); +} + +SignalConnectionRef Signal::Once(lua_State* state) { + auto conn = std::dynamic_pointer_cast(std::make_shared(state, weak_from_this())); + onceConnections.push_back(conn); + return SignalConnectionRef(conn); +} + int __waitingThreads = 0; int Signal::Wait(lua_State* thread) { // If the table hasn't been constructed yet, make it @@ -108,6 +120,13 @@ void Signal::Fire(std::vector args) { connection->Call(args); } + // Call once connections + auto prevOnceConns = std::move(onceConnections); + onceConnections = std::vector>(); + for (std::shared_ptr connection : prevOnceConns) { + connection->Call(args); + } + // Call waiting threads auto prevThreads = std::move(waitingThreads); waitingThreads = std::vector>(); @@ -139,23 +158,45 @@ void Signal::DisconnectAll() { connection->parentSignal = {}; } connections.clear(); + + for (std::shared_ptr connection : onceConnections) { + connection->parentSignal = {}; + } + onceConnections.clear(); + + for (auto& [threadId, thread] : waitingThreads) { + lua_rawgeti(thread, LUA_REGISTRYINDEX, __waitingThreads); + luaL_unref(thread, -1, threadId); + lua_pop(thread, 1); + } + waitingThreads.clear(); } void SignalConnection::Disconnect() { if (!Connected()) return; auto signal = parentSignal.lock(); + for(auto it = signal->connections.begin(); it != signal->connections.end();) { if (*it == shared_from_this()) it = signal->connections.erase(it); else it++; } + + for(auto it = signal->onceConnections.begin(); it != signal->onceConnections.end();) { + if (*it == shared_from_this()) + it = signal->onceConnections.erase(it); + else + it++; + } + parentSignal = {}; } // static int signal_Connect(lua_State*); +static int signal_Once(lua_State*); static int signal_Wait(lua_State*); static int signal_gc(lua_State*); @@ -228,6 +269,9 @@ static int signal_index(lua_State* L) { if (key == "Connect") { lua_pushcfunction(L, signal_Connect); return 1; + } else if (key == "Once") { + lua_pushcfunction(L, signal_Once); + return 1; } else if (key == "Wait") { lua_pushcfunction(L, signal_Wait); return 1; @@ -253,6 +297,17 @@ static int signal_Connect(lua_State* L) { return 1; } +static int signal_Once(lua_State* L) { + auto userdata = (std::weak_ptr**)luaL_checkudata(L, 1, "__mt_signal"); + std::shared_ptr signal = (**userdata).lock(); + luaL_checktype(L, 2, LUA_TFUNCTION); + + SignalConnectionRef ref = signal->Once(L); + ref.PushLuaValue(L); + + return 1; +} + static int signal_Wait(lua_State* L) { auto userdata = (std::weak_ptr**)luaL_checkudata(L, 1, "__mt_signal"); // TODO: Add expiry check here and everywhere else @@ -263,6 +318,8 @@ static int signal_Wait(lua_State* L) { // +static int signalconnection_Disconnect(lua_State*); + static int signalconnection_gc(lua_State*); static int signalconnection_index(lua_State*); static int signalconnection_tostring(lua_State*); @@ -336,5 +393,19 @@ static int signalconnection_index(lua_State* L) { std::string key(lua_tostring(L, 2)); lua_pop(L, 2); + if (key == "Disconnect") { + lua_pushcfunction(L, signalconnection_Disconnect); + return 1; + } + return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "SignalConnection"); +} + +static int signalconnection_Disconnect(lua_State* L) { + auto userdata = (std::weak_ptr**)luaL_checkudata(L, 1, "__mt_signalconnection"); + std::shared_ptr signal = (**userdata).lock(); + + signal->Disconnect(); + + return 0; } \ No newline at end of file diff --git a/core/src/datatypes/signal.h b/core/src/datatypes/signal.h index 26702a0..831d39b 100644 --- a/core/src/datatypes/signal.h +++ b/core/src/datatypes/signal.h @@ -58,6 +58,7 @@ public: class Signal : public std::enable_shared_from_this { std::vector> connections; + std::vector> onceConnections; std::vector> waitingThreads; friend SignalConnection; @@ -72,6 +73,8 @@ public: void Fire(); Data::SignalConnectionRef Connect(std::function)> callback); Data::SignalConnectionRef Connect(lua_State*); + Data::SignalConnectionRef Once(std::function)> callback); + Data::SignalConnectionRef Once(lua_State*); int Wait(lua_State*); };