diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 7092b65f03c2..25e8813625d8 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -44,10 +44,13 @@ use super::json::pg_text_row_to_json; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] +#[serde(rename_all = "camelCase")] struct QueryData { query: String, #[serde(deserialize_with = "bytes_to_pg_text")] params: Vec>, + #[serde(default)] + array_mode: Option, } #[derive(serde::Deserialize)] @@ -330,7 +333,7 @@ async fn handle_inner( // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); - let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); + let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in @@ -402,83 +405,87 @@ async fn handle_inner( // Now execute the query and return the result // let mut size = 0; - let result = - match payload { - Payload::Single(stmt) => { - let (status, results) = - query_to_json(&*client, stmt, &mut 0, raw_output, array_mode) - .await - .map_err(|e| { - client.discard(); - e - })?; - client.check_idle(status); - results + let result = match payload { + Payload::Single(stmt) => { + let (status, results) = + query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode) + .await + .map_err(|e| { + client.discard(); + e + })?; + client.check_idle(status); + results + } + Payload::Batch(statements) => { + let (inner, mut discard) = client.inner(); + let mut builder = inner.build_transaction(); + if let Some(isolation_level) = txn_isolation_level { + builder = builder.isolation_level(isolation_level); + } + if txn_read_only { + builder = builder.read_only(true); + } + if txn_deferrable { + builder = builder.deferrable(true); } - Payload::Batch(statements) => { - let (inner, mut discard) = client.inner(); - let mut builder = inner.build_transaction(); - if let Some(isolation_level) = txn_isolation_level { - builder = builder.isolation_level(isolation_level); - } - if txn_read_only { - builder = builder.read_only(true); - } - if txn_deferrable { - builder = builder.deferrable(true); - } - let transaction = builder.start().await.map_err(|e| { - // if we cannot start a transaction, we should return immediately - // and not return to the pool. connection is clearly broken - discard.discard(); - e - })?; - - let results = - match query_batch(&transaction, statements, &mut size, raw_output, array_mode) - .await - { - Ok(results) => { - let status = transaction.commit().await.map_err(|e| { - // if we cannot commit - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - e - })?; - discard.check_idle(status); - results - } - Err(err) => { - let status = transaction.rollback().await.map_err(|e| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - e - })?; - discard.check_idle(status); - return Err(err); - } - }; - - if txn_read_only { - response = response.header( - TXN_READ_ONLY.clone(), - HeaderValue::try_from(txn_read_only.to_string())?, - ); - } - if txn_deferrable { - response = response.header( - TXN_DEFERRABLE.clone(), - HeaderValue::try_from(txn_deferrable.to_string())?, - ); + let transaction = builder.start().await.map_err(|e| { + // if we cannot start a transaction, we should return immediately + // and not return to the pool. connection is clearly broken + discard.discard(); + e + })?; + + let results = match query_batch( + &transaction, + statements, + &mut size, + raw_output, + default_array_mode, + ) + .await + { + Ok(results) => { + let status = transaction.commit().await.map_err(|e| { + // if we cannot commit - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + results } - if let Some(txn_isolation_level) = txn_isolation_level_raw { - response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); + Err(err) => { + let status = transaction.rollback().await.map_err(|e| { + // if we cannot rollback - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + return Err(err); } - json!({ "results": results }) + }; + + if txn_read_only { + response = response.header( + TXN_READ_ONLY.clone(), + HeaderValue::try_from(txn_read_only.to_string())?, + ); + } + if txn_deferrable { + response = response.header( + TXN_DEFERRABLE.clone(), + HeaderValue::try_from(txn_deferrable.to_string())?, + ); + } + if let Some(txn_isolation_level) = txn_isolation_level_raw { + response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); } - }; + json!({ "results": results }) + } + }; ctx.set_success(); ctx.log(); @@ -524,7 +531,7 @@ async fn query_to_json( data: QueryData, current_size: &mut usize, raw_output: bool, - array_mode: bool, + default_array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; @@ -578,6 +585,8 @@ async fn query_to_json( columns.push(client.get_type(c.type_oid()).await?); } + let array_mode = data.array_mode.unwrap_or(default_array_mode); + // convert rows to JSON let rows = rows .iter() diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index b3b35e446d6e..49a0450f0cd0 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -390,6 +390,39 @@ def qq( assert result[0]["rows"] == [{"answer": 42}] +def test_sql_over_http_batch_output_options(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps( + { + "queries": [ + {"query": "select $1 as answer", "params": [42], "arrayMode": True}, + {"query": "select $1 as answer", "params": [42], "arrayMode": False}, + ] + } + ), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Batch-Isolation-Level": "Serializable", + "Neon-Batch-Read-Only": "false", + "Neon-Batch-Deferrable": "false", + }, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200 + results = response.json()["results"] + + assert results[0]["rowAsArray"] + assert results[0]["rows"] == [["42"]] + + assert not results[1]["rowAsArray"] + assert results[1]["rows"] == [{"answer": "42"}] + + def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser")