Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(awc): allow to retrieve request head in client response #3535

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions actix-http/src/h1/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,16 @@ impl Decoder for ClientPayloadCodec {
}
}

impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {
impl Encoder<Message<(&mut RequestHeadType, BodySize)>> for ClientCodec {
type Error = io::Error;

fn encode(
&mut self,
item: Message<(RequestHeadType, BodySize)>,
item: Message<(&mut RequestHeadType, BodySize)>,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
Message::Item((mut head, length)) => {
Message::Item((head, length)) => {
let inner = &mut self.inner;
inner.version = head.as_ref().version;
inner
Expand All @@ -219,7 +219,7 @@ impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {

inner.encoder.encode(
dst,
&mut head,
head,
false,
false,
inner.version,
Expand Down
1 change: 1 addition & 0 deletions awc/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Update `brotli` dependency to `7`.
- Prevent panics on connection pool drop when Tokio runtime is shutdown early.
- Minimum supported Rust version (MSRV) is now 1.75.
- Allow to retrieve request head used to send the http request on `ClientResponse`

## 3.5.1

Expand Down
19 changes: 13 additions & 6 deletions awc/src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ where
self,
head: H,
body: RB,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
) -> LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>>
where
H: Into<RequestHeadType> + 'static,
RB: MessageBody + 'static,
Expand Down Expand Up @@ -273,17 +273,24 @@ where
head: H,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Connection<A, B>, ClientCodec>), SendRequestError>,
Result<
(
RequestHeadType,
ResponseHead,
Framed<Connection<A, B>, ClientCodec>,
),
SendRequestError,
>,
> {
Box::pin(async move {
match self {
Connection::Tcp(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, res_head, framed))
}
Connection::Tls(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, res_head, framed))
}
Connection::Tls(ConnectionType::H2(mut conn)) => {
conn.release();
Expand Down
21 changes: 10 additions & 11 deletions awc/src/client/h1proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub(crate) async fn send_request<Io, B>(
io: H1Connection<Io>,
mut head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
Expand Down Expand Up @@ -86,7 +86,7 @@ where

// special handle for EXPECT request.
let (do_send, mut res_head) = if is_expect {
pin_framed.send((head, body.size()).into()).await?;
pin_framed.send((&mut head, body.size()).into()).await?;

let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await
Expand All @@ -96,7 +96,7 @@ where
// and current head would be used as final response head.
(head.status == StatusCode::CONTINUE, Some(head))
} else {
pin_framed.feed((head, body.size()).into()).await?;
pin_framed.feed((&mut head, body.size()).into()).await?;

(true, None)
};
Expand All @@ -118,17 +118,16 @@ where
res_head = Some(head);
}

let head = res_head.unwrap();

match pin_framed.codec_ref().message_type() {
h1::MessageType::None => {
let keep_alive = pin_framed.codec_ref().keep_alive();
pin_framed.io_mut().on_release(keep_alive);

Ok((head, Payload::None))
Ok((head, res_head.unwrap(), Payload::None))
}
_ => Ok((
head,
res_head.unwrap(),
Payload::Stream {
payload: Box::pin(PlStream::new(framed)),
},
Expand All @@ -138,21 +137,21 @@ where

pub(crate) async fn open_tunnel<Io>(
io: Io,
head: RequestHeadType,
) -> Result<(ResponseHead, Framed<Io, h1::ClientCodec>), SendRequestError>
mut head: RequestHeadType,
) -> Result<(RequestHeadType, ResponseHead, Framed<Io, h1::ClientCodec>), SendRequestError>
where
Io: ConnectionIo,
{
// create Framed and send request.
let mut framed = Framed::new(io, h1::ClientCodec::default());
framed.send((head, BodySize::None).into()).await?;
framed.send((&mut head, BodySize::None).into()).await?;

// read response head.
let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
let res_head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;

Ok((head, framed))
Ok((head, res_head, framed))
}

/// send request body to the peer
Expand Down
10 changes: 5 additions & 5 deletions awc/src/client/h2proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(crate) async fn send_request<Io, B>(
mut io: H2Connection<Io>,
head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
Expand Down Expand Up @@ -129,10 +129,10 @@ where
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };

let mut head = ResponseHead::new(parts.status);
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
let mut res_head = ResponseHead::new(parts.status);
res_head.version = parts.version;
res_head.headers = parts.headers.into();
Ok((head, res_head, payload))
}

async fn send_body<B>(body: B, mut send: SendStream<Bytes>) -> Result<(), SendRequestError>
Expand Down
28 changes: 19 additions & 9 deletions awc/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ pub enum ConnectResponse {
/// Tunnel used for WebSocket communication.
///
/// Contains response head and framed HTTP/1.1 codec.
Tunnel(ResponseHead, Framed<BoxedSocket, ClientCodec>),
Tunnel(
RequestHeadType,
ResponseHead,
Framed<BoxedSocket, ClientCodec>,
),
}

impl ConnectResponse {
Expand All @@ -70,9 +74,15 @@ impl ConnectResponse {
///
/// # Panics
/// Panics if enum variant is not `Tunnel`.
pub fn into_tunnel_response(self) -> (ResponseHead, Framed<BoxedSocket, ClientCodec>) {
pub fn into_tunnel_response(
self,
) -> (
RequestHeadType,
ResponseHead,
Framed<BoxedSocket, ClientCodec>,
) {
match self {
ConnectResponse::Tunnel(head, framed) => (head, framed),
ConnectResponse::Tunnel(req, head, framed) => (req, head, framed),
_ => {
panic!("TunnelResponse only reachable with ConnectResponse::TunnelResponse variant")
}
Expand Down Expand Up @@ -133,12 +143,12 @@ pin_project_lite::pin_project! {
req: Option<ConnectRequest>
},
Client {
fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
fut: LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>>
},
Tunnel {
fut: LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Connection<Io>, ClientCodec>), SendRequestError>,
Result<(RequestHeadType, ResponseHead, Framed<Connection<Io>, ClientCodec>), SendRequestError>,
>,
}
}
Expand Down Expand Up @@ -181,16 +191,16 @@ where
}

ConnectRequestProj::Client { fut } => {
let (head, payload) = ready!(fut.as_mut().poll(cx))?;
let (req, head, payload) = ready!(fut.as_mut().poll(cx))?;
Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new(
head, payload,
req, head, payload,
))))
}

ConnectRequestProj::Tunnel { fut } => {
let (head, framed) = ready!(fut.as_mut().poll(cx))?;
let (req, head, framed) = ready!(fut.as_mut().poll(cx))?;
let framed = framed.into_map_io(|io| Box::new(io) as _);
Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed)))
Poll::Ready(Ok(ConnectResponse::Tunnel(req, head, framed)))
}
}
}
Expand Down
1 change: 1 addition & 0 deletions awc/src/middleware/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ mod tests {
let res = client.get(srv.url("/")).send().await.unwrap();

assert_eq!(res.status().as_u16(), 400);
assert_eq!(res.req_head().uri.path(), "/test");
}

#[actix_rt::test]
Expand Down
14 changes: 12 additions & 2 deletions awc/src/responses/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{

use actix_http::{
error::PayloadError, header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Payload,
ResponseHead, StatusCode, Version,
RequestHead, RequestHeadType, ResponseHead, StatusCode, Version,
};
use actix_rt::time::{sleep, Sleep};
use bytes::Bytes;
Expand All @@ -23,6 +23,7 @@ use crate::cookie::{Cookie, ParseError as CookieParseError};
pin_project! {
/// Client Response
pub struct ClientResponse<S = BoxedPayloadStream> {
pub(crate) req_head: RequestHeadType,
pub(crate) head: ResponseHead,
#[pin]
pub(crate) payload: Payload<S>,
Expand All @@ -34,15 +35,22 @@ pin_project! {

impl<S> ClientResponse<S> {
/// Create new Request instance
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
pub(crate) fn new(req_head: RequestHeadType, head: ResponseHead, payload: Payload<S>) -> Self {
ClientResponse {
req_head,
head,
payload,
timeout: ResponseTimeout::default(),
extensions: RefCell::new(Extensions::new()),
}
}

/// Returns the request head used to send the request.
#[inline]
pub fn req_head(&self) -> &RequestHead {
self.req_head.as_ref()
}

#[inline]
pub(crate) fn head(&self) -> &ResponseHead {
&self.head
Expand Down Expand Up @@ -77,6 +85,7 @@ impl<S> ClientResponse<S> {

ClientResponse {
payload,
req_head: self.req_head,
head: self.head,
timeout: self.timeout,
extensions: self.extensions,
Expand Down Expand Up @@ -105,6 +114,7 @@ impl<S> ClientResponse<S> {
Self {
payload: self.payload,
head: self.head,
req_head: self.req_head,
timeout,
extensions: self.extensions,
}
Expand Down
10 changes: 7 additions & 3 deletions awc/src/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Test helpers for actix http client to use during testing.

use actix_http::{h1, header::TryIntoHeaderPair, Payload, ResponseHead, StatusCode, Version};
use actix_http::{
h1, header::TryIntoHeaderPair, Payload, RequestHead, ResponseHead, StatusCode, Version,
};
use bytes::Bytes;

#[cfg(feature = "cookies")]
Expand All @@ -9,6 +11,7 @@ use crate::ClientResponse;

/// Test `ClientResponse` builder
pub struct TestResponse {
req_head: RequestHead,
head: ResponseHead,
#[cfg(feature = "cookies")]
cookies: CookieJar,
Expand All @@ -18,6 +21,7 @@ pub struct TestResponse {
impl Default for TestResponse {
fn default() -> TestResponse {
TestResponse {
req_head: RequestHead::default(),
head: ResponseHead::new(StatusCode::OK),
#[cfg(feature = "cookies")]
cookies: CookieJar::new(),
Expand Down Expand Up @@ -88,10 +92,10 @@ impl TestResponse {
}

if let Some(pl) = self.payload {
ClientResponse::new(head, pl)
ClientResponse::new(self.req_head.into(), head, pl)
} else {
let (_, payload) = h1::Payload::create(true);
ClientResponse::new(head, payload.into())
ClientResponse::new(self.req_head.into(), head, payload.into())
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions awc/src/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ impl WebsocketsRequest {
fut.await?
};

let (head, framed) = res.into_tunnel_response();
let (req_head, head, framed) = res.into_tunnel_response();

// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
Expand Down Expand Up @@ -411,7 +411,7 @@ impl WebsocketsRequest {

// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
ClientResponse::new(req_head, head, Payload::None),
framed.into_map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
Expand Down
Loading