The Complete Guide to Tower: Service and Layer Architecture
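Tower is a core library in the Rust ecosystem for building robust networking clients and servers. It provides two key abstractions: Service (asynchronous request/response handling) and Layer (middleware composition). Popular frameworks such as Axum and tonic are built on Tower, and Hyper interoperates with it. Understanding Tower's architecture is essential for mastering modern Rust web development.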
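Tower Overview
What Is Tower?
Tower's core idea is that a service is, conceptually, an asynchronous function:
// async fn(Request) -> Result<Response, Error>
This simple yet powerful abstraction describes an asynchronous function that takes a request and returns either a response or an error.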
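The Tower Ecosystem
The Tower ecosystem consists of the following crates:
- tower: the core library; it re-exports the Service and Layer traits and provides ready-made middleware
- tower-service: defines the Service trait
- tower-layer: defines the Layer trait
- tower-test: testing utilities
- tower-http: HTTP-specific middleware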
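Why Use Tower?
- Modularity: each middleware is an independent, composable component
- Protocol-agnostic: the Service trait can model request/response interactions for many network protocols
- Reusability: the community provides a rich library of middleware
- Type safety: Tower makes full use of Rust's type system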
Core Abstraction: The Service Trait
Service Trait Definition
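use std::future::Future;
use std::task::{Context, Poll};
pub trait Service<Request> {
    /// Response type produced by the service
    type Response;
    /// Error type the service may return
    type Error;
    /// Future that resolves to the response or an error
    type Future: Future<Output = Result<Self::Response, Self::Error>>;
    /// Check whether the service is ready to accept a request.
    /// `call` must not be invoked until this returns Poll::Ready(Ok(()))
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
    /// Process the request and return a future of the response
    fn call(&mut self, req: Request) -> Self::Future;
}
Implementing Your First Service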
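Before pulling in the tower and http crates, the poll_ready/call contract can be illustrated with the standard library alone. The sketch below copies the shape of the Service trait into a local trait (an assumption made so the example is self-contained), implements it for a hypothetical Greeter, and drives one request the way Tower expects: poll until ready, and only then call. The busy-polling loops stand in for a real async runtime such as Tokio.

```rust
use std::convert::Infallible;
use std::future::{ready, Future, Ready};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local, std-only copy of the shape of tower::Service (assumption for illustration)
trait Service<Req> {
    type Response;
    type Error;
    type Future: Future<Output = Result<Self::Response, Self::Error>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
    fn call(&mut self, req: Req) -> Self::Future;
}

// Hypothetical service that greets whatever name it receives
struct Greeter;

impl Service<String> for Greeter {
    type Response = String;
    type Error = Infallible;
    type Future = Ready<Result<String, Infallible>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Infallible>> {
        Poll::Ready(Ok(())) // no resources to wait for
    }

    fn call(&mut self, name: String) -> Self::Future {
        ready(Ok(format!("Hello, {name}!")))
    }
}

// Waker that does nothing; the loops below simply poll again
struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Drive one request the way Tower expects: poll_ready until Ready, then call
fn serve_once<S: Service<Req>, Req>(svc: &mut S, req: Req) -> Result<S::Response, S::Error> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        match svc.poll_ready(&mut cx) {
            Poll::Ready(Ok(())) => break,
            Poll::Ready(Err(e)) => return Err(e),
            Poll::Pending => std::thread::yield_now(),
        }
    }
    // Only after readiness is it valid to call the service
    let mut fut = Box::pin(svc.call(req));
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
        std::thread::yield_now();
    }
}
```

For example, `serve_once(&mut Greeter, "Tower".to_string())` returns `Ok("Hello, Tower!".to_string())`.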
use std::convert::Infallible;
use std::future::{ready, Ready};
use std::task::{Context, Poll};
use http::{Request, Response};
use tower::Service;
// A stateless service that answers every request with a fixed greeting
#[derive(Clone)]
pub struct HelloService;
// Generic over the request body type B, so it works with any body
impl<B> Service<Request<B>> for HelloService {
    type Response = Response<String>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // No internal resources to wait for: always ready
        Poll::Ready(Ok(()))
    }
    fn call(&mut self, _req: Request<B>) -> Self::Future {
        let response = Response::builder()
            .status(200)
            .body("Hello, Tower!".to_string())
            .unwrap();
        ready(Ok(response))
    }
}
Service Methods
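tower's ServiceExt::ready returns a future that keeps calling poll_ready and resolves to a mutable reference to the service once it has capacity. The std-only approximation below makes that mechanism visible; the PollReady trait, ReadyFut, the WarmingUp service and block_on_counting are hypothetical names introduced for illustration, not Tower's actual types.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local stand-in for the readiness half of tower::Service (assumption)
trait PollReady {
    type Error;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
}

// Resolves to &mut S once poll_ready reports Ready(Ok), similar in spirit
// to the future returned by tower's ServiceExt::ready
struct ReadyFut<'a, S> {
    svc: Option<&'a mut S>,
}

impl<'a, S: PollReady> Future for ReadyFut<'a, S> {
    type Output = Result<&'a mut S, S::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let svc = self.svc.as_mut().expect("polled after completion");
        match svc.poll_ready(cx) {
            // Hand the now-ready service back to the caller
            Poll::Ready(Ok(())) => Poll::Ready(Ok(self.svc.take().unwrap())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}

// Hypothetical service that needs a few polls before it has capacity
struct WarmingUp {
    polls_left: u32,
}

impl PollReady for WarmingUp {
    type Error = String;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), String>> {
        if self.polls_left == 0 {
            return Poll::Ready(Ok(()));
        }
        self.polls_left -= 1;
        cx.waker().wake_by_ref(); // ask the executor to poll again
        Poll::Pending
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Minimal executor: polls until Ready and reports how many polls it took
fn block_on_counting<F: Future>(fut: F) -> (F::Output, u32) {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    let mut polls = 0;
    loop {
        polls += 1;
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return (out, polls);
        }
    }
}
```

With `WarmingUp { polls_left: 2 }`, the first two polls return Pending and the future resolves on the third.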
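use tower::{Service, ServiceExt};
// Callers must not invoke call before the service is ready.
// ServiceExt::ready drives poll_ready to completion and yields &mut S.
async fn call_when_ready<S, Req>(svc: &mut S, req: Req) -> Result<S::Response, S::Error>
where
    S: Service<Req>,
{
    svc.ready().await?.call(req).await
}
// ServiceExt::oneshot combines both steps for a service that is used only once
async fn call_once<S, Req>(svc: S, req: Req) -> Result<S::Response, S::Error>
where
    S: Service<Req>,
{
    svc.oneshot(req).await
}
State Management in a Service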
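Tower middleware and servers frequently clone services (for instance once per connection, or per request inside middleware), so state that must be shared has to sit behind a pointer such as Arc. The std-only sketch below simulates that with threads standing in for connections. CounterService here is a simplified stand-in whose handle method plays the role of call, and serve_connections is a hypothetical helper.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Simplified service state: every clone shares the same counter via Arc
#[derive(Clone, Default)]
struct CounterService {
    counter: Arc<Mutex<usize>>,
}

impl CounterService {
    // Stand-in for Service::call: increment the shared count and return it
    fn handle(&self) -> usize {
        let mut count = self.counter.lock().unwrap();
        *count += 1;
        *count
    }
}

// Hypothetical server loop that clones the service once per connection
fn serve_connections(svc: &CounterService, connections: usize, requests_each: usize) -> usize {
    let workers: Vec<_> = (0..connections)
        .map(|_| {
            let conn_svc = svc.clone(); // cheap: only the Arc is cloned
            thread::spawn(move || {
                for _ in 0..requests_each {
                    conn_svc.handle();
                }
            })
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
    *svc.counter.lock().unwrap()
}
```

Because every clone points at the same Mutex, four connections making 25 requests each leave the counter at 100.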
use std::convert::Infallible;
use std::future::{ready, Ready};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use http::{Request, Response};
use tower::Service;
// Clone copies the Arc, so every clone of the service updates the same counter
#[derive(Clone)]
pub struct CounterService {
    counter: Arc<Mutex<usize>>,
}
impl CounterService {
    pub fn new() -> Self {
        CounterService {
            counter: Arc::new(Mutex::new(0)),
        }
    }
}
impl Service<Request<()>> for CounterService {
    type Response = Response<String>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
    fn call(&mut self, _req: Request<()>) -> Self::Future {
        // Hold the lock only while updating the counter
        let count = {
            let mut counter = self.counter.lock().unwrap();
            *counter += 1;
            *counter
        };
        let response = Response::builder()
            .status(200)
            .body(format!("Request count: {}", count))
            .unwrap();
        ready(Ok(response))
    }
}
Core Abstraction: The Layer Trait
Layer Trait Definition
pub trait Layer<S> {
    /// The Service type produced by applying this layer
    type Service;
    /// Wrap the inner service and return the new, decorated service
    fn layer(&self, inner: S) -> Self::Service;
}
Implementing Your First Layer
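A middleware that moves its inner service into a 'static async block needs its own copy of that service. The subtle point is that poll_ready was called on self.inner, not on a fresh clone, and the Tower documentation for Service recommends swapping the ready instance out with std::mem::replace. The std-only sketch below uses a hypothetical Reserving service, whose readiness is tracked per instance, to show why a naive clone fails. take_for_call and run are also hypothetical helpers.

```rust
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical inner service whose readiness is tracked per instance,
// like a service that reserves a slot in poll_ready
#[derive(Default)]
struct Reserving {
    reserved: bool,
}

impl Clone for Reserving {
    // A fresh clone has not reserved anything yet
    fn clone(&self) -> Self {
        Reserving { reserved: false }
    }
}

impl Reserving {
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> {
        self.reserved = true;
        Poll::Ready(())
    }

    fn call(&mut self, req: u32) -> Result<u32, &'static str> {
        // Consume the reservation; calling without one violates the contract
        if !std::mem::take(&mut self.reserved) {
            return Err("call without poll_ready");
        }
        Ok(req * 2)
    }
}

// Returns the instance a middleware would move into its 'static future
fn take_for_call(inner: &mut Reserving, use_replace: bool) -> Reserving {
    if use_replace {
        // Move out the ready instance and leave a fresh clone in its place
        let clone = inner.clone();
        std::mem::replace(inner, clone)
    } else {
        // Naive: the clone was never polled, so it is not ready
        inner.clone()
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// One request through a middleware whose poll_ready delegates to `inner`
fn run(use_replace: bool) -> Result<u32, &'static str> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut inner = Reserving::default();
    let _ = inner.poll_ready(&mut cx);
    let mut svc = take_for_call(&mut inner, use_replace);
    svc.call(21)
}
```

`run(true)` succeeds with 42, while `run(false)` reports that the clone was called without being made ready.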
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use http::Request;
use tower::{Layer, Service};
#[derive(Clone)]
pub struct TimingLayer;
impl<S> Layer<S> for TimingLayer {
    type Service = TimingService<S>;
    fn layer(&self, inner: S) -> Self::Service {
        TimingService { inner }
    }
}
#[derive(Clone)]
pub struct TimingService<S> {
    inner: S,
}
impl<S, B> Service<Request<B>> for TimingService<S>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let start = Instant::now();
        // poll_ready was called on self.inner, so move that ready instance into
        // the future and leave a fresh clone behind for the next request
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        Box::pin(async move {
            let res = inner.call(req).await?;
            println!("Request took: {:?}", start.elapsed());
            Ok(res)
        })
    }
}
Composing Multiple Layers
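ServiceBuilder nests layers so that the first one added becomes the outermost wrapper and sees the request first. The std-only sketch below reproduces that nesting with a local copy of the Layer trait and a hypothetical Tag layer that records its name on the way in. The synchronous Handle trait is an assumption standing in for Service.

```rust
// Local copy of the shape of tower::Layer (assumption, std-only)
trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// Synchronous stand-in for a service: it appends to a trail of visited names
trait Handle {
    fn handle(&mut self, trail: Vec<String>) -> Vec<String>;
}

// The innermost service records that it was reached
struct Endpoint;
impl Handle for Endpoint {
    fn handle(&mut self, mut trail: Vec<String>) -> Vec<String> {
        trail.push("endpoint".to_string());
        trail
    }
}

// Hypothetical layer that records its name before calling the inner service
struct Tag(&'static str);
struct Tagged<S> {
    name: &'static str,
    inner: S,
}

impl<S> Layer<S> for Tag {
    type Service = Tagged<S>;
    fn layer(&self, inner: S) -> Tagged<S> {
        Tagged { name: self.0, inner }
    }
}

impl<S: Handle> Handle for Tagged<S> {
    fn handle(&mut self, mut trail: Vec<String>) -> Vec<String> {
        trail.push(self.name.to_string()); // runs before the inner service
        self.inner.handle(trail)
    }
}

// Like ServiceBuilder::new().layer(a).layer(b).service(s), which produces
// a.layer(b.layer(s)): the first layer added is the outermost
fn build() -> impl Handle {
    let (a, b) = (Tag("timing"), Tag("logging"));
    a.layer(b.layer(Endpoint))
}
```

Calling `build().handle(Vec::new())` yields the trail `["timing", "logging", "endpoint"]`.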
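use tower::ServiceBuilder;
// Layers added first are outermost: a request passes through TimingLayer,
// then LoggingLayer, then AuthLayer, and finally reaches HelloService
let service = ServiceBuilder::new()
    .layer(TimingLayer)
    .layer(LoggingLayer)
    .layer(AuthLayer::new("secret-token"))
    .service(HelloService);
Built-in Middleware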
Timeout Middleware
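Timeout wraps the response future of the inner service and resolves to an error if that future has not completed before the deadline. The std-only sketch below shows the idea. Timeout, Elapsed, with_timeout and the busy-polling block_on are hypothetical stand-ins. Unlike Tower's version, which relies on a Tokio timer to wake the task, this sketch simply asks to be polled again.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::time::{Duration, Instant};

// Hypothetical error type for an expired deadline
#[derive(Debug, PartialEq)]
struct Elapsed;

// Wraps a response future and fails it once the deadline has passed
struct Timeout<F> {
    inner: Pin<Box<F>>,
    deadline: Instant,
}

impl<F: Future> Future for Timeout<F> {
    type Output = Result<F::Output, Elapsed>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Poll the inner future first, so a response that is ready wins
        if let Poll::Ready(out) = self.inner.as_mut().poll(cx) {
            return Poll::Ready(Ok(out));
        }
        if Instant::now() >= self.deadline {
            return Poll::Ready(Err(Elapsed));
        }
        cx.waker().wake_by_ref(); // no timer here: request another poll
        Poll::Pending
    }
}

fn with_timeout<F: Future>(fut: F, limit: Duration) -> Timeout<F> {
    Timeout { inner: Box::pin(fut), deadline: Instant::now() + limit }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Busy-polling executor, good enough for this demonstration
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
        std::thread::yield_now();
    }
}
```

A future that is already complete passes through as `Ok`, while `std::future::pending()` times out with `Err(Elapsed)`.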
use std::time::Duration;
use tower::timeout::Timeout;
// If my_service has not responded within 30 seconds, the call fails with a timeout error
let service = Timeout::new(my_service, Duration::from_secs(30));
Retry Middleware
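Retry consults a policy after each attempt. Calling a service consumes the request, so the policy must also be able to clone it; Tower's Policy trait has a clone_request method for this purpose. Below is a simplified synchronous sketch of that control flow. The Attempts policy, call_with_retry and the flaky closure standing in for the service are all hypothetical.

```rust
// Hypothetical retry policy: retry failed calls while attempts remain
#[derive(Clone)]
struct Attempts {
    remaining: u32,
}

impl Attempts {
    // Inspect a result and decide whether to try again (uses up one attempt)
    fn retry<T, E>(&mut self, result: &Result<T, E>) -> bool {
        if result.is_ok() || self.remaining == 0 {
            return false;
        }
        self.remaining -= 1;
        true
    }
}

// The request is cloned for every attempt because each call consumes it
fn call_with_retry<Req: Clone, T, E>(
    mut policy: Attempts,
    req: Req,
    mut service: impl FnMut(Req) -> Result<T, E>,
) -> Result<T, E> {
    loop {
        let result = service(req.clone());
        if !policy.retry(&result) {
            return result;
        }
    }
}

// Hypothetical flaky backend: fails its first `fail_times` calls
fn flaky(fail_times: u32) -> impl FnMut(&'static str) -> Result<String, String> {
    let mut calls = 0;
    move |req: &'static str| {
        calls += 1;
        if calls <= fail_times {
            Err(format!("attempt {calls} failed"))
        } else {
            Ok(format!("{req} ok after {calls} calls"))
        }
    }
}
```

With three retries allowed, a backend that fails twice eventually succeeds. With only one retry, the second failure is returned.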
use http::{Request, Response};
use tower::retry::{Policy, Retry};
// Define a retry policy
#[derive(Clone)]
struct RetryPolicy;
impl<B, E> Policy<Request<B>, Response<String>, E> for RetryPolicy
where
    B: Send + 'static,
    E: std::fmt::Debug + Send + 'static,
{
    // ... implement the Policy trait (its method signatures differ between tower 0.4 and 0.5)
}
// Retry::new takes the policy first, then the service to wrap
let service = Retry::new(RetryPolicy, my_service);
Rate Limit Middleware
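Conceptually, RateLimit admits a fixed number of requests per period. Once that budget is used up, it reports not-ready until the next period begins. The std-only sketch below models that fixed-window bookkeeping. Window and try_acquire are hypothetical names, and `now` is passed in explicitly to keep the example deterministic.

```rust
use std::time::{Duration, Instant};

// Hypothetical fixed-window limiter: at most `num` requests per `per`
struct Window {
    num: u64,
    per: Duration,
    remaining: u64,
    window_end: Instant,
}

impl Window {
    fn new(num: u64, per: Duration) -> Self {
        Window { num, per, remaining: num, window_end: Instant::now() + per }
    }

    // true: the request may proceed; false: the service would report "not ready"
    fn try_acquire(&mut self, now: Instant) -> bool {
        if now >= self.window_end {
            // A new period starts: refill the budget
            self.remaining = self.num;
            self.window_end = now + self.per;
        }
        if self.remaining == 0 {
            return false;
        }
        self.remaining -= 1;
        true
    }
}
```

With a budget of 3 per second, the fourth request in the same window is refused, and a request two seconds later is admitted again.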
use std::time::Duration;
use tower::limit::rate::{Rate, RateLimit};
// Allow at most 100 requests per second
let service = RateLimit::new(my_service, Rate::new(100, Duration::from_secs(1)));
Load Shedding Middleware
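LoadShed turns backpressure into fast failure. Its poll_ready always reports ready, and if the inner service was not ready, the call fails immediately with an overload error instead of waiting. The simplified synchronous sketch below illustrates this. Busy, Overloaded and this LoadShed struct are hypothetical stand-ins that use booleans instead of Poll.

```rust
// Hypothetical overload error returned when a request is shed
#[derive(Debug, PartialEq)]
struct Overloaded;

// Simplified inner service with a fixed number of free slots
struct Busy {
    free_slots: u32,
}

impl Busy {
    fn is_ready(&self) -> bool {
        self.free_slots > 0
    }
    fn call(&mut self, req: u32) -> u32 {
        self.free_slots -= 1;
        req + 1
    }
}

// Load shedding wrapper: always "ready", but fails fast when the inner service is not
struct LoadShed {
    inner: Busy,
    inner_ready: bool,
}

impl LoadShed {
    fn poll_ready(&mut self) -> bool {
        // Record the inner readiness instead of waiting for it
        self.inner_ready = self.inner.is_ready();
        true
    }

    fn call(&mut self, req: u32) -> Result<u32, Overloaded> {
        if std::mem::take(&mut self.inner_ready) {
            Ok(self.inner.call(req))
        } else {
            Err(Overloaded) // shed the request immediately
        }
    }
}
```

With one free slot, the first request is served and the second is shed with `Err(Overloaded)`.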
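use tower::load_shed::LoadShed;
// When my_service is not ready, reject requests immediately with an Overloaded error instead of waiting
let service = LoadShed::new(my_service);
Buffer Middleware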
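Buffer moves the inner service into a background worker and hands out cloneable handles that send requests over a bounded channel. This is how a non-Clone service can be shared. The std-only sketch below models the arrangement with a thread and std::sync::mpsc in place of a Tokio task and channel. BufferHandle and demo are hypothetical names, and the service is a closure that hands out ticket numbers.

```rust
use std::sync::mpsc::{self, SyncSender};
use std::thread;

// A request paired with a channel on which the worker sends back the response
type Job = (u32, mpsc::Sender<u32>);

// Cheap, cloneable handle: just the sending half of a bounded queue
#[derive(Clone)]
struct BufferHandle {
    tx: SyncSender<Job>,
}

impl BufferHandle {
    // Move the service into a worker thread that drains the queue in order
    fn new<S>(mut service: S, bound: usize) -> (Self, thread::JoinHandle<()>)
    where
        S: FnMut(u32) -> u32 + Send + 'static,
    {
        let (tx, rx) = mpsc::sync_channel::<Job>(bound);
        let worker = thread::spawn(move || {
            // The loop ends once every handle (sender) has been dropped
            for (req, reply) in rx {
                let _ = reply.send(service(req));
            }
        });
        (BufferHandle { tx }, worker)
    }

    // Enqueue the request (blocks while the queue is full), then wait for the reply
    fn call(&self, req: u32) -> u32 {
        let (reply_tx, reply_rx) = mpsc::channel();
        self.tx.send((req, reply_tx)).expect("worker stopped");
        reply_rx.recv().expect("worker dropped the request")
    }
}

// Many client threads share one stateful, non-Clone service through handles
fn demo(clients: u32) -> Vec<u32> {
    let mut next_ticket = 0;
    let (handle, worker) = BufferHandle::new(
        move |_req| {
            next_ticket += 1; // no Mutex needed: only the worker touches this state
            next_ticket
        },
        4,
    );
    let threads: Vec<_> = (0..clients)
        .map(|i| {
            let h = handle.clone();
            thread::spawn(move || h.call(i))
        })
        .collect();
    let mut tickets: Vec<u32> = threads.into_iter().map(|t| t.join().unwrap()).collect();
    drop(handle); // last sender gone, so the worker's loop finishes
    worker.join().unwrap();
    tickets.sort();
    tickets
}
```

Because one worker processes requests sequentially, eight clients receive exactly the tickets 1 through 8.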
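use tower::buffer::Buffer;
// Move my_service into a background worker task and hand out cheap, cloneable handles;
// at most 256 requests can wait in the queue. Must be called inside a Tokio runtime.
let service = Buffer::new(my_service, 256);
Practical Examples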
Authentication Middleware
use std::future::{ready, Future};
use std::pin::Pin;
use std::task::{Context, Poll};
use http::{header::AUTHORIZATION, Request, Response, StatusCode};
use tower::{Layer, Service};
#[derive(Clone)]
pub struct AuthLayer {
    expected_token: String,
}
impl AuthLayer {
    pub fn new(expected_token: impl Into<String>) -> Self {
        Self {
            expected_token: expected_token.into(),
        }
    }
}
impl<S> Layer<S> for AuthLayer {
    type Service = AuthMiddleware<S>;
    fn layer(&self, inner: S) -> Self::Service {
        AuthMiddleware {
            inner,
            expected_token: self.expected_token.clone(),
        }
    }
}
#[derive(Clone)]
pub struct AuthMiddleware<S> {
    inner: S,
    expected_token: String,
}
impl<S, B, ResB> Service<Request<B>> for AuthMiddleware<S>
where
    S: Service<Request<B>, Response = Response<ResB>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    B: Send + 'static,
    // The 401 response body is built from a string literal
    ResB: From<&'static str> + Send + 'static,
{
    type Response = Response<ResB>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request<B>) -> Self::Future {
        // Accept only "Authorization: Bearer <expected_token>"
        let is_authorized = req
            .headers()
            .get(AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.strip_prefix("Bearer "))
            .map_or(false, |token| token == self.expected_token);
        if is_authorized {
            // Hand the instance that poll_ready prepared to the future
            let clone = self.inner.clone();
            let mut inner = std::mem::replace(&mut self.inner, clone);
            Box::pin(async move { inner.call(req).await })
        } else {
            // Reject without calling the inner service
            let response = Response::builder()
                .status(StatusCode::UNAUTHORIZED)
                .body(ResB::from("Unauthorized"))
                .unwrap();
            Box::pin(ready(Ok(response)))
        }
    }
}
Logging Middleware
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use http::{Request, Response};
use tower::{Layer, Service};
#[derive(Clone)]
pub struct LoggingLayer;
impl<S> Layer<S> for LoggingLayer {
    type Service = LoggingService<S>;
    fn layer(&self, inner: S) -> Self::Service {
        LoggingService { inner }
    }
}
#[derive(Clone)]
pub struct LoggingService<S> {
    inner: S,
}
impl<S, B, ResB> Service<Request<B>> for LoggingService<S>
where
    S: Service<Request<B>, Response = Response<ResB>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
    ResB: Send + 'static,
{
    type Response = Response<ResB>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let method = req.method().clone();
        let uri = req.uri().clone();
        let start = Instant::now();
        // Use the instance that poll_ready prepared; keep a clone for later requests
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        Box::pin(async move {
            println!("--> {} {}", method, uri);
            let result = inner.call(req).await;
            // An error carries no HTTP status, so log 500 as a placeholder
            let status = match &result {
                Ok(res) => res.status().as_u16(),
                Err(_) => 500,
            };
            println!("<-- {} {} ({:?})", status, uri, start.elapsed());
            result
        })
    }
}
Request ID Middleware
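The request ID counter only needs to hand out distinct values, even when many requests arrive concurrently. The std-only sketch below checks this property with threads standing in for concurrent requests. next_request_id, collect_ids and all_unique are hypothetical helpers.

```rust
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;

static NEXT_ID: AtomicU64 = AtomicU64::new(0);

// fetch_add is one atomic read-modify-write, so concurrent callers never see the
// same value; Relaxed suffices because the ID does not order other memory accesses
fn next_request_id() -> u64 {
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}

// Simulate concurrent requests on several threads and collect their IDs
fn collect_ids(threads: usize, per_thread: usize) -> Vec<u64> {
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            thread::spawn(move || (0..per_thread).map(|_| next_request_id()).collect::<Vec<u64>>())
        })
        .collect();
    handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
}

// True when no ID was handed out twice
fn all_unique(ids: &[u64]) -> bool {
    let seen: HashSet<&u64> = ids.iter().collect();
    seen.len() == ids.len()
}
```

Eight threads drawing 1,000 IDs each produce 8,000 distinct values.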
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use http::header::{HeaderName, HeaderValue};
use http::Request;
use tower::{Layer, Service};
static REQUEST_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
#[derive(Clone)]
pub struct RequestIdLayer;
impl<S> Layer<S> for RequestIdLayer {
    type Service = RequestIdService<S>;
    fn layer(&self, inner: S) -> Self::Service {
        RequestIdService { inner }
    }
}
#[derive(Clone)]
pub struct RequestIdService<S> {
    inner: S,
}
impl<S, B> Service<Request<B>> for RequestIdService<S>
where
    S: Service<Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    // Nothing happens after the inner call, so its future is returned as-is (no boxing)
    type Future = S::Future;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, mut req: Request<B>) -> Self::Future {
        // Generate a request ID; Relaxed is enough because only uniqueness matters
        let request_id = REQUEST_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
        // Attach it as the x-request-id request header
        req.headers_mut().insert(
            HeaderName::from_static("x-request-id"),
            HeaderValue::from(request_id),
        );
        self.inner.call(req)
    }
}
ServiceBuilder
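Basic Usage
use std::time::Duration;
use tower::ServiceBuilder;
let service = ServiceBuilder::new()
    .rate_limit(100, Duration::from_secs(1)) // at most 100 requests per second
    .concurrency_limit(256) // at most 256 requests in flight
    .timeout(Duration::from_secs(30)) // fail requests that take longer than 30 s
    .service(my_service);
Combining Multiple Middleware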
use std::time::Duration;
use tower::ServiceBuilder;
let service = ServiceBuilder::new()
    // Authentication
    .layer(AuthLayer::new("secret-token"))
    // Logging
    .layer(LoggingLayer)
    // Request ID
    .layer(RequestIdLayer)
    // RateLimit is not Clone, so put a Buffer in front of it for the Clone-requiring layers above
    .buffer(1024)
    // Rate limiting
    .rate_limit(100, Duration::from_secs(1))
    // Timeout
    .timeout(Duration::from_secs(30))
    // The target service
    .service(MyService);
Custom Middleware Stacks
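use http::{Request, Response};
use tower::{Service, ServiceBuilder};
pub fn with_middleware<S, ResB>(
    service: S,
) -> impl Service<Request<()>, Response = Response<ResB>, Error = S::Error> + Clone
where
    S: Service<Request<()>, Response = Response<ResB>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    // Required by AuthLayer, which builds its own 401 body from a string
    ResB: From<&'static str> + Send + 'static,
{
    ServiceBuilder::new()
        .layer(TimingLayer)
        .layer(LoggingLayer)
        .layer(AuthLayer::new("token"))
        .service(service)
}
Integrating with Axum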
Using Tower Middleware in Axum
use std::net::SocketAddr;
use axum::{routing::get, Router};
use tower_http::trace::TraceLayer;
#[tokio::main]
async fn main() {
    // Router::layer wraps every route with the Tower layer
    let app = Router::new()
        .route("/", get(handler))
        .layer(TraceLayer::new_for_http());
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("Listening on {}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
async fn handler() -> &'static str {
    "Hello, Tower!"
}
Custom Tower Middleware in Axum
use std::task::{Context, Poll};
use axum::{extract::Request, response::Response, routing::get, Router};
use tower::{Layer, Service};
#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(handler))
        .layer(MyTowerLayer::new());
    // Start the server...
}
async fn handler() -> &'static str {
    "Hello, Tower Middleware!"
}
#[derive(Clone)]
pub struct MyTowerLayer;
impl MyTowerLayer {
    pub fn new() -> Self {
        Self
    }
}
impl<S> Layer<S> for MyTowerLayer {
    type Service = MyTowerMiddleware<S>;
    fn layer(&self, inner: S) -> Self::Service {
        MyTowerMiddleware { inner }
    }
}
#[derive(Clone)]
pub struct MyTowerMiddleware<S> {
    inner: S,
}
// axum's Request and Response are aliases for http::Request<Body> and http::Response<Body>
impl<S> Service<Request> for MyTowerMiddleware<S>
where
    S: Service<Request, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    // Nothing to do after the inner call, so forward its future directly
    type Future = S::Future;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request) -> Self::Future {
        println!("Processing request: {} {}", req.method(), req.uri());
        self.inner.call(req)
    }
}
Tower HTTP
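Trace Middleware
use std::time::Duration;
use http::{Request, Response};
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::Span;
// Customize what TraceLayer logs at each stage of a request
let trace_layer = TraceLayer::new_for_http()
    .on_request(|request: &Request<_>, _span: &Span| {
        tracing::info!("Started {} {}", request.method(), request.uri());
    })
    .on_response(|response: &Response<_>, latency: Duration, _span: &Span| {
        tracing::info!("Completed with {} in {:?}", response.status(), latency);
    })
    .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
        tracing::error!("Request failed after {:?}: {}", latency, error);
    });
CORS Middleware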
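use tower::Layer;
use tower_http::cors::{Any, CorsLayer};
// Allow any origin, method, and header (permissive; tighten this in production)
let cors = CorsLayer::new()
    .allow_origin(Any)
    .allow_methods(Any)
    .allow_headers(Any);
// A layer wraps a service via Layer::layer (or ServiceBuilder::layer)
let service = cors.layer(my_service);
Compression Middleware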
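use tower::Layer;
use tower_http::compression::CompressionLayer;
// Compress response bodies with gzip or deflate, depending on the client's Accept-Encoding header
let compression = CompressionLayer::new()
    .gzip(true)
    .deflate(true);
let service = compression.layer(my_service);
Sensitive Headers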
use std::iter::once;
use http::header::AUTHORIZATION;
use tower::Layer;
use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
// Mark Authorization as sensitive so its value is hidden from Debug output (and therefore from logs)
let service = SetSensitiveHeadersLayer::new(once(AUTHORIZATION)).layer(my_service);
Testing
Unit Tests
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::{ready, Ready};
    use std::task::{Context, Poll};
    use http::{Request, Response, StatusCode};
    use tower::{Service, ServiceExt};
    #[tokio::test]
    async fn test_auth_middleware_authorized() {
        let service = AuthMiddleware {
            inner: PassThroughService,
            expected_token: "secret".to_string(),
        };
        let req = Request::builder()
            .header("Authorization", "Bearer secret")
            .body(())
            .unwrap();
        // oneshot waits for poll_ready, then calls the service with req
        let res = service.oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }
    #[tokio::test]
    async fn test_auth_middleware_unauthorized() {
        let service = AuthMiddleware {
            inner: PassThroughService,
            expected_token: "secret".to_string(),
        };
        let req = Request::builder()
            .header("Authorization", "Bearer wrong")
            .body(())
            .unwrap();
        let res = service.oneshot(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }
    // A pass-through service for tests: always answers 200 OK
    #[derive(Clone)]
    struct PassThroughService;
    impl Service<Request<()>> for PassThroughService {
        type Response = Response<String>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;
        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
        fn call(&mut self, _: Request<()>) -> Self::Future {
            ready(Ok(Response::builder()
                .status(200)
                .body("OK".to_string())
                .unwrap()))
        }
    }
}
Testing with Mocks
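use http::{Request, Response};
use tokio_test::{assert_pending, assert_ready_ok};
use tower_test::mock;
#[tokio::test]
async fn test_with_mock() {
    // Wrap a mock inner service with the layer under test (MyLayer is a placeholder)
    let (mut service, mut handle) =
        mock::spawn_layer::<Request<()>, Response<String>, _>(MyLayer);
    // Allow 0 requests: the mock is not ready, so neither is the layered service
    handle.allow(0);
    assert_pending!(service.poll_ready());
    // Allow 1 request: readiness now propagates through the layer
    handle.allow(1);
    assert_ready_ok!(service.poll_ready());
    // Test logic...
}
Best Practices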
1. Keep Middleware Simple
// Good: single responsibility
struct RateLimitLayer { /* ... */ }
struct AuthLayer { /* ... */ }
struct LoggingLayer { /* ... */ }
// Bad: too many responsibilities
struct MegaLayer { /* ... */ } // does everything
2. Use Clone to Make Middleware Copyable
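// Layers are frequently cloned (Axum's Router::layer requires Clone, for example)
#[derive(Clone)]
pub struct MyLayer;
impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>;
    fn layer(&self, inner: S) -> Self::Service {
        // MyMiddleware<S> should also derive Clone so the wrapped service can be cloned
        MyMiddleware { inner }
    }
}
3. Implement poll_ready Correctly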
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    // Forward readiness: this middleware is only ready when the inner service is
    self.inner.poll_ready(cx)
}
4. Handle Errors
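// Assumes type Error = tower::BoxError and S::Error: Into<tower::BoxError>
fn call(&mut self, req: Request) -> Self::Future {
    let clone = self.inner.clone();
    let mut inner = std::mem::replace(&mut self.inner, clone);
    Box::pin(async move {
        // Convert the inner error into a common boxed error type
        inner.call(req).await.map_err(Into::into)
    })
}
5. Mind the Middleware Order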
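ServiceBuilder::new()
    // 1. Request ID first (outermost), so every request gets an ID
    .layer(RequestIdLayer)
    // 2. Logging
    .layer(LoggingLayer)
    // 3. Authentication (may reject the request)
    .layer(AuthLayer::new("secret-token"))
    // RateLimit is not Clone, so buffer it for the Clone-requiring layers above
    .buffer(1024)
    // 4. Rate limiting (after authentication, so rejected requests do not use up the budget)
    .rate_limit(100, Duration::from_secs(1))
    // 5. Target service
    .service(MyService);
Summary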
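Tower is core infrastructure for network programming in Rust:
Core abstractions:
- Service: async fn(Request) -> Result<Response, Error>
- Layer: decorates a Service to produce a new Service
Key features:
- Modularity: each middleware is independent and composable
- Protocol-agnostic: works with many network protocols
- Type safety: makes full use of Rust's type system
- Rich ecosystem: tower-http, tonic (gRPC), Axum, and more
Practical applications:
- Authentication and authorization
- Logging and tracing
- Rate limiting and circuit breaking
- Request IDs and debugging
- CORS and compression
Best practices:
- Keep each middleware simple, with a single responsibility
- Implement poll_ready correctly
- Mind the middleware order
- Write tests to ensure correctness
Master Tower, and you will be able to build modular, maintainable, modern network applications!
Happy coding, and come join the Rust community! 🦀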