diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index e22d3b9007..c4bfe5db48 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -32,6 +32,20 @@ impl_type_checking!( sqlx::postgres::types::PgCube, + sqlx::postgres::types::PgPoint, + + sqlx::postgres::types::PgLine, + + sqlx::postgres::types::PgLSeg, + + sqlx::postgres::types::PgBox, + + sqlx::postgres::types::PgPath, + + sqlx::postgres::types::PgPolygon, + + sqlx::postgres::types::PgCircle, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs new file mode 100644 index 0000000000..988c028ed4 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -0,0 +1,321 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding BOX"; + +/// ## Postgres Geometric Box type +/// +/// Description: Rectangular box +/// Representation: `((upper_right_x,upper_right_y),(lower_left_x,lower_left_y))` +/// +/// Boxes are represented by pairs of points that are opposite corners of the box. Values of type box are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) ) +/// ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) +/// upper_right_x , upper_right_y , lower_left_x , lower_left_y +/// ``` +/// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box. +/// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +#[derive(Debug, Clone, PartialEq)] +pub struct PgBox { + pub upper_right_x: f64, + pub upper_right_y: f64, + pub lower_left_x: f64, + pub lower_left_y: f64, +} + +impl Type for PgBox { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("box") + } +} + +impl PgHasArrayType for PgBox { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_box") + } +} + +impl<'r> Decode<'r, Postgres> for PgBox { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgBox::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgBox::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgBox { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("box")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgBox { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let upper_right_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_x from {}", ERROR, s))?; + + let upper_right_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_y from {}", ERROR, s))?; + + let lower_left_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_x from {}", ERROR, s))?; + + let lower_left_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } +} + +impl PgBox { + fn from_bytes(mut bytes: &[u8]) -> Result { + let upper_right_x = bytes.get_f64(); + let upper_right_y = bytes.get_f64(); + let lower_left_x = bytes.get_f64(); + let lower_left_y = bytes.get_f64(); + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let min_x = &self.upper_right_x.min(self.lower_left_x); + let min_y = &self.upper_right_y.min(self.lower_left_y); + let max_x = &self.upper_right_x.max(self.lower_left_x); + let max_y = &self.upper_right_y.max(self.lower_left_y); + + buff.extend_from_slice(&max_x.to_be_bytes()); + buff.extend_from_slice(&max_y.to_be_bytes()); + buff.extend_from_slice(&min_x.to_be_bytes()); + buff.extend_from_slice(&min_y.to_be_bytes()); + + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod box_tests { + + use std::str::FromStr; + + use super::PgBox; + + const BOX_BYTES: &[u8] = &[ + 64, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, + 0, 0, 0, 0, + ]; + + #[test] + fn can_deserialise_box_type_bytes_in_order() { + let pg_box = PgBox::from_bytes(BOX_BYTES).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } + + #[test] + fn can_deserialise_box_type_str_first_syntax() { + let pg_box = PgBox::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + #[test] + fn can_deserialise_box_type_str_second_syntax() { + let pg_box = PgBox::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_third_syntax() { + let pg_box = PgBox::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_fourth_syntax() { + let pg_box = PgBox::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn cannot_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_too_few_numbers() { + let input_str = "1, 2, 3 "; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_box_type_str_float() { + let pg_box = PgBox::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1.1, + upper_right_y: 2.2, + lower_left_x: 3.3, + lower_left_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_box_type_in_order() { + let pg_box = PgBox { + upper_right_x: 2., + lower_left_x: -2., + upper_right_y: -2., + lower_left_y: 2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_serialise_box_type_out_of_order() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_order_box() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + let bytes = pg_box.serialize_to_vec(); + + let pg_box = PgBox::from_bytes(&bytes).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } +} diff --git a/sqlx-postgres/src/types/geometry/circle.rs b/sqlx-postgres/src/types/geometry/circle.rs new file mode 100644 index 0000000000..839505c35a --- /dev/null +++ b/sqlx-postgres/src/types/geometry/circle.rs @@ -0,0 +1,247 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::str::FromStr; + +const ERROR: &str = "error decoding CIRCLE"; + +/// ## Postgres Geometric Circle type +/// +/// Description: Circle +/// Representation: `< (x, y), radius >` (center point and radius) +/// +/// ```text +/// < ( x , y ) , radius > +/// ( ( x , y ) , radius ) +/// ( x , y ) , radius +/// x , y , radius +/// ``` +/// where `(x,y)` is the center point. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-CIRCLE +#[derive(Debug, Clone, PartialEq)] +pub struct PgCircle { + pub x: f64, + pub y: f64, + pub radius: f64, +} + +impl Type for PgCircle { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("circle") + } +} + +impl PgHasArrayType for PgCircle { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_circle") + } +} + +impl<'r> Decode<'r, Postgres> for PgCircle { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgCircle::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgCircle::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgCircle { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("circle")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgCircle { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['<', '>', '(', ')', ' '], ""); + let mut parts = sanitised.split(','); + + let x = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get x from {}", ERROR, s))?; + + let y = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get y from {}", ERROR, s))?; + + let radius = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get radius from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + if radius < 0. { + return Err(format!("{}: cannot have negative radius: {}", ERROR, s).into()); + } + + Ok(PgCircle { x, y, radius }) + } +} + +impl PgCircle { + fn from_bytes(mut bytes: &[u8]) -> Result { + let x = bytes.get_f64(); + let y = bytes.get_f64(); + let r = bytes.get_f64(); + Ok(PgCircle { x, y, radius: r }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> { + buff.extend_from_slice(&self.x.to_be_bytes()); + buff.extend_from_slice(&self.y.to_be_bytes()); + buff.extend_from_slice(&self.radius.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod circle_tests { + + use std::str::FromStr; + + use super::PgCircle; + + const CIRCLE_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, + ]; + + #[test] + fn can_deserialise_circle_type_bytes() { + let circle = PgCircle::from_bytes(CIRCLE_BYTES).unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ) + } + + #[test] + fn can_deserialise_circle_type_str() { + let circle = PgCircle::from_str("<(1, 2), 3 >").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_second_syntax() { + let circle = PgCircle::from_str("((1, 2), 3 )").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_third_syntax() { + let circle = PgCircle::from_str("(1, 2), 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_fourth_syntax() { + let circle = PgCircle::from_str("1, 2, 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn cannot_deserialise_circle_invalid_numbers() { + let input_str = "1, 2, Three"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: could not get radius from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_circle_negative_radius() { + let input_str = "1, 2, -3"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: cannot have negative radius: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_circle_type_str_float() { + let circle = PgCircle::from_str("<(1.1, 2.2), 3.3>").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ); + } + + #[test] + fn can_serialise_circle_type() { + let circle = PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3, + }; + assert_eq!(circle.serialize_to_vec(), CIRCLE_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/line.rs b/sqlx-postgres/src/types/geometry/line.rs new file mode 100644 index 0000000000..43f93c1c33 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/line.rs @@ -0,0 +1,211 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding LINE"; + +/// ## Postgres Geometric Line type +/// +/// Description: Infinite line +/// Representation: `{A, B, C}` +/// +/// Lines are represented by the linear equation Ax + By + C = 0, where A and B are not both zero. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LINE +#[derive(Debug, Clone, PartialEq)] +pub struct PgLine { + pub a: f64, + pub b: f64, + pub c: f64, +} + +impl Type for PgLine { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("line") + } +} + +impl PgHasArrayType for PgLine { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_line") + } +} + +impl<'r> Decode<'r, Postgres> for PgLine { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgLine::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgLine::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgLine { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("line")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgLine { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let mut parts = s + .trim_matches(|c| c == '{' || c == '}' || c == ' ') + .split(','); + + let a = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get a from {}", ERROR, s))?; + + let b = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get b from {}", ERROR, s))?; + + let c = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get c from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgLine { a, b, c }) + } +} + +impl PgLine { + fn from_bytes(mut bytes: &[u8]) -> Result { + let a = bytes.get_f64(); + let b = bytes.get_f64(); + let c = bytes.get_f64(); + Ok(PgLine { a, b, c }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + buff.extend_from_slice(&self.a.to_be_bytes()); + buff.extend_from_slice(&self.b.to_be_bytes()); + buff.extend_from_slice(&self.c.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod line_tests { + + use std::str::FromStr; + + use super::PgLine; + + const LINE_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, + ]; + + #[test] + fn can_deserialise_line_type_bytes() { + let line = PgLine::from_bytes(LINE_BYTES).unwrap(); + assert_eq!( + line, + PgLine { + a: 1.1, + b: 2.2, + c: 3.3 + } + ) + } + + #[test] + fn can_deserialise_line_type_str() { + let line = PgLine::from_str("{ 1, 2, 3 }").unwrap(); + assert_eq!( + line, + PgLine { + a: 1.0, + b: 2.0, + c: 3.0 + } + ); + } + + #[test] + fn cannot_deserialise_line_too_few_numbers() { + let input_str = "{ 1, 2 }"; + let line = PgLine::from_str(input_str); + assert!(line.is_err()); + if let Err(err) = line { + assert_eq!( + err.to_string(), + format!("error decoding LINE: could not get c from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_line_too_many_numbers() { + let input_str = "{ 1, 2, 3, 4 }"; + let line = PgLine::from_str(input_str); + assert!(line.is_err()); + if let Err(err) = line { + assert_eq!( + err.to_string(), + format!("error decoding LINE: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_line_invalid_numbers() { + let input_str = "{ 1, 2, three }"; + let line = PgLine::from_str(input_str); + assert!(line.is_err()); + if let Err(err) = line { + assert_eq!( + err.to_string(), + format!("error decoding LINE: could not get c from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_line_type_str_float() { + let line = PgLine::from_str("{1.1, 2.2, 3.3}").unwrap(); + assert_eq!( + line, + PgLine { + a: 1.1, + b: 2.2, + c: 3.3 + } + ); + } + + #[test] + fn can_serialise_line_type() { + let line = PgLine { + a: 1.1, + b: 2.2, + c: 3.3, + }; + assert_eq!(line.serialize_to_vec(), LINE_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs new file mode 100644 index 0000000000..ebe32d97d0 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -0,0 +1,282 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding LSEG"; + +/// ## Postgres Geometric Line Segment type +/// +/// Description: Finite line segment +/// Representation: `((start_x,start_y),(end_x,end_y))` +/// +/// +/// Line segments are represented by pairs of points that are the endpoints of the segment. Values of type lseg are specified using any of the following syntaxes: +/// ```text +/// [ ( start_x , start_y ) , ( end_x , end_y ) ] +/// ( ( start_x , start_y ) , ( end_x , end_y ) ) +/// ( start_x , start_y ) , ( end_x , end_y ) +/// start_x , start_y , end_x , end_y +/// ``` +/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG +#[derive(Debug, Clone, PartialEq)] +pub struct PgLSeg { + pub start_x: f64, + pub start_y: f64, + pub end_x: f64, + pub end_y: f64, +} + +impl Type for PgLSeg { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("lseg") + } +} + +impl PgHasArrayType for PgLSeg { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_lseg") + } +} + +impl<'r> Decode<'r, Postgres> for PgLSeg { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgLSeg::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgLSeg::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgLSeg { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("lseg")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgLSeg { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let start_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_x from {}", ERROR, s))?; + + let start_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_y from {}", ERROR, s))?; + + let end_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_x from {}", ERROR, s))?; + + let end_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } +} + +impl PgLSeg { + fn from_bytes(mut bytes: &[u8]) -> Result { + let start_x = bytes.get_f64(); + let start_y = bytes.get_f64(); + let end_x = bytes.get_f64(); + let end_y = bytes.get_f64(); + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + buff.extend_from_slice(&self.start_x.to_be_bytes()); + buff.extend_from_slice(&self.start_y.to_be_bytes()); + buff.extend_from_slice(&self.end_x.to_be_bytes()); + buff.extend_from_slice(&self.end_y.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod lseg_tests { + + use std::str::FromStr; + + use super::PgLSeg; + + const LINE_SEGMENT_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, 64, 17, 153, 153, 153, 153, 153, 154, + ]; + + #[test] + fn can_deserialise_lseg_type_bytes() { + let lseg = PgLSeg::from_bytes(LINE_SEGMENT_BYTES).unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ) + } + + #[test] + fn can_deserialise_lseg_type_str_first_syntax() { + let lseg = PgLSeg::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + #[test] + fn can_deserialise_lseg_type_str_second_syntax() { + let lseg = PgLSeg::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_third_syntax() { + let lseg = PgLSeg::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_fourth_syntax() { + let lseg = PgLSeg::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn can_deserialise_too_few_numbers() { + let input_str = "1, 2, 3"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_lseg_type_str_float() { + let lseg = PgLSeg::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_lseg_type() { + let lseg = PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4, + }; + assert_eq!(lseg.serialize_to_vec(), LINE_SEGMENT_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs new file mode 100644 index 0000000000..c3142145ee --- /dev/null +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -0,0 +1,7 @@ +pub mod r#box; +pub mod circle; +pub mod line; +pub mod line_segment; +pub mod path; +pub mod point; +pub mod polygon; diff --git a/sqlx-postgres/src/types/geometry/path.rs b/sqlx-postgres/src/types/geometry/path.rs new file mode 100644 index 0000000000..87a3b3e8d3 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/path.rs @@ -0,0 +1,372 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Path type +/// +/// Description: Open path or Closed path (similar to polygon) +/// Representation: Open `[(x1,y1),...]`, Closed `((x1,y1),...)` +/// +/// Paths are represented by lists of connected points. Paths can be open, where the first and last points in the list are considered not connected, or closed, where the first and last points are considered connected. +/// Values of type path are specified using any of the following syntaxes: +/// ```text +/// [ ( x1 , y1 ) , ... , ( xn , yn ) ] +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path. +/// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS +#[derive(Debug, Clone, PartialEq)] +pub struct PgPath { + pub closed: bool, + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + is_closed: bool, + length: usize, +} + +impl Type for PgPath { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("path") + } +} + +impl PgHasArrayType for PgPath { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_path") + } +} + +impl<'r> Decode<'r, Postgres> for PgPath { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPath { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("path")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPath { + type Err = Error; + + fn from_str(s: &str) -> Result { + let closed = !s.contains('['); + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in PATH: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPath { points, closed }); + } + + Err(Error::Decode( + format!("could not get path from {}", s).into(), + )) + } +} + +impl PgPath { + fn header(&self) -> Header { + Header { + is_closed: self.closed, + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPath { + closed: header.is_closed, + points: out_points, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected PATH data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let is_closed = buf.get_i8(); + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received PATH data length: {length}. Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { + is_closed: is_closed != 0, + length, + }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let is_closed = self.is_closed as i8; + + let length = i32::try_from(self.length).map_err(|_| { + format!( + "PATH length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(is_closed.to_be_bytes()); + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod path_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPath; + + const PATH_CLOSED_BYTES: &[u8] = &[ + 1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_OPEN_BYTES: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_UNEVEN_POINTS: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, + ]; + + #[test] + fn can_deserialise_path_type_bytes_closed() { + let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + + #[test] + fn cannot_deserialise_path_type_uneven_point_bytes() { + let path = PgPath::from_bytes(PATH_UNEVEN_POINTS); + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("expected 32 bytes after header, got 28") + ) + } + } + + #[test] + fn can_deserialise_path_type_bytes_open() { + let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + + #[test] + fn can_deserialise_path_type_str_first_syntax() { + let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_path_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let path = PgPath::from_str(input_str); + + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in PATH: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_path_type_str_second_syntax() { + let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_third_syntax() { + let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_fourth_syntax() { + let path = PgPath::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_float() { + let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_path_type() { + let path = PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }], + }; + assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/point.rs b/sqlx-postgres/src/types/geometry/point.rs new file mode 100644 index 0000000000..cc10672950 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/point.rs @@ -0,0 +1,138 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::str::FromStr; + +/// ## Postgres Geometric Point type +/// +/// Description: Point on a plane +/// Representation: `(x, y)` +/// +/// Points are the fundamental two-dimensional building block for geometric types. Values of type point are specified using either of the following syntaxes: +/// ```text +/// ( x , y ) +/// x , y +/// ```` +/// where x and y are the respective coordinates, as floating-point numbers. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS +#[derive(Debug, Clone, PartialEq)] +pub struct PgPoint { + pub x: f64, + pub y: f64, +} + +impl Type for PgPoint { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("point") + } +} + +impl PgHasArrayType for PgPoint { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_point") + } +} + +impl<'r> Decode<'r, Postgres> for PgPoint { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPoint::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPoint::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPoint { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("point")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.trim() + .parse() + .map_err(|_| Error::Decode(error_msg.into())) +} + +impl FromStr for PgPoint { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let (x_str, y_str) = s + .trim_matches(|c| c == '(' || c == ')' || c == ' ') + .split_once(',') + .ok_or_else(|| format!("error decoding POINT: could not get x and y from {}", s))?; + + let x = parse_float_from_str(x_str, "error decoding POINT: could not get x")?; + let y = parse_float_from_str(y_str, "error decoding POINT: could not get x")?; + + Ok(PgPoint { x, y }) + } +} + +impl PgPoint { + fn from_bytes(mut bytes: &[u8]) -> Result { + let x = bytes.get_f64(); + let y = bytes.get_f64(); + Ok(PgPoint { x, y }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + buff.extend_from_slice(&self.x.to_be_bytes()); + buff.extend_from_slice(&self.y.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod point_tests { + + use std::str::FromStr; + + use super::PgPoint; + + const POINT_BYTES: &[u8] = &[ + 64, 0, 204, 204, 204, 204, 204, 205, 64, 20, 204, 204, 204, 204, 204, 205, + ]; + + #[test] + fn can_deserialise_point_type_bytes() { + let point = PgPoint::from_bytes(POINT_BYTES).unwrap(); + assert_eq!(point, PgPoint { x: 2.1, y: 5.2 }) + } + + #[test] + fn can_deserialise_point_type_str() { + let point = PgPoint::from_str("(2, 3)").unwrap(); + assert_eq!(point, PgPoint { x: 2., y: 3. }); + } + + #[test] + fn can_deserialise_point_type_str_float() { + let point = PgPoint::from_str("(2.5, 3.4)").unwrap(); + assert_eq!(point, PgPoint { x: 2.5, y: 3.4 }); + } + + #[test] + fn can_serialise_point_type() { + let point = PgPoint { x: 2.1, y: 5.2 }; + assert_eq!(point.serialize_to_vec(), POINT_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/polygon.rs b/sqlx-postgres/src/types/geometry/polygon.rs new file mode 100644 index 0000000000..500c9933e9 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/polygon.rs @@ -0,0 +1,363 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Polygon type +/// +/// Description: Polygon (similar to closed polygon) +/// Representation: `((x1,y1),...)` +/// +/// Polygons are represented by lists of points (the vertexes of the polygon). Polygons are very similar to closed paths; the essential semantic difference is that a polygon is considered to include the area within it, while a path is not. +/// An important implementation difference between polygons and paths is that the stored representation of a polygon includes its smallest bounding box. This speeds up certain search operations, although computing the bounding box adds overhead while constructing new polygons. +/// Values of type polygon are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// +/// where the points are the end points of the line segments comprising the boundary of the polygon. +/// +/// Seeh ttps://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-POLYGON +#[derive(Debug, Clone, PartialEq)] +pub struct PgPolygon { + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + length: usize, +} + +impl Type for PgPolygon { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("polygon") + } +} + +impl PgHasArrayType for PgPolygon { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_polygon") + } +} + +impl<'r> Decode<'r, Postgres> for PgPolygon { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPolygon::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPolygon::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPolygon { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("polygon")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPolygon { + type Err = Error; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in POLYGON: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPolygon { points }); + } + + Err(Error::Decode( + format!("could not get polygon from {}", s).into(), + )) + } +} + +impl PgPolygon { + fn header(&self) -> Header { + Header { + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPolygon { points: out_points }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected polygon data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received polygon with length: {length}. Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { length }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let length = i32::try_from(self.length).map_err(|_| { + format!( + "polygon length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod polygon_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPolygon; + + const POLYGON_BYTES: &[u8] = &[ + 0, 0, 0, 12, 192, 0, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, + 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 63, + 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, + 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, + 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, + 240, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, + 0, 0, 0, + ]; + + #[test] + fn can_deserialise_polygon_type_bytes() { + let polygon = PgPolygon::from_bytes(POLYGON_BYTES).unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. }, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. } + ] + } + ) + } + + #[test] + fn can_deserialise_polygon_type_str_first_syntax() { + let polygon = PgPolygon::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_second_syntax() { + let polygon = PgPolygon::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_polygon_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in POLYGON: {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_polygon_type_str_invalid_numbers() { + let input_str = "[( 1, 2), (2, three)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: could not get y") + ) + } + } + + #[test] + fn can_deserialise_polygon_type_str_third_syntax() { + let polygon = PgPolygon::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_fourth_syntax() { + let polygon = PgPolygon::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_float() { + let polygon = PgPolygon::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_polygon_type() { + let polygon = PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. }, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. }, + ], + }; + assert_eq!(polygon.serialize_to_vec(), POLYGON_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 846f1b731d..7dd25cb272 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -21,6 +21,13 @@ //! | [`PgLQuery`] | LQUERY | //! | [`PgCiText`] | CITEXT1 | //! | [`PgCube`] | CUBE | +//! | [`PgPoint] | POINT | +//! | [`PgLine`] | LINE | +//! | [`PgLSeg`] | LSEG | +//! | [`PgBox`] | BOX | +//! | [`PgPath`] | PATH | +//! | [`PgPolygon`] | POLYGON | +//! | [`PgCircle`] | CIRCLE | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -212,6 +219,8 @@ mod bigdecimal; mod cube; +mod geometry; + #[cfg(any(feature = "bigdecimal", feature = "rust_decimal"))] mod numeric; @@ -242,6 +251,13 @@ mod bit_vec; pub use array::PgHasArrayType; pub use citext::PgCiText; pub use cube::PgCube; +pub use geometry::circle::PgCircle; +pub use geometry::line::PgLine; +pub use geometry::line_segment::PgLSeg; +pub use geometry::path::PgPath; +pub use geometry::point::PgPoint; +pub use geometry::polygon::PgPolygon; +pub use geometry::r#box::PgBox; pub use hstore::PgHstore; pub use interval::PgInterval; pub use lquery::PgLQuery; diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 5ba0f6323f..304492e695 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -51,6 +51,12 @@ macro_rules! test_type { } }; + ($name:ident<$ty:ty>($db:ident, $($text:literal @= $value:expr),+ $(,)?)) => { + paste::item! { + $crate::__test_prepared_type!($name<$ty>($db, $crate::[< $db _query_for_test_prepared_geometric_type >]!(), $($text == $value),+)); + } + }; + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { $crate::test_type!($name<$name>($db, $($text == $value),+)); }; @@ -82,6 +88,7 @@ macro_rules! test_prepared_type { } }; + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { $crate::__test_prepared_type!($name<$name>($db, $($text == $value),+)); }; @@ -223,3 +230,10 @@ macro_rules! Postgres_query_for_test_prepared_type { "SELECT ({0} is not distinct from $1)::int4, {0}, $2" }; } + +#[macro_export] +macro_rules! Postgres_query_for_test_prepared_geometric_type { + () => { + "SELECT ({0}::text is not distinct from $1::text)::int4, {0}, $2" + }; +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 4912339dc2..fa1ab08d5e 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -3,7 +3,7 @@ extern crate time_ as time; use std::net::SocketAddr; use std::ops::Bound; -use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; +use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgPoint, PgRange}; use sqlx::postgres::Postgres; use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; @@ -492,6 +492,81 @@ test_type!(_cube>(Postgres, "array[cube(2.2,-3.4)]" == vec![sqlx::postgres::types::PgCube::OneDimensionInterval(2.2, -3.4)], )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(point(Postgres, + "point(2.2,-3.4)" @= sqlx::postgres::types::PgPoint { x: 2.2, y:-3.4 }, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_point>(Postgres, + "array[point(2,3),point(2.1,3.4)]" @= vec![sqlx::postgres::types::PgPoint { x:2., y: 3. }, sqlx::postgres::types::PgPoint { x:2.1, y: 3.4 }], + "array[point(2.2,-3.4)]" @= vec![sqlx::postgres::types::PgPoint { x: 2.2, y: -3.4 }], +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(line(Postgres, + "line('{1.1, -2.2, 3.3}')" @= sqlx::postgres::types::PgLine { a: 1.1, b:-2.2, c: 3.3 }, + "line('((0.0, 0.0), (1.0,1.0))')" @= sqlx::postgres::types::PgLine { a: 1., b: -1., c: 0. }, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_line>(Postgres, + "array[line('{1,2,3}'),line('{1.1, 2.2, 3.3}')]" @= vec![sqlx::postgres::types::PgLine { a:1., b: 2., c: 3. }, sqlx::postgres::types::PgLine { a:1.1, b: 2.2, c: 3.3 }], +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(lseg(Postgres, + "lseg('((1.0, 2.0), (3.0,4.0))')" @= sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_lseg>(Postgres, + "array[lseg('(1,2,3,4)'),lseg('[(1.1, 2.2), (3.3, 4.4)]')]" @= vec![sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3., end_y: 4. }, sqlx::postgres::types::PgLSeg { start_x: 1.1, start_y: 2.2, end_x: 3.3, end_y: 4.4 }], +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(box(Postgres, + "box('((1.0, 2.0), (3.0,4.0))')" @= sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1. , lower_left_y: 2.}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_box>(Postgres, + "array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }], +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(path(Postgres, + "path('((1.0, 2.0), (3.0,4.0))')" @= sqlx::postgres::types::PgPath { closed: true, points: vec![ PgPoint { x: 1., y: 2. }, PgPoint { x: 3. , y: 4. } ]}, + "path('[(1.0, 2.0), (3.0,4.0)]')" @= sqlx::postgres::types::PgPath { closed: false, points: vec![ PgPoint { x: 1., y: 2. }, PgPoint { x: 3. , y: 4. } ]}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_path>(Postgres, + "array[path('(1,2),(3,4)'),path('[(1.1, 2.2), (3.3, 4.4)]')]" @= vec![sqlx::postgres::types::PgPath { closed: true, points: vec![ PgPoint { x: 1., y: 2. }, PgPoint { x: 3. , y: 4. } ]}, sqlx::postgres::types::PgPath { closed: false, points: vec![ PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3 , y: 4.4 } ]},], +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(polygon(Postgres, + "polygon('((-2,-3),(-1,-3),(-1,-1),(1,1),(1,3),(2,3),(2,-3),(1,-3),(1,0),(-1,0),(-1,-2),(-2,-2))')" @= sqlx::postgres::types::PgPolygon { points: vec![ + PgPoint { x: -2., y: -3. }, PgPoint { x: -1., y: -3. }, PgPoint { x: -1., y: -1. }, PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, PgPoint { x: 2., y: 3. }, PgPoint { x: 2., y: -3. }, PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, PgPoint { x: -1., y: 0. }, PgPoint { x: -1., y: -2. }, PgPoint { x: -2., y: -2. }, + ]}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(circle(Postgres, + "circle('<(1.1, -2.2), 3.3>')" @= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('((1.1, -2.2), 3.3)')" @= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('(1.1, -2.2), 3.3')" @= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('1.1, -2.2, 3.3')" @= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_circle>(Postgres, + "array[circle('<(1,2),3>'),circle('(1.1, 2.2), 3.3')]" @= vec![sqlx::postgres::types::PgCircle { x: 1., y: 2., radius: 3. }, sqlx::postgres::types::PgCircle { x: 1.1, y: 2.2, radius: 3.3 }], +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),