fix(signal): double-free

This commit is contained in:
maelstrom 2025-05-11 12:43:54 +02:00
parent b9dc280311
commit d86fd754bd
3 changed files with 80 additions and 67 deletions

View file

@ -4,6 +4,7 @@
#include "lua.h"
#include <pugixml.hpp>
#include <memory>
#include <vector>
SignalSource::SignalSource() : std::shared_ptr<Signal>(std::make_shared<Signal>()) {}
SignalSource::~SignalSource() = default;
@ -11,44 +12,43 @@ SignalSource::~SignalSource() = default;
Signal::Signal() {}
Signal::~Signal() = default;
SignalConnection::SignalConnection(std::weak_ptr<Signal> parent) : parentSignal(parent) {}
SignalConnection::~SignalConnection() = default;
// Only used for its address
int __savedThreads = 0;
LuaSignalConnection::LuaSignalConnection(lua_State* L) {
// Create thread from function at top of stack
thread = lua_newthread(L);
lua_xmove(L, thread, 1);
int __savedCallbacks = 0;
LuaSignalConnection::LuaSignalConnection(lua_State* L, std::weak_ptr<Signal> parent) : SignalConnection(parent) {
state = L;
// https://stackoverflow.com/a/31952046/16255372
// Create the table
if (__savedThreads == 0) {
lua_newtable(thread);
__savedThreads = luaL_ref(thread, LUA_REGISTRYINDEX);
if (__savedCallbacks == 0) {
lua_newtable(L);
__savedCallbacks = luaL_ref(L, LUA_REGISTRYINDEX);
}
// Save thread so it doesn't get GC'd
lua_rawgeti(thread, LUA_REGISTRYINDEX, __savedThreads);
lua_pushthread(thread); // key
lua_pushboolean(thread, true); // value
lua_rawset(thread, -3); // set
lua_pop(thread, 1); // Pop __savedThreads
// Save function so it doesn't get GC'd
lua_rawgeti(L, LUA_REGISTRYINDEX, __savedCallbacks);
lua_pushvalue(L, -2);
function = luaL_ref(L, -2);
lua_pop(L, 2);
}
LuaSignalConnection::~LuaSignalConnection() {
// Remove thread so that it can get properly GC'd
lua_rawgeti(thread, LUA_REGISTRYINDEX, __savedThreads);
lua_pushthread(thread); // key
lua_pushnil(thread); // value
lua_rawset(thread, -3); // set
lua_pop(thread, 1); // Pop __savedThreads
// Remove LuaSignalConnectionthread so that it can get properly GC'd
lua_rawgeti(state, LUA_REGISTRYINDEX, __savedCallbacks);
luaL_unref(state, -1, function);
lua_pop(state, 1); // Pop __savedCallbacks
}
void LuaSignalConnection::Call(std::vector<Data::Variant> args) {
lua_State* thread = lua_newthread(state);
// Push function
lua_rawgeti(thread, LUA_REGISTRYINDEX, __savedCallbacks);
lua_rawgeti(thread, -1, function);
lua_remove(thread, -2);
for (Data::Variant arg : args) {
arg.PushLuaValue(thread);
}
@ -62,7 +62,7 @@ void LuaSignalConnection::Call(std::vector<Data::Variant> args) {
//
CSignalConnection::CSignalConnection(std::function<void(std::vector<Data::Variant>)> func) {
CSignalConnection::CSignalConnection(std::function<void(std::vector<Data::Variant>)> func, std::weak_ptr<Signal> parent) : SignalConnection(parent) {
this->function = func;
}
@ -73,13 +73,13 @@ void CSignalConnection::Call(std::vector<Data::Variant> args) {
//
SignalConnectionRef Signal::Connect(std::function<void(std::vector<Data::Variant>)> callback) {
auto conn = std::dynamic_pointer_cast<SignalConnection>(std::make_shared<CSignalConnection>(CSignalConnection(callback)));
auto conn = std::dynamic_pointer_cast<SignalConnection>(std::make_shared<CSignalConnection>(callback, weak_from_this()));
connections.push_back(conn);
return SignalConnectionRef(conn);
}
SignalConnectionRef Signal::Connect(lua_State* state) {
auto conn = std::dynamic_pointer_cast<SignalConnection>(std::make_shared<LuaSignalConnection>(LuaSignalConnection(state)));
auto conn = std::dynamic_pointer_cast<SignalConnection>(std::make_shared<LuaSignalConnection>(state, weak_from_this()));
connections.push_back(conn);
return SignalConnectionRef(conn);
}
@ -90,6 +90,10 @@ void Signal::Fire(std::vector<Data::Variant> args) {
}
}
void Signal::Fire() {
return Fire(std::vector<Data::Variant> {});
}
void Signal::DisconnectAll() {
for (std::shared_ptr<SignalConnection> connection : connections) {
connection->parentSignal = {};
@ -148,8 +152,8 @@ void Data::SignalRef::Serialize(pugi::xml_node node) const {
void Data::SignalRef::PushLuaValue(lua_State* L) const {
int n = lua_gettop(L);
auto userdata = (std::weak_ptr<Signal>*)lua_newuserdata(L, sizeof(std::weak_ptr<Signal>));
new(userdata) std::weak_ptr<Signal>(signal);
auto userdata = (std::weak_ptr<Signal>**)lua_newuserdata(L, sizeof(std::weak_ptr<Signal>));
*userdata = new std::weak_ptr<Signal>(signal);
// Create the instance's metatable
luaL_newmetatable(L, "__mt_signal");
@ -159,15 +163,16 @@ void Data::SignalRef::PushLuaValue(lua_State* L) const {
}
result<Data::Variant, LuaCastError> Data::SignalRef::FromLuaValue(lua_State* L, int idx) {
auto userdata = (std::weak_ptr<Signal>*)luaL_checkudata(L, 1, "__mt_signal");
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
lua_pop(L, 1);
return Data::Variant(Data::SignalRef(*userdata));
return Data::Variant(Data::SignalRef(**userdata));
}
static int signal_gc(lua_State* L) {
printf("Elle!\n");
// Destroy the contained shared_ptr
auto userdata = (std::weak_ptr<Signal>*)luaL_checkudata(L, 1, "__mt_signal");
delete userdata;
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
delete *userdata;
lua_pop(L, 1);
return 0;
@ -175,8 +180,8 @@ static int signal_gc(lua_State* L) {
// __index(t,k)
static int signal_index(lua_State* L) {
auto userdata = (std::weak_ptr<Signal>*)luaL_checkudata(L, 1, "__mt_signal");
std::weak_ptr<Signal> signal = *userdata;
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
std::weak_ptr<Signal> signal = **userdata;
std::string key(lua_tostring(L, 2));
lua_pop(L, 2);
@ -195,8 +200,8 @@ static int signal_tostring(lua_State* L) {
}
static int signal_Connect(lua_State* L) {
auto userdata = (std::weak_ptr<Signal>*)luaL_checkudata(L, 1, "__mt_signal");
std::shared_ptr<Signal> signal = (*userdata).lock();
auto userdata = (std::weak_ptr<Signal>**)luaL_checkudata(L, 1, "__mt_signal");
std::shared_ptr<Signal> signal = (**userdata).lock();
luaL_checktype(L, 2, LUA_TFUNCTION);
SignalConnectionRef ref = signal->Connect(L);
@ -217,26 +222,6 @@ static const struct luaL_Reg signalconnection_metatable [] = {
{NULL, NULL} /* end of array */
};
static int signalconnection_gc(lua_State* L) {
// Destroy the contained shared_ptr
auto userdata = (std::weak_ptr<SignalConnection>*)luaL_checkudata(L, 1, "__mt_signalconnection");
delete userdata;
lua_pop(L, 1);
return 0;
}
// __index(t,k)
static int signalconnection_index(lua_State* L) {
auto userdata = (std::weak_ptr<SignalConnection>*)luaL_checkudata(L, 1, "__mt_signalconnection");
std::weak_ptr<SignalConnection> signalConnection = *userdata;
std::string key(lua_tostring(L, 2));
lua_pop(L, 2);
return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "SignalConnection");
}
Data::SignalConnectionRef::SignalConnectionRef(std::weak_ptr<SignalConnection> ref) : signalConnection(ref) {}
Data::SignalConnectionRef::~SignalConnectionRef() = default;
@ -262,8 +247,8 @@ void Data::SignalConnectionRef::Serialize(pugi::xml_node node) const {
void Data::SignalConnectionRef::PushLuaValue(lua_State* L) const {
int n = lua_gettop(L);
auto userdata = (std::weak_ptr<SignalConnection>*)lua_newuserdata(L, sizeof(std::weak_ptr<SignalConnection>));
new(userdata) std::weak_ptr<SignalConnection>(signalConnection);
auto userdata = (std::weak_ptr<SignalConnection>**)lua_newuserdata(L, sizeof(std::weak_ptr<SignalConnection>));
*userdata = new std::weak_ptr<SignalConnection>(signalConnection);
// Create the instance's metatable
luaL_newmetatable(L, "__mt_signalconnection");
@ -273,13 +258,32 @@ void Data::SignalConnectionRef::PushLuaValue(lua_State* L) const {
}
result<Data::Variant, LuaCastError> Data::SignalConnectionRef::FromLuaValue(lua_State* L, int idx) {
auto userdata = (std::weak_ptr<SignalConnection>*)luaL_checkudata(L, 1, "__mt_signalconnection");
auto userdata = (std::weak_ptr<SignalConnection>**)luaL_checkudata(L, 1, "__mt_signalconnection");
lua_pop(L, 1);
return Data::Variant(Data::SignalConnectionRef(*userdata));
return Data::Variant(Data::SignalConnectionRef(**userdata));
}
static int signalconnection_tostring(lua_State* L) {
lua_pop(L, 1);
lua_pushstring(L, "SignalConnection");
return 1;
}
static int signalconnection_gc(lua_State* L) {
// Destroy the contained shared_ptr
auto userdata = (std::weak_ptr<SignalConnection>**)luaL_checkudata(L, 1, "__mt_signalconnection");
delete *userdata;
lua_pop(L, 1);
return 0;
}
// __index(t,k)
static int signalconnection_index(lua_State* L) {
auto userdata = (std::weak_ptr<SignalConnection>**)luaL_checkudata(L, 1, "__mt_signalconnection");
std::weak_ptr<SignalConnection> signalConnection = **userdata;
std::string key(lua_tostring(L, 2));
lua_pop(L, 2);
return luaL_error(L, "'%s' is not a valid member of %s", key.c_str(), "SignalConnection");
}

View file

@ -21,6 +21,8 @@ class SignalConnection : public std::enable_shared_from_this<SignalConnection> {
protected:
std::weak_ptr<Signal> parentSignal;
SignalConnection(std::weak_ptr<Signal> parent);
virtual void Call(std::vector<Data::Variant>) = 0;
friend Signal;
public:
@ -33,35 +35,40 @@ public:
class CSignalConnection : public SignalConnection {
std::function<void(std::vector<Data::Variant>)> function;
CSignalConnection(std::function<void(std::vector<Data::Variant>)>);
friend Signal;
protected:
void Call(std::vector<Data::Variant>) override;
public:
CSignalConnection(std::function<void(std::vector<Data::Variant>)>, std::weak_ptr<Signal> parent);
};
class LuaSignalConnection : public SignalConnection {
lua_State* thread;
LuaSignalConnection(lua_State*);
lua_State* state;
int function;
friend Signal;
protected:
void Call(std::vector<Data::Variant>) override;
public:
LuaSignalConnection(lua_State*, std::weak_ptr<Signal> parent);
LuaSignalConnection (const LuaSignalConnection&) = delete;
LuaSignalConnection& operator= (const LuaSignalConnection&) = delete;
~LuaSignalConnection();
};
class Signal {
class Signal : public std::enable_shared_from_this<Signal> {
std::vector<std::shared_ptr<SignalConnection>> connections;
friend SignalConnection;
public:
Signal();
virtual ~Signal();
Signal (const Signal&) = delete;
Signal& operator= (const Signal&) = delete;
void DisconnectAll();
void Fire(std::vector<Data::Variant> args);
void Fire();
Data::SignalConnectionRef Connect(std::function<void(std::vector<Data::Variant>)> callback);
Data::SignalConnectionRef Connect(lua_State*);
};

View file

@ -60,6 +60,8 @@ void Part::onUpdated(std::string property) {
// When position/rotation/size is manually edited, break all joints, they don't apply anymore
if (property != "Anchored")
BreakJoints();
OnParentUpdated->Fire();
}
// Expands provided extents to fit point