diff --git a/Cargo.lock b/Cargo.lock index 2d52e3d..b7589f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1256,9 +1256,13 @@ dependencies = [ [[package]] name = "twirp-build" -version = "0.8.0" +version = "0.7.0" dependencies = [ + "prettyplease", + "proc-macro2", "prost-build", + "quote", + "syn", ] [[package]] diff --git a/crates/twirp-build/Cargo.toml b/crates/twirp-build/Cargo.toml index 9a84091..941d535 100644 --- a/crates/twirp-build/Cargo.toml +++ b/crates/twirp-build/Cargo.toml @@ -16,3 +16,7 @@ license-file = "./LICENSE" [dependencies] prost-build = "0.13" +prettyplease = { version = "0.2" } +quote = "1.0" +syn = "2.0" +proc-macro2 = "1.0" diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 62dd6a9..14ea9e4 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -1,4 +1,5 @@ -use std::fmt::Write; +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; /// Generates twirp services for protobuf rpc service definitions. /// @@ -11,123 +12,165 @@ pub fn service_generator() -> Box { Box::new(ServiceGenerator {}) } +struct MethodTypes { + input_type: TokenStream, + output_type: TokenStream, +} + +impl MethodTypes { + fn from_prost(m: &prost_build::Method) -> Self { + let as_type = |s| -> TokenStream { + let Ok(typ) = syn::parse_str::(s) else { + panic!( + "twirp-build generated invalid Rust. this is a bug in twirp-build, please file an issue:\n + method={name} + input_type={input_type} + output_type={output_type} + ", + name = m.name, + input_type = m.input_type, + output_type = m.output_type, + ); + }; + typ.to_token_stream() + }; + + let input_type = as_type(&m.input_type); + let output_type = as_type(&m.output_type); + + Self { + input_type, + output_type, + } + } +} + pub struct ServiceGenerator; impl prost_build::ServiceGenerator for ServiceGenerator { fn generate(&mut self, service: prost_build::Service, buf: &mut String) { - let service_name = service.name; + let service_name = format_ident!("{}", &service.name); let service_fqn = format!("{}.{}", service.package, service.proto_name); - writeln!(buf).unwrap(); - writeln!(buf, "pub use twirp;").unwrap(); - writeln!(buf).unwrap(); - writeln!(buf, "pub const SERVICE_FQN: &str = \"/{service_fqn}\";").unwrap(); - - // // generate the twirp server - // - writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); - writeln!(buf, "pub trait {} {{", service_name).unwrap(); - writeln!(buf, " type Error;").unwrap(); - for m in &service.methods { - writeln!( - buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error>;", - m.name, m.input_type, m.output_type, - ) - .unwrap(); - } - writeln!(buf, "}}").unwrap(); - - writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); - writeln!(buf, "impl {service_name} for std::sync::Arc").unwrap(); - writeln!(buf, "where").unwrap(); - writeln!(buf, " T: {service_name} + Sync + Send").unwrap(); - writeln!(buf, "{{").unwrap(); - writeln!(buf, " type Error = T::Error;\n").unwrap(); + let mut trait_methods = Vec::with_capacity(service.methods.len()); + let mut proxy_methods = Vec::with_capacity(service.methods.len()); for m in &service.methods { - writeln!( - buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error> {{", - m.name, m.input_type, m.output_type, - ) - .unwrap(); - writeln!(buf, " T::{}(&*self, ctx, req).await", m.name).unwrap(); - writeln!(buf, " }}").unwrap(); + let name = format_ident!("{}", &m.name); + let MethodTypes { + input_type, + output_type, + } = MethodTypes::from_prost(m); + + trait_methods.push(quote! { + async fn #name(&self, ctx: twirp::Context, req: #input_type) -> Result<#output_type, Self::Error>; + }); + + proxy_methods.push(quote! { + async fn #name(&self, ctx: twirp::Context, req: #input_type) -> Result<#output_type, Self::Error> { + T::#name(&*self, ctx, req).await + } + }); } - writeln!(buf, "}}").unwrap(); - - // add_service - writeln!( - buf, - r#"pub fn router(api: T) -> twirp::Router -where - T: {service_name} + Clone + Send + Sync + 'static, - ::Error: twirp::IntoTwirpResponse, -{{ - twirp::details::TwirpRouterBuilder::new(api)"#, - ) - .unwrap(); + + let server_trait = quote! { + #[twirp::async_trait::async_trait] + pub trait #service_name { + type Error; + + #(#trait_methods)* + } + + #[twirp::async_trait::async_trait] + impl #service_name for std::sync::Arc + where + T: #service_name + Sync + Send + { + type Error = T::Error; + + #(#proxy_methods)* + } + }; + + // generate the router + let mut route_calls = Vec::with_capacity(service.methods.len()); for m in &service.methods { - let uri = &m.proto_name; - let req_type = &m.input_type; - let rust_method_name = &m.name; - writeln!( - buf, - r#" .route("/{uri}", |api: T, ctx: twirp::Context, req: {req_type}| async move {{ - api.{rust_method_name}(ctx, req).await - }})"#, - ) - .unwrap(); + let name = format_ident!("{}", &m.name); + let uri = format!("/{}", &m.proto_name); + let MethodTypes { input_type, .. } = MethodTypes::from_prost(&m); + route_calls.push(quote! { + .route(#uri, |api: T, ctx: twirp::Context, req: #input_type| async move { + api.#name(ctx, req).await + }) + }); } - writeln!( - buf, - r#" - .build() -}}"# - ) - .unwrap(); + let router = quote! { + pub fn router(api: T) -> twirp::Router + where + T: #service_name + Clone + Send + Sync + 'static, + ::Error: twirp::IntoTwirpResponse + { + twirp::details::TwirpRouterBuilder::new(api) + #(#route_calls)* + .build() + } + }; // // generate the twirp client // - writeln!(buf).unwrap(); - writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); - writeln!(buf, "pub trait {service_name}Client: Send + Sync {{",).unwrap(); - for m in &service.methods { - // Define: - writeln!( - buf, - " async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError>;", - m.name, m.input_type, m.output_type, - ) - .unwrap(); - } - writeln!(buf, "}}").unwrap(); - - // Implement the rpc traits for: `twirp::client::Client` - writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); - writeln!( - buf, - "impl {service_name}Client for twirp::client::Client {{", - ) - .unwrap(); + let client_name = format_ident!("{}Client", service_name); + + let mut client_trait_methods = Vec::with_capacity(service.methods.len()); + let mut client_methods = Vec::with_capacity(service.methods.len()); for m in &service.methods { - // Define the rpc `` - writeln!( - buf, - " async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError> {{", - m.name, m.input_type, m.output_type, - ) - .unwrap(); - writeln!( - buf, - r#" self.request("{}/{}", req).await"#, - service_fqn, m.proto_name - ) - .unwrap(); - writeln!(buf, " }}").unwrap(); + let name = format_ident!("{}", &m.name); + let MethodTypes { + input_type, + output_type, + } = MethodTypes::from_prost(&m); + + client_trait_methods.push(quote! { + async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError>; + }); + + let url = format!("{}/{}", service_fqn, m.proto_name); + client_methods.push(quote! { + async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { + self.request(#url, req).await + } + }) } - writeln!(buf, "}}").unwrap(); + let client_trait = quote! { + #[twirp::async_trait::async_trait] + pub trait #client_name: Send + Sync { + #(#client_trait_methods)* + } + + #[twirp::async_trait::async_trait] + impl #client_name for twirp::client::Client { + #(#client_methods)* + } + }; + + // generate the service and client as a single file. run it through + // prettyplease before outputting it. + let service_fqn_path = format!("/{}", service_fqn); + let generated = quote! { + pub use twirp; + + pub const SERVICE_FQN: &str = #service_fqn_path; + + #server_trait + + #router + + #client_trait + }; + + let ast: syn::File = syn::parse2(generated) + .expect("twirp-build generated invalid Rust. this is a bug in twirp-build, please file an issue"); + let code = prettyplease::unparse(&ast); + buf.push_str(&code); } } diff --git a/example/proto/haberdash/v1/haberdash_api.proto b/example/proto/haberdash/v1/haberdash_api.proto index 6515ba9..e385d4a 100644 --- a/example/proto/haberdash/v1/haberdash_api.proto +++ b/example/proto/haberdash/v1/haberdash_api.proto @@ -9,6 +9,7 @@ option go_package = "haberdash.v1"; service HaberdasherAPI { // MakeHat produces a hat of mysterious, randomly-selected color! rpc MakeHat(MakeHatRequest) returns (MakeHatResponse); + rpc GetStatus(GetStatusRequest) returns (GetStatusResponse); } // Size is passed when requesting a new hat to be made. It's always @@ -32,3 +33,9 @@ message MakeHatResponse { // Demonstrate importing an external message. google.protobuf.Timestamp timestamp = 4; } + +message GetStatusRequest {} + +message GetStatusResponse { + string status = 1; +} diff --git a/example/src/bin/advanced-server.rs b/example/src/bin/advanced-server.rs index 486c22b..cd24fa3 100644 --- a/example/src/bin/advanced-server.rs +++ b/example/src/bin/advanced-server.rs @@ -17,7 +17,9 @@ pub mod service { } } } -use service::haberdash::v1::{self as haberdash, MakeHatRequest, MakeHatResponse}; +use service::haberdash::v1::{ + self as haberdash, GetStatusRequest, GetStatusResponse, MakeHatRequest, MakeHatResponse, +}; async fn ping() -> &'static str { "Pong\n" @@ -95,6 +97,16 @@ impl haberdash::HaberdasherApi for HaberdasherApiServer { }), }) } + + async fn get_status( + &self, + _ctx: Context, + _req: GetStatusRequest, + ) -> Result { + Ok(GetStatusResponse { + status: "making hats".to_string(), + }) + } } // Demonstrate sending back custom extensions from the handlers. diff --git a/example/src/bin/client.rs b/example/src/bin/client.rs index 9f388ed..89c6e71 100644 --- a/example/src/bin/client.rs +++ b/example/src/bin/client.rs @@ -12,7 +12,9 @@ pub mod service { } } -use service::haberdash::v1::{HaberdasherApiClient, MakeHatRequest, MakeHatResponse}; +use service::haberdash::v1::{ + GetStatusRequest, GetStatusResponse, HaberdasherApiClient, MakeHatRequest, MakeHatResponse, +}; #[tokio::main] pub async fn main() -> Result<(), GenericError> { @@ -79,4 +81,11 @@ impl HaberdasherApiClient for MockHaberdasherApiClient { ) -> Result { todo!() } + + async fn get_status( + &self, + _req: GetStatusRequest, + ) -> Result { + todo!() + } } diff --git a/example/src/bin/simple-server.rs b/example/src/bin/simple-server.rs index 852a543..12eb18b 100644 --- a/example/src/bin/simple-server.rs +++ b/example/src/bin/simple-server.rs @@ -12,7 +12,9 @@ pub mod service { } } } -use service::haberdash::v1::{self as haberdash, MakeHatRequest, MakeHatResponse}; +use service::haberdash::v1::{ + self as haberdash, GetStatusRequest, GetStatusResponse, MakeHatRequest, MakeHatResponse, +}; async fn ping() -> &'static str { "Pong\n" @@ -69,6 +71,16 @@ impl haberdash::HaberdasherApi for HaberdasherApiServer { }), }) } + + async fn get_status( + &self, + _ctx: Context, + _req: GetStatusRequest, + ) -> Result { + Ok(GetStatusResponse { + status: "making hats".to_string(), + }) + } } // Demonstrate sending back custom extensions from the handlers. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy