Skip to content

Commit

Permalink
clone fix + add equal
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielkonge committed Dec 1, 2023
1 parent 8567fc6 commit cfe4705
Showing 1 changed file with 82 additions and 25 deletions.
107 changes: 82 additions & 25 deletions lua-api-crates/table-funcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub fn register(lua: &Lua) -> anyhow::Result<()> {
table.set("length", lua.create_function(length)?)?;
table.set("has_key", lua.create_function(has_key)?)?;
table.set("has_value", lua.create_function(has_value)?)?;
table.set("equal", lua.create_function(equal)?)?;
table.set("to_string", lua.create_function(to_string)?)?;
table.set(
"to_string_fallback",
Expand All @@ -22,27 +23,30 @@ pub fn register(lua: &Lua) -> anyhow::Result<()> {

#[derive(Debug, Clone, PartialEq, Eq)]
enum ConflictMode {
Keep,
Force,
Error,
Keep,
Force,
Error,
}

impl<'lua> mlua::FromLua<'lua> for ConflictMode {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
match value {
LuaValue::String(s) => {
match s.to_str() {
Ok("Keep") => Ok(ConflictMode::Keep),
Ok("keep") => Ok(ConflictMode::Keep),
Ok("Force") => Ok(ConflictMode::Force),
Ok("force") => Ok(ConflictMode::Force),
Ok("Error") => Ok(ConflictMode::Error),
Ok("error") => Ok(ConflictMode::Error),
_ => Err(mlua::Error::runtime("Unknown string. Expected 'Keep', 'Force' or 'Error'".to_string()))
}
}
LuaValue::String(s) => match s.to_str() {
Ok("Keep") => Ok(ConflictMode::Keep),
Ok("keep") => Ok(ConflictMode::Keep),
Ok("Force") => Ok(ConflictMode::Force),
Ok("force") => Ok(ConflictMode::Force),
Ok("Error") => Ok(ConflictMode::Error),
Ok("error") => Ok(ConflictMode::Error),
_ => Err(mlua::Error::runtime(
"Unknown string. Expected 'Keep', 'Force' or 'Error'".to_string(),
)),
},
LuaValue::Error(err) => Err(err),
other => Err(mlua::Error::runtime(format!("Expected a Lua string. Got something of type: {}", other.type_name()))),
other => Err(mlua::Error::runtime(format!(
"Expected a Lua string. Got something of type: {}",
other.type_name()
))),
}
}
}
Expand All @@ -68,7 +72,7 @@ fn extend<'lua>(

match behavior {
Some(ConflictMode::Keep) => {
for (key,value) in tbl_vec {
for (key, value) in tbl_vec {
if !tbl.contains_key(key.clone())? {
tbl.set(key, value)?;
}
Expand All @@ -83,7 +87,10 @@ fn extend<'lua>(
Some(ConflictMode::Error) => {
for (key, value) in tbl_vec {
if tbl.contains_key(key.clone())? {
return Err(mlua::Error::runtime(format!("The key {} is in more than one of the tables.", key.to_string()?)));
return Err(mlua::Error::runtime(format!(
"The key {} is in more than one of the tables.",
key.to_string()?
)));
}
tbl.set(key, value)?;
}
Expand All @@ -93,7 +100,6 @@ fn extend<'lua>(
Ok(tbl)
}


// merge tables entrywise recursively
// (in case of overlap of the tables, we default to taking the key-value pair from the last table)
// Note that we don't use a HashMap since we want to keep the order of the tables, which
Expand All @@ -115,11 +121,14 @@ fn deep_extend<'lua>(

match behavior {
Some(ConflictMode::Keep) => {
for (key,value) in tbl_vec {
for (key, value) in tbl_vec {
if !tbl.contains_key(key.clone())? {
tbl.set(key, value)?;
} else if let LuaValue::Table(t) = value {
let inner_tbl = deep_extend(lua, (vec![tbl.get(key.clone())?, t], Some(ConflictMode::Keep)))?;
let inner_tbl = deep_extend(
lua,
(vec![tbl.get(key.clone())?, t], Some(ConflictMode::Keep)),
)?;
tbl.set(key, inner_tbl)?;
}
}
Expand All @@ -130,7 +139,10 @@ fn deep_extend<'lua>(
if !tbl.contains_key(key.clone())? {
tbl.set(key, value)?;
} else if let LuaValue::Table(t) = value {
let inner_tbl = deep_extend(lua, (vec![tbl.get(key.clone())?, t], Some(ConflictMode::Force)))?;
let inner_tbl = deep_extend(
lua,
(vec![tbl.get(key.clone())?, t], Some(ConflictMode::Force)),
)?;
tbl.set(key, inner_tbl)?;
} else {
tbl.set(key, value)?;
Expand All @@ -142,10 +154,16 @@ fn deep_extend<'lua>(
if !tbl.contains_key(key.clone())? {
tbl.set(key, value)?;
} else if let LuaValue::Table(t) = value {
let inner_tbl = deep_extend(lua, (vec![tbl.get(key.clone())?, t], Some(ConflictMode::Keep)))?;
let inner_tbl = deep_extend(
lua,
(vec![tbl.get(key.clone())?, t], Some(ConflictMode::Keep)),
)?;
tbl.set(key, inner_tbl)?;
} else {
return Err(mlua::Error::runtime(format!("The key {} is in more than one of the tables.", key.to_string()?)));
return Err(mlua::Error::runtime(format!(
"The key {} is in more than one of the tables.",
key.to_string()?
)));
}
}
}
Expand All @@ -154,9 +172,20 @@ fn deep_extend<'lua>(
Ok(tbl)
}

fn clone<'lua>(lua: &'lua Lua, table: Table<'lua>) -> mlua::Result<Table<'lua>> {
let table_len = table.clone().pairs::<LuaValue, LuaValue>().count();
let res: Table<'lua> = lua.create_table_with_capacity(0, table_len)?;

fn clone<'lua>(_: &'lua Lua, table: Table<'lua>) -> mlua::Result<Table<'lua>> {
Ok(table.clone())
for pair in table.pairs::<LuaValue, LuaValue>() {
let (key, value) = pair?;
if let LuaValue::Table(tbl) = value {
let inner_res = clone(lua, tbl)?;
res.set(key, inner_res)?;
} else {
res.set(key, value)?;
}
}
Ok(res)
}

fn flatten<'lua>(lua: &'lua Lua, arrays: Vec<LuaValue<'lua>>) -> mlua::Result<Vec<LuaValue<'lua>>> {
Expand Down Expand Up @@ -203,6 +232,34 @@ fn has_value<'lua>(_: &'lua Lua, (table, value): (Table<'lua>, LuaValue)) -> mlu
Ok(false)
}

fn equal<'lua>(lua: &'lua Lua, (table1, table2): (Table<'lua>, Table<'lua>)) -> mlua::Result<bool> {
let mut res = true;

// check if the tables are the same length to ensure we don't miss anything in table2
// when we only loop through table1
let table1_len = table2.clone().pairs::<LuaValue, LuaValue>().count();
let table2_len = table2.clone().pairs::<LuaValue, LuaValue>().count();
if table1_len != table2_len {
return Ok(false);
}

for pair in table1.pairs::<LuaValue, LuaValue>() {
let (key, value1) = pair?;
let value2 = table2.get(key.clone())?;
if let LuaValue::Table(tbl1) = value1.clone() {
if let LuaValue::Table(tbl2) = value2 {
res = equal(lua, (tbl1, tbl2))?;
} else {
return Ok(false);
}
} else {
res = value1.eq(&value2);
}
}

Ok(res)
}

fn to_string_fallback<'lua>(_: &'lua Lua, table: Table<'lua>) -> mlua::Result<String> {
Ok(format!("{:#?}", table))
}
Expand Down

0 comments on commit cfe4705

Please sign in to comment.