JWTConfig

JWT stands for JSON Web Token, and it can be used with any middleware of your choice that implements the BaseAuthMiddleware.

Tip

More information about JWT here.
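
To make the format concrete, here is a minimal sketch of how a compact HS256 token is put together and verified, using only the Python standard library. It is not how Esmerald or pyjwt implement it, and it skips checks a real library performs (header validation, `exp` and so on). The helper names `b64url`, `sign_hs256` and `verify_hs256` are made up for this illustration.

```python
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def sign_hs256(payload: dict, key: str) -> str:
    """Build a compact token: header.payload.signature."""
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        b64url(json.dumps(header, separators=(",", ":")).encode("utf-8")),
        b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8")),
    ]
    signing_input = ".".join(segments).encode("ascii")
    signature = hmac.new(key.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return ".".join([*segments, b64url(signature)])


def verify_hs256(token: str, key: str) -> dict:
    """Check the signature and return the payload, or raise ValueError."""
    try:
        header_b64, payload_b64, signature_b64 = token.split(".")
    except ValueError:
        raise ValueError("A JWT must have exactly three segments") from None

    signing_input = f"{header_b64}.{payload_b64}".encode("utf-8")
    expected = b64url(hmac.new(key.encode("utf-8"), signing_input, hashlib.sha256).digest())
    # Constant-time comparison to avoid leaking how many characters matched.
    if not hmac.compare_digest(expected.encode("utf-8"), signature_b64.encode("utf-8")):
        raise ValueError("Signature mismatch")

    # Restore the padding that was stripped when the segment was encoded.
    padded = payload_b64 + "=" * (-len(payload_b64) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))
```

In practice you would rely on pyjwt (or the Token model described further down) rather than rolling your own; the sketch only shows what the three dot-separated segments contain.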

Requirements

Esmerald uses pyjwt and passlib for this JWT integration. You can install them by running:

$ pip install esmerald[jwt]

JWTConfig and application

To use the JWTConfig with a middleware:

from myapp.models import User

from esmerald import Esmerald, settings
from esmerald.config.jwt import JWTConfig
from esmerald.contrib.auth.edgy.middleware import JWTAuthMiddleware
from lilya.middleware import DefineMiddleware as LilyaMiddleware

jwt_config = JWTConfig(
    signing_key=settings.secret_key,
)

auth_middleware = LilyaMiddleware(JWTAuthMiddleware, config=jwt_config, user_model=User)

app = Esmerald(middleware=[auth_middleware])

Info

The example uses a supported JWTAuthMiddleware from Esmerald with Edgy ORM.

Parameters

All the parameters and defaults are available in the JWTConfig Reference.

JWTConfig and application settings

The JWTConfig can be passed directly when instantiating the application, but it can also be defined via settings.

from typing import TYPE_CHECKING, List

from esmerald import EsmeraldAPISettings
from esmerald.config.jwt import JWTConfig
from esmerald.contrib.auth.edgy.middleware import JWTAuthMiddleware
from lilya._internal._module_loading import import_string
from lilya.middleware import DefineMiddleware as LilyaMiddleware

if TYPE_CHECKING:
    from esmerald.types import Middleware


class CustomSettings(EsmeraldAPISettings):
    @property
    def jwt_config(self) -> JWTConfig:
        """
        A JWT object configuration to be passed to the application middleware
        """
        return JWTConfig(signing_key=self.secret_key, auth_header_types=["Bearer", "Token"])

    @property
    def middleware(self) -> List["Middleware"]:
        """
        Initial middlewares to be loaded on startup of the application.
        """
        return [
            LilyaMiddleware(
                JWTAuthMiddleware,
                config=self.jwt_config,
                user_model=import_string("myapp.models.User"),
            )
        ]

This will make sure you keep the settings clean, separated and without a bloated Esmerald instance.

Token model

Esmerald offers a standard Token object that allows you to generate and decode tokens with ease.

from esmerald.security.jwt.token import Token

token = Token(exp=..., iat=..., sub=...)

The parameters are the standard ones from Python JOSE, so they should feel familiar.

Generate a Token (encode)

The token offers simple and standard operations to interact with pyjwt.

from esmerald.security.jwt.token import Token
from esmerald.conf import settings

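# Create the token model
token = Token(exp=..., iat=..., sub=...)

# Optional extra claims to embed in the token (illustrative values)
claims = {"token_type": "access_token"}

# Generate the JWT token from the token instance
jwt_token = token.encode(key=settings.secret_key, algorithm="HS256", **claims)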

Decode a Token (decode)

The same decoding functionality is also provided.

from esmerald.security.jwt.token import Token
from esmerald.conf import settings

# Decodes the JWT token
jwt_token = Token.decode(token=..., key=settings.secret_key, algorithms=["HS256"])

The Token.decode returns a Token object.

Note

This functionality relies heavily on pyjwt but it is not mandatory to use it in any way. You are free to use any library that suits your unique needs. Esmerald only offers some examples and alternatives.

The claims

The **claims are most useful when you want to generate separate tokens for access and refresh. Any extra parameter you pass as a claim is available to you once the token is decoded.
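
As a rough illustration of the idea, the sketch below merges extra keyword claims with the registered ones (`sub`, `iat`, `exp`) into a plain payload dictionary. It does not use Esmerald's Token; `build_payload` is a hypothetical helper, and the lifetimes and the `token_type` values are assumptions chosen for the example.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Dict


def build_payload(sub: str, lifetime: timedelta, **claims: Any) -> Dict[str, Any]:
    """Merge extra claims with the registered ones (sub, iat, exp)."""
    now = datetime.now(timezone.utc)
    registered = {
        "sub": sub,
        "iat": int(now.timestamp()),
        "exp": int((now + lifetime).timestamp()),
    }
    # Registered claims win, so an extra claim cannot overwrite "exp" by accident.
    return {**claims, **registered}


# The same subject gets two payloads that differ in lifetime and token_type.
access = build_payload("1", timedelta(minutes=5), token_type="access_token")
refresh = build_payload("1", timedelta(days=1), token_type="refresh_token")
```

Once such a payload is signed and later decoded, the extra `token_type` entry comes back alongside the registered claims, which is what makes it possible to tell the two tokens apart.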

The database integrations show examples of how to do these operations, but let us run through a quick example here.

We will be using a middleware and generating an access_token and a refresh_token for a given API.

Let us assume a few things.

  • There is a User model inside an accounts/models.py.
  • The controllers are placed inside accounts/controllers.py.
  • We will be subclassing an existing middleware to make it easier.
  • The middleware is placed inside an accounts/middleware.py.
  • The JWTConfig is inside a settings file already configured.
  • The Token class will be subclassed to allow extra parameters like token_type.
  • There is an accounts/backends.py file containing operations for the authentication and refreshing of the token.

The token class

You can and should subclass the Token class if you want to add extra parameters for your own purposes, for instance an extra token_type that indicates whether the token is an access or a refresh token, or anything else you need in your claims.

Something like this:

from typing import Union

from esmerald.security.jwt.token import Token as EsmeraldToken


class Token(EsmeraldToken):
    token_type: Union[str, None] = None

This will be particularly useful in the next steps as we will be using the token_type to distinguish between the access_token and the refresh_token.

The middleware

Let us use an existing middleware from the contrib to make it easier. This middleware will serve only for accessing the APIs and not for refreshing the token.

Tip

Feel free to build your own; this one is for explanation purposes.

from jose import JWSError, JWTError

from esmerald.conf import settings
from esmerald.contrib.auth.edgy.middleware import JWTAuthMiddleware as EsmeraldMiddleware
from esmerald.exceptions import AuthenticationError, NotAuthorized
from esmerald.middleware.authentication import AuthResult
from esmerald.security.jwt.token import Token
from lilya._internal._connection import Connection
from lilya._internal._module_loading import import_string
from lilya.middleware import DefineMiddleware as LilyaMiddleware


class JWTAuthMiddleware(EsmeraldMiddleware):
    def get_token(self, request: Connection) -> Token:
        """
        Gets the token from the headers.
        """
        token = request.headers.get(self.config.authorization_header, None)

        if not token:
            raise NotAuthorized(detail="Token not found in the request header")

        token_partition = token.partition(" ")
        token_type = token_partition[0]
        auth_token = token_partition[-1]

        if token_type not in self.config.auth_header_types:
            raise NotAuthorized(detail=f"'{token_type}' is not an authorized header.")

        try:
            token = Token.decode(
                token=auth_token,
                key=self.config.signing_key,
                algorithms=[self.config.algorithm],
            )
        except (JWSError, JWTError) as e:
            raise AuthenticationError(str(e)) from e
        return token

    async def authenticate(self, request: Connection) -> AuthResult:
        """
        Retrieves the header default of the config, validates
        and returns the AuthResult.

        Raises Authentication error if invalid.
        """
        token: Token = self.get_token(request)

        if token.token_type == settings.jwt_config.refresh_token_name:
            raise NotAuthorized(detail="Refresh tokens cannot be used for operations.")

        user = await self.retrieve_user(token.sub)
        if not user:
            raise AuthenticationError("User not found.")
        return AuthResult(user=user)


# Middleware responsible for user access.
# This can be imported in any level of the application
AuthMiddleware = LilyaMiddleware(
    JWTAuthMiddleware,
    config=settings.jwt_config,
    user_model=import_string("accounts.models.User"),
)

There is a lot happening here, but what are we actually doing?

  • Checking for the token in the header.
  • Checking that the header type (for example, Bearer) is allowed and decoding the token.
  • Checking the token_type and raising an exception if it is a refresh_token (the name comes from the JWTConfig and can be whatever you want), so that refresh tokens cannot be used to access the APIs.
  • Returning the AuthResult object with the details of the retrieved user.
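
The header handling and the token type check can be isolated into a small, framework-free sketch. `NotAuthorizedError`, `extract_credentials` and `ensure_not_refresh` are hypothetical names used only here, and the defaults (`Authorization`, `Bearer`, `refresh_token`) are assumptions standing in for the values read from the JWTConfig.

```python
from typing import Mapping, Sequence


class NotAuthorizedError(Exception):
    """Hypothetical stand-in for esmerald.exceptions.NotAuthorized."""


def extract_credentials(
    headers: Mapping[str, str],
    header_name: str = "Authorization",
    allowed_types: Sequence[str] = ("Bearer",),
) -> str:
    """Return the raw token from a '<type> <token>' header value."""
    value = headers.get(header_name)
    if not value:
        raise NotAuthorizedError("Token not found in the request header")

    # "Bearer abc.def.ghi" -> ("Bearer", " ", "abc.def.ghi")
    header_type, _, credentials = value.partition(" ")
    if header_type not in allowed_types:
        raise NotAuthorizedError(f"'{header_type}' is not an authorized header.")
    if not credentials:
        # Extra check added in this sketch: reject a bare "Bearer" with no token.
        raise NotAuthorizedError("Token not found in the request header")
    return credentials


def ensure_not_refresh(claims: Mapping[str, str], refresh_name: str = "refresh_token") -> None:
    """Refuse refresh tokens, mirroring the check done in authenticate()."""
    if claims.get("token_type") == refresh_name:
        raise NotAuthorizedError("Refresh tokens cannot be used for operations.")
```

Keeping these two checks separate from the decoding step makes each of them easy to unit test without a running application.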

The module also defines a wrapper called AuthMiddleware, which will be used later on in the controllers.

Backend

This is where we will place the logic that handles the authentication and refreshing of the token.

Warning

The example below uses Edgy from the contrib to make it simpler to explain and query.

from datetime import datetime
from typing import Any, Dict

from accounts.models import User
from edgy.exceptions import ObjectNotFound
from jose import JWSError, JWTError
from pydantic import BaseModel

from esmerald.conf import settings
from esmerald.exceptions import AuthenticationError, NotAuthorized
from esmerald.security.jwt.token import Token


class AccessToken(BaseModel):
    access_token: str


class RefreshToken(BaseModel):
    """
    Model used only to refresh
    """

    refresh_token: str


class TokenAccess(AccessToken, RefreshToken):
    """
    Model representation of an access token.
    """

    ...


class LoginIn(BaseModel):
    email: str
    password: str


class BackendAuthentication(BaseModel):
    """
    Utility class that helps with the authentication process.
    """

    email: str
    password: str

    async def authenticate(self) -> Dict[str, str]:
        """Authenticates a user and returns
        a dictionary containing the `access_token` and `refresh_token`
        in the format of:

        {
            "access_token": ...,
            "refresh_token": ...
        }
        """
        try:
            user: User = await User.query.get(email=self.email)
        except ObjectNotFound:
            # Run the default password hasher once to reduce the timing
            # difference between an existing and a nonexistent user.
            await User().set_password(self.password)
            raise NotAuthorized(detail="Invalid credentials.")
        else:
            is_password_valid = await user.check_password(self.password)
            if is_password_valid and self.is_user_able_to_authenticate(user):
                # An access token should be short-lived (for example, 5 minutes).
                # Both lifetimes are taken directly from the JWT config.
                access_time = datetime.now() + settings.jwt_config.access_token_lifetime
                refresh_time = datetime.now() + settings.jwt_config.refresh_token_lifetime
                access_token = TokenAccess(
                    access_token=self.generate_user_token(
                        user,
                        time=access_time,
                        token_type=settings.jwt_config.access_token_name,  # 'access_token'
                    ),
                    refresh_token=self.generate_user_token(
                        user,
                        time=refresh_time,
                        token_type=settings.jwt_config.refresh_token_name,  # 'refresh_token'
                    ),
                )
                return access_token.model_dump()
            else:
                raise NotAuthorized(detail="Invalid credentials.")

    def is_user_able_to_authenticate(self, user: Any):
        """
        Reject users with is_active=False. Custom user models that don't have
        that attribute are allowed.
        """
        return getattr(user, "is_active", True)

    def generate_user_token(self, user: User, token_type: str, time: datetime = None):
        """
        Generates the JWT token for the authenticated user.
        """
        if not time:
            later = datetime.now() + settings.jwt_config.access_token_lifetime
        else:
            later = time

        token = Token(sub=str(user.id), exp=later)
        return token.encode(
            key=settings.jwt_config.signing_key,
            algorithm=settings.jwt_config.algorithm,
            token_type=token_type,
        )


class RefreshAuthentication(BaseModel):
    """
    Refreshes the access token given a refresh token of a given user.

    This object does not perform any DB action, instead, uses the existing refresh
    token to generate a new access.
    """

    token: RefreshToken

    async def refresh(self) -> AccessToken:
        """
        Decodes the refresh token, checks that it really is a refresh token
        and returns a new access token.
        Raises AuthenticationError if the token cannot be decoded.
        """
        token = self.token.refresh_token

        try:
            token = Token.decode(
                token=token,
                key=settings.jwt_config.signing_key,
                algorithms=[settings.jwt_config.algorithm],
            )
        except (JWSError, JWTError) as e:
            raise AuthenticationError(str(e)) from e

        if token.token_type != settings.jwt_config.refresh_token_name:
            raise NotAuthorized(detail="Only refresh tokens are allowed.")

        # Apply the maximum living time
        expiry_date = datetime.now() + settings.jwt_config.access_token_lifetime

        # New token object
        new_token = Token(sub=token.sub, exp=expiry_date)

        # Encode the token
        access_token = new_token.encode(
            key=settings.jwt_config.signing_key,
            algorithm=settings.jwt_config.algorithm,
            token_type=settings.jwt_config.access_token_name,
        )

        return AccessToken(access_token=access_token)

Quite a lot of code, right? Yes, but it is mostly the logic used to authenticate and to refresh the existing token.

Did you see the BackendAuthentication and the RefreshAuthentication? Now this will be very useful.

The RefreshAuthentication is where we validate the refresh_token. Remember that the middleware rejects refresh tokens? This is the reason why. The middleware is used only for APIs that require authentication, whereas the refresh_token, by design, should do only one thing: refresh.

Since the refresh token already contains all the information needed to generate the new access token, there is no need to query the user again and repeat the whole process.

The way the refresh token was designed and passed in the claims also allows us to directly use it and generate the new access_token.

Remember the Token we subclassed? This is where the token_type plays the role in dictating which type of token is being validated and sent.

The access_token is sent via the headers, as it should be, and the refresh_token is sent in the body of a POST request.
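
Stripped of the encoding and decoding, the refresh step boils down to the following sketch over already decoded claims. The function name `refresh_claims_to_access` is hypothetical, and the two constants are assumed values for access_token_name and refresh_token_name from the JWTConfig.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

ACCESS_TOKEN_NAME = "access_token"  # assumed value of access_token_name
REFRESH_TOKEN_NAME = "refresh_token"  # assumed value of refresh_token_name


def refresh_claims_to_access(
    refresh_claims: Dict[str, Any],
    access_lifetime: timedelta,
    now: Optional[datetime] = None,
) -> Dict[str, Any]:
    """Derive the claims of a new access token from decoded refresh claims."""
    now = now or datetime.now(timezone.utc)

    # Only a refresh token may be exchanged for a new access token.
    if refresh_claims.get("token_type") != REFRESH_TOKEN_NAME:
        raise PermissionError("Only refresh tokens are allowed.")

    # The subject is copied over, so no database lookup is needed.
    return {
        "sub": refresh_claims["sub"],
        "exp": int((now + access_lifetime).timestamp()),
        "token_type": ACCESS_TOKEN_NAME,
    }
```

Passing `now` explicitly is a design choice made for this sketch: it keeps the function deterministic, which makes the expiry easy to assert in tests.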

The controllers

Now it is time to assemble everything in the controllers where we will have:

  • /auth/create - Endpoint to create the users.
  • /auth/signin - Login endpoint for the user.
  • /auth/users - Endpoint that returns a list of all users.
  • /auth/refresh-access - The endpoint responsible only for refreshing the access_token.

In the end, something like this:

from typing import Any, Dict, List, Union

from accounts.backends import (
    AccessToken,
    BackendAuthentication,
    RefreshAuthentication,
    RefreshToken,
    TokenAccess,
)
from accounts.middleware import AuthMiddleware
from accounts.models import User
from pydantic import BaseModel, EmailStr

from esmerald import APIView, JSONResponse, get, post, status
from esmerald.openapi.datastructures import OpenAPIResponse
from esmerald.openapi.security.http import Bearer


class UserIn(BaseModel):
    """
    Model responsible for the creation of a User.
    """

    first_name: str
    last_name: str
    email: str
    password: str
    username: str


class UserOut(BaseModel):
    """
    Representation of the list of users.
    """

    id: int
    first_name: str
    last_name: str
    email: str
    username: str
    is_staff: bool
    is_active: bool
    is_superuser: bool
    is_verified: bool


class LoginIn(BaseModel):
    """
    Details needed for a login of a user in the system.
    """

    email: EmailStr
    password: str


class ErrorDetail(BaseModel):
    """
    Used by the OpenAPI to describe the error
    exposing the details.
    """

    detail: str


class UserAPIView(APIView):
    tags: List[str] = ["User and Access"]
    security: List[Any] = [Bearer]

    @get(
        "/users",
        summary="Gets all the users",
        responses={200: OpenAPIResponse(model=[UserOut])},
        middleware=[AuthMiddleware],
    )
    async def get_all(self) -> List[UserOut]:
        return await User.query.all()

    @post(
        path="/create",
        summary="Creates a user in the system",
        responses={400: OpenAPIResponse(model=ErrorDetail)},
    )
    async def create_user(self, data: UserIn) -> None:
        """
        Creates a user in the system and returns the default 201
        status code.
        """
        user_data = data.model_dump()
        user_data.update({"is_verified": False})
        await User.query.create(**user_data)

    @post(
        path="/signin",
        summary="Login API and returns a JWT Token.",
        status_code=status.HTTP_200_OK,
        responses={
            200: OpenAPIResponse(model=TokenAccess),
            401: OpenAPIResponse(model=ErrorDetail),
        },
    )
    async def signin(self, data: LoginIn) -> JSONResponse:
        """
        Logs a user in and returns the JWT tokens.
        Raises NotAuthorized if the credentials are invalid.
        """
        auth = BackendAuthentication(email=data.email, password=data.password)
        access_tokens: Dict[str, str] = await auth.authenticate()
        return JSONResponse(access_tokens)

    @post(
        path="/refresh-access",
        summary="Refreshes the access token",
        description="When a token expires, a new access token must be generated from the refresh token previously provided. The refresh token must be just that, a refresh and it should only return a new access token and nothing else",
        status_code=status.HTTP_200_OK,
        responses={
            200: OpenAPIResponse(model=AccessToken),
            401: OpenAPIResponse(model=ErrorDetail),
        },
    )
    async def refresh_token(self, payload: RefreshToken) -> AccessToken:
        authentication = RefreshAuthentication(token=payload)
        access_token: AccessToken = await authentication.refresh()
        return access_token

As you can see, everything is now assembled. The /auth/users endpoint requires authentication, and /auth/refresh-access returns only the new access_token.
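
From the client side, the two kinds of request could be built roughly as follows. This is only a sketch with the standard library: `BASE_URL` is a hypothetical address, the `Authorization` header name and `Bearer` type are assumed to match the JWTConfig in use, and the requests are constructed but never sent.

```python
import json
from urllib.request import Request

BASE_URL = "http://localhost:8000/auth"  # hypothetical address of the API


def users_request(access_token: str) -> Request:
    """The access token travels in the Authorization header."""
    return Request(
        f"{BASE_URL}/users",
        headers={"Authorization": f"Bearer {access_token}"},
        method="GET",
    )


def refresh_request(refresh_token: str) -> Request:
    """The refresh token travels in the JSON body of a POST request."""
    body = json.dumps({"refresh_token": refresh_token}).encode("utf-8")
    return Request(
        f"{BASE_URL}/refresh-access",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

Either object can then be passed to `urllib.request.urlopen` against a running instance of the application.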