Functioning device manager with renew,revoke, updated model for cert id

This commit is contained in:
2025-10-30 23:00:57 +01:00
parent 7446e9b4ac
commit 4df582b330
13 changed files with 468 additions and 94 deletions

View File

@@ -0,0 +1,270 @@
import datetime
import logging
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
from fastapi import FastAPI, HTTPException
from cert_manager import CertificateManager
from database import get_db_context
from models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
logger = logging.getLogger(__name__)
app = FastAPI()
cert_manager = CertificateManager()
@app.get("/")
async def hello():
return {"Hello": "World"}
@app.post("/devices/register")
async def register_device(
request: DeviceRegistrationRequest,
) -> DeviceRegistrationResponse:
"""
Register a new device and issue an X.509 certificate.
"""
try:
response = cert_manager.register_device(
name=request.name,
location=request.location,
)
with get_db_context() as db:
device = Device(
id=response.device_id,
name=request.name,
location=request.location,
created_at=datetime.datetime.now(datetime.UTC),
)
db.add(device)
device_cert = DeviceCertificate(
id =response.certificate_id,
device_id=response.device_id,
certificate_pem=response.certificate_pem,
private_key_pem=response.private_key_pem,
issued_at=datetime.datetime.now(datetime.UTC),
expires_at=response.expires_at,
)
db.add(device_cert)
except Exception as e:
logger.error(
f"Failed to register device {request.name}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=500, detail="Failed to register device. Please try again."
) from e
return response
@app.get("/ca_certificate")
async def get_ca_certificate() -> str:
"""
Retrieve the CA certificate in PEM format.
"""
try:
ca_cert_pem = cert_manager.get_ca_certificate_pem()
return ca_cert_pem
except Exception as e:
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CA certificate."
) from e
@app.get("/devices/{device_id}")
async def get_device(device_id: str) -> DeviceResponse:
"""
Retrieve device information by ID.
"""
try:
with get_db_context() as db:
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
raise HTTPException(status_code=404, detail="Device not found")
return Device(
id=device.id,
name=device.name,
location=device.location,
created_at=device.created_at,
)
except Exception as e:
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve device information."
) from e
@app.get("/devices/")
async def list_devices() -> list[DeviceResponse]:
"""
List all registered devices.
"""
try:
with get_db_context() as db:
devices = db.query(Device).all()
return [
DeviceResponse(
id=device.id,
name=device.name,
location=device.location,
created_at=device.created_at,
)
for device in devices
]
except Exception as e:
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to list devices."
) from e
@app.post("/devices/{device_id}/revoke")
async def revoke_device_certificate(device_id: str):
"""
Revoke a device's certificate by:
1. Marking it as revoked in the database
2. Adding it to the Certificate Revocation List (CRL)
"""
try:
with get_db_context() as db:
device_cert = (
db.query(DeviceCertificate)
.filter(DeviceCertificate.device_id == device_id)
.first()
)
if not device_cert:
raise HTTPException(status_code=404, detail="Device certificate not found")
if device_cert.revoked_at:
raise HTTPException(status_code=400, detail="Certificate already revoked")
cert_manager.revoke_certificate(device_cert.certificate_pem)
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
db.commit()
logger.info(f"Successfully revoked certificate for device {device_id}")
return {
"device_id": device_id,
"revoked_at": device_cert.revoked_at.isoformat(),
"message": "Certificate revoked successfully"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to revoke device certificate."
) from e
@app.get("/crl")
async def get_crl():
"""
Get the Certificate Revocation List (CRL) in PEM format.
Mosquitto and other MQTT clients can check this to validate certificates.
"""
try:
crl_pem = cert_manager.get_crl_pem()
if not crl_pem:
return {"message": "No certificates have been revoked yet"}
return {"crl_pem": crl_pem}
except Exception as e:
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CRL."
) from e
@app.post("/devices/{device_id}/renew")
async def renew_certificate(device_id: str):
"""
Renew a device's certificate by issuing a new one and revoking the old one.
This endpoint:
1. Retrieves the current certificate from DB
2. Generates a new certificate with new keys
3. Revokes the old certificate (adds to CRL)
4. Updates the database with the new certificate
5. Returns the new credentials
"""
try:
with get_db_context() as db:
# Get current certificate
device_cert = (
db.query(DeviceCertificate)
.filter(
DeviceCertificate.device_id == device_id,
# DeviceCertificate.revoked_at.is_(None)
)
.first()
)
if not device_cert:
raise HTTPException(
status_code=404,
detail="No active certificate found for device"
)
# Check if certificate is about to expire (optional warning)
days_until_expiry = (device_cert.expires_at - datetime.datetime.now(datetime.UTC)).days
if days_until_expiry > 30:
logger.warning(
f"Certificate for device {device_id} renewed early "
f"({days_until_expiry} days remaining)"
)
# Revoke old certificate and add to CRL
cert_manager.revoke_certificate(device_cert.certificate_pem)
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
# Generate new certificate with new keys
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
current_cert_pem=device_cert.certificate_pem,
validity_days=365,
key_size=4096
)
# Extract certificate ID (serial number) from the new certificate
from cryptography import x509
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
new_cert_id = format(new_cert.serial_number, 'x')
# Create new certificate record in DB
now = datetime.datetime.now(datetime.UTC)
new_device_cert = DeviceCertificate(
id=new_cert_id,
device_id=device_id,
certificate_pem=new_cert_pem.decode("utf-8"),
private_key_pem=new_key_pem.decode("utf-8"),
issued_at=now,
expires_at=now + datetime.timedelta(days=365),
)
db.add(new_device_cert)
db.commit()
logger.info(f"Successfully renewed certificate for device {device_id}")
return DeviceRegistrationResponse(
certificate_id=new_cert_id,
device_id=device_id,
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
certificate_pem=new_device_cert.certificate_pem,
private_key_pem=new_device_cert.private_key_pem,
expires_at=new_device_cert.expires_at,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to renew device certificate."
) from e

View File

@@ -1,4 +1,5 @@
import datetime
from pathlib import Path
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
@@ -55,12 +56,12 @@ class CertificateManager:
) -> tuple[bytes, bytes]:
"""Generate an X.509 certificate for a device signed by the CA."""
# Build device certificate
subject = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, device_id),
]
)
issuer = ca_cert.subject
now = datetime.datetime.now(datetime.UTC)
device_cert = (
@@ -78,7 +79,6 @@ class CertificateManager:
.sign(private_key=ca_key, algorithm=hashes.SHA256())
)
# Serialize certificate and key to PEM format
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM)
key_pem = device_key.private_bytes(
encoding=serialization.Encoding.PEM,
@@ -93,7 +93,7 @@ class CertificateManager:
) -> dict:
"""Create device credentials: private key and signed certificate.
Returns:
dict with device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
"""
device_key = self.generate_device_key(key_size=key_size)
@@ -106,11 +106,14 @@ class CertificateManager:
key_size=key_size,
)
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
days=validity_days
)
# Extract serial number from certificate to use as ID
cert = x509.load_pem_x509_certificate(cert_pem)
cert_id = format(cert.serial_number, 'x') # Hex string of serial number
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days)
return {
"certificate_id": cert_id,
"device_id": device_id,
"certificate_pem": cert_pem,
"private_key_pem": key_pem,
@@ -118,9 +121,7 @@ class CertificateManager:
"expires_at": expires_at,
}
def register_device(
self, name: str, location: str | None = None
) -> DeviceRegistrationResponse:
def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse:
"""Register a new device and generate its credentials.
Returns:
DeviceRegistrationResponse
@@ -129,6 +130,7 @@ class CertificateManager:
credentials = self.create_device_credentials(device_id=device_id)
return DeviceRegistrationResponse(
certificate_id=credentials["certificate_id"],
device_id=credentials["device_id"],
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
@@ -140,3 +142,99 @@ class CertificateManager:
"""Get the CA certificate in PEM format as a string."""
return self.ca_cert_pem.decode("utf-8")
def revoke_certificate(
self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified
) -> None:
"""
Revoke a device certificate by adding it to the CRL.
Args:
certificate_pem: PEM-encoded certificate to revoke
reason: Revocation reason (default: unspecified)
"""
# Load the certificate to get serial number
cert = x509.load_pem_x509_certificate(certificate_pem.encode())
# Load existing CRL or create new one
crl_path = Path(config.CRL_PATH)
revoked_certs = []
if crl_path.exists():
with open(crl_path, "rb") as f:
existing_crl = x509.load_pem_x509_crl(f.read())
# Copy existing revoked certificates
revoked_certs = list(existing_crl)
# Add the new revoked certificate
revoked_cert = (
x509.RevokedCertificateBuilder()
.serial_number(cert.serial_number)
.revocation_date(datetime.datetime.now(datetime.UTC))
.add_extension(
x509.CRLReason(reason),
critical=False,
)
.build()
)
revoked_certs.append(revoked_cert)
# Build new CRL with all revoked certificates
crl_builder = x509.CertificateRevocationListBuilder()
crl_builder = crl_builder.issuer_name(self.ca_cert.subject)
crl_builder = crl_builder.last_update(datetime.datetime.now(datetime.UTC))
crl_builder = crl_builder.next_update(
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30)
)
for revoked in revoked_certs:
crl_builder = crl_builder.add_revoked_certificate(revoked)
# Sign the CRL with CA key
crl = crl_builder.sign(private_key=self.ca_key, algorithm=hashes.SHA256())
# Write CRL to file
crl_path.parent.mkdir(parents=True, exist_ok=True)
with open(crl_path, "wb") as f:
f.write(crl.public_bytes(serialization.Encoding.PEM))
def get_crl_pem(self) -> str | None:
"""Get the current CRL in PEM format."""
crl_path = Path(config.CRL_PATH)
if not crl_path.exists():
return None
with open(crl_path, "rb") as f:
return f.read().decode("utf-8")
def renew_certificate(
self,
current_cert_pem: str,
validity_days: int = 365,
key_size: int = 4096,
) -> tuple[bytes, bytes]:
"""Renew a device certificate before expiration.
Args:
current_cert_pem: PEM-encoded current certificate
validity_days: Validity period for new certificate
key_size: Key size for new device key
Returns:
tuple of (new_cert_pem, new_key_pem)
"""
# Load current certificate
current_cert = x509.load_pem_x509_certificate(current_cert_pem.encode())
device_id = current_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
# Generate new device key
new_device_key = self.generate_device_key(key_size=key_size)
# Generate new device certificate
new_cert_pem, new_key_pem = self.generate_device_certificate(
device_id=device_id,
ca_cert=self.ca_cert,
ca_key=self.ca_key,
device_key=new_device_key,
validity_days=validity_days,
key_size=key_size,
)
return new_cert_pem, new_key_pem

View File

@@ -15,6 +15,7 @@ class Config:
CERTS_DIR = SERVICE_DIR / "certs"
CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt"))
CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key"))
CRL_PATH = os.getenv("CRL_PATH", str(CERTS_DIR / "ca.crl"))
# Certificate settings
CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365"))

View File

@@ -31,6 +31,7 @@ class DeviceCertificate(Base):
__tablename__ = "device_certificates"
id = Column(Text, primary_key=True)
device_id = Column(
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
)

View File

@@ -11,8 +11,17 @@ class DeviceRegistrationRequest(BaseModel):
class DeviceRegistrationResponse(BaseModel):
"""Response model after registering a new device."""
certificate_id: str
device_id: str
ca_certificate_pem: str
certificate_pem: str
private_key_pem: str
expires_at: datetime.datetime
class DeviceResponse(BaseModel):
"""Response model for device information."""
id: str
name: str
location: str | None = None
created_at: datetime.datetime

View File

@@ -1,77 +0,0 @@
import datetime
import logging
from fastapi import FastAPI, HTTPException
from cert_manager import CertificateManager
from database import get_db_context
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
from models import DeviceRegistrationRequest, DeviceRegistrationResponse # Pydantic API models
logger = logging.getLogger(__name__)
app = FastAPI()
cert_manager = CertificateManager()
@app.get("/")
async def hello():
return {"Hello": "World"}
@app.post("/devices/register")
async def register_device(
request: DeviceRegistrationRequest,
) -> DeviceRegistrationResponse:
"""
Register a new device and issue an X.509 certificate.
"""
try:
response = cert_manager.register_device(
name=request.name,
location=request.location,
)
with get_db_context() as db:
device = Device(
id=response.device_id,
name=request.name,
location=request.location,
created_at=datetime.datetime.now(datetime.UTC),
)
db.add(device)
device_cert = DeviceCertificate(
device_id=response.device_id,
certificate_pem=response.certificate_pem,
private_key_pem=response.private_key_pem,
issued_at=datetime.datetime.now(datetime.UTC),
expires_at=response.expires_at,
)
db.add(device_cert)
except Exception as e:
logger.error(
f"Failed to register device {request.name}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=500, detail="Failed to register device. Please try again."
) from e
return response
@app.get("/ca_certificate")
async def get_ca_certificate() -> str:
"""
Retrieve the CA certificate in PEM format.
"""
try:
ca_cert_pem = cert_manager.get_ca_certificate_pem()
return ca_cert_pem
except Exception as e:
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CA certificate."
) from e

View File

@@ -1,7 +1,7 @@
[project]
name = "device-manager"
version = "0.1.0"
description = "Add your description here"
description = "Device Manager"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [