Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adapter/storage: Cast MySQL bit columns to uint8, add convience functions #31097

Merged
merged 8 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/environmentd/tests/testdata/http/ws

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/expr/src/scalar.proto
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ message ProtoUnaryFunc {
ProtoToCharTimestamp to_char_timestamp = 331;
ProtoToCharTimestamp to_char_timestamp_tz = 332;
google.protobuf.Empty cast_date_to_mz_timestamp = 333;
google.protobuf.Empty bit_count_bytes = 334;
}
}

Expand Down Expand Up @@ -668,6 +669,7 @@ message ProtoBinaryFunc {
bool list_contains_list = 193;
bool array_contains_array = 194;
google.protobuf.Empty starts_with = 195;
google.protobuf.Empty get_bit = 196;
}
}

Expand Down
36 changes: 35 additions & 1 deletion src/expr/src/scalar/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,27 @@ fn power_numeric<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError>
}
}

fn get_bit<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
let bytes = a.unwrap_bytes();
let index = b.unwrap_int32();
let err = EvalError::IndexOutOfRange {
provided: index,
valid_end: i32::try_from(bytes.len().saturating_mul(8)).unwrap() - 1,
};

let index = usize::try_from(index).map_err(|_| err.clone())?;

let byte_index = index / 8;
let bit_index = index % 8;

let i = bytes
.get(byte_index)
.map(|b| (*b >> bit_index) & 1)
.ok_or(err)?;
assert!(i == 0 || i == 1);
Ok(Datum::from(i32::from(i)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding correctly, the following might be useful: assert!(i == 0 || i == 1)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call!

}

fn get_byte<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
let bytes = a.unwrap_bytes();
let index = b.unwrap_int32();
Expand Down Expand Up @@ -2344,6 +2365,7 @@ pub enum BinaryFunc {
LogNumeric,
Power,
PowerNumeric,
GetBit,
GetByte,
ConstantTimeEqBytes,
ConstantTimeEqString,
Expand Down Expand Up @@ -2607,6 +2629,7 @@ impl BinaryFunc {
BinaryFunc::Power => power(a, b),
BinaryFunc::PowerNumeric => power_numeric(a, b),
BinaryFunc::RepeatString => repeat_string(a, b, temp_storage),
BinaryFunc::GetBit => get_bit(a, b),
BinaryFunc::GetByte => get_byte(a, b),
BinaryFunc::ConstantTimeEqBytes => constant_time_eq_bytes(a, b),
BinaryFunc::ConstantTimeEqString => constant_time_eq_string(a, b),
Expand Down Expand Up @@ -2804,6 +2827,7 @@ impl BinaryFunc {
ScalarType::Numeric { max_scale: None }.nullable(in_nullable)
}

GetBit => ScalarType::Int32.nullable(in_nullable),
GetByte => ScalarType::Int32.nullable(in_nullable),

ConstantTimeEqBytes | ConstantTimeEqString => {
Expand Down Expand Up @@ -3023,6 +3047,7 @@ impl BinaryFunc {
| LogNumeric
| Power
| PowerNumeric
| GetBit
| GetByte
| RangeContainsElem { .. }
| RangeContainsRange { .. }
Expand Down Expand Up @@ -3241,6 +3266,7 @@ impl BinaryFunc {
| ListRemove
| LikeEscape
| UuidGenerateV5
| GetBit
| GetByte
| MzAclItemContainsPrivilege
| ConstantTimeEqBytes
Expand Down Expand Up @@ -3508,7 +3534,8 @@ impl BinaryFunc {
| BinaryFunc::Decode => (false, false),
// TODO: it may be safe to treat these as monotone.
BinaryFunc::LogNumeric | BinaryFunc::Power | BinaryFunc::PowerNumeric => (false, false),
BinaryFunc::GetByte
BinaryFunc::GetBit
| BinaryFunc::GetByte
| BinaryFunc::RangeContainsElem { .. }
| BinaryFunc::RangeContainsRange { .. }
| BinaryFunc::RangeOverlaps
Expand Down Expand Up @@ -3716,6 +3743,7 @@ impl fmt::Display for BinaryFunc {
BinaryFunc::Power => f.write_str("power"),
BinaryFunc::PowerNumeric => f.write_str("power_numeric"),
BinaryFunc::RepeatString => f.write_str("repeat"),
BinaryFunc::GetBit => f.write_str("get_bit"),
BinaryFunc::GetByte => f.write_str("get_byte"),
BinaryFunc::ConstantTimeEqBytes => f.write_str("constant_time_compare_bytes"),
BinaryFunc::ConstantTimeEqString => f.write_str("constant_time_compare_strings"),
Expand Down Expand Up @@ -4140,6 +4168,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
BinaryFunc::LogNumeric => LogNumeric(()),
BinaryFunc::Power => Power(()),
BinaryFunc::PowerNumeric => PowerNumeric(()),
BinaryFunc::GetBit => GetBit(()),
BinaryFunc::GetByte => GetByte(()),
BinaryFunc::RangeContainsElem { elem_type, rev } => {
RangeContainsElem(crate::scalar::proto_binary_func::ProtoRangeContainsInner {
Expand Down Expand Up @@ -4360,6 +4389,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
LogNumeric(()) => Ok(BinaryFunc::LogNumeric),
Power(()) => Ok(BinaryFunc::Power),
PowerNumeric(()) => Ok(BinaryFunc::PowerNumeric),
GetBit(()) => Ok(BinaryFunc::GetBit),
GetByte(()) => Ok(BinaryFunc::GetByte),
RangeContainsElem(inner) => Ok(BinaryFunc::RangeContainsElem {
elem_type: inner
Expand Down Expand Up @@ -4799,6 +4829,7 @@ derive_unary!(
FloorFloat64,
FloorNumeric,
Ascii,
BitCountBytes,
BitLengthBytes,
BitLengthString,
ByteLengthBytes,
Expand Down Expand Up @@ -5209,6 +5240,7 @@ impl Arbitrary for UnaryFunc {
FloorFloat64::arbitrary().prop_map_into().boxed(),
FloorNumeric::arbitrary().prop_map_into().boxed(),
Ascii::arbitrary().prop_map_into().boxed(),
BitCountBytes::arbitrary().prop_map_into().boxed(),
BitLengthBytes::arbitrary().prop_map_into().boxed(),
BitLengthString::arbitrary().prop_map_into().boxed(),
ByteLengthBytes::arbitrary().prop_map_into().boxed(),
Expand Down Expand Up @@ -5597,6 +5629,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
UnaryFunc::FloorFloat64(_) => FloorFloat64(()),
UnaryFunc::FloorNumeric(_) => FloorNumeric(()),
UnaryFunc::Ascii(_) => Ascii(()),
UnaryFunc::BitCountBytes(_) => BitCountBytes(()),
UnaryFunc::BitLengthBytes(_) => BitLengthBytes(()),
UnaryFunc::BitLengthString(_) => BitLengthString(()),
UnaryFunc::ByteLengthBytes(_) => ByteLengthBytes(()),
Expand Down Expand Up @@ -6071,6 +6104,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
FloorFloat64(_) => Ok(impls::FloorFloat64.into()),
FloorNumeric(_) => Ok(impls::FloorNumeric.into()),
Ascii(_) => Ok(impls::Ascii.into()),
BitCountBytes(_) => Ok(impls::BitCountBytes.into()),
BitLengthBytes(_) => Ok(impls::BitLengthBytes.into()),
BitLengthString(_) => Ok(impls::BitLengthString.into()),
ByteLengthBytes(_) => Ok(impls::ByteLengthBytes.into()),
Expand Down
9 changes: 9 additions & 0 deletions src/expr/src/scalar/func/impls/byte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use mz_ore::cast::CastFrom;
use mz_repr::strconv;

use crate::EvalError;
Expand Down Expand Up @@ -64,6 +65,14 @@ sqlfunc!(
}
);

sqlfunc!(
#[sqlname = "bit_count"]
fn bit_count_bytes<'a>(a: &'a [u8]) -> Result<i64, EvalError> {
let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();
i64::try_from(count).or(Err(EvalError::Int64OutOfRange(count.to_string().into())))
}
);

sqlfunc!(
#[sqlname = "bit_length"]
fn bit_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
Expand Down
31 changes: 30 additions & 1 deletion src/mysql-util/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,35 @@ fn pack_val_as_datum(
ScalarType::Int16 => packer.push(Datum::from(from_value_opt::<i16>(value)?)),
ScalarType::UInt32 => packer.push(Datum::from(from_value_opt::<u32>(value)?)),
ScalarType::Int32 => packer.push(Datum::from(from_value_opt::<i32>(value)?)),
ScalarType::UInt64 => packer.push(Datum::from(from_value_opt::<u64>(value)?)),
ScalarType::UInt64 => {
if let Some(MySqlColumnMeta::Bit(precision)) = &col_desc.meta {
let mut value = from_value_opt::<Vec<u8>>(value)?;

// Ensure we have the correct number of bytes.
let precision_bytes = (precision + 7) / 8;
if value.len() != usize::cast_from(precision_bytes) {
return Err(anyhow::anyhow!("'bit' column out of range!"));
}
// Be defensive and prune any bits that come over the wire and are
// greater than our precision.
let bit_index = precision % 8;
if bit_index != 0 {
let mask = !(u8::MAX << bit_index);
if value.len() > 0 {
value[0] &= mask;
}
}

// Based on experimentation the value coming across the wire is
// encoded in big-endian.
let mut buf = [0u8; 8];
buf[(8 - value.len())..].copy_from_slice(value.as_slice());
let value = u64::from_be_bytes(buf);
packer.push(Datum::from(value))
} else {
packer.push(Datum::from(from_value_opt::<u64>(value)?))
}
}
ScalarType::Int64 => packer.push(Datum::from(from_value_opt::<i64>(value)?)),
ScalarType::Float32 => packer.push(Datum::from(from_value_opt::<f32>(value)?)),
ScalarType::Float64 => packer.push(Datum::from(from_value_opt::<f64>(value)?)),
Expand Down Expand Up @@ -198,6 +226,7 @@ fn pack_val_as_datum(
))?;
}
}
Some(MySqlColumnMeta::Bit(_)) => unreachable!("parsed as a u64"),
None => {
packer.push(Datum::String(&from_value_opt::<String>(value)?));
}
Expand Down
5 changes: 5 additions & 0 deletions src/mysql-util/src/desc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ message ProtoMySqlColumnMetaTimestamp {
uint32 precision = 1;
}

message ProtoMySqlColumnMetaBit {
uint32 precision = 1;
}

message ProtoMySqlColumnDesc {
string name = 1;
optional mz_repr.relation_and_scalar.ProtoColumnType column_type = 2;
Expand All @@ -44,6 +48,7 @@ message ProtoMySqlColumnDesc {
ProtoMySqlColumnMetaYear year = 5;
ProtoMySqlColumnMetaDate date = 6;
ProtoMySqlColumnMetaTimestamp timestamp = 7;
ProtoMySqlColumnMetaBit bit = 8;
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/mysql-util/src/desc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub enum MySqlColumnMeta {
Date,
/// The described column is a timestamp value with a set precision.
Timestamp(u32),
/// The described column is a `bit` column, with the given possibly precision.
Bit(u32),
}

impl IsCompatible for Option<MySqlColumnMeta> {
Expand All @@ -195,6 +197,9 @@ impl IsCompatible for Option<MySqlColumnMeta> {
Some(MySqlColumnMeta::Timestamp(precision)),
Some(MySqlColumnMeta::Timestamp(other_precision)),
) => precision <= other_precision,
// We always cast bit columns to u64's and the max precision of a bit column
// is 64 bits, so any bit column is always compatible with another.
(Some(MySqlColumnMeta::Bit(_)), Some(MySqlColumnMeta::Bit(_))) => true,
_ => false,
}
}
Expand Down Expand Up @@ -226,6 +231,9 @@ impl RustType<ProtoMySqlColumnDesc> for MySqlColumnDesc {
precision: *precision,
}))
}
MySqlColumnMeta::Bit(precision) => Some(Meta::Bit(ProtoMySqlColumnMetaBit {
precision: *precision,
})),
}),
}
}
Expand All @@ -245,6 +253,7 @@ impl RustType<ProtoMySqlColumnDesc> for MySqlColumnDesc {
Meta::Year(_) => Some(Ok(MySqlColumnMeta::Year)),
Meta::Date(_) => Some(Ok(MySqlColumnMeta::Date)),
Meta::Timestamp(e) => Some(Ok(MySqlColumnMeta::Timestamp(e.precision))),
Meta::Bit(e) => Some(Ok(MySqlColumnMeta::Bit(e.precision))),
})
.transpose()?,
})
Expand Down
Loading
Loading