mirror of
https://github.com/ferdzo/iotDashboard.git
synced 2026-04-05 09:06:26 +00:00
Added multi-protocol support for devices, improved models and updated readme.md and instructions
This commit is contained in:
@@ -5,8 +5,12 @@ from fastapi import FastAPI, HTTPException
|
||||
|
||||
from app.cert_manager import CertificateManager
|
||||
from app.database import get_db_context
|
||||
from app.db_models import Device, DeviceCertificate # SQLAlchemy ORM models
|
||||
from app.models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
|
||||
from app.db_models import Device, DeviceCertificate
|
||||
from app.models import (
|
||||
DeviceRegistrationRequest,
|
||||
DeviceRegistrationResponse,
|
||||
DeviceResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,43 +29,62 @@ async def register_device(
|
||||
request: DeviceRegistrationRequest,
|
||||
) -> DeviceRegistrationResponse:
|
||||
"""
|
||||
Register a new device and issue an X.509 certificate.
|
||||
Register a new device.
|
||||
- MQTT devices: issues X.509 certificate for mTLS
|
||||
- HTTP/webhook devices: generates API key or HMAC secret
|
||||
"""
|
||||
try:
|
||||
response = cert_manager.register_device(
|
||||
name=request.name,
|
||||
location=request.location,
|
||||
)
|
||||
|
||||
with get_db_context() as db:
|
||||
device = Device(
|
||||
id=response.device_id,
|
||||
if request.protocol == "mqtt":
|
||||
cert_response = cert_manager.register_device(
|
||||
name=request.name,
|
||||
location=request.location,
|
||||
created_at=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
db.add(device)
|
||||
|
||||
device_cert = DeviceCertificate(
|
||||
id =response.certificate_id,
|
||||
device_id=response.device_id,
|
||||
certificate_pem=response.certificate_pem,
|
||||
private_key_pem=response.private_key_pem,
|
||||
issued_at=datetime.datetime.now(datetime.UTC),
|
||||
expires_at=response.expires_at,
|
||||
with get_db_context() as db:
|
||||
device = Device(
|
||||
id=cert_response.device_id,
|
||||
name=request.name,
|
||||
location=request.location,
|
||||
protocol=request.protocol,
|
||||
connection_config=request.connection_config,
|
||||
created_at=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
db.add(device)
|
||||
|
||||
device_cert = DeviceCertificate(
|
||||
id=cert_response.certificate_id,
|
||||
device_id=cert_response.device_id,
|
||||
certificate_pem=cert_response.certificate_pem,
|
||||
private_key_pem=cert_response.private_key_pem,
|
||||
issued_at=datetime.datetime.now(datetime.UTC),
|
||||
expires_at=cert_response.expires_at,
|
||||
)
|
||||
db.add(device_cert)
|
||||
|
||||
return DeviceRegistrationResponse(
|
||||
device_id=cert_response.device_id,
|
||||
protocol=request.protocol,
|
||||
certificate_id=cert_response.certificate_id,
|
||||
ca_certificate_pem=cert_response.ca_certificate_pem,
|
||||
certificate_pem=cert_response.certificate_pem,
|
||||
private_key_pem=cert_response.private_key_pem,
|
||||
expires_at=cert_response.expires_at,
|
||||
)
|
||||
db.add(device_cert)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Protocol '{request.protocol}' not yet implemented. Only 'mqtt' is supported.",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to register device {request.name}: {str(e)}", exc_info=True
|
||||
)
|
||||
logger.error(f"Failed to register device {request.name}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to register device. Please try again."
|
||||
) from e
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/ca_certificate")
|
||||
async def get_ca_certificate() -> str:
|
||||
@@ -73,9 +96,8 @@ async def get_ca_certificate() -> str:
|
||||
return ca_cert_pem
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve CA certificate."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve CA certificate.") from e
|
||||
|
||||
|
||||
@app.get("/devices/{device_id}")
|
||||
async def get_device(device_id: str) -> DeviceResponse:
|
||||
@@ -88,18 +110,19 @@ async def get_device(device_id: str) -> DeviceResponse:
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
|
||||
return Device(
|
||||
return DeviceResponse(
|
||||
id=device.id,
|
||||
name=device.name,
|
||||
location=device.location,
|
||||
protocol=device.protocol,
|
||||
connection_config=device.connection_config,
|
||||
created_at=device.created_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve device information."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve device information.") from e
|
||||
|
||||
|
||||
@app.get("/devices/")
|
||||
async def list_devices() -> list[DeviceResponse]:
|
||||
@@ -114,6 +137,8 @@ async def list_devices() -> list[DeviceResponse]:
|
||||
id=device.id,
|
||||
name=device.name,
|
||||
location=device.location,
|
||||
protocol=device.protocol,
|
||||
connection_config=device.connection_config,
|
||||
created_at=device.created_at,
|
||||
)
|
||||
for device in devices
|
||||
@@ -121,9 +146,8 @@ async def list_devices() -> list[DeviceResponse]:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to list devices."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to list devices.") from e
|
||||
|
||||
|
||||
@app.post("/devices/{device_id}/revoke")
|
||||
async def revoke_device_certificate(device_id: str):
|
||||
@@ -135,9 +159,7 @@ async def revoke_device_certificate(device_id: str):
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
device_cert = (
|
||||
db.query(DeviceCertificate)
|
||||
.filter(DeviceCertificate.device_id == device_id)
|
||||
.first()
|
||||
db.query(DeviceCertificate).filter(DeviceCertificate.device_id == device_id).first()
|
||||
)
|
||||
if not device_cert:
|
||||
raise HTTPException(status_code=404, detail="Device certificate not found")
|
||||
@@ -155,16 +177,14 @@ async def revoke_device_certificate(device_id: str):
|
||||
return {
|
||||
"device_id": device_id,
|
||||
"revoked_at": device_cert.revoked_at.isoformat(),
|
||||
"message": "Certificate revoked successfully"
|
||||
"message": "Certificate revoked successfully",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to revoke device certificate."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to revoke device certificate.") from e
|
||||
|
||||
|
||||
@app.get("/crl")
|
||||
@@ -180,15 +200,14 @@ async def get_crl():
|
||||
return {"crl_pem": crl_pem}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve CRL."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve CRL.") from e
|
||||
|
||||
|
||||
@app.post("/devices/{device_id}/renew")
|
||||
async def renew_certificate(device_id: str):
|
||||
"""
|
||||
Renew a device's certificate by issuing a new one and revoking the old one.
|
||||
|
||||
|
||||
This endpoint:
|
||||
1. Retrieves the current certificate from DB
|
||||
2. Generates a new certificate with new keys
|
||||
@@ -209,8 +228,7 @@ async def renew_certificate(device_id: str):
|
||||
)
|
||||
if not device_cert:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No active certificate found for device"
|
||||
status_code=404, detail="No active certificate found for device"
|
||||
)
|
||||
|
||||
# Check if certificate is about to expire (optional warning)
|
||||
@@ -227,15 +245,14 @@ async def renew_certificate(device_id: str):
|
||||
|
||||
# Generate new certificate with new keys
|
||||
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
|
||||
current_cert_pem=device_cert.certificate_pem,
|
||||
validity_days=365,
|
||||
key_size=4096
|
||||
current_cert_pem=device_cert.certificate_pem, validity_days=365, key_size=4096
|
||||
)
|
||||
|
||||
# Extract certificate ID (serial number) from the new certificate
|
||||
from cryptography import x509
|
||||
|
||||
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
|
||||
new_cert_id = format(new_cert.serial_number, 'x')
|
||||
new_cert_id = format(new_cert.serial_number, "x")
|
||||
|
||||
# Create new certificate record in DB
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
@@ -252,9 +269,12 @@ async def renew_certificate(device_id: str):
|
||||
|
||||
logger.info(f"Successfully renewed certificate for device {device_id}")
|
||||
|
||||
device = db.query(Device).filter(Device.id == device_id).first()
|
||||
|
||||
return DeviceRegistrationResponse(
|
||||
certificate_id=new_cert_id,
|
||||
device_id=device_id,
|
||||
protocol=device.protocol if device else "mqtt",
|
||||
certificate_id=new_cert_id,
|
||||
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
|
||||
certificate_pem=new_device_cert.certificate_pem,
|
||||
private_key_pem=new_device_cert.private_key_pem,
|
||||
@@ -265,6 +285,4 @@ async def renew_certificate(device_id: str):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to renew device certificate."
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail="Failed to renew device certificate.") from e
|
||||
|
||||
Reference in New Issue
Block a user