Skip to content

Commit b33e3a5

Browse files
authored
Merge pull request #153 from github/jorendorff/error-param
Make the error type configurable
2 parents e087e60 + 84c892b commit b33e3a5

File tree

10 files changed

+268
-27
lines changed

10 files changed

+268
-27
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n
2828
```toml
2929
# Cargo.toml
3030
[build-dependencies]
31-
twirp-build = "0.3"
31+
twirp-build = "0.7"
3232
prost-build = "0.13"
3333
```
3434

@@ -83,6 +83,8 @@ struct HaberdasherApiServer;
8383

8484
#[async_trait]
8585
impl haberdash::HaberdasherApi for HaberdasherApiServer {
86+
type Error = TwirpErrorResponse;
87+
8688
async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result<MakeHatResponse, TwirpErrorResponse> {
8789
todo!()
8890
}

crates/twirp-build/src/lib.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
2828
//
2929
writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap();
3030
writeln!(buf, "pub trait {} {{", service_name).unwrap();
31+
writeln!(buf, " type Error;").unwrap();
3132
for m in &service.methods {
3233
writeln!(
3334
buf,
34-
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
35+
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error>;",
3536
m.name, m.input_type, m.output_type,
3637
)
3738
.unwrap();
@@ -43,10 +44,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
4344
writeln!(buf, "where").unwrap();
4445
writeln!(buf, " T: {service_name} + Sync + Send").unwrap();
4546
writeln!(buf, "{{").unwrap();
47+
writeln!(buf, " type Error = T::Error;\n").unwrap();
4648
for m in &service.methods {
4749
writeln!(
4850
buf,
49-
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse> {{",
51+
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error> {{",
5052
m.name, m.input_type, m.output_type,
5153
)
5254
.unwrap();
@@ -61,6 +63,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
6163
r#"pub fn router<T>(api: T) -> twirp::Router
6264
where
6365
T: {service_name} + Clone + Send + Sync + 'static,
66+
<T as {service_name}>::Error: twirp::IntoTwirpResponse,
6467
{{
6568
twirp::details::TwirpRouterBuilder::new(api)"#,
6669
)

crates/twirp/src/client.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,9 @@ mod tests {
303303
}
304304

305305
#[tokio::test]
306-
#[ignore = "integration"]
307306
async fn test_standard_client() {
308-
let h = run_test_server(3001).await;
309-
let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
307+
let h = run_test_server(3002).await;
308+
let base_url = Url::parse("http://localhost:3002/twirp/").unwrap();
310309
let client = Client::from_base_url(base_url).unwrap();
311310
let resp = client
312311
.ping(PingRequest {

crates/twirp/src/details.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::future::Future;
55
use axum::extract::{Request, State};
66
use axum::Router;
77

8-
use crate::{server, Context, TwirpErrorResponse};
8+
use crate::{server, Context, IntoTwirpResponse};
99

1010
/// Builder object used by generated code to build a Twirp service.
1111
///
@@ -31,12 +31,13 @@ where
3131
///
3232
/// The generated code passes a closure that calls the method, like
3333
/// `|api: Arc<HaberdasherApiServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
34-
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
34+
pub fn route<F, Fut, Req, Res, Err>(self, url: &str, f: F) -> Self
3535
where
3636
F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
37-
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
37+
Fut: Future<Output = Result<Res, Err>> + Send,
3838
Req: prost::Message + Default + serde::de::DeserializeOwned,
3939
Res: prost::Message + serde::Serialize,
40+
Err: IntoTwirpResponse,
4041
{
4142
TwirpRouterBuilder {
4243
service: self.service,

crates/twirp/src/error.rs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,35 @@ use http::header::{self, HeaderMap, HeaderValue};
88
use hyper::{Response, StatusCode};
99
use serde::{Deserialize, Serialize, Serializer};
1010

11+
/// Trait for user-defined error types that can be converted to Twirp responses.
12+
pub trait IntoTwirpResponse {
13+
/// Generate a Twirp response. The return type is the `http::Response` type, with a
14+
/// [`TwirpErrorResponse`] as the body. The simplest way to implement this is:
15+
///
16+
/// ```
17+
/// use axum::body::Body;
18+
/// use http::Response;
19+
/// use twirp::{TwirpErrorResponse, IntoTwirpResponse};
20+
/// # struct MyError { message: String }
21+
///
22+
/// impl IntoTwirpResponse for MyError {
23+
/// fn into_twirp_response(self) -> Response<TwirpErrorResponse> {
24+
/// // Use TwirpErrorResponse to generate a valid starting point
25+
/// let mut response = twirp::invalid_argument(&self.message)
26+
/// .into_twirp_response();
27+
///
28+
/// // Customize the response as desired.
29+
/// response.headers_mut().insert("X-Server-Pid", std::process::id().into());
30+
/// response
31+
/// }
32+
/// }
33+
/// ```
34+
///
35+
/// The `Response` that `TwirpErrorResponse` generates can be used as a starting point,
36+
/// adding headers and extensions to it.
37+
fn into_twirp_response(self) -> Response<TwirpErrorResponse>;
38+
}
39+
1140
/// Alias for a generic error
1241
pub type GenericError = Box<dyn std::error::Error + Send + Sync>;
1342

@@ -152,20 +181,30 @@ impl TwirpErrorResponse {
152181
pub fn insert_meta(&mut self, key: String, value: String) -> Option<String> {
153182
self.meta.insert(key, value)
154183
}
184+
185+
pub fn into_axum_body(self) -> Body {
186+
let json =
187+
serde_json::to_string(&self).expect("JSON serialization of an error should not fail");
188+
Body::new(json)
189+
}
155190
}
156191

157-
impl IntoResponse for TwirpErrorResponse {
158-
fn into_response(self) -> Response<Body> {
192+
impl IntoTwirpResponse for TwirpErrorResponse {
193+
fn into_twirp_response(self) -> Response<TwirpErrorResponse> {
159194
let mut headers = HeaderMap::new();
160195
headers.insert(
161196
header::CONTENT_TYPE,
162197
HeaderValue::from_static("application/json"),
163198
);
164199

165-
let json =
166-
serde_json::to_string(&self).expect("JSON serialization of an error should not fail");
200+
let code = self.code.http_status_code();
201+
(code, headers).into_response().map(|_| self)
202+
}
203+
}
167204

168-
(self.code.http_status_code(), headers, json).into_response()
205+
impl IntoResponse for TwirpErrorResponse {
206+
fn into_response(self) -> Response<Body> {
207+
self.into_twirp_response().map(|err| err.into_axum_body())
169208
}
170209
}
171210

crates/twirp/src/server.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use serde::Serialize;
1717
use tokio::time::{Duration, Instant};
1818

1919
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
20-
use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse};
20+
use crate::{error, serialize_proto_message, Context, GenericError, IntoTwirpResponse};
2121

2222
// TODO: Properly implement JsonPb (de)serialization as it is slightly different
2323
// than standard JSON.
@@ -42,16 +42,17 @@ impl BodyFormat {
4242
}
4343

4444
/// Entry point used in code generated by `twirp-build`.
45-
pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
45+
pub(crate) async fn handle_request<S, F, Fut, Req, Resp, Err>(
4646
service: S,
4747
req: Request<Body>,
4848
f: F,
4949
) -> Response<Body>
5050
where
5151
F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
52-
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
52+
Fut: Future<Output = Result<Resp, Err>> + Send,
5353
Req: prost::Message + Default + serde::de::DeserializeOwned,
5454
Resp: prost::Message + serde::Serialize,
55+
Err: IntoTwirpResponse,
5556
{
5657
let mut timings = req
5758
.extensions()
@@ -114,12 +115,13 @@ where
114115
Ok((request, parts.extensions, format))
115116
}
116117

117-
fn write_response<T>(
118-
response: Result<T, TwirpErrorResponse>,
118+
fn write_response<T, Err>(
119+
response: Result<T, Err>,
119120
response_format: BodyFormat,
120121
) -> Result<Response<Body>, GenericError>
121122
where
122123
T: prost::Message + Serialize,
124+
Err: IntoTwirpResponse,
123125
{
124126
let res = match response {
125127
Ok(response) => match response_format {
@@ -133,7 +135,7 @@ where
133135
.body(Body::from(data))?
134136
}
135137
},
136-
Err(err) => err.into_response(),
138+
Err(err) => err.into_twirp_response().map(|err| err.into_axum_body()),
137139
};
138140
Ok(res)
139141
}

example/Cargo.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,13 @@ prost-build = "0.13"
2121
prost-wkt-build = "0.6"
2222

2323
[[bin]]
24-
name = "example-client"
25-
path = "src/bin/example-client.rs"
24+
name = "client"
25+
path = "src/bin/client.rs"
26+
27+
[[bin]]
28+
name = "simple-server"
29+
path = "src/bin/simple-server.rs"
30+
31+
[[bin]]
32+
name = "advanced-server"
33+
path = "src/bin/advanced-server.rs"

example/src/main.rs renamed to example/src/bin/advanced-server.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! This example is like simple-server but uses middleware and a custom error type.
2+
13
use std::net::SocketAddr;
24
use std::time::UNIX_EPOCH;
35

@@ -6,7 +8,7 @@ use twirp::axum::body::Body;
68
use twirp::axum::http;
79
use twirp::axum::middleware::{self, Next};
810
use twirp::axum::routing::get;
9-
use twirp::{invalid_argument, Context, Router, TwirpErrorResponse};
11+
use twirp::{invalid_argument, Context, IntoTwirpResponse, Router, TwirpErrorResponse};
1012

1113
pub mod service {
1214
pub mod haberdash {
@@ -48,15 +50,30 @@ pub async fn main() {
4850
#[derive(Clone)]
4951
struct HaberdasherApiServer;
5052

53+
#[derive(Debug, PartialEq)]
54+
enum HatError {
55+
InvalidSize,
56+
}
57+
58+
impl IntoTwirpResponse for HatError {
59+
fn into_twirp_response(self) -> http::Response<TwirpErrorResponse> {
60+
match self {
61+
HatError::InvalidSize => invalid_argument("inches").into_twirp_response(),
62+
}
63+
}
64+
}
65+
5166
#[async_trait]
5267
impl haberdash::HaberdasherApi for HaberdasherApiServer {
68+
type Error = HatError;
69+
5370
async fn make_hat(
5471
&self,
5572
ctx: Context,
5673
req: MakeHatRequest,
57-
) -> Result<MakeHatResponse, TwirpErrorResponse> {
74+
) -> Result<MakeHatResponse, HatError> {
5875
if req.inches == 0 {
59-
return Err(invalid_argument("inches"));
76+
return Err(HatError::InvalidSize);
6077
}
6178

6279
if let Some(id) = ctx.get::<RequestId>() {
@@ -118,7 +135,6 @@ mod test {
118135
use service::haberdash::v1::HaberdasherApiClient;
119136
use twirp::client::Client;
120137
use twirp::url::Url;
121-
use twirp::TwirpErrorCode;
122138

123139
use crate::service::haberdash::v1::HaberdasherApi;
124140

@@ -141,7 +157,7 @@ mod test {
141157
let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await;
142158
assert!(res.is_err());
143159
let err = res.unwrap_err();
144-
assert_eq!(err.code, TwirpErrorCode::InvalidArgument);
160+
assert_eq!(err, HatError::InvalidSize);
145161
}
146162

147163
/// A running network server task, bound to an arbitrary port on localhost, chosen by the OS
File renamed without changes.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy