feat(signal): wait function

This commit is contained in:
maelstrom 2025-05-11 13:20:36 +02:00
parent d86fd754bd
commit 6e1cfcac80
2 changed files with 54 additions and 1 deletions

View file

@ -84,10 +84,50 @@ SignalConnectionRef Signal::Connect(lua_State* state) {
return SignalConnectionRef(conn); return SignalConnectionRef(conn);
} }
int __waitingThreads = 0;
int Signal::Wait(lua_State* thread) {
// If the table hasn't been constructed yet, make it
if (__waitingThreads == 0) {
lua_newtable(thread);
__waitingThreads = luaL_ref(thread, LUA_REGISTRYINDEX);
}
// Get waitingThreads table
lua_rawgeti(thread, LUA_REGISTRYINDEX, __waitingThreads);
lua_pushthread(thread);
int threadId = luaL_ref(thread, -2);
lua_pop(thread, -1); // pop __waitingThreads
waitingThreads.push_back(std::make_pair(threadId, thread));
// Yield and return results
return lua_yield(thread, 0);
}
void Signal::Fire(std::vector<Data::Variant> args) { void Signal::Fire(std::vector<Data::Variant> args) {
for (std::shared_ptr<SignalConnection> connection : connections) { for (std::shared_ptr<SignalConnection> connection : connections) {
connection->Call(args); connection->Call(args);
} }
// Call waiting threads
auto prevThreads = std::move(waitingThreads);
waitingThreads = std::vector<std::pair<int, lua_State*>>();
for (auto& [threadId, thread] : prevThreads) {
for (Data::Variant arg : args) {
arg.PushLuaValue(thread);
}
int status = lua_resume(thread, args.size());
if (status > LUA_YIELD) {
Logger::error(lua_tostring(thread, -1));
lua_pop(thread, 1); // Pop return value
}
// Remove thread from registry
lua_rawgeti(thread, LUA_REGISTRYINDEX, __waitingThreads);
luaL_unref(thread, -1, threadId);
lua_pop(thread, 1); // pop __waitingThreads
}
} }
void Signal::Fire() { void Signal::Fire() {
@ -116,6 +156,7 @@ void SignalConnection::Disconnect() {
// //
static int signal_Connect(lua_State*); static int signal_Connect(lua_State*);
static int signal_Wait(lua_State*);
static int signal_gc(lua_State*); static int signal_gc(lua_State*);
static int signal_index(lua_State*); static int signal_index(lua_State*);
@ -169,7 +210,6 @@ result<Data::Variant, LuaCastError> Data::SignalRef::FromLuaValue(lua_State* L,
} }
static int signal_gc(lua_State* L) { static int signal_gc(lua_State* L) {
printf("Elle!\n");
// Destroy the contained shared_ptr // Destroy the contained shared_ptr
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal"); auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
delete *userdata; delete *userdata;
@ -188,6 +228,9 @@ static int signal_index(lua_State* L) {
if (key == "Connect") { if (key == "Connect") {
lua_pushcfunction(L, signal_Connect); lua_pushcfunction(L, signal_Connect);
return 1; return 1;
} else if (key == "Wait") {
lua_pushcfunction(L, signal_Wait);
return 1;
} }
return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "Signal"); return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "Signal");
@ -210,6 +253,14 @@ static int signal_Connect(lua_State* L) {
return 1; return 1;
} }
static int signal_Wait(lua_State* L) {
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
// TODO: Add expiry check here and everywhere else
std::shared_ptr<Signal> signal = (**userdata).lock();
return signal->Wait(L);
}
// //
static int signalconnection_gc(lua_State*); static int signalconnection_gc(lua_State*);

View file

@ -58,6 +58,7 @@ public:
class Signal : public std::enable_shared_from_this<Signal> { class Signal : public std::enable_shared_from_this<Signal> {
std::vector<std::shared_ptr<SignalConnection>> connections; std::vector<std::shared_ptr<SignalConnection>> connections;
std::vector<std::pair<int, lua_State*>> waitingThreads;
friend SignalConnection; friend SignalConnection;
public: public:
@ -71,6 +72,7 @@ public:
void Fire(); void Fire();
Data::SignalConnectionRef Connect(std::function<void(std::vector<Data::Variant>)> callback); Data::SignalConnectionRef Connect(std::function<void(std::vector<Data::Variant>)> callback);
Data::SignalConnectionRef Connect(lua_State*); Data::SignalConnectionRef Connect(lua_State*);
int Wait(lua_State*);
}; };
class SignalSource : public std::shared_ptr<Signal> { class SignalSource : public std::shared_ptr<Signal> {