diff --git a/Cargo.lock b/Cargo.lock index b1894de..8f7a016 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1291,7 +1291,6 @@ name = "twirp-build" version = "0.8.0" dependencies = [ "prettyplease", - "proc-macro2", "prost-build", "quote", "syn", diff --git a/crates/twirp-build/Cargo.toml b/crates/twirp-build/Cargo.toml index 908c318..900843e 100644 --- a/crates/twirp-build/Cargo.toml +++ b/crates/twirp-build/Cargo.toml @@ -18,4 +18,3 @@ prost-build = "0.13" prettyplease = { version = "0.2" } quote = "1.0" syn = "2.0" -proc-macro2 = "1.0" diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 0540c67..78216aa 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -166,19 +166,25 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let mut client_methods = Vec::with_capacity(service.methods.len()); for m in &service.methods { let name = &m.name; + let name_request = format_ident!("{}_request", name); let input_type = &m.input_type; let output_type = &m.output_type; let request_path = format!("{}/{}", service.fqn, m.proto_name); client_trait_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError>; + async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { + self.#name_request(req)?.send().await + } + }); + client_trait_methods.push(quote! { + fn #name_request(&self, req: #input_type) -> Result, twirp::ClientError>; }); client_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { - self.request(#request_path, req).await + fn #name_request(&self, req: #input_type) -> Result, twirp::ClientError> { + self.request(#request_path, req) } - }) + }); } let client_trait = quote! { #[twirp::async_trait::async_trait] diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 5f8ac5b..c38176b 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use std::vec; use async_trait::async_trait; +use http::{HeaderName, HeaderValue}; use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE}; use reqwest::StatusCode; use thiserror::Error; @@ -155,23 +156,12 @@ impl Client { } } - /// Make an HTTP twirp request. - pub async fn request(&self, path: &str, body: I) -> Result + /// Executes a `Request`. + pub(super) async fn execute(&self, req: reqwest::Request) -> Result where - I: prost::Message, O: prost::Message + Default, { - let mut url = self.inner.base_url.join(path)?; - if let Some(host) = &self.host { - url.set_host(Some(host))? - }; - let path = url.path().to_string(); - let req = self - .http_client - .post(url) - .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(serialize_proto_message(body)) - .build()?; + let path = req.url().path().to_string(); // Create and execute the middleware handlers let next = Next::new(&self.http_client, &self.inner.middlewares); @@ -204,6 +194,68 @@ impl Client { }), } } + + /// Start building a `Request` with a path and a request body. + /// + /// Returns a `RequestBuilder`, which will allow setting headers before sending. + pub fn request(&self, path: &str, body: I) -> Result> + where + I: prost::Message, + O: prost::Message + Default, + { + let mut url = self.inner.base_url.join(path)?; + if let Some(host) = &self.host { + url.set_host(Some(host))? + }; + + let req = self + .http_client + .post(url) + .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) + .body(serialize_proto_message(body)); + Ok(RequestBuilder::new(self.clone(), req)) + } +} + +pub struct RequestBuilder +where + O: prost::Message + Default, +{ + client: Client, + inner: reqwest::RequestBuilder, + _input: std::marker::PhantomData, + _output: std::marker::PhantomData, +} + +impl RequestBuilder +where + O: prost::Message + Default, +{ + pub fn new(client: Client, inner: reqwest::RequestBuilder) -> Self { + Self { + client, + inner, + _input: std::marker::PhantomData, + _output: std::marker::PhantomData, + } + } + + /// Add a `Header` to this Request. + pub fn header(mut self, key: K, value: V) -> RequestBuilder + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.inner = self.inner.header(key, value); + self + } + + pub async fn send(self) -> Result { + let req = self.inner.build()?; + self.client.execute(req).await + } } // This concept of reqwest middleware is taken pretty much directly from: diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index 5b66b2b..6cbbb52 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -10,7 +10,7 @@ pub mod test; #[doc(hidden)] pub mod details; -pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; +pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, RequestBuilder, Result}; pub use context::Context; pub use error::*; // many constructors like `invalid_argument()` pub use http::Extensions; diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index e80effd..7489a76 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -121,7 +121,7 @@ pub trait TestApiClient { #[async_trait] impl TestApiClient for Client { async fn ping(&self, req: PingRequest) -> Result { - self.request("test.TestAPI/Ping", req).await + self.request("test.TestAPI/Ping", req)?.send().await } async fn boom(&self, _req: PingRequest) -> Result { diff --git a/example/src/bin/client.rs b/example/src/bin/client.rs index 89c6e71..51b132a 100644 --- a/example/src/bin/client.rs +++ b/example/src/bin/client.rs @@ -38,6 +38,14 @@ pub async fn main() -> Result<(), GenericError> { .await; eprintln!("{:?}", resp); + let resp = client + .with_host("localhost") + .make_hat_request(MakeHatRequest { inches: 1 })? + .header("x-custom-header", "a") + .send() + .await?; + eprintln!("{:?}", resp); + Ok(()) } @@ -69,23 +77,39 @@ impl Middleware for PrintResponseHeaders { } } +// NOTE: This is just to demonstrate manually implementing the client trait. You don't need to do this as this code will +// be generated for you by twirp-build. +// +// This is here so that we can visualize changes to the generated client code #[allow(dead_code)] #[derive(Debug)] struct MockHaberdasherApiClient; #[async_trait] impl HaberdasherApiClient for MockHaberdasherApiClient { - async fn make_hat( + fn make_hat_request( &self, _req: MakeHatRequest, - ) -> Result { + ) -> Result, twirp::ClientError> { + todo!() + } + // implementing this one is optional + async fn make_hat(&self, _req: MakeHatRequest) -> Result { todo!() } + fn get_status_request( + &self, + _req: GetStatusRequest, + ) -> Result, twirp::ClientError> + { + todo!() + } + // implementing this one is optional async fn get_status( &self, _req: GetStatusRequest, - ) -> Result { + ) -> Result { todo!() } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy