Files
iotDashboard/services/device_manager/app/app.py

394 lines
14 KiB
Python

import datetime
import logging
import secrets
from cryptography import x509
from fastapi import FastAPI, HTTPException, Query
from app.cert_manager import CertificateManager
from app.database import get_db_context
from app.db_models import Device, DeviceCertificate, DeviceOnboardingToken
from app.models import (
DeviceCertificateResponse,
DeviceRegistrationRequest,
DeviceRegistrationResponse,
DeviceResponse,
)
# Module-level logger used by all endpoints below for error/info reporting.
logger = logging.getLogger(__name__)
# The FastAPI application; every route in this module is registered on it.
app = FastAPI()
# Single CertificateManager created at import time. The endpoints below use it to
# issue, renew and revoke device certificates and to read the CA certificate / CRL.
cert_manager = CertificateManager()
@app.get("/")
async def hello():
    """Root endpoint; responds with a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
@app.post("/devices/register")
async def register_device(
    request: DeviceRegistrationRequest,
) -> DeviceRegistrationResponse:
    """
    Register a new device and return its credentials.

    Only ``protocol == "mqtt"`` is implemented. ``cert_manager`` issues an X.509
    client certificate (for mTLS). Then the device row, its certificate and a
    single-use onboarding token (valid for 15 minutes) are added and committed
    together. Other protocols (planned: API key / HMAC secret for HTTP and
    webhook devices) are currently rejected.

    Returns:
        DeviceRegistrationResponse carrying the device/certificate IDs, the CA
        certificate, the device certificate and private key, the certificate
        expiry and the onboarding token.

    Raises:
        HTTPException(400): ``request.protocol`` is not ``"mqtt"``.
        HTTPException(500): certificate issuance or persistence failed.
    """
    if request.protocol != "mqtt":
        raise HTTPException(
            status_code=400,
            detail=f"Protocol '{request.protocol}' not yet implemented. Only 'mqtt' is supported.",
        )

    try:
        cert_response = cert_manager.register_device(
            name=request.name,
            location=request.location,
        )
        # NOTE(review): the certificate is issued before the DB write. If the commit
        # below fails, this certificate is not revoked here.

        # One timestamp for every row created by this registration, so that
        # created_at / issued_at / token timestamps agree exactly.
        now = datetime.datetime.now(datetime.UTC)
        onboarding_token = secrets.token_urlsafe(32)

        with get_db_context() as db:
            db.add(
                Device(
                    id=cert_response.device_id,
                    name=request.name,
                    location=request.location,
                    protocol=request.protocol,
                    connection_config=request.connection_config,
                    created_at=now,
                )
            )
            db.add(
                DeviceCertificate(
                    id=cert_response.certificate_id,
                    device_id=cert_response.device_id,
                    certificate_pem=cert_response.certificate_pem,
                    private_key_pem=cert_response.private_key_pem,
                    issued_at=now,
                    expires_at=cert_response.expires_at,
                )
            )
            db.add(
                DeviceOnboardingToken(
                    token=onboarding_token,
                    device_id=cert_response.device_id,
                    certificate_id=cert_response.certificate_id,
                    created_at=now,
                    expires_at=now + datetime.timedelta(minutes=15),
                )
            )
            db.commit()

        return DeviceRegistrationResponse(
            device_id=cert_response.device_id,
            protocol=request.protocol,
            certificate_id=cert_response.certificate_id,
            ca_certificate_pem=cert_response.ca_certificate_pem,
            certificate_pem=cert_response.certificate_pem,
            private_key_pem=cert_response.private_key_pem,
            expires_at=cert_response.expires_at,
            onboarding_token=onboarding_token,
        )
    except Exception as e:
        logger.exception(f"Failed to register device {request.name}: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to register device. Please try again."
        ) from e
@app.get("/ca_certificate")
async def get_ca_certificate() -> str:
    """
    Return the CA certificate as a PEM-encoded string.

    Raises:
        HTTPException(500): the certificate manager could not provide it.
    """
    try:
        return cert_manager.get_ca_certificate_pem()
    except Exception as e:
        logger.exception(f"Failed to retrieve CA certificate: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to retrieve CA certificate.") from e
@app.get("/devices/{device_id}/credentials")
async def get_device_credentials(device_id: str, token: str = Query(...)) -> DeviceCertificateResponse:
    """
    Fetch a device's credentials using its one-time onboarding token (e.g. from a QR code).

    The token must belong to ``device_id``, must be unused and must not be past
    its ``expires_at``. The token is consumed only after everything needed for
    the response has been loaded. Consumption is a conditional UPDATE, so two
    concurrent requests cannot both redeem the same token.

    Returns:
        DeviceCertificateResponse with the CA certificate, device certificate,
        private key and expiry.

    Raises:
        HTTPException(404): unknown token for this device, or its certificate is
            missing/revoked.
        HTTPException(403): token already used or expired.
        HTTPException(500): any unexpected failure.
    """
    try:
        with get_db_context() as db:
            db_token = (
                db.query(DeviceOnboardingToken)
                .filter(
                    DeviceOnboardingToken.token == token,
                    DeviceOnboardingToken.device_id == device_id,
                )
                .first()
            )
            if not db_token:
                raise HTTPException(status_code=404, detail="Invalid or expired onboarding token")
            if db_token.used_at is not None:
                raise HTTPException(status_code=403, detail="Token has already been used")

            now = datetime.datetime.now(datetime.UTC)
            expires_at = db_token.expires_at
            # Some DB backends hand back naive datetimes. Comparing naive with aware
            # raises TypeError (-> 500). Tokens are written with UTC timestamps,
            # so a naive value is interpreted as UTC.
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=datetime.UTC)
            if now > expires_at:
                raise HTTPException(status_code=403, detail="Token has expired")

            device_cert = (
                db.query(DeviceCertificate)
                .filter(
                    DeviceCertificate.id == db_token.certificate_id,
                    DeviceCertificate.device_id == device_id,
                    DeviceCertificate.revoked_at.is_(None),
                )
                .first()
            )
            if not device_cert:
                raise HTTPException(status_code=404, detail="Certificate not found or revoked")

            # Do everything that can still fail *before* burning the token, so a
            # transient error does not leave the device unable to onboard.
            response = DeviceCertificateResponse(
                certificate_id=device_cert.id,
                device_id=device_cert.device_id,
                ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
                certificate_pem=device_cert.certificate_pem,
                private_key_pem=device_cert.private_key_pem,
                expires_at=device_cert.expires_at,
            )

            # Atomically consume the token: only an UPDATE that still finds
            # used_at IS NULL succeeds, so a concurrent redemption loses here.
            consumed = (
                db.query(DeviceOnboardingToken)
                .filter(
                    DeviceOnboardingToken.token == token,
                    DeviceOnboardingToken.device_id == device_id,
                    DeviceOnboardingToken.used_at.is_(None),
                )
                .update({DeviceOnboardingToken.used_at: now}, synchronize_session=False)
            )
            if not consumed:
                raise HTTPException(status_code=403, detail="Token has already been used")
            db.commit()

        logger.info(f"Device {device_id} fetched credentials using onboarding token")
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to fetch credentials for device {device_id}: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to fetch device credentials"
        ) from e
@app.get("/devices/{device_id}")
async def get_device(device_id: str) -> DeviceResponse:
    """
    Retrieve a device's stored information by ID.

    Raises:
        HTTPException(404): no device with ``device_id`` exists.
        HTTPException(500): any other failure.
    """
    try:
        with get_db_context() as db:
            device = db.query(Device).filter(Device.id == device_id).first()
            if not device:
                raise HTTPException(status_code=404, detail="Device not found")
            return DeviceResponse(
                id=device.id,
                name=device.name,
                location=device.location,
                protocol=device.protocol,
                connection_config=device.connection_config,
                created_at=device.created_at,
            )
    except HTTPException:
        # Pass deliberate HTTP errors (the 404 above) through unchanged. Otherwise
        # the generic handler below would turn "not found" into a 500.
        raise
    except Exception as e:
        logger.exception(f"Failed to retrieve device {device_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve device information.") from e
@app.get("/devices/")
async def list_devices() -> list[DeviceResponse]:
    """
    Return every registered device.

    Raises:
        HTTPException(500): the device table could not be read.
    """

    def _as_response(row):
        # Map one Device ORM row onto the public response model.
        return DeviceResponse(
            id=row.id,
            name=row.name,
            location=row.location,
            protocol=row.protocol,
            connection_config=row.connection_config,
            created_at=row.created_at,
        )

    try:
        with get_db_context() as db:
            rows = db.query(Device).all()
            return list(map(_as_response, rows))
    except Exception as e:
        logger.exception(f"Failed to list devices: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to list devices.") from e
@app.post("/devices/{device_id}/revoke")
async def revoke_device_certificate(device_id: str):
    """
    Revoke the certificate currently active for ``device_id``.

    The certificate is passed to ``cert_manager.revoke_certificate`` (adding it
    to the Certificate Revocation List). It is then stamped with a UTC
    ``revoked_at`` in the database.

    Returns:
        dict with ``device_id``, the ISO-8601 ``revoked_at`` and a ``message``.

    Raises:
        HTTPException(404): the device has no certificate with ``revoked_at`` unset.
        HTTPException(500): any other failure.
    """
    try:
        with get_db_context() as db:
            active_cert = (
                db.query(DeviceCertificate)
                .filter(
                    DeviceCertificate.device_id == device_id,
                    DeviceCertificate.revoked_at.is_(None),
                )
                .first()
            )
            if not active_cert:
                raise HTTPException(
                    status_code=404,
                    detail="No active certificate found for this device",
                )

            # Update the CRL first, then record the revocation in the DB.
            cert_manager.revoke_certificate(active_cert.certificate_pem)
            active_cert.revoked_at = datetime.datetime.now(datetime.UTC)
            db.commit()
            logger.info(f"Successfully revoked certificate for device {device_id}")

            result = {
                "device_id": device_id,
                "revoked_at": active_cert.revoked_at.isoformat(),
                "message": "Certificate revoked successfully",
            }
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to revoke device {device_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to revoke device certificate.") from e
@app.get("/crl")
async def get_crl():
    """
    Return the Certificate Revocation List (CRL) in PEM format.

    Mosquitto and other MQTT clients can use it to reject revoked certificates.
    When the certificate manager returns an empty/falsy CRL, a message saying
    nothing has been revoked yet is returned instead.

    Raises:
        HTTPException(500): the CRL could not be retrieved.
    """
    try:
        crl_pem = cert_manager.get_crl_pem()
    except Exception as e:
        logger.exception(f"Failed to retrieve CRL: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to retrieve CRL.") from e

    if crl_pem:
        return {"crl_pem": crl_pem}
    return {"message": "No certificates have been revoked yet"}
@app.post("/devices/{device_id}/delete")
async def delete_device(device_id: str):
    """
    Delete a device together with its certificates and onboarding tokens.

    Any of the device's certificates that are not yet revoked are first passed
    to ``cert_manager.revoke_certificate``. Once their rows are deleted, no
    stored PEM remains to revoke later, and an un-revoked certificate would still
    be signed by the CA.

    Returns:
        dict with a confirmation ``message``.

    Raises:
        HTTPException(404): no device with ``device_id`` exists.
        HTTPException(500): any other failure.
    """
    try:
        with get_db_context() as db:
            device = db.query(Device).filter(Device.id == device_id).first()
            if not device:
                raise HTTPException(status_code=404, detail="Device not found")

            active_certs = (
                db.query(DeviceCertificate)
                .filter(
                    DeviceCertificate.device_id == device_id,
                    DeviceCertificate.revoked_at.is_(None),
                )
                .all()
            )
            for cert in active_certs:
                cert_manager.revoke_certificate(cert.certificate_pem)

            # Onboarding tokens reference both the device and a certificate
            # (device_id / certificate_id), so remove them before those rows.
            db.query(DeviceOnboardingToken).filter(
                DeviceOnboardingToken.device_id == device_id
            ).delete()
            db.query(DeviceCertificate).filter(DeviceCertificate.device_id == device_id).delete()
            db.delete(device)
            db.commit()

        logger.info(f"Successfully deleted device {device_id} and its certificates")
        return {"message": f"Device {device_id} and its certificates have been deleted."}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete device {device_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete device.") from e
@app.post("/devices/{device_id}/renew")
async def renew_certificate(device_id: str):
    """
    Renew a device's certificate: issue a new one, then revoke the old one.

    Steps:
    1. Load the device's current (non-revoked) certificate from the DB.
    2. Generate a new certificate with new keys (365 days, 4096-bit key).
    3. Revoke the old certificate (adds it to the CRL) and mark it revoked.
    4. Store the new certificate in the DB.
    5. Return the new credentials.

    The new certificate is issued before the old one is revoked. If issuance
    fails, the device keeps a working certificate.

    Returns:
        DeviceRegistrationResponse with the new certificate, key and CA certificate.

    Raises:
        HTTPException(404): the device has no certificate with ``revoked_at`` unset.
        HTTPException(500): any other failure.
    """
    try:
        with get_db_context() as db:
            device_cert = (
                db.query(DeviceCertificate)
                .filter(
                    DeviceCertificate.device_id == device_id,
                    # Each renewal leaves the old row behind (revoked). Without this
                    # filter, .first() may pick an already-revoked certificate.
                    DeviceCertificate.revoked_at.is_(None),
                )
                .first()
            )
            if not device_cert:
                raise HTTPException(
                    status_code=404, detail="No active certificate found for device"
                )

            now = datetime.datetime.now(datetime.UTC)

            # Warn (but proceed) when renewing well before expiry. Naive values
            # from the DB are treated as UTC to avoid a naive/aware TypeError.
            old_expires_at = device_cert.expires_at
            if old_expires_at.tzinfo is None:
                old_expires_at = old_expires_at.replace(tzinfo=datetime.UTC)
            days_until_expiry = (old_expires_at - now).days
            if days_until_expiry > 30:
                logger.warning(
                    f"Certificate for device {device_id} renewed early "
                    f"({days_until_expiry} days remaining)"
                )

            # Issue the replacement first.
            new_cert_pem, new_key_pem = cert_manager.renew_certificate(
                current_cert_pem=device_cert.certificate_pem, validity_days=365, key_size=4096
            )
            new_cert = x509.load_pem_x509_certificate(new_cert_pem)
            # Certificate ID is the serial number in hex.
            new_cert_id = format(new_cert.serial_number, "x")
            # Record the expiry actually encoded in the certificate rather than
            # re-deriving it from "now". not_valid_after_utc exists in
            # cryptography >= 42; older versions expose a naive not_valid_after.
            new_expires_at = getattr(new_cert, "not_valid_after_utc", None)
            if new_expires_at is None:
                new_expires_at = new_cert.not_valid_after.replace(tzinfo=datetime.UTC)

            # Only now retire the old certificate (adds it to the CRL).
            cert_manager.revoke_certificate(device_cert.certificate_pem)
            device_cert.revoked_at = now

            new_cert_text = new_cert_pem.decode("utf-8")
            new_key_text = new_key_pem.decode("utf-8")
            db.add(
                DeviceCertificate(
                    id=new_cert_id,
                    device_id=device_id,
                    certificate_pem=new_cert_text,
                    private_key_pem=new_key_text,
                    issued_at=now,
                    expires_at=new_expires_at,
                )
            )
            db.commit()
            logger.info(f"Successfully renewed certificate for device {device_id}")

            device = db.query(Device).filter(Device.id == device_id).first()
            return DeviceRegistrationResponse(
                device_id=device_id,
                protocol=device.protocol if device else "mqtt",
                certificate_id=new_cert_id,
                ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
                certificate_pem=new_cert_text,
                private_key_pem=new_key_text,
                expires_at=new_expires_at,
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to renew certificate for device {device_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to renew device certificate.") from e