Project repository: https://github.com/poem-web/poem
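What is a web framework

A web framework is a software framework that simplifies web development. It provides a set of features such as routing, middleware, template engines, and database connectivity, so developers can focus on business logic instead of spending much of their effort on low-level plumbing.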
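Conceptually, a request handler is just a function that takes a request and returns a response:

```
func handle(request Request) -> (Response)
```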
A Java servlet, however, looks like this:
```java
void service(ServletRequest req, ServletResponse res) throws ServletException, IOException
```
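Why the servlet is designed this way

One reason is convenience: the developer writes directly to the response object supplied by the servlet container instead of constructing a response. A second reason is buffer ownership. If the servlet returned a `Response`, the servlet's runtime, that is, the web server itself, would have to reach into a buffer living inside the container, which means the servlet container would have to create that buffer. Yet creating buffers is the web server's job in the first place. In the embedded model that servlets use, the return-value design would therefore couple the two sides more tightly.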
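To make the contrast concrete, here is a minimal std-only sketch of the two handler shapes. Every type and trait in it (`Request`, `Response`, `ResponseWriter`, `ReturnStyle`, `ServletStyle`, `Hello`) is a hypothetical stand-in, not a poem or servlet API; the only point is who allocates the output buffer.

```rust
use std::fmt::Write;

/// Hypothetical request: just a path.
pub struct Request {
    pub path: String,
}

/// Return style: the handler builds and returns its own response value.
pub struct Response {
    pub status: u16,
    pub body: String,
}

pub trait ReturnStyle {
    fn handle(&self, req: &Request) -> Response;
}

/// Servlet style: the server owns the buffer and lends it to the handler.
pub struct ResponseWriter {
    pub status: u16,
    pub buf: String, // allocated by the server, not by the handler
}

pub trait ServletStyle {
    fn service(&self, req: &Request, res: &mut ResponseWriter);
}

pub struct Hello;

impl ReturnStyle for Hello {
    fn handle(&self, req: &Request) -> Response {
        // The handler allocates the body itself.
        Response { status: 200, body: format!("hello: {}", req.path) }
    }
}

impl ServletStyle for Hello {
    fn service(&self, req: &Request, res: &mut ResponseWriter) {
        res.status = 200;
        // Write straight into the server-provided buffer.
        write!(res.buf, "hello: {}", req.path).unwrap();
    }
}
```

In the servlet-style variant the server can pre-allocate (and reuse) `buf` however it likes, while the return-style variant hands ownership of a freshly built value back to the caller.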
Poem's design

So how does poem define it?
```rust
#[async_trait::async_trait]
pub trait Endpoint: Send + Sync {
    type Output: IntoResponse;

    async fn call(&self, req: Request) -> Result<Self::Output>;

    async fn get_response(&self, req: Request) -> Response {
        self.call(req)
            .await
            .map(IntoResponse::into_response)
            .unwrap_or_else(|err| err.into_response())
    }
}
```
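This is essentially the return-based pattern described earlier. So why are there two methods? `call` returns a typed `Result<Self::Output>`, which lets endpoints and middleware work with the concrete output and error types. `get_response` is a provided method that turns both the success value and the error into a plain `Response`, which is what the server ultimately writes back; it is the method the server calls for every request.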
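The split can be mirrored in a small synchronous, std-only sketch. `async` is dropped for simplicity, and `Response`, `IntoResponse`, `NotFound`, and `Hello` here are simplified hypothetical types rather than poem's; the shape of the provided `get_response` method follows the trait above.

```rust
/// Hypothetical, simplified response.
#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

pub trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

/// Hypothetical error type that also knows how to become a response.
#[derive(Debug)]
pub struct NotFound;

impl IntoResponse for NotFound {
    fn into_response(self) -> Response {
        Response { status: 404, body: "not found".into() }
    }
}

pub trait Endpoint {
    type Output: IntoResponse;

    /// Typed result: middleware can inspect `Output` or the error.
    fn call(&self, path: &str) -> Result<Self::Output, NotFound>;

    /// Provided method: success and error both collapse into a `Response`.
    fn get_response(&self, path: &str) -> Response {
        self.call(path)
            .map(IntoResponse::into_response)
            .unwrap_or_else(|err| err.into_response())
    }
}

pub struct Hello;

impl Endpoint for Hello {
    type Output = String;

    fn call(&self, path: &str) -> Result<String, NotFound> {
        match path.strip_prefix("/hello/") {
            Some(name) if !name.is_empty() => Ok(format!("hello: {name}")),
            _ => Err(NotFound),
        }
    }
}
```

Implementors only write `call`; the server side only ever needs `get_response`, so it never has to know about `String` or `NotFound`.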
A typical poem handler looks like this:

```rust
use poem::{handler, web::Path};

#[handler]
fn hello(Path(name): Path<String>) -> String {
    format!("hello: {name}")
}
```
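This handler extracts the `name` parameter from the HTTP path and returns a `String`. The `#[handler]` macro expands it into roughly the following: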
```rust
#[allow(non_camel_case_types)]
struct hello;

impl poem::Endpoint for hello {
    type Output = poem::Response;

    #[allow(unused_mut)]
    async fn call(&self, mut req: poem::Request) -> poem::Result<Self::Output> {
        let (req, mut body) = req.split();
        let p0 = <Path<String> as poem::FromRequest>::from_request(&req, &mut body).await?;
        fn hello(Path(name): Path<String>) -> String {
            format!("hello: {name}")
        }
        let res = hello(p0);
        let res = poem::error::IntoResult::into_result(res);
        std::result::Result::map(res, poem::IntoResponse::into_response)
    }
}
```
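As we can see, the procedural macro turns the `hello` function into a unit struct named `hello` and implements the `Endpoint` trait for it. The argument to `call` is a standard HTTP request that the framework has already parsed from the `TcpStream`: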
```rust
#[derive(Default)]
pub struct Request {
    method: Method,
    uri: Uri,
    version: Version,
    headers: HeaderMap,
    extensions: Extensions,
    body: Body,
    state: RequestState,
}
```
The handler's return value, in turn, only needs to implement `IntoResponse`:

```rust
pub trait IntoResponse: Send {
    fn into_response(self) -> Response;
}
```
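The three steps of the expansion (run each extractor, call the original function, convert the result with `IntoResponse`) can be imitated by hand. This std-only sketch replaces poem's types with simplified hypothetical ones: `Request` holds only a path, and `FromRequest` and `Path` are toy versions that read the last path segment, as if the route were `/hello/:name`.

```rust
pub struct Request {
    pub path: String,
}

#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

pub trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

/// Hypothetical extractor trait: build a handler argument from the request.
pub trait FromRequest: Sized {
    fn from_request(req: &Request) -> Result<Self, String>;
}

/// Hypothetical extractor: the last path segment.
pub struct Path<T>(pub T);

impl FromRequest for Path<String> {
    fn from_request(req: &Request) -> Result<Self, String> {
        match req.path.rsplit('/').next() {
            Some(seg) if !seg.is_empty() => Ok(Path(seg.to_string())),
            _ => Err("missing path parameter".to_string()),
        }
    }
}

/// Hand-written equivalent of `#[handler] fn hello(...)`.
#[allow(non_camel_case_types)]
pub struct hello;

impl hello {
    pub fn call(&self, req: &Request) -> Result<Response, String> {
        // Step 1: run the extractor for each argument.
        let p0 = <Path<String> as FromRequest>::from_request(req)?;
        // Step 2: the original function, nested inside (shadows the struct name).
        fn hello(Path(name): Path<String>) -> String {
            format!("hello: {name}")
        }
        // Step 3: convert the return value into a Response.
        Ok(hello(p0).into_response())
    }
}
```

Because extraction failures short-circuit with `?`, the user's function only ever runs with fully parsed arguments.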
Parsing

So how is our `TcpStream` parsed into an HTTP request? Start from the usual startup code:
```rust
let app = Route::new().at("/hello/:name", get(hello)).with(Tracing);
Server::new(TcpListener::bind(""))
    .name("hello-world")
    .run(app)
    .await
```
`Server::new` simply stores the listener:

```rust
impl<L: Listener> Server<L, Infallible> {
    pub fn new(listener: L) -> Self {
        Self {
            listener: Either::Listener(listener),
            name: None,
            idle_timeout: None,
        }
    }
}
```
`run` delegates to `run_with_graceful_shutdown`, passing a shutdown signal that never completes:

```rust
pub async fn run<E>(self, ep: E) -> IoResult<()>
where
    E: IntoEndpoint,
    E::Endpoint: 'static,
{
    self.run_with_graceful_shutdown(ep, futures_util::future::pending(), None)
        .await
}
```

This is where our endpoint is passed in. It does not have to be a concrete request-handling function; it can also be a router. Typically, by the time the server starts, the routes have already been assembled statically in code:

```rust
let app = Route::new().nest("/api", api());
```

Inside `run_with_graceful_shutdown`, the endpoint is converted into one whose output is a plain `Response` (`map_to_response`) and wrapped in an `Arc` so that every connection task can share it:

```rust
pub async fn run_with_graceful_shutdown<E>(
    self,
    ep: E,
    signal: impl Future<Output = ()>,
    timeout: Option<Duration>,
) -> IoResult<()>
where
    E: IntoEndpoint,
    E::Endpoint: 'static,
{
    let ep = Arc::new(ep.into_endpoint().map_to_response());
    let Server {
        listener,
        name,
        idle_timeout,
    } = self;
    let name = name.as_deref();
    let alive_connections = Arc::new(AtomicUsize::new(0));
    let notify = Arc::new(Notify::new());
    let timeout_token = CancellationToken::new();
    let server_graceful_shutdown_token = CancellationToken::new();

    let mut acceptor = match listener {
        Either::Listener(listener) => listener.into_acceptor().await?.boxed(),
        Either::Acceptor(acceptor) => acceptor.boxed(),
    };

    tokio::pin!(signal);

    for addr in acceptor.local_addr() {
        tracing::info!(name = name, addr = %addr, "listening");
    }
    tracing::info!(name = name, "server started");

    loop {
        tokio::select! {
            _ = &mut signal => {
                server_graceful_shutdown_token.cancel();
                if let Some(timeout) = timeout {
                    tracing::info!(
                        name = name,
                        timeout_in_seconds = timeout.as_secs_f32(),
                        "initiate graceful shutdown",
                    );

                    let timeout_token = timeout_token.clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        timeout_token.cancel();
                    });
                } else {
                    tracing::info!(name = name, "initiate graceful shutdown");
                }
                break;
            },
            res = acceptor.accept() => {
                if let Ok((socket, local_addr, remote_addr, scheme)) = res {
                    alive_connections.fetch_add(1, Ordering::Release);

                    let ep = ep.clone();
                    let alive_connections = alive_connections.clone();
                    let notify = notify.clone();
                    let timeout_token = timeout_token.clone();
                    let server_graceful_shutdown_token = server_graceful_shutdown_token.clone();

                    tokio::spawn(async move {
                        let serve_connection = serve_connection(socket, local_addr, remote_addr, scheme, ep, server_graceful_shutdown_token, idle_timeout);

                        if timeout.is_some() {
                            tokio::select! {
                                _ = serve_connection => {}
                                _ = timeout_token.cancelled() => {}
                            }
                        } else {
                            serve_connection.await;
                        }

                        if alive_connections.fetch_sub(1, Ordering::Acquire) == 1 {
                            notify.notify_waiters();
                        }
                    });
                }
            }
        }
    }

    drop(acceptor);
    if alive_connections.load(Ordering::Acquire) > 0 {
        tracing::info!(name = name, "wait for all connections to close.");
        notify.notified().await;
    }

    tracing::info!(name = name, "server stopped");
    Ok(())
}
```
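`alive_connections` counts open connections, and `notify` is used for graceful shutdown. In the accept loop, the outer `tokio::select!` races the shutdown `signal` against `acceptor.accept()`. If an external signal such as Ctrl+C arrives, the loop cancels `server_graceful_shutdown_token` (which every `serve_connection` also watches) and breaks immediately, so the process stops accepting new sockets. Each connection task decrements `alive_connections` when it ends, and the task that brings the count to zero calls `notify.notify_waiters()`; after the loop, the server waits on `notify.notified()` while any connection is still alive. When a `timeout` is given, a spawned task cancels `timeout_token` after that delay, which also ends any connection task still running.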
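The counting-and-waiting part of this bookkeeping can be sketched with threads from the standard library. This is an analogy, not poem's code: a `Mutex<usize>` plus a `Condvar` stand in for `AtomicUsize` plus tokio's `Notify`, and `Tracker` and `simulate` are hypothetical names.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// Hypothetical tracker for in-flight connections.
#[derive(Default)]
pub struct Tracker {
    alive: Mutex<usize>,
    all_closed: Condvar,
}

impl Tracker {
    /// Called when a connection is accepted (like `fetch_add(1)`).
    pub fn open(&self) {
        *self.alive.lock().unwrap() += 1;
    }

    /// Called when a connection task ends (like `fetch_sub(1)`).
    pub fn close(&self) {
        let mut alive = self.alive.lock().unwrap();
        *alive -= 1;
        if *alive == 0 {
            // Last connection gone: wake the shutdown path.
            self.all_closed.notify_all();
        }
    }

    /// Block until every connection has closed (like `notify.notified().await`).
    pub fn wait_all_closed(&self) {
        let mut alive = self.alive.lock().unwrap();
        while *alive > 0 {
            alive = self.all_closed.wait(alive).unwrap();
        }
    }
}

/// Simulate `n` connections, then "shut down" and wait for them.
/// Returns the live count observed after waiting.
pub fn simulate(n: usize) -> usize {
    let tracker = Arc::new(Tracker::default());
    let mut handles = Vec::new();
    for _ in 0..n {
        tracker.open(); // counted before the task starts, as in the accept loop
        let t = Arc::clone(&tracker);
        handles.push(thread::spawn(move || {
            thread::sleep(Duration::from_millis(10)); // "serving" the connection
            t.close();
        }));
    }
    // Shutdown signal: stop accepting, then wait for in-flight work.
    tracker.wait_all_closed();
    let remaining = *tracker.alive.lock().unwrap();
    for h in handles {
        h.join().unwrap();
    }
    remaining
}
```

Checking the counter and waiting under the same lock means a connection that closes just before the wait cannot be missed.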
Each accepted socket is handled by `serve_connection`, which adapts the poem endpoint into a hyper service and lets hyper drive the connection:

```rust
async fn serve_connection(
    socket: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
    local_addr: LocalAddr,
    remote_addr: RemoteAddr,
    scheme: Scheme,
    ep: Arc<dyn Endpoint<Output = Response>>,
    server_graceful_shutdown_token: CancellationToken,
    idle_connection_close_timeout: Option<Duration>,
) {
    let connection_shutdown_token = CancellationToken::new();

    let service = hyper::service::service_fn({
        let remote_addr = remote_addr.clone();

        move |req: http::Request<Incoming>| {
            let ep = ep.clone();
            let local_addr = local_addr.clone();
            let remote_addr = remote_addr.clone();
            let scheme = scheme.clone();
            async move {
                Ok::<http::Response<_>, Infallible>(
                    ep.get_response((req, local_addr, remote_addr, scheme).into())
                        .await
                        .into(),
                )
            }
        }
    });

    let socket = match idle_connection_close_timeout {
        Some(timeout) => {
            tokio_util::either::Either::Left(ClosingInactiveConnection::new(socket, timeout, {
                let connection_shutdown_token = connection_shutdown_token.clone();

                move || {
                    let connection_shutdown_token = connection_shutdown_token.clone();
                    async move {
                        connection_shutdown_token.cancel();
                    }
                }
            }))
        }
        None => tokio_util::either::Either::Right(socket),
    };

    let builder = auto::Builder::new(hyper_util::rt::TokioExecutor::new());
    let conn =
        builder.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(socket), service);
    futures_util::pin_mut!(conn);

    tokio::select! {
        _ = conn => {
            // hyper finished serving this connection
        },
        _ = connection_shutdown_token.cancelled() => {
            tracing::info!(remote_addr=%remote_addr, "closing connection due to inactivity");
        }
        _ = server_graceful_shutdown_token.cancelled() => {}
    }
}
```
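The key adapter is this line, which uses hyper's `service_fn` to turn the closure into a hyper service:

```rust
let service = hyper::service::service_fn({
```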
On the server side, hyper's connection types accept an `HttpService`:

```rust
pub trait HttpService<ReqBody>: sealed::Sealed<ReqBody> {
    type ResBody: Body;
    type Error: Into<Box<dyn StdError + Send + Sync>>;
    type Future: Future<Output = Result<Response<Self::ResBody>, Self::Error>>;

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future;
}
```

This closely mirrors poem's `Endpoint`. In the end, what the async runtime actually polls is this:

```rust
let conn = builder.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(socket), service);
```
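What `service_fn` buys can be shown with a synchronous, std-only analogue: a `Service` trait with an associated response type, and a `service_fn` helper that turns any closure into a `Service`. The names echo hyper's, but this is a simplified sketch (no futures, errors fixed to `String`), not hyper's real definitions; `serve_one` is a hypothetical stand-in for connection code.

```rust
/// Simplified stand-in for hyper's `Service` trait.
pub trait Service<Req> {
    type Response;
    fn call(&self, req: Req) -> Result<Self::Response, String>;
}

/// Wrapper that adapts a closure into a `Service`.
pub struct ServiceFn<F> {
    f: F,
}

pub fn service_fn<F, Req, Res>(f: F) -> ServiceFn<F>
where
    F: Fn(Req) -> Result<Res, String>,
{
    ServiceFn { f }
}

impl<F, Req, Res> Service<Req> for ServiceFn<F>
where
    F: Fn(Req) -> Result<Res, String>,
{
    type Response = Res;

    fn call(&self, req: Req) -> Result<Res, String> {
        // Just forward to the wrapped closure.
        (self.f)(req)
    }
}

/// Connection-side code: generic over the trait, unaware of the closure type.
pub fn serve_one<S: Service<String, Response = String>>(svc: &S, raw: &str) -> String {
    svc.call(raw.to_string()).unwrap_or_else(|e| format!("error: {e}"))
}
```

For example, `serve_one(&service_fn(|req: String| Ok(req.to_uppercase())), "get /")` returns `"GET /"`. In poem's case the closure's body is the call to `ep.get_response(...)`, so the endpoint never needs to implement hyper's trait itself.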
`serve_connection_with_upgrades` is defined in `hyper_util`:

```rust
pub fn serve_connection_with_upgrades<I, S, B>(
    &self,
    io: I,
    service: S,
) -> UpgradeableConnection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + Send + 'static,
    E: Http2ServerConnExec<S::Future, B>,
{
    UpgradeableConnection {
        state: UpgradeableConnState::ReadVersion {
            read_version: read_version(io),
            builder: self,
            service: Some(service),
        },
    }
}
```

At this point the code moves into the hyper ecosystem, where parsing is split by HTTP protocol version. Let's go straight to how the `Future` state machine is written:

```rust
impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + Send + 'static,
    E: Http2ServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let mut this = self.as_mut().project();

            match this.state.as_mut().project() {
                UpgradeableConnStateProj::ReadVersion {
                    read_version,
                    builder,
                    service,
                } => {
                    let (version, io) = ready!(read_version.poll(cx))?;
                    let service = service.take().unwrap();
                    match version {
                        Version::H1 => {
                            let conn = builder.http1.serve_connection(io, service).with_upgrades();
                            this.state.set(UpgradeableConnState::H1 { conn });
                        }
                        Version::H2 => {
                            let conn = builder.http2.serve_connection(io, service);
                            this.state.set(UpgradeableConnState::H2 { conn });
                        }
                    }
                }
                UpgradeableConnStateProj::H1 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                UpgradeableConnStateProj::H2 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
            }
        }
    }
}
```
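This state machine first reads the opening bytes of the connection to determine the HTTP version (an HTTP/2 connection begins with a fixed connection preface; anything else is treated as HTTP/1). It then builds the connection future for that protocol, stores it as the new state, and on the next loop iteration polls it directly.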
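The version check can be sketched in a few lines of std-only Rust. It assumes detection works as described above: compare the first bytes against the HTTP/2 client connection preface from RFC 9113 and fall back to HTTP/1 as soon as they diverge. `Version` and `detect_version` are hypothetical names, and a real implementation also has to keep reading until it can decide.

```rust
/// HTTP/2 client connection preface (RFC 9113), 24 bytes.
pub const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

#[derive(Debug, PartialEq)]
pub enum Version {
    H1,
    H2,
}

/// Decide the protocol from the bytes read so far.
/// Returns `None` while the data is still a prefix of the preface,
/// meaning more bytes are needed before a decision can be made.
pub fn detect_version(buf: &[u8]) -> Option<Version> {
    let n = buf.len().min(H2_PREFACE.len());
    if buf[..n] != H2_PREFACE[..n] {
        Some(Version::H1) // diverged from the preface: must be HTTP/1.x
    } else if n == H2_PREFACE.len() {
        Some(Version::H2) // the full preface matched
    } else {
        None // still ambiguous, keep reading
    }
}
```

An HTTP/1 request line such as `GET / HTTP/1.1` differs from the preface at the very first byte, so in practice the decision is usually made immediately.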
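For HTTP/1, hyper keeps the connection in `proto::Conn`, which wraps the I/O in a read/write buffer (`Buffered`) next to the protocol `State`:

```rust
pub(crate) struct Conn<I, B, T> {
    io: Buffered<I, EncodedBuf<B>>,
    state: State,
    _marker: PhantomData<fn(T)>,
}
```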
The HTTP/1 `serve_connection` creates that `Conn`, applies the builder's settings, and wraps it in a `Dispatcher` together with the service:

```rust
pub fn serve_connection<I, S>(&self, io: I, service: S) -> Connection<I, S>
where
    S: HttpService<IncomingBody>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    S::ResBody: 'static,
    <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
{
    let mut conn = proto::Conn::new(io);
    conn.set_timer(self.timer.clone());
    if !self.h1_keep_alive {
        conn.disable_keep_alive();
    }
    if self.h1_half_close {
        conn.set_allow_half_close();
    }
    if self.h1_title_case_headers {
        conn.set_title_case_headers();
    }
    if self.h1_preserve_header_case {
        conn.set_preserve_header_case();
    }
    if let Some(dur) = self
        .timer
        .check(self.h1_header_read_timeout, "header_read_timeout")
    {
        conn.set_http1_header_read_timeout(dur);
    };
    if let Some(writev) = self.h1_writev {
        if writev {
            conn.set_write_strategy_queue();
        } else {
            conn.set_write_strategy_flatten();
        }
    }
    conn.set_flush_pipeline(self.pipeline_flush);
    if let Some(max) = self.max_buf_size {
        conn.set_max_buf_size(max);
    }
    let sd = proto::h1::dispatch::Server::new(service);
    let proto = proto::h1::Dispatcher::new(sd, conn);
    Connection { conn: proto }
}
```
The `Dispatcher` is itself a `Future`:

```rust
impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T>
where
    D: Dispatch<
            PollItem = MessageHead<T::Outgoing>,
            PollBody = Bs,
            RecvItem = MessageHead<T::Incoming>,
        > + Unpin,
    D::PollError: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
    T: Http1Transaction + Unpin,
    Bs: Body + 'static,
    Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
{
    type Output = crate::Result<Dispatched>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_catch(cx, true)
    }
}
```
From there the call chain is `poll_catch` -> `poll_inner` -> `poll_loop`, whose core is:
```rust
let _ = self.poll_read(cx)?;
let _ = self.poll_write(cx)?;
let _ = self.poll_flush(cx)?;
```

This is where the actual reads and writes happen. Let's look at the read side first:

```rust
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
    loop {
        if self.is_closing {
            return Poll::Ready(Ok(()));
        } else if self.conn.can_read_head() {
            ready!(self.poll_read_head(cx))?;
        } else if let Some(mut body) = self.body_tx.take() {
            if self.conn.can_read_body() {
                match body.poll_ready(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Pending => {
                        self.body_tx = Some(body);
                        return Poll::Pending;
                    }
                    Poll::Ready(Err(_canceled)) => {
                        trace!("body receiver dropped before eof, draining or closing");
                        self.conn.poll_drain_or_close_read(cx);
                        continue;
                    }
                }
                match self.conn.poll_read_body(cx) {
                    Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) {
                        Ok(()) => {
                            self.body_tx = Some(body);
                        }
                        Err(_canceled) => {
                            if self.conn.can_read_body() {
                                trace!("body receiver dropped before eof, closing");
                                self.conn.close_read();
                            }
                        }
                    },
                    Poll::Ready(None) => {
                        // End of body: the sender is dropped here, closing the body.
                    }
                    Poll::Pending => {
                        self.body_tx = Some(body);
                        return Poll::Pending;
                    }
                    Poll::Ready(Some(Err(e))) => {
                        body.send_error(crate::Error::new_body(e));
                    }
                }
            } else {
                // No more body to read: dropping the sender closes the body.
            }
        } else {
            return self.conn.poll_read_keep_alive(cx);
        }
    }
}
```
Here, at last, is where the HTTP protocol is actually parsed. `poll_read_head` polls for the HTTP request head (the request line and headers):
```rust
fn poll_read_head(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
    // Make sure the dispatcher can accept another message.
    match ready!(self.dispatch.poll_ready(cx)) {
        Ok(()) => (),
        Err(()) => {
            trace!("dispatch no longer receiving messages");
            self.close();
            return Poll::Ready(Ok(()));
        }
    }

    match ready!(self.conn.poll_read_head(cx)) {
        Some(Ok((mut head, body_len, wants))) => {
            let body = match body_len {
                DecodedLength::ZERO => IncomingBody::empty(),
                other => {
                    // Non-empty body: keep the sender so poll_read can feed it.
                    let (tx, rx) =
                        IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
                    self.body_tx = Some(tx);
                    rx
                }
            };
            if wants.contains(Wants::UPGRADE) {
                let upgrade = self.conn.on_upgrade();
                debug_assert!(!upgrade.is_none(), "empty upgrade");
                debug_assert!(
                    head.extensions.get::<OnUpgrade>().is_none(),
                    "OnUpgrade already set"
                );
                head.extensions.insert(upgrade);
            }
            self.dispatch.recv_msg(Ok((head, body)))?;
            Poll::Ready(Ok(()))
        }
        Some(Err(err)) => {
            debug!("read_head error: {}", err);
            self.dispatch.recv_msg(Err(err))?;
            self.close();
            Poll::Ready(Ok(()))
        }
        None => {
            debug_assert!(self.conn.is_read_closed());
            if self.conn.is_write_closed() {
                self.close();
            }
            Poll::Ready(Ok(()))
        }
    }
}
```

`poll_read_head` calls `conn.poll_read_head`, which yields the parsed head together with the body length. For a non-empty body it creates a channel: the sender is kept in `body_tx` so that `poll_read` can push body chunks into it, and the receiver becomes the request body. Finally, `self.dispatch.recv_msg` assembles the request and calls our service with it:

```rust
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()> {
    let (msg, body) = msg?;
    let mut req = Request::new(body);
    *req.method_mut() = msg.subject.0;
    *req.uri_mut() = msg.subject.1;
    *req.headers_mut() = msg.headers;
    *req.version_mut() = msg.version;
    *req.extensions_mut() = msg.extensions;
    let fut = self.service.call(req);
    self.in_flight.set(Some(fut));
    Ok(())
}
```

Note that `recv_msg` only calls `service.call(req)` and stores the returned future in `in_flight`; the dispatcher polls that future (which runs poem's `get_response`) later.
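To make the head-parsing step concrete, here is a deliberately small std-only parser for an HTTP/1.1 request head. It is a simplified sketch, not hyper's parser (hyper uses the `httparse` crate and handles chunked encoding, size limits, and many edge cases). It splits the request line and headers and derives the body length from `Content-Length`, the kind of information `poll_read_head` returns as `body_len`. `RequestHead` and `parse_head` are hypothetical names.

```rust
/// Hypothetical parsed request head.
#[derive(Debug, PartialEq)]
pub struct RequestHead {
    pub method: String,
    pub uri: String,
    pub version: String,
    pub headers: Vec<(String, String)>,
    pub body_len: usize,
}

/// Parse a complete request head (everything up to the blank line).
/// Returns the head and the number of bytes consumed, or `None` if the
/// head is incomplete or malformed.
pub fn parse_head(buf: &[u8]) -> Option<(RequestHead, usize)> {
    // Find the blank line that ends the head; body bytes may follow it.
    let end = buf.windows(4).position(|w| w == b"\r\n\r\n")?;
    let text = std::str::from_utf8(&buf[..end]).ok()?;
    let mut lines = text.split("\r\n");

    // Request line: METHOD SP URI SP VERSION
    let mut parts = lines.next()?.split(' ');
    let method = parts.next()?.to_string();
    let uri = parts.next()?.to_string();
    let version = parts.next()?.to_string();

    let mut headers = Vec::new();
    let mut body_len = 0;
    for line in lines {
        let (name, value) = line.split_once(':')?;
        let (name, value) = (name.trim().to_string(), value.trim().to_string());
        if name.eq_ignore_ascii_case("content-length") {
            body_len = value.parse().ok()?;
        }
        headers.push((name, value));
    }
    Some((RequestHead { method, uri, version, headers, body_len }, end + 4))
}
```

As in `recv_msg`, the parsed pieces would then be copied into a request object and handed to the service, while the remaining `body_len` bytes after the consumed head are streamed into the body.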
Middleware

Middleware is an important concept in web frameworks. In poem, a middleware is defined like this:
```rust
pub trait Middleware<E: Endpoint> {
    type Output: Endpoint;

    fn transform(&self, ep: E) -> Self::Output;
}
```
A middleware consumes an endpoint and returns a new endpoint that wraps it. For example, a simple logging middleware:

```rust
struct Log;

impl<E: Endpoint> Middleware<E> for Log {
    type Output = LogImpl<E>;

    fn transform(&self, ep: E) -> Self::Output {
        LogImpl(ep)
    }
}

struct LogImpl<E>(E);

#[async_trait]
impl<E: Endpoint> Endpoint for LogImpl<E> {
    type Output = Response;

    async fn call(&self, req: Request) -> Result<Self::Output> {
        println!("request: {}", req.uri().path());
        let res = self.0.call(req).await;

        match res {
            Ok(resp) => {
                let resp = resp.into_response();
                println!("response: {}", resp.status());
                Ok(resp)
            }
            Err(err) => {
                println!("error: {err}");
                Err(err)
            }
        }
    }
}
```
Middleware is applied through `EndpointExt::with`:

```rust
pub trait EndpointExt: IntoEndpoint {
    fn with<T>(self, middleware: T) -> T::Output
    where
        T: Middleware<Self::Endpoint>,
        Self: Sized,
    {
        middleware.transform(self.into_endpoint())
    }
}
```
This is the `.with(Tracing)` call used in the startup code:

```rust
let app = Route::new().at("/hello/:name", get(hello)).with(Tracing);
```
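The transform-and-wrap pattern can be reproduced in a synchronous, std-only sketch. The `Endpoint`, `Middleware`, and `EndpointExt` traits below are simplified versions of poem's (no `async`, no `IntoEndpoint`, string in and string out), and `Hello`, `Tag`, and `TagImpl` are hypothetical.

```rust
/// Simplified synchronous endpoint: request string in, response string out.
pub trait Endpoint {
    fn call(&self, req: &str) -> String;
}

/// Hypothetical leaf endpoint.
pub struct Hello;

impl Endpoint for Hello {
    fn call(&self, req: &str) -> String {
        format!("hello: {req}")
    }
}

/// Same shape as poem's trait: consume an endpoint, return a new one.
pub trait Middleware<E: Endpoint> {
    type Output: Endpoint;
    fn transform(&self, ep: E) -> Self::Output;
}

/// Hypothetical middleware that wraps the response in a tag.
pub struct Tag(pub &'static str);

pub struct TagImpl<E> {
    inner: E,
    tag: &'static str,
}

impl<E: Endpoint> Middleware<E> for Tag {
    type Output = TagImpl<E>;
    fn transform(&self, ep: E) -> TagImpl<E> {
        TagImpl { inner: ep, tag: self.0 }
    }
}

impl<E: Endpoint> Endpoint for TagImpl<E> {
    fn call(&self, req: &str) -> String {
        // Code before `inner.call` runs on the way in, code after on the way out.
        let resp = self.inner.call(req);
        format!("<{}>{}</{}>", self.tag, resp, self.tag)
    }
}

/// Extension trait giving every endpoint a `.with(...)` method.
pub trait EndpointExt: Endpoint + Sized {
    fn with<M: Middleware<Self>>(self, middleware: M) -> M::Output {
        middleware.transform(self)
    }
}

impl<E: Endpoint> EndpointExt for E {}
```

With this sketch, `Hello.with(Tag("b")).with(Tag("i")).call("poem")` returns `"<i><b>hello: poem</b></i>"`: each `with` produces a new concrete type (`TagImpl<TagImpl<Hello>>`), so the last middleware added becomes the outermost layer and the whole stack is known at compile time.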
`EndpointExt` is built on `IntoEndpoint`, which every `Endpoint` implements through a blanket impl:

```rust
pub trait IntoEndpoint {
    type Endpoint: Endpoint;

    fn into_endpoint(self) -> Self::Endpoint;
}

impl<T: Endpoint> IntoEndpoint for T {
    type Endpoint = T;

    fn into_endpoint(self) -> Self::Endpoint {
        self
    }
}
```

You can also use the `around` method to wrap a closure (`Fn`) as middleware:

```rust
fn around<F, Fut, R>(self, f: F) -> Around<Self::Endpoint, F>
where
    F: Fn(Arc<Self::Endpoint>, Request) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<R>> + Send + 'static,
    R: IntoResponse,
    Self: Sized,
{
    Around::new(self.into_endpoint(), f)
}
```
`Around` keeps the inner endpoint in an `Arc` alongside the closure:

```rust
pub struct Around<E, F> {
    inner: Arc<E>,
    f: F,
}
```
Its `Endpoint` implementation simply calls the closure with a clone of that `Arc` and the request:

```rust
#[async_trait::async_trait]
impl<E, F, Fut, T> Endpoint for Around<E, F>
where
    E: Endpoint,
    F: Fn(Arc<E>, Request) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<T>> + Send,
    T: IntoResponse,
{
    type Output = T;

    async fn call(&self, req: Request) -> Result<Self::Output> {
        (self.f)(self.inner.clone(), req).await
    }
}
```
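The `around` idea, a closure that receives the inner endpoint plus the request and decides when (or whether) to call it, can be sketched synchronously with std only. `Endpoint`, `Hello`, `Around`, and `around` here are simplified stand-ins for poem's types (string requests, `String` errors, no futures).

```rust
use std::sync::Arc;

pub trait Endpoint {
    fn call(&self, req: &str) -> Result<String, String>;
}

/// Hypothetical inner endpoint.
pub struct Hello;

impl Endpoint for Hello {
    fn call(&self, req: &str) -> Result<String, String> {
        Ok(format!("hello: {req}"))
    }
}

/// Inner endpoint behind an `Arc`, so the closure can receive its own handle.
pub struct Around<E, F> {
    inner: Arc<E>,
    f: F,
}

pub fn around<E, F>(ep: E, f: F) -> Around<E, F>
where
    E: Endpoint,
    F: Fn(Arc<E>, &str) -> Result<String, String>,
{
    Around { inner: Arc::new(ep), f }
}

impl<E, F> Endpoint for Around<E, F>
where
    E: Endpoint,
    F: Fn(Arc<E>, &str) -> Result<String, String>,
{
    fn call(&self, req: &str) -> Result<String, String> {
        // Hand the closure a cheap clone of the inner endpoint and the request.
        (self.f)(self.inner.clone(), req)
    }
}
```

For example, a closure can reject empty requests before the inner endpoint runs and post-process its output afterwards:

```rust
let app = around(Hello, |inner: Arc<Hello>, req: &str| {
    if req.is_empty() {
        return Err("empty name".to_string());
    }
    inner.call(req).map(|body| body.to_uppercase())
});
```

Here `app.call("poem")` yields `Ok("HELLO: POEM")` and `app.call("")` yields `Err("empty name")`. Passing an owned `Arc` rather than a reference matters in poem's async version, since the returned future may need to hold on to the inner endpoint.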
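Poem's built-in middleware follow the same pattern; the server-side session middleware, for example, is backed by an endpoint type like this (body elided):

```rust
impl<T, E> Endpoint for ServerSessionEndpoint<T, E>
where
    T: SessionStorage,
    E: Endpoint,
{
    ...
}
```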