Introduction

more-di-axum is a crate which provides dependency injection (DI) extensions for the axum web framework. Any trait or struct can be used as an injected service.

axum provides a dependency injection example; however, it is very limited because axum neither has nor provides a fully-fledged DI framework. State in axum must implement Clone and is cloned many times within the request pipeline. As a consequence, the native State model does not intrinsically support a scoped (e.g. per-request) lifetime. This is a limitation of Clone: a state wrapped in Arc behaves as a singleton, while a state that owns its data directly is copied on every clone and therefore behaves as transient. more-di-axum brings full support for various lifetimes by layering over the more-di library and makes them ergonomic to consume within axum. Since more-di is a complete DI framework, swapping out dependency registration in different contexts, such as testing, is trivial.
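To make the Clone limitation concrete, here is a small std-only sketch. It does not use axum or more-di; OwnedState, SharedState, and simulate_requests are hypothetical names, and the loop merely stands in for a framework cloning its state once per request.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical state that owns its data directly: every clone is an
// independent copy, so each request sees a fresh value (transient-like).
#[derive(Clone, Default)]
struct OwnedState {
    hits: usize,
}

// Hypothetical state that keeps its data behind an Arc: every clone points
// at the same counter, so all requests share one value (singleton-like).
#[derive(Clone, Default)]
struct SharedState {
    hits: Arc<AtomicUsize>,
}

// Simulates a framework that clones the state once per request and lets the
// handler update its clone. Returns the counts held by the original states.
fn simulate_requests(requests: usize) -> (usize, usize) {
    let owned = OwnedState::default();
    let shared = SharedState::default();

    for _ in 0..requests {
        let mut per_request = owned.clone();
        per_request.hits += 1; // only this request's copy changes

        let per_request_shared = shared.clone();
        per_request_shared.hits.fetch_add(1, Ordering::SeqCst); // visible to every clone
    }

    (owned.hits, shared.hits.load(Ordering::SeqCst))
}

fn main() {
    let (owned, shared) = simulate_requests(3);
    println!("owned state: {owned}, shared state: {shared}");
}
```

Neither variant yields an instance that is created once per request and then shared by every extractor within that request, which is the scoped lifetime that more-di-axum adds.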

Contributing

more-di-axum is free and open source. The source code is hosted on GitHub, and issues and feature requests can be posted on the GitHub issue tracker. more-di-axum relies on the community to fix bugs and add features; if you'd like to contribute, please read the CONTRIBUTING guide and consider opening a pull request.

License

This project is licensed under the MIT license.

Getting Started

The simplest way to get started is to install the crate using the default features.

cargo add more-di-axum

Example

The following example reworks the axum dependency injection example with full dependency injection support using more-di.

use axum::{
    async_trait,
    extract::Path,
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use di::*;
use di_axum::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::net::TcpListener;
use uuid::Uuid;

#[tokio::main]
async fn main() {
    let provider = ServiceCollection::new()
        .add(ExampleUserRepo::singleton())
        .build_provider()
        .unwrap();

    let app = Router::new()
        .route("/users/:id", get(one_user))
        .route("/users", post(new_user))
        .with_provider(provider);

    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    
    println!("listening on: {}", listener.local_addr().unwrap());

    axum::serve(listener, app).await.unwrap();
}

async fn one_user(
    Path(id): Path<Uuid>,
    Inject(repo): Inject<dyn UserRepo + Send + Sync>,
) -> Result<Json<User>, AppError> {
    let user = repo.find(id).await?;
    // Json<User> implements From<User>
    Ok(user.into())
}

async fn new_user(
    Inject(repo): Inject<dyn UserRepo + Send + Sync>,
    Json(params): Json<CreateUser>,
) -> Result<Json<User>, AppError> {
    let user = repo.create(params).await?;
    Ok(user.into())
}

#[derive(Debug)]
enum UserRepoError {
    #[allow(dead_code)]
    NotFound,
    #[allow(dead_code)]
    InvalidUserName,
}

enum AppError {
    UserRepo(UserRepoError),
}

#[async_trait]
trait UserRepo {
    async fn find(&self, user_id: Uuid) -> Result<User, UserRepoError>;
    async fn create(&self, params: CreateUser) -> Result<User, UserRepoError>;
}

#[derive(Debug, Serialize)]
struct User {
    id: Uuid,
    username: String,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CreateUser {
    username: String,
}

#[injectable(UserRepo + Send + Sync)]
struct ExampleUserRepo;

#[async_trait]
impl UserRepo for ExampleUserRepo {
    async fn find(&self, _user_id: Uuid) -> Result<User, UserRepoError> {
        unimplemented!()
    }

    async fn create(&self, _params: CreateUser) -> Result<User, UserRepoError> {
        unimplemented!()
    }
}

impl From<UserRepoError> for AppError {
    fn from(inner: UserRepoError) -> Self {
        AppError::UserRepo(inner)
    }
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AppError::UserRepo(UserRepoError::NotFound) => {
                (StatusCode::NOT_FOUND, "User not found")
            }
            AppError::UserRepo(UserRepoError::InvalidUserName) => {
                (StatusCode::UNPROCESSABLE_ENTITY, "Invalid user name")
            }
        };

        let body = Json(json!({
            "error": error_message,
        }));

        (status, body).into_response()
    }
}

Service Registration

axum handlers execute in an asynchronous context, so an injected service must be thread-safe: axum requires that such a service implement Send and Sync. Most structs already satisfy this requirement because Send and Sync are auto traits that the compiler implements whenever every field is Send and Sync. If a struct does not, it must be wrapped in another struct that does (for example, by guarding the non-thread-safe state with a Mutex). A trait, on the other hand, has several options, depending on how the trait itself is defined. Which one you choose is largely a matter of preference.
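The auto-trait rules and the wrapping approach described above can be checked with plain Rust. This std-only sketch uses hypothetical types (PlainRepo, Counter, SyncCounter); assert_thread_safe is a compile-time check that only builds when its type argument is Send + Sync, the same bounds an injected service needs.

```rust
use std::cell::Cell;
use std::sync::Mutex;
use std::thread;

// Compiles only when T is Send + Sync, the same bounds an injected service needs.
fn assert_thread_safe<T: Send + Sync + ?Sized>() {}

// Every field is Send + Sync, so the compiler implements both traits automatically.
struct PlainRepo {
    name: String,
}

// Cell<u32> is Send but not Sync, so this struct is not Sync.
struct Counter {
    count: Cell<u32>,
}

// Wrapping the non-Sync state in a Mutex makes the wrapper Send + Sync again.
struct SyncCounter {
    inner: Mutex<Counter>,
}

impl SyncCounter {
    fn new() -> Self {
        SyncCounter { inner: Mutex::new(Counter { count: Cell::new(0) }) }
    }

    fn increment(&self) {
        let guard = self.inner.lock().unwrap();
        guard.count.set(guard.count.get() + 1);
    }

    fn get(&self) -> u32 {
        self.inner.lock().unwrap().count.get()
    }
}

// Shares one SyncCounter across several threads; this requires SyncCounter: Sync.
fn increment_concurrently(counter: &SyncCounter, threads: u32) -> u32 {
    thread::scope(|s| {
        for _ in 0..threads {
            s.spawn(move || counter.increment());
        }
    });
    counter.get()
}

fn main() {
    assert_thread_safe::<PlainRepo>();
    assert_thread_safe::<SyncCounter>();
    // assert_thread_safe::<Counter>(); // would not compile: Cell<u32> is not Sync
    println!("{}", increment_concurrently(&SyncCounter::new(), 4));
}
```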

Thread-Safe Trait

A trait declares that it is thread-safe by requiring Send and Sync as supertraits.

use di::*;
use di_axum::*;

trait Service: Send + Sync {}

#[injectable(Service)]
struct ServiceImpl;

impl Service for ServiceImpl {}

// dyn Service is already Send + Sync through its supertraits
async fn handler(Inject(service): Inject<dyn Service>) {}
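Why the supertraits matter can be seen without axum: a trait object whose trait requires Send + Sync can be moved into another thread. In this std-only sketch, Arc is a stand-in for a shared service handle, and call_on_other_thread and the name method are hypothetical.

```rust
use std::sync::Arc;
use std::thread;

// Send + Sync are supertraits, so every `dyn Service` is thread-safe.
trait Service: Send + Sync {
    fn name(&self) -> String;
}

struct ServiceImpl;

impl Service for ServiceImpl {
    fn name(&self) -> String {
        "ServiceImpl".to_string()
    }
}

// Moves a shared handle into another thread; this compiles only because
// `dyn Service` is Send + Sync via its supertraits.
fn call_on_other_thread(service: Arc<dyn Service>) -> String {
    thread::spawn(move || service.name()).join().unwrap()
}

fn main() {
    let service: Arc<dyn Service> = Arc::new(ServiceImpl);
    println!("{}", call_on_other_thread(service));
}
```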

Multiple Trait Implementation

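If the original trait does not declare thread safety with Send and Sync, the struct implementation can declare that it is thread-safe directly in its #[injectable] registration. Handlers then request the service using the same trait object type, dyn Service + Send + Sync.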

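use di::*;
use di_axum::*;

trait Service {}

// Send + Sync are declared at registration instead of on the trait
#[injectable(Service + Send + Sync)]
struct ServiceImpl;

impl Service for ServiceImpl {}

// request the same trait object type that was registered
async fn handler(Inject(service): Inject<dyn Service + Send + Sync>) {}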

Trait Unification

If the original trait does not declare thread safety with Send and Sync, another alternative is to unify the trait with Send and Sync in a new trait. You might choose this approach for better ergonomics, since handlers can then request dyn ServiceAsync without repeating the Send and Sync bounds.

use di::*;
use di_axum::*;

trait Service {}

// unifies Service with the thread-safety requirements
trait ServiceAsync: Service + Send + Sync {}

#[injectable(ServiceAsync)]
struct ServiceImpl;

impl Service for ServiceImpl {}
impl ServiceAsync for ServiceImpl {}

async fn handler(Inject(service): Inject<dyn ServiceAsync>) {}
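A common Rust idiom for unified traits is a blanket implementation, so each struct does not need its own empty impl ServiceAsync. The std-only sketch below shows the idiom; greet and greet_on_other_thread are hypothetical. We assume, but have not verified, that it composes with #[injectable(ServiceAsync)]. With the blanket impl in place, an explicit impl ServiceAsync for ServiceImpl {} would conflict and must be removed.

```rust
use std::sync::Arc;
use std::thread;

// The original trait, without any thread-safety requirement.
trait Service {
    fn greet(&self) -> String;
}

// The unified trait: Service plus Send and Sync.
trait ServiceAsync: Service + Send + Sync {}

// Blanket impl: every Service that is also Send + Sync is a ServiceAsync.
impl<T: Service + Send + Sync> ServiceAsync for T {}

struct ServiceImpl;

impl Service for ServiceImpl {
    fn greet(&self) -> String {
        "hello from ServiceImpl".to_string()
    }
}

// Methods of the supertrait Service can be called through dyn ServiceAsync,
// and the handle can cross threads because ServiceAsync requires Send + Sync.
fn greet_on_other_thread(service: Arc<dyn ServiceAsync>) -> String {
    thread::spawn(move || service.greet()).join().unwrap()
}

fn main() {
    println!("{}", greet_on_other_thread(Arc::new(ServiceImpl)));
}
```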

Service Resolution

Services are resolved and injected using the functions provided by ServiceProvider. A new scope is created during each HTTP request before the handler is executed.

Extractor               Function
TryInject               get
TryInjectMut            get_mut
TryInjectWithKey        get_by_key
TryInjectWithKeyMut     get_by_key_mut
InjectAll               get_all
InjectAllMut            get_all_mut
InjectAllWithKey        get_all_by_key
InjectAllWithKeyMut     get_all_by_key_mut
Inject                  get_required
InjectMut               get_required_mut
InjectWithKey           get_required_by_key
InjectWithKeyMut        get_required_by_key_mut

If resolution fails, the HTTP request will short-circuit with HTTP status code 500 - Internal Server Error.
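more-di-axum performs this resolution for you. The std-only sketch below is only a hypothetical model of two rules stated above, not how more-di is implemented: a scoped service is created once per request scope and reused within it, and optional resolution (compare TryInject / get) returns None for a missing registration, while required resolution (compare Inject / get_required) fails, mapped here to status 500. Scope, Factory, and RequestContext are illustrative names.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical scoped service: one instance per request scope.
struct RequestContext {
    path: String,
}

// Hypothetical registry entry: a factory that builds a new instance.
type Factory = fn() -> RequestContext;

// Hypothetical per-request scope; instances are cached for the scope's lifetime.
struct Scope<'a> {
    registry: &'a HashMap<&'static str, Factory>,
    cache: HashMap<&'static str, Arc<RequestContext>>,
}

impl<'a> Scope<'a> {
    fn new(registry: &'a HashMap<&'static str, Factory>) -> Self {
        Scope { registry, cache: HashMap::new() }
    }

    // Optional resolution: None when the service is not registered.
    fn get(&mut self, name: &'static str) -> Option<Arc<RequestContext>> {
        if let Some(hit) = self.cache.get(name) {
            return Some(Arc::clone(hit)); // reuse the instance of this scope
        }
        let factory = *self.registry.get(name)?;
        let instance = Arc::new(factory());
        self.cache.insert(name, Arc::clone(&instance));
        Some(instance)
    }

    // Required resolution: a missing service becomes an error (HTTP 500 here).
    fn get_required(&mut self, name: &'static str) -> Result<Arc<RequestContext>, u16> {
        self.get(name).ok_or(500)
    }
}

fn default_context() -> RequestContext {
    RequestContext { path: "/users".to_string() }
}

fn main() {
    let mut registry: HashMap<&'static str, Factory> = HashMap::new();
    registry.insert("RequestContext", default_context);

    // Request 1: two extractors in the same scope share one instance.
    let mut first = Scope::new(&registry);
    let a = first.get_required("RequestContext").unwrap();
    let b = first.get_required("RequestContext").unwrap();
    println!("{} shared within scope: {}", a.path, Arc::ptr_eq(&a, &b));

    // Request 2: a fresh scope creates a fresh instance.
    let mut second = Scope::new(&registry);
    let c = second.get_required("RequestContext").unwrap();
    println!("shared across scopes: {}", Arc::ptr_eq(&a, &c));

    // Unregistered service: optional gives None, required gives an error.
    println!("{:?} {:?}", second.get("Missing").is_none(), second.get_required("Missing").err());
}
```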

Best Practices

To make it easy to test an application, it is recommended that you expose a function that configures the default set of services. This will make it simple to use the same default configuration as the application and replace only the parts that are necessary for testing.

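If a service is meant to be replaceable, register it using try_add. try_add registers the service only if no service of the same type has been registered yet; otherwise, it skips the registration altogether. A registered service can still be replaced later, but try_add lets a caller, such as a test, register a replacement first and then apply the default configuration without overwriting it.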
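The ordering that try_add enables can be modeled in a few lines. This std-only sketch uses a hypothetical Services registry keyed by service name; for simplicity it keeps one registration per service and lets add overwrite, which is not how ServiceCollection stores registrations. Only the try_add rule (register only when nothing is registered yet) mirrors the description above.

```rust
use std::collections::HashMap;

// Hypothetical registry: service name -> implementation name.
#[derive(Default)]
struct Services {
    map: HashMap<&'static str, &'static str>,
}

impl Services {
    // Always registers (simplified here to overwrite any existing entry).
    fn add(&mut self, service: &'static str, implementation: &'static str) -> &mut Self {
        self.map.insert(service, implementation);
        self
    }

    // Registers only if nothing is registered for `service` yet.
    fn try_add(&mut self, service: &'static str, implementation: &'static str) -> &mut Self {
        self.map.entry(service).or_insert(implementation);
        self
    }

    fn resolve(&self, service: &str) -> Option<&'static str> {
        self.map.get(service).copied()
    }
}

// Mirrors add_default_services: defaults use try_add so callers can override them.
fn add_defaults(services: &mut Services) {
    services.try_add("UserRepo", "ExampleUserRepo");
}

fn main() {
    // Application: only the defaults are registered.
    let mut app = Services::default();
    add_defaults(&mut app);

    // Test: register the replacement first, then apply the defaults.
    let mut test = Services::default();
    test.add("UserRepo", "TestUserRepo");
    add_defaults(&mut test);

    println!("{:?} {:?}", app.resolve("UserRepo"), test.resolve("UserRepo"));
}
```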

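use axum::{
    async_trait,
    extract::Path,
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use di::*;
use di_axum::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::net::TcpListener;
use uuid::Uuid;

// User, CreateUser, UserRepoError, and AppError (including its IntoResponse
// impl) are defined exactly as in the previous example.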

// provide a function that can be called with the expected set of services
fn add_default_services(services: &mut ServiceCollection) {
    services.try_add(ExampleUserRepo::scoped());
}

// provide a function that can build a router representing the application
fn build_app(services: ServiceCollection) -> Router {
    Router::new()
        .route("/users/:id", get(one_user))
        .route("/users", post(new_user))
        .with_provider(services.build_provider().unwrap())
}

#[tokio::main]
async fn main() {
    let mut services = ServiceCollection::new();

    add_default_services(&mut services);

    let app = build_app(services);
    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    
    println!("listening on: {}", listener.local_addr().unwrap());

    axum::serve(listener, app).await.unwrap();
}

#[async_trait]
trait UserRepo {
    async fn find(&self, user_id: Uuid) -> Result<User, UserRepoError>;
    async fn create(&self, params: CreateUser) -> Result<User, UserRepoError>;
}

#[injectable(UserRepo + Send + Sync)]
struct ExampleUserRepo;

#[async_trait]
impl UserRepo for ExampleUserRepo {
    async fn find(&self, _user_id: Uuid) -> Result<User, UserRepoError> {
        unimplemented!()
    }

    async fn create(&self, _params: CreateUser) -> Result<User, UserRepoError> {
        unimplemented!()
    }
}

async fn one_user(
    Path(id): Path<Uuid>,
    Inject(repo): Inject<dyn UserRepo + Send + Sync>,
) -> Result<Json<User>, AppError> {
    let user = repo.find(id).await?;
    Ok(user.into())
}

async fn new_user(
    Inject(repo): Inject<dyn UserRepo + Send + Sync>,
    Json(params): Json<CreateUser>,
) -> Result<Json<User>, AppError> {
    let user = repo.create(params).await?;
    Ok(user.into())
}

You can now easily test your application by replacing only the services that need to change. In the following test, we:

  1. Create a TestUserRepo to simulate the behavior of a dyn UserRepo
  2. Register TestUserRepo in a new ServiceCollection
  3. Register all other default services
    • Since dyn UserRepo has been registered as TestUserRepo and try_add was used, the default registration is skipped
  4. Create a Router representing the application
  5. Run the application with a test client
  6. Invoke the HTTP GET method to return a single User

use super::*;
use crate::*;
use di::*;
use axum::{body::Body, http::Request};
use tower::ServiceExt; // provides oneshot; requires tower as a dev-dependency

#[tokio::test]
async fn get_should_return_user() {
    // arrange
    #[injectable(UserRepo + Send + Sync)]
    struct TestUserRepo;

    #[async_trait]
    impl UserRepo for TestUserRepo {
        async fn find(&self, user_id: Uuid) -> Result<User, UserRepoError> {
            Ok(User {
                id: user_id,
                username: "test".into(),
            })
        }

        async fn create(&self, _params: CreateUser) -> Result<User, UserRepoError> {
            unimplemented!()
        }
    }

    let mut services = ServiceCollection::new();

    // registered first, so the try_add in add_default_services is skipped
    services.add(TestUserRepo::scoped());
    add_default_services(&mut services);

    let app = build_app(services);

    // act: drive the Router directly with a single request
    let request = Request::builder()
        .uri("/users/b51565c273c04bb4ac179232c90b20af")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    // assert
    assert_eq!(response.status(), StatusCode::OK);
}