Added multi-protocol support for devices, improved models and updated readme.md and instructions

This commit is contained in:
2025-11-02 14:09:29 +01:00
parent ddbc588c77
commit 96e2377073
13 changed files with 730 additions and 375 deletions

View File

@@ -0,0 +1,32 @@
FROM ghcr.io/astral-sh/uv:python3.13-alpine AS builder
WORKDIR /app
ENV UV_COMPILE_BYTECODE=1
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-dev --no-install-project
COPY ./src/ ./src/
COPY main.py ./
RUN uv sync --frozen --no-dev
FROM python:3.13-alpine
WORKDIR /app
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app/*.py /app/
RUN adduser -D -u 1000 appuser && \
chown -R appuser:appuser /app
USER appuser
ENV PATH="/app/.venv/bin:$PATH"
CMD ["python", "main.py"]

View File

@@ -5,8 +5,12 @@ from fastapi import FastAPI, HTTPException
from app.cert_manager import CertificateManager
from app.database import get_db_context
from app.db_models import Device, DeviceCertificate # SQLAlchemy ORM models
from app.models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
from app.db_models import Device, DeviceCertificate
from app.models import (
DeviceRegistrationRequest,
DeviceRegistrationResponse,
DeviceResponse,
)
logger = logging.getLogger(__name__)
@@ -25,43 +29,62 @@ async def register_device(
request: DeviceRegistrationRequest,
) -> DeviceRegistrationResponse:
"""
Register a new device and issue an X.509 certificate.
Register a new device.
- MQTT devices: issues X.509 certificate for mTLS
- HTTP/webhook devices: generates API key or HMAC secret
"""
try:
response = cert_manager.register_device(
name=request.name,
location=request.location,
)
with get_db_context() as db:
device = Device(
id=response.device_id,
if request.protocol == "mqtt":
cert_response = cert_manager.register_device(
name=request.name,
location=request.location,
created_at=datetime.datetime.now(datetime.UTC),
)
db.add(device)
device_cert = DeviceCertificate(
id =response.certificate_id,
device_id=response.device_id,
certificate_pem=response.certificate_pem,
private_key_pem=response.private_key_pem,
issued_at=datetime.datetime.now(datetime.UTC),
expires_at=response.expires_at,
with get_db_context() as db:
device = Device(
id=cert_response.device_id,
name=request.name,
location=request.location,
protocol=request.protocol,
connection_config=request.connection_config,
created_at=datetime.datetime.now(datetime.UTC),
)
db.add(device)
device_cert = DeviceCertificate(
id=cert_response.certificate_id,
device_id=cert_response.device_id,
certificate_pem=cert_response.certificate_pem,
private_key_pem=cert_response.private_key_pem,
issued_at=datetime.datetime.now(datetime.UTC),
expires_at=cert_response.expires_at,
)
db.add(device_cert)
return DeviceRegistrationResponse(
device_id=cert_response.device_id,
protocol=request.protocol,
certificate_id=cert_response.certificate_id,
ca_certificate_pem=cert_response.ca_certificate_pem,
certificate_pem=cert_response.certificate_pem,
private_key_pem=cert_response.private_key_pem,
expires_at=cert_response.expires_at,
)
db.add(device_cert)
else:
raise HTTPException(
status_code=400,
detail=f"Protocol '{request.protocol}' not yet implemented. Only 'mqtt' is supported.",
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Failed to register device {request.name}: {str(e)}", exc_info=True
)
logger.error(f"Failed to register device {request.name}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to register device. Please try again."
) from e
return response
@app.get("/ca_certificate")
async def get_ca_certificate() -> str:
@@ -73,9 +96,8 @@ async def get_ca_certificate() -> str:
return ca_cert_pem
except Exception as e:
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CA certificate."
) from e
raise HTTPException(status_code=500, detail="Failed to retrieve CA certificate.") from e
@app.get("/devices/{device_id}")
async def get_device(device_id: str) -> DeviceResponse:
@@ -88,18 +110,19 @@ async def get_device(device_id: str) -> DeviceResponse:
if not device:
raise HTTPException(status_code=404, detail="Device not found")
return Device(
return DeviceResponse(
id=device.id,
name=device.name,
location=device.location,
protocol=device.protocol,
connection_config=device.connection_config,
created_at=device.created_at,
)
except Exception as e:
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve device information."
) from e
raise HTTPException(status_code=500, detail="Failed to retrieve device information.") from e
@app.get("/devices/")
async def list_devices() -> list[DeviceResponse]:
@@ -114,6 +137,8 @@ async def list_devices() -> list[DeviceResponse]:
id=device.id,
name=device.name,
location=device.location,
protocol=device.protocol,
connection_config=device.connection_config,
created_at=device.created_at,
)
for device in devices
@@ -121,9 +146,8 @@ async def list_devices() -> list[DeviceResponse]:
except Exception as e:
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to list devices."
) from e
raise HTTPException(status_code=500, detail="Failed to list devices.") from e
@app.post("/devices/{device_id}/revoke")
async def revoke_device_certificate(device_id: str):
@@ -135,9 +159,7 @@ async def revoke_device_certificate(device_id: str):
try:
with get_db_context() as db:
device_cert = (
db.query(DeviceCertificate)
.filter(DeviceCertificate.device_id == device_id)
.first()
db.query(DeviceCertificate).filter(DeviceCertificate.device_id == device_id).first()
)
if not device_cert:
raise HTTPException(status_code=404, detail="Device certificate not found")
@@ -155,16 +177,14 @@ async def revoke_device_certificate(device_id: str):
return {
"device_id": device_id,
"revoked_at": device_cert.revoked_at.isoformat(),
"message": "Certificate revoked successfully"
"message": "Certificate revoked successfully",
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to revoke device certificate."
) from e
raise HTTPException(status_code=500, detail="Failed to revoke device certificate.") from e
@app.get("/crl")
@@ -180,15 +200,14 @@ async def get_crl():
return {"crl_pem": crl_pem}
except Exception as e:
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to retrieve CRL."
) from e
raise HTTPException(status_code=500, detail="Failed to retrieve CRL.") from e
@app.post("/devices/{device_id}/renew")
async def renew_certificate(device_id: str):
"""
Renew a device's certificate by issuing a new one and revoking the old one.
This endpoint:
1. Retrieves the current certificate from DB
2. Generates a new certificate with new keys
@@ -209,8 +228,7 @@ async def renew_certificate(device_id: str):
)
if not device_cert:
raise HTTPException(
status_code=404,
detail="No active certificate found for device"
status_code=404, detail="No active certificate found for device"
)
# Check if certificate is about to expire (optional warning)
@@ -227,15 +245,14 @@ async def renew_certificate(device_id: str):
# Generate new certificate with new keys
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
current_cert_pem=device_cert.certificate_pem,
validity_days=365,
key_size=4096
current_cert_pem=device_cert.certificate_pem, validity_days=365, key_size=4096
)
# Extract certificate ID (serial number) from the new certificate
from cryptography import x509
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
new_cert_id = format(new_cert.serial_number, 'x')
new_cert_id = format(new_cert.serial_number, "x")
# Create new certificate record in DB
now = datetime.datetime.now(datetime.UTC)
@@ -252,9 +269,12 @@ async def renew_certificate(device_id: str):
logger.info(f"Successfully renewed certificate for device {device_id}")
device = db.query(Device).filter(Device.id == device_id).first()
return DeviceRegistrationResponse(
certificate_id=new_cert_id,
device_id=device_id,
protocol=device.protocol if device else "mqtt",
certificate_id=new_cert_id,
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
certificate_pem=new_device_cert.certificate_pem,
private_key_pem=new_device_cert.private_key_pem,
@@ -265,6 +285,4 @@ async def renew_certificate(device_id: str):
raise
except Exception as e:
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail="Failed to renew device certificate."
) from e
raise HTTPException(status_code=500, detail="Failed to renew device certificate.") from e

View File

@@ -8,7 +8,7 @@ from cryptography.x509.oid import NameOID
from nanoid import generate
from app.config import config
from app.models import DeviceRegistrationResponse
from app.models import DeviceCertificateResponse, DeviceCredentials
lowercase_numbers = "abcdefghijklmnopqrstuvwxyz0123456789"
@@ -21,10 +21,12 @@ class CertificateManager:
self.ca_key: rsa.RSAPrivateKey = self.load_ca_private_key(config.CA_KEY_PATH)
self.ca_cert_pem: bytes = self.ca_cert.public_bytes(serialization.Encoding.PEM)
def generate_device_id(self) -> str:
"""Generate a unique device ID using nanoid."""
return generate(alphabet=lowercase_numbers, size=config.DEVICE_ID_LENGTH)
def load_ca_certificate(self, ca_cert_path: str) -> x509.Certificate:
"""Load a CA certificate from file."""
with open(ca_cert_path, "rb") as f:
@@ -32,6 +34,7 @@ class CertificateManager:
ca_cert = x509.load_pem_x509_certificate(ca_data)
return ca_cert
def load_ca_private_key(self, ca_key_path: str, password: bytes = None) -> rsa.RSAPrivateKey:
"""Load a CA private key from file."""
from cryptography.hazmat.primitives import serialization
@@ -41,10 +44,12 @@ class CertificateManager:
ca_key = serialization.load_pem_private_key(key_data, password=password)
return ca_key
def generate_device_key(self, key_size: int = 4096) -> rsa.RSAPrivateKey:
"""Generate an RSA private key for a device."""
return rsa.generate_private_key(public_exponent=65537, key_size=key_size)
def generate_device_certificate(
self,
device_id: str,
@@ -88,12 +93,13 @@ class CertificateManager:
return cert_pem, key_pem
def create_device_credentials(
self, device_id: str, validity_days: int = 365, key_size: int = 4096
) -> dict:
) -> DeviceCredentials:
"""Create device credentials: private key and signed certificate.
Returns:
dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
DeviceCredentials model with certificate_id, device_id, certificate_pem, private_key_pem, expires_at
"""
device_key = self.generate_device_key(key_size=key_size)
@@ -108,40 +114,42 @@ class CertificateManager:
# Extract serial number from certificate to use as ID
cert = x509.load_pem_x509_certificate(cert_pem)
cert_id = format(cert.serial_number, 'x') # Hex string of serial number
cert_id = format(cert.serial_number, 'x')
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days)
return {
"certificate_id": cert_id,
"device_id": device_id,
"certificate_pem": cert_pem,
"private_key_pem": key_pem,
"ca_certificate_pem": self.ca_cert_pem,
"expires_at": expires_at,
}
return DeviceCredentials(
certificate_id=cert_id,
device_id=device_id,
certificate_pem=cert_pem,
private_key_pem=key_pem,
expires_at=expires_at,
)
def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse:
def register_device(self, name: str, location: str | None = None) -> DeviceCertificateResponse:
"""Register a new device and generate its credentials.
Returns:
DeviceRegistrationResponse
DeviceCertificateResponse
"""
device_id = self.generate_device_id()
credentials = self.create_device_credentials(device_id=device_id)
return DeviceRegistrationResponse(
certificate_id=credentials["certificate_id"],
device_id=credentials["device_id"],
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
private_key_pem=credentials["private_key_pem"].decode("utf-8"),
expires_at=credentials["expires_at"],
return DeviceCertificateResponse(
certificate_id=credentials.certificate_id,
device_id=credentials.device_id,
ca_certificate_pem=self.ca_cert_pem.decode("utf-8"),
certificate_pem=credentials.certificate_pem.decode("utf-8"),
private_key_pem=credentials.private_key_pem.decode("utf-8"),
expires_at=credentials.expires_at,
)
def get_ca_certificate_pem(self) -> str:
"""Get the CA certificate in PEM format as a string."""
return self.ca_cert_pem.decode("utf-8")
def revoke_certificate(
self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified
) -> None:
@@ -197,6 +205,7 @@ class CertificateManager:
with open(crl_path, "wb") as f:
f.write(crl.public_bytes(serialization.Encoding.PEM))
def get_crl_pem(self) -> str | None:
"""Get the current CRL in PEM format."""
crl_path = Path(config.CRL_PATH)
@@ -206,6 +215,7 @@ class CertificateManager:
with open(crl_path, "rb") as f:
return f.read().decode("utf-8")
def renew_certificate(
self,
current_cert_pem: str,

View File

@@ -4,7 +4,7 @@ SQLAlchemy ORM models for device manager service.
These models mirror the database schema defined in db_migrations.
Kept separate to make the service independent.
"""
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Text
from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Index, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
@@ -19,11 +19,13 @@ class Device(Base):
id = Column(Text, primary_key=True)
name = Column(Text, nullable=False)
location = Column(Text)
protocol = Column(Text, nullable=False, default="mqtt")
connection_config = Column(JSON)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<Device(id={self.id}, name={self.name})>"
return f"<Device(id={self.id}, name={self.name}, protocol={self.protocol})>"
class DeviceCertificate(Base):
@@ -33,7 +35,7 @@ class DeviceCertificate(Base):
id = Column(Text, primary_key=True)
device_id = Column(
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
Text, ForeignKey("devices.id", ondelete="CASCADE"), nullable=False
)
certificate_pem = Column(Text, nullable=False)
private_key_pem = Column(Text)
@@ -41,5 +43,34 @@ class DeviceCertificate(Base):
expires_at = Column(DateTime(timezone=True), nullable=False)
revoked_at = Column(DateTime(timezone=True))
__table_args__ = (
Index("idx_device_certificates_device_id", "device_id"),
Index("idx_device_certificates_active", "device_id", "revoked_at"),
)
def __repr__(self):
return f"<DeviceCertificate(device_id={self.device_id}, expires={self.expires_at})>"
return f"<DeviceCertificate(id={self.id}, device_id={self.device_id}, expires={self.expires_at})>"
class DeviceCredential(Base):
"""Authentication credentials for non-mTLS protocols (HTTP, webhook, etc)."""
__tablename__ = "device_credentials"
id = Column(Text, primary_key=True)
device_id = Column(
Text, ForeignKey("devices.id", ondelete="CASCADE"), nullable=False
)
credential_type = Column(Text, nullable=False)
credential_hash = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), nullable=False)
expires_at = Column(DateTime(timezone=True))
revoked_at = Column(DateTime(timezone=True))
__table_args__ = (
Index("idx_device_credentials_device_id", "device_id"),
Index("idx_device_credentials_active", "device_id", "revoked_at"),
)
def __repr__(self):
return f"<DeviceCredential(id={self.id}, device_id={self.device_id}, type={self.credential_type})>"

View File

@@ -1,17 +1,38 @@
import datetime
from typing import Any
from pydantic import BaseModel
class DeviceRegistrationRequest(BaseModel):
"""Request model for registering a new device."""
name: str
location: str | None = None
protocol: str = "mqtt"
connection_config: dict[str, Any] | None = None
class DeviceRegistrationResponse(BaseModel):
"""Response model after registering a new device."""
device_id: str
protocol: str
certificate_id: str | None = None
ca_certificate_pem: str | None = None
certificate_pem: str | None = None
private_key_pem: str | None = None
expires_at: datetime.datetime | None = None
credential_id: str | None = None
api_key: str | None = None
webhook_secret: str | None = None
class DeviceResponse(BaseModel):
id: str
name: str
location: str | None = None
protocol: str
connection_config: dict[str, Any] | None = None
created_at: datetime.datetime
class DeviceCertificateResponse(BaseModel):
certificate_id: str
device_id: str
ca_certificate_pem: str
@@ -19,10 +40,10 @@ class DeviceRegistrationResponse(BaseModel):
private_key_pem: str
expires_at: datetime.datetime
class DeviceResponse(BaseModel):
"""Response model for device information."""
id: str
name: str
location: str | None = None
created_at: datetime.datetime
class DeviceCredentials(BaseModel):
certificate_id: str
device_id: str
certificate_pem: bytes
private_key_pem: bytes
expires_at: datetime.datetime