Functioning device manager with renew,revoke, updated model for cert id

This commit is contained in:
2025-10-30 23:00:57 +01:00
parent 7446e9b4ac
commit 4df582b330
13 changed files with 468 additions and 94 deletions

View File

@@ -0,0 +1,270 @@
import datetime
import logging
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
from fastapi import FastAPI, HTTPException
from cert_manager import CertificateManager
from database import get_db_context
from models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
logger = logging.getLogger(__name__)
app = FastAPI()
cert_manager = CertificateManager()
@app.get("/")
async def hello():
return {"Hello": "World"}
@app.post("/devices/register")
async def register_device(
request: DeviceRegistrationRequest,
) -> DeviceRegistrationResponse:
"""
Register a new device and issue an X.509 certificate.
"""
try:
response = cert_manager.register_device(
name=request.name,
location=request.location,
)
with get_db_context() as db:
device = Device(
id=response.device_id,
name=request.name,
location=request.location,
created_at=datetime.datetime.now(datetime.UTC),
)
db.add(device)
device_cert = DeviceCertificate(
id =response.certificate_id,
device_id=response.device_id,
certificate_pem=response.certificate_pem,
private_key_pem=response.private_key_pem,
issued_at=datetime.datetime.now(datetime.UTC),
expires_at=response.expires_at,
)
db.add(device_cert)
except Exception as e:
logger.error(
f"Failed to register device {request.name}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=500, detail="Failed to register device. Please try again."
) from e
return response
@app.get("/ca_certificate")
async def get_ca_certificate() -> str:
"""
Retrieve the CA certificate in PEM format.
"""
try:
ca_cert_pem = cert_manager.get_ca_certificate_pem()
return ca_cert_pem
except Exception as e:
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CA certificate."
) from e
@app.get("/devices/{device_id}")
async def get_device(device_id: str) -> DeviceResponse:
"""
Retrieve device information by ID.
"""
try:
with get_db_context() as db:
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
raise HTTPException(status_code=404, detail="Device not found")
return Device(
id=device.id,
name=device.name,
location=device.location,
created_at=device.created_at,
)
except Exception as e:
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve device information."
) from e
@app.get("/devices/")
async def list_devices() -> list[DeviceResponse]:
"""
List all registered devices.
"""
try:
with get_db_context() as db:
devices = db.query(Device).all()
return [
DeviceResponse(
id=device.id,
name=device.name,
location=device.location,
created_at=device.created_at,
)
for device in devices
]
except Exception as e:
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to list devices."
) from e
@app.post("/devices/{device_id}/revoke")
async def revoke_device_certificate(device_id: str):
"""
Revoke a device's certificate by:
1. Marking it as revoked in the database
2. Adding it to the Certificate Revocation List (CRL)
"""
try:
with get_db_context() as db:
device_cert = (
db.query(DeviceCertificate)
.filter(DeviceCertificate.device_id == device_id)
.first()
)
if not device_cert:
raise HTTPException(status_code=404, detail="Device certificate not found")
if device_cert.revoked_at:
raise HTTPException(status_code=400, detail="Certificate already revoked")
cert_manager.revoke_certificate(device_cert.certificate_pem)
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
db.commit()
logger.info(f"Successfully revoked certificate for device {device_id}")
return {
"device_id": device_id,
"revoked_at": device_cert.revoked_at.isoformat(),
"message": "Certificate revoked successfully"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to revoke device certificate."
) from e
@app.get("/crl")
async def get_crl():
"""
Get the Certificate Revocation List (CRL) in PEM format.
Mosquitto and other MQTT clients can check this to validate certificates.
"""
try:
crl_pem = cert_manager.get_crl_pem()
if not crl_pem:
return {"message": "No certificates have been revoked yet"}
return {"crl_pem": crl_pem}
except Exception as e:
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CRL."
) from e
@app.post("/devices/{device_id}/renew")
async def renew_certificate(device_id: str):
"""
Renew a device's certificate by issuing a new one and revoking the old one.
This endpoint:
1. Retrieves the current certificate from DB
2. Generates a new certificate with new keys
3. Revokes the old certificate (adds to CRL)
4. Updates the database with the new certificate
5. Returns the new credentials
"""
try:
with get_db_context() as db:
# Get current certificate
device_cert = (
db.query(DeviceCertificate)
.filter(
DeviceCertificate.device_id == device_id,
# DeviceCertificate.revoked_at.is_(None)
)
.first()
)
if not device_cert:
raise HTTPException(
status_code=404,
detail="No active certificate found for device"
)
# Check if certificate is about to expire (optional warning)
days_until_expiry = (device_cert.expires_at - datetime.datetime.now(datetime.UTC)).days
if days_until_expiry > 30:
logger.warning(
f"Certificate for device {device_id} renewed early "
f"({days_until_expiry} days remaining)"
)
# Revoke old certificate and add to CRL
cert_manager.revoke_certificate(device_cert.certificate_pem)
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
# Generate new certificate with new keys
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
current_cert_pem=device_cert.certificate_pem,
validity_days=365,
key_size=4096
)
# Extract certificate ID (serial number) from the new certificate
from cryptography import x509
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
new_cert_id = format(new_cert.serial_number, 'x')
# Create new certificate record in DB
now = datetime.datetime.now(datetime.UTC)
new_device_cert = DeviceCertificate(
id=new_cert_id,
device_id=device_id,
certificate_pem=new_cert_pem.decode("utf-8"),
private_key_pem=new_key_pem.decode("utf-8"),
issued_at=now,
expires_at=now + datetime.timedelta(days=365),
)
db.add(new_device_cert)
db.commit()
logger.info(f"Successfully renewed certificate for device {device_id}")
return DeviceRegistrationResponse(
certificate_id=new_cert_id,
device_id=device_id,
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
certificate_pem=new_device_cert.certificate_pem,
private_key_pem=new_device_cert.private_key_pem,
expires_at=new_device_cert.expires_at,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to renew device certificate."
) from e

View File

@@ -0,0 +1,240 @@
import datetime
from pathlib import Path
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from nanoid import generate
from config import config
from models import DeviceRegistrationResponse
lowercase_numbers = "abcdefghijklmnopqrstuvwxyz0123456789"
class CertificateManager:
"""Manages device certificate generation and handling"""
def __init__(self):
self.ca_cert: x509.Certificate = self.load_ca_certificate(config.CA_CERT_PATH)
self.ca_key: rsa.RSAPrivateKey = self.load_ca_private_key(config.CA_KEY_PATH)
self.ca_cert_pem: bytes = self.ca_cert.public_bytes(serialization.Encoding.PEM)
def generate_device_id(self) -> str:
"""Generate a unique device ID using nanoid."""
return generate(alphabet=lowercase_numbers, size=config.DEVICE_ID_LENGTH)
def load_ca_certificate(self, ca_cert_path: str) -> x509.Certificate:
"""Load a CA certificate from file."""
with open(ca_cert_path, "rb") as f:
ca_data = f.read()
ca_cert = x509.load_pem_x509_certificate(ca_data)
return ca_cert
def load_ca_private_key(self, ca_key_path: str, password: bytes = None) -> rsa.RSAPrivateKey:
"""Load a CA private key from file."""
from cryptography.hazmat.primitives import serialization
with open(ca_key_path, "rb") as f:
key_data = f.read()
ca_key = serialization.load_pem_private_key(key_data, password=password)
return ca_key
def generate_device_key(self, key_size: int = 4096) -> rsa.RSAPrivateKey:
"""Generate an RSA private key for a device."""
return rsa.generate_private_key(public_exponent=65537, key_size=key_size)
def generate_device_certificate(
self,
device_id: str,
ca_cert: x509.Certificate,
ca_key: rsa.RSAPrivateKey,
device_key: rsa.RSAPrivateKey,
validity_days: int = 365,
key_size: int = 4096,
) -> tuple[bytes, bytes]:
"""Generate an X.509 certificate for a device signed by the CA."""
subject = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, device_id),
]
)
issuer = ca_cert.subject
now = datetime.datetime.now(datetime.UTC)
device_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(device_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=validity_days))
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)
.sign(private_key=ca_key, algorithm=hashes.SHA256())
)
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM)
key_pem = device_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
return cert_pem, key_pem
def create_device_credentials(
self, device_id: str, validity_days: int = 365, key_size: int = 4096
) -> dict:
"""Create device credentials: private key and signed certificate.
Returns:
dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
"""
device_key = self.generate_device_key(key_size=key_size)
cert_pem, key_pem = self.generate_device_certificate(
device_id=device_id,
ca_cert=self.ca_cert,
ca_key=self.ca_key,
device_key=device_key,
validity_days=validity_days,
key_size=key_size,
)
# Extract serial number from certificate to use as ID
cert = x509.load_pem_x509_certificate(cert_pem)
cert_id = format(cert.serial_number, 'x') # Hex string of serial number
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days)
return {
"certificate_id": cert_id,
"device_id": device_id,
"certificate_pem": cert_pem,
"private_key_pem": key_pem,
"ca_certificate_pem": self.ca_cert_pem,
"expires_at": expires_at,
}
def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse:
"""Register a new device and generate its credentials.
Returns:
DeviceRegistrationResponse
"""
device_id = self.generate_device_id()
credentials = self.create_device_credentials(device_id=device_id)
return DeviceRegistrationResponse(
certificate_id=credentials["certificate_id"],
device_id=credentials["device_id"],
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
private_key_pem=credentials["private_key_pem"].decode("utf-8"),
expires_at=credentials["expires_at"],
)
def get_ca_certificate_pem(self) -> str:
"""Get the CA certificate in PEM format as a string."""
return self.ca_cert_pem.decode("utf-8")
def revoke_certificate(
self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified
) -> None:
"""
Revoke a device certificate by adding it to the CRL.
Args:
certificate_pem: PEM-encoded certificate to revoke
reason: Revocation reason (default: unspecified)
"""
# Load the certificate to get serial number
cert = x509.load_pem_x509_certificate(certificate_pem.encode())
# Load existing CRL or create new one
crl_path = Path(config.CRL_PATH)
revoked_certs = []
if crl_path.exists():
with open(crl_path, "rb") as f:
existing_crl = x509.load_pem_x509_crl(f.read())
# Copy existing revoked certificates
revoked_certs = list(existing_crl)
# Add the new revoked certificate
revoked_cert = (
x509.RevokedCertificateBuilder()
.serial_number(cert.serial_number)
.revocation_date(datetime.datetime.now(datetime.UTC))
.add_extension(
x509.CRLReason(reason),
critical=False,
)
.build()
)
revoked_certs.append(revoked_cert)
# Build new CRL with all revoked certificates
crl_builder = x509.CertificateRevocationListBuilder()
crl_builder = crl_builder.issuer_name(self.ca_cert.subject)
crl_builder = crl_builder.last_update(datetime.datetime.now(datetime.UTC))
crl_builder = crl_builder.next_update(
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30)
)
for revoked in revoked_certs:
crl_builder = crl_builder.add_revoked_certificate(revoked)
# Sign the CRL with CA key
crl = crl_builder.sign(private_key=self.ca_key, algorithm=hashes.SHA256())
# Write CRL to file
crl_path.parent.mkdir(parents=True, exist_ok=True)
with open(crl_path, "wb") as f:
f.write(crl.public_bytes(serialization.Encoding.PEM))
def get_crl_pem(self) -> str | None:
"""Get the current CRL in PEM format."""
crl_path = Path(config.CRL_PATH)
if not crl_path.exists():
return None
with open(crl_path, "rb") as f:
return f.read().decode("utf-8")
def renew_certificate(
self,
current_cert_pem: str,
validity_days: int = 365,
key_size: int = 4096,
) -> tuple[bytes, bytes]:
"""Renew a device certificate before expiration.
Args:
current_cert_pem: PEM-encoded current certificate
validity_days: Validity period for new certificate
key_size: Key size for new device key
Returns:
tuple of (new_cert_pem, new_key_pem)
"""
# Load current certificate
current_cert = x509.load_pem_x509_certificate(current_cert_pem.encode())
device_id = current_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Generate new device key
new_device_key = self.generate_device_key(key_size=key_size)
# Generate new device certificate
new_cert_pem, new_key_pem = self.generate_device_certificate(
device_id=device_id,
ca_cert=self.ca_cert,
ca_key=self.ca_key,
device_key=new_device_key,
validity_days=validity_days,
key_size=key_size,
)
return new_cert_pem, new_key_pem

View File

@@ -0,0 +1,33 @@
import os
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
class Config:
"""Configuration settings for the Device Manager service."""
DATABASE_URL = os.getenv("DATABASE_URL")
SERVICE_DIR = Path(__file__).parent
CERTS_DIR = SERVICE_DIR / "certs"
CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt"))
CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key"))
CRL_PATH = os.getenv("CRL_PATH", str(CERTS_DIR / "ca.crl"))
# Certificate settings
CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365"))
CERT_KEY_SIZE = int(os.getenv("CERT_KEY_SIZE", "4096"))
# Device ID settings
DEVICE_ID_LENGTH = int(os.getenv("DEVICE_ID_LENGTH", "8"))
# Service settings
SERVICE_HOST = os.getenv("DEVICE_MANAGER_HOST", "0.0.0.0")
SERVICE_PORT = int(os.getenv("DEVICE_MANAGER_PORT", "8000"))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
config = Config()

View File

@@ -0,0 +1,83 @@
"""
Database session management for FastAPI with SQLAlchemy.
Uses dependency injection pattern for database sessions.
"""
from collections.abc import Generator
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from config import config
# Create engine with connection pooling
engine = create_engine(
config.DATABASE_URL,
pool_pre_ping=True,
pool_size=5,
max_overflow=10,
echo=False,
)
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
)
def get_db() -> Generator[Session]:
"""
FastAPI dependency that provides a database session.
Usage in endpoints:
@app.post("/devices")
async def create_device(db: Session = Depends(get_db)):
device = Device(...)
db.add(device)
db.commit()
return device
The session is automatically closed after the request completes.
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def get_db_context():
"""
Context manager for database sessions outside of FastAPI endpoints.
Usage:
with get_db_context() as db:
device = db.query(Device).first()
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
def check_db_connection() -> bool:
"""
Check if database connection is working.
Returns:
True if connection successful, False otherwise
"""
try:
with engine.connect() as conn:
conn.execute("SELECT 1")
return True
except Exception as e:
print(f"Database connection failed: {e}")
return False

View File

@@ -0,0 +1,45 @@
"""
SQLAlchemy ORM models for device manager service.
These models mirror the database schema defined in db_migrations.
Kept separate to make the service independent.
"""
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
class Device(Base):
"""IoT devices registered in the system."""
__tablename__ = "devices"
id = Column(Text, primary_key=True)
name = Column(Text, nullable=False)
location = Column(Text)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<Device(id={self.id}, name={self.name})>"
class DeviceCertificate(Base):
"""X.509 certificates issued to devices for mTLS authentication."""
__tablename__ = "device_certificates"
id = Column(Text, primary_key=True)
device_id = Column(
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
)
certificate_pem = Column(Text, nullable=False)
private_key_pem = Column(Text)
issued_at = Column(DateTime(timezone=True), nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=False)
revoked_at = Column(DateTime(timezone=True))
def __repr__(self):
return f"<DeviceCertificate(device_id={self.device_id}, expires={self.expires_at})>"

View File

@@ -0,0 +1,27 @@
import datetime
from pydantic import BaseModel
class DeviceRegistrationRequest(BaseModel):
"""Request model for registering a new device."""
name: str
location: str | None = None
class DeviceRegistrationResponse(BaseModel):
"""Response model after registering a new device."""
certificate_id: str
device_id: str
ca_certificate_pem: str
certificate_pem: str
private_key_pem: str
expires_at: datetime.datetime
class DeviceResponse(BaseModel):
"""Response model for device information."""
id: str
name: str
location: str | None = None
created_at: datetime.datetime