mirror of
https://github.com/ferdzo/iotDashboard.git
synced 2026-04-05 09:06:26 +00:00
Functioning device manager with renew,revoke, updated model for cert id
This commit is contained in:
@@ -0,0 +1,60 @@
|
|||||||
|
"""add_certificate_id_and_indices
|
||||||
|
|
||||||
|
Revision ID: 4f152b34e800
|
||||||
|
Revises: f94393f57c35
|
||||||
|
Create Date: 2025-10-30 21:29:43.843375+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '4f152b34e800'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = 'f94393f57c35'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# Step 1: Add id column as nullable first
|
||||||
|
op.add_column('device_certificates', sa.Column('id', sa.Text(), nullable=True))
|
||||||
|
|
||||||
|
# Step 2: Generate IDs for existing records (use device_id as temporary ID)
|
||||||
|
op.execute("""
|
||||||
|
UPDATE device_certificates
|
||||||
|
SET id = device_id || '-' || EXTRACT(EPOCH FROM issued_at)::text
|
||||||
|
WHERE id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Step 3: Drop old primary key constraint
|
||||||
|
op.drop_constraint('device_certificates_pkey', 'device_certificates', type_='primary')
|
||||||
|
|
||||||
|
# Step 4: Make id NOT NULL now that all rows have values
|
||||||
|
op.alter_column('device_certificates', 'id', nullable=False)
|
||||||
|
|
||||||
|
# Step 5: Create new primary key on id
|
||||||
|
op.create_primary_key('device_certificates_pkey', 'device_certificates', ['id'])
|
||||||
|
|
||||||
|
# Step 6: Create indices
|
||||||
|
op.create_index('idx_device_certificates_active', 'device_certificates', ['device_id', 'revoked_at'], unique=False)
|
||||||
|
op.create_index('idx_device_certificates_device_id', 'device_certificates', ['device_id'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# Drop indices
|
||||||
|
op.drop_index('idx_device_certificates_device_id', table_name='device_certificates')
|
||||||
|
op.drop_index('idx_device_certificates_active', table_name='device_certificates')
|
||||||
|
|
||||||
|
# Drop new primary key
|
||||||
|
op.drop_constraint('device_certificates_pkey', 'device_certificates', type_='primary')
|
||||||
|
|
||||||
|
# Recreate old primary key on device_id
|
||||||
|
op.create_primary_key('device_certificates_pkey', 'device_certificates', ['device_id'])
|
||||||
|
|
||||||
|
# Drop id column
|
||||||
|
op.drop_column('device_certificates', 'id')
|
||||||
@@ -35,8 +35,9 @@ class DeviceCertificate(Base):
|
|||||||
|
|
||||||
__tablename__ = "device_certificates"
|
__tablename__ = "device_certificates"
|
||||||
|
|
||||||
|
id = Column(Text, primary_key=True)
|
||||||
device_id = Column(
|
device_id = Column(
|
||||||
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
|
Text, ForeignKey("devices.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
certificate_pem = Column(Text, nullable=False)
|
certificate_pem = Column(Text, nullable=False)
|
||||||
private_key_pem = Column(Text) # Optional: for backup/escrow
|
private_key_pem = Column(Text) # Optional: for backup/escrow
|
||||||
@@ -44,8 +45,13 @@ class DeviceCertificate(Base):
|
|||||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
revoked_at = Column(DateTime(timezone=True))
|
revoked_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_device_certificates_device_id", "device_id"),
|
||||||
|
Index("idx_device_certificates_active", "device_id", "revoked_at"),
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<DeviceCertificate(device_id={self.device_id}, expires={self.expires_at})>"
|
return f"<DeviceCertificate(id={self.id}, device_id={self.device_id}, expires={self.expires_at})>"
|
||||||
|
|
||||||
|
|
||||||
class Telemetry(Base):
|
class Telemetry(Base):
|
||||||
|
|||||||
3
infrastructure/.gitignore
vendored
Normal file
3
infrastructure/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mosquitto/certs/
|
||||||
|
mosquitto/data/
|
||||||
|
mosquitto/logs/
|
||||||
@@ -15,7 +15,7 @@ services:
|
|||||||
- "9001:9001"
|
- "9001:9001"
|
||||||
- "8883:8883"
|
- "8883:8883"
|
||||||
volumes:
|
volumes:
|
||||||
- ./mosquitto/:/mosquitto/config/
|
- ./mosquitto/:/mosquitto/
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
timescaledb:
|
timescaledb:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ persistence true
|
|||||||
persistence_location /mosquitto/data/
|
persistence_location /mosquitto/data/
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
log_dest file /mosquitto/log/mosquitto.log
|
#log_dest file /mosquitto/log/mosquitto.log
|
||||||
|
|
||||||
# Standard MQTT listener (for testing without certs)
|
# Standard MQTT listener (for testing without certs)
|
||||||
listener 1883
|
listener 1883
|
||||||
@@ -15,11 +15,14 @@ allow_anonymous true
|
|||||||
protocol mqtt
|
protocol mqtt
|
||||||
|
|
||||||
# Server certificates (mosquitto's identity)
|
# Server certificates (mosquitto's identity)
|
||||||
certfile /mosquitto/config/server.crt
|
certfile /mosquitto/certs/server.crt
|
||||||
keyfile /mosquitto/config/server.key
|
keyfile /mosquitto/certs/server.key
|
||||||
|
|
||||||
# CA certificate to verify client certificates
|
# CA certificate to verify client certificates
|
||||||
cafile /mosquitto/config/ca.crt
|
cafile /mosquitto/certs/ca.crt
|
||||||
|
|
||||||
|
# CRL file
|
||||||
|
crlfile /mosquitto/certs/ca.crl
|
||||||
|
|
||||||
# Certificate-based authentication
|
# Certificate-based authentication
|
||||||
require_certificate true
|
require_certificate true
|
||||||
270
services/device_manager/app/app.py
Normal file
270
services/device_manager/app/app.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
|
||||||
|
from cert_manager import CertificateManager
|
||||||
|
from database import get_db_context
|
||||||
|
from models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
cert_manager = CertificateManager()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def hello():
|
||||||
|
return {"Hello": "World"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/devices/register")
|
||||||
|
async def register_device(
|
||||||
|
request: DeviceRegistrationRequest,
|
||||||
|
) -> DeviceRegistrationResponse:
|
||||||
|
"""
|
||||||
|
Register a new device and issue an X.509 certificate.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = cert_manager.register_device(
|
||||||
|
name=request.name,
|
||||||
|
location=request.location,
|
||||||
|
)
|
||||||
|
|
||||||
|
with get_db_context() as db:
|
||||||
|
device = Device(
|
||||||
|
id=response.device_id,
|
||||||
|
name=request.name,
|
||||||
|
location=request.location,
|
||||||
|
created_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
)
|
||||||
|
db.add(device)
|
||||||
|
|
||||||
|
device_cert = DeviceCertificate(
|
||||||
|
id =response.certificate_id,
|
||||||
|
device_id=response.device_id,
|
||||||
|
certificate_pem=response.certificate_pem,
|
||||||
|
private_key_pem=response.private_key_pem,
|
||||||
|
issued_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
expires_at=response.expires_at,
|
||||||
|
)
|
||||||
|
db.add(device_cert)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to register device {request.name}: {str(e)}", exc_info=True
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to register device. Please try again."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/ca_certificate")
|
||||||
|
async def get_ca_certificate() -> str:
|
||||||
|
"""
|
||||||
|
Retrieve the CA certificate in PEM format.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ca_cert_pem = cert_manager.get_ca_certificate_pem()
|
||||||
|
return ca_cert_pem
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to retrieve CA certificate."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@app.get("/devices/{device_id}")
|
||||||
|
async def get_device(device_id: str) -> DeviceResponse:
|
||||||
|
"""
|
||||||
|
Retrieve device information by ID.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
device = db.query(Device).filter(Device.id == device_id).first()
|
||||||
|
if not device:
|
||||||
|
raise HTTPException(status_code=404, detail="Device not found")
|
||||||
|
|
||||||
|
return Device(
|
||||||
|
id=device.id,
|
||||||
|
name=device.name,
|
||||||
|
location=device.location,
|
||||||
|
created_at=device.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to retrieve device information."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@app.get("/devices/")
|
||||||
|
async def list_devices() -> list[DeviceResponse]:
|
||||||
|
"""
|
||||||
|
List all registered devices.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
devices = db.query(Device).all()
|
||||||
|
return [
|
||||||
|
DeviceResponse(
|
||||||
|
id=device.id,
|
||||||
|
name=device.name,
|
||||||
|
location=device.location,
|
||||||
|
created_at=device.created_at,
|
||||||
|
)
|
||||||
|
for device in devices
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to list devices."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@app.post("/devices/{device_id}/revoke")
|
||||||
|
async def revoke_device_certificate(device_id: str):
|
||||||
|
"""
|
||||||
|
Revoke a device's certificate by:
|
||||||
|
1. Marking it as revoked in the database
|
||||||
|
2. Adding it to the Certificate Revocation List (CRL)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
device_cert = (
|
||||||
|
db.query(DeviceCertificate)
|
||||||
|
.filter(DeviceCertificate.device_id == device_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not device_cert:
|
||||||
|
raise HTTPException(status_code=404, detail="Device certificate not found")
|
||||||
|
|
||||||
|
if device_cert.revoked_at:
|
||||||
|
raise HTTPException(status_code=400, detail="Certificate already revoked")
|
||||||
|
|
||||||
|
cert_manager.revoke_certificate(device_cert.certificate_pem)
|
||||||
|
|
||||||
|
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Successfully revoked certificate for device {device_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"device_id": device_id,
|
||||||
|
"revoked_at": device_cert.revoked_at.isoformat(),
|
||||||
|
"message": "Certificate revoked successfully"
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to revoke device certificate."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/crl")
|
||||||
|
async def get_crl():
|
||||||
|
"""
|
||||||
|
Get the Certificate Revocation List (CRL) in PEM format.
|
||||||
|
Mosquitto and other MQTT clients can check this to validate certificates.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
crl_pem = cert_manager.get_crl_pem()
|
||||||
|
if not crl_pem:
|
||||||
|
return {"message": "No certificates have been revoked yet"}
|
||||||
|
return {"crl_pem": crl_pem}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to retrieve CRL."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@app.post("/devices/{device_id}/renew")
|
||||||
|
async def renew_certificate(device_id: str):
|
||||||
|
"""
|
||||||
|
Renew a device's certificate by issuing a new one and revoking the old one.
|
||||||
|
|
||||||
|
This endpoint:
|
||||||
|
1. Retrieves the current certificate from DB
|
||||||
|
2. Generates a new certificate with new keys
|
||||||
|
3. Revokes the old certificate (adds to CRL)
|
||||||
|
4. Updates the database with the new certificate
|
||||||
|
5. Returns the new credentials
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
# Get current certificate
|
||||||
|
device_cert = (
|
||||||
|
db.query(DeviceCertificate)
|
||||||
|
.filter(
|
||||||
|
DeviceCertificate.device_id == device_id,
|
||||||
|
# DeviceCertificate.revoked_at.is_(None)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not device_cert:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="No active certificate found for device"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if certificate is about to expire (optional warning)
|
||||||
|
days_until_expiry = (device_cert.expires_at - datetime.datetime.now(datetime.UTC)).days
|
||||||
|
if days_until_expiry > 30:
|
||||||
|
logger.warning(
|
||||||
|
f"Certificate for device {device_id} renewed early "
|
||||||
|
f"({days_until_expiry} days remaining)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Revoke old certificate and add to CRL
|
||||||
|
cert_manager.revoke_certificate(device_cert.certificate_pem)
|
||||||
|
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
|
||||||
|
|
||||||
|
# Generate new certificate with new keys
|
||||||
|
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
|
||||||
|
current_cert_pem=device_cert.certificate_pem,
|
||||||
|
validity_days=365,
|
||||||
|
key_size=4096
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract certificate ID (serial number) from the new certificate
|
||||||
|
from cryptography import x509
|
||||||
|
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
|
||||||
|
new_cert_id = format(new_cert.serial_number, 'x')
|
||||||
|
|
||||||
|
# Create new certificate record in DB
|
||||||
|
now = datetime.datetime.now(datetime.UTC)
|
||||||
|
new_device_cert = DeviceCertificate(
|
||||||
|
id=new_cert_id,
|
||||||
|
device_id=device_id,
|
||||||
|
certificate_pem=new_cert_pem.decode("utf-8"),
|
||||||
|
private_key_pem=new_key_pem.decode("utf-8"),
|
||||||
|
issued_at=now,
|
||||||
|
expires_at=now + datetime.timedelta(days=365),
|
||||||
|
)
|
||||||
|
db.add(new_device_cert)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Successfully renewed certificate for device {device_id}")
|
||||||
|
|
||||||
|
return DeviceRegistrationResponse(
|
||||||
|
certificate_id=new_cert_id,
|
||||||
|
device_id=device_id,
|
||||||
|
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
|
||||||
|
certificate_pem=new_device_cert.certificate_pem,
|
||||||
|
private_key_pem=new_device_cert.private_key_pem,
|
||||||
|
expires_at=new_device_cert.expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to renew device certificate."
|
||||||
|
) from e
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
@@ -55,12 +56,12 @@ class CertificateManager:
|
|||||||
) -> tuple[bytes, bytes]:
|
) -> tuple[bytes, bytes]:
|
||||||
"""Generate an X.509 certificate for a device signed by the CA."""
|
"""Generate an X.509 certificate for a device signed by the CA."""
|
||||||
|
|
||||||
# Build device certificate
|
|
||||||
subject = x509.Name(
|
subject = x509.Name(
|
||||||
[
|
[
|
||||||
x509.NameAttribute(NameOID.COMMON_NAME, device_id),
|
x509.NameAttribute(NameOID.COMMON_NAME, device_id),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
issuer = ca_cert.subject
|
issuer = ca_cert.subject
|
||||||
now = datetime.datetime.now(datetime.UTC)
|
now = datetime.datetime.now(datetime.UTC)
|
||||||
device_cert = (
|
device_cert = (
|
||||||
@@ -78,7 +79,6 @@ class CertificateManager:
|
|||||||
.sign(private_key=ca_key, algorithm=hashes.SHA256())
|
.sign(private_key=ca_key, algorithm=hashes.SHA256())
|
||||||
)
|
)
|
||||||
|
|
||||||
# Serialize certificate and key to PEM format
|
|
||||||
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM)
|
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM)
|
||||||
key_pem = device_key.private_bytes(
|
key_pem = device_key.private_bytes(
|
||||||
encoding=serialization.Encoding.PEM,
|
encoding=serialization.Encoding.PEM,
|
||||||
@@ -93,7 +93,7 @@ class CertificateManager:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create device credentials: private key and signed certificate.
|
"""Create device credentials: private key and signed certificate.
|
||||||
Returns:
|
Returns:
|
||||||
dict with device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
|
dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
|
||||||
"""
|
"""
|
||||||
device_key = self.generate_device_key(key_size=key_size)
|
device_key = self.generate_device_key(key_size=key_size)
|
||||||
|
|
||||||
@@ -106,11 +106,14 @@ class CertificateManager:
|
|||||||
key_size=key_size,
|
key_size=key_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
|
# Extract serial number from certificate to use as ID
|
||||||
days=validity_days
|
cert = x509.load_pem_x509_certificate(cert_pem)
|
||||||
)
|
cert_id = format(cert.serial_number, 'x') # Hex string of serial number
|
||||||
|
|
||||||
|
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"certificate_id": cert_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"certificate_pem": cert_pem,
|
"certificate_pem": cert_pem,
|
||||||
"private_key_pem": key_pem,
|
"private_key_pem": key_pem,
|
||||||
@@ -118,9 +121,7 @@ class CertificateManager:
|
|||||||
"expires_at": expires_at,
|
"expires_at": expires_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
def register_device(
|
def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse:
|
||||||
self, name: str, location: str | None = None
|
|
||||||
) -> DeviceRegistrationResponse:
|
|
||||||
"""Register a new device and generate its credentials.
|
"""Register a new device and generate its credentials.
|
||||||
Returns:
|
Returns:
|
||||||
DeviceRegistrationResponse
|
DeviceRegistrationResponse
|
||||||
@@ -129,6 +130,7 @@ class CertificateManager:
|
|||||||
credentials = self.create_device_credentials(device_id=device_id)
|
credentials = self.create_device_credentials(device_id=device_id)
|
||||||
|
|
||||||
return DeviceRegistrationResponse(
|
return DeviceRegistrationResponse(
|
||||||
|
certificate_id=credentials["certificate_id"],
|
||||||
device_id=credentials["device_id"],
|
device_id=credentials["device_id"],
|
||||||
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
|
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
|
||||||
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
|
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
|
||||||
@@ -140,3 +142,99 @@ class CertificateManager:
|
|||||||
"""Get the CA certificate in PEM format as a string."""
|
"""Get the CA certificate in PEM format as a string."""
|
||||||
return self.ca_cert_pem.decode("utf-8")
|
return self.ca_cert_pem.decode("utf-8")
|
||||||
|
|
||||||
|
def revoke_certificate(
|
||||||
|
self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Revoke a device certificate by adding it to the CRL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
certificate_pem: PEM-encoded certificate to revoke
|
||||||
|
reason: Revocation reason (default: unspecified)
|
||||||
|
"""
|
||||||
|
# Load the certificate to get serial number
|
||||||
|
cert = x509.load_pem_x509_certificate(certificate_pem.encode())
|
||||||
|
|
||||||
|
# Load existing CRL or create new one
|
||||||
|
crl_path = Path(config.CRL_PATH)
|
||||||
|
revoked_certs = []
|
||||||
|
|
||||||
|
if crl_path.exists():
|
||||||
|
with open(crl_path, "rb") as f:
|
||||||
|
existing_crl = x509.load_pem_x509_crl(f.read())
|
||||||
|
# Copy existing revoked certificates
|
||||||
|
revoked_certs = list(existing_crl)
|
||||||
|
|
||||||
|
# Add the new revoked certificate
|
||||||
|
revoked_cert = (
|
||||||
|
x509.RevokedCertificateBuilder()
|
||||||
|
.serial_number(cert.serial_number)
|
||||||
|
.revocation_date(datetime.datetime.now(datetime.UTC))
|
||||||
|
.add_extension(
|
||||||
|
x509.CRLReason(reason),
|
||||||
|
critical=False,
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
revoked_certs.append(revoked_cert)
|
||||||
|
|
||||||
|
# Build new CRL with all revoked certificates
|
||||||
|
crl_builder = x509.CertificateRevocationListBuilder()
|
||||||
|
crl_builder = crl_builder.issuer_name(self.ca_cert.subject)
|
||||||
|
crl_builder = crl_builder.last_update(datetime.datetime.now(datetime.UTC))
|
||||||
|
crl_builder = crl_builder.next_update(
|
||||||
|
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30)
|
||||||
|
)
|
||||||
|
|
||||||
|
for revoked in revoked_certs:
|
||||||
|
crl_builder = crl_builder.add_revoked_certificate(revoked)
|
||||||
|
|
||||||
|
# Sign the CRL with CA key
|
||||||
|
crl = crl_builder.sign(private_key=self.ca_key, algorithm=hashes.SHA256())
|
||||||
|
|
||||||
|
# Write CRL to file
|
||||||
|
crl_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(crl_path, "wb") as f:
|
||||||
|
f.write(crl.public_bytes(serialization.Encoding.PEM))
|
||||||
|
|
||||||
|
def get_crl_pem(self) -> str | None:
|
||||||
|
"""Get the current CRL in PEM format."""
|
||||||
|
crl_path = Path(config.CRL_PATH)
|
||||||
|
if not crl_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(crl_path, "rb") as f:
|
||||||
|
return f.read().decode("utf-8")
|
||||||
|
|
||||||
|
def renew_certificate(
|
||||||
|
self,
|
||||||
|
current_cert_pem: str,
|
||||||
|
validity_days: int = 365,
|
||||||
|
key_size: int = 4096,
|
||||||
|
) -> tuple[bytes, bytes]:
|
||||||
|
"""Renew a device certificate before expiration.
|
||||||
|
Args:
|
||||||
|
current_cert_pem: PEM-encoded current certificate
|
||||||
|
validity_days: Validity period for new certificate
|
||||||
|
key_size: Key size for new device key
|
||||||
|
Returns:
|
||||||
|
tuple of (new_cert_pem, new_key_pem)
|
||||||
|
"""
|
||||||
|
# Load current certificate
|
||||||
|
current_cert = x509.load_pem_x509_certificate(current_cert_pem.encode())
|
||||||
|
device_id = current_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
||||||
|
|
||||||
|
# Generate new device key
|
||||||
|
new_device_key = self.generate_device_key(key_size=key_size)
|
||||||
|
|
||||||
|
# Generate new device certificate
|
||||||
|
new_cert_pem, new_key_pem = self.generate_device_certificate(
|
||||||
|
device_id=device_id,
|
||||||
|
ca_cert=self.ca_cert,
|
||||||
|
ca_key=self.ca_key,
|
||||||
|
device_key=new_device_key,
|
||||||
|
validity_days=validity_days,
|
||||||
|
key_size=key_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_cert_pem, new_key_pem
|
||||||
@@ -15,6 +15,7 @@ class Config:
|
|||||||
CERTS_DIR = SERVICE_DIR / "certs"
|
CERTS_DIR = SERVICE_DIR / "certs"
|
||||||
CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt"))
|
CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt"))
|
||||||
CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key"))
|
CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key"))
|
||||||
|
CRL_PATH = os.getenv("CRL_PATH", str(CERTS_DIR / "ca.crl"))
|
||||||
|
|
||||||
# Certificate settings
|
# Certificate settings
|
||||||
CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365"))
|
CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365"))
|
||||||
@@ -31,6 +31,7 @@ class DeviceCertificate(Base):
|
|||||||
|
|
||||||
__tablename__ = "device_certificates"
|
__tablename__ = "device_certificates"
|
||||||
|
|
||||||
|
id = Column(Text, primary_key=True)
|
||||||
device_id = Column(
|
device_id = Column(
|
||||||
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
|
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
@@ -11,8 +11,17 @@ class DeviceRegistrationRequest(BaseModel):
|
|||||||
class DeviceRegistrationResponse(BaseModel):
|
class DeviceRegistrationResponse(BaseModel):
|
||||||
"""Response model after registering a new device."""
|
"""Response model after registering a new device."""
|
||||||
|
|
||||||
|
certificate_id: str
|
||||||
device_id: str
|
device_id: str
|
||||||
ca_certificate_pem: str
|
ca_certificate_pem: str
|
||||||
certificate_pem: str
|
certificate_pem: str
|
||||||
private_key_pem: str
|
private_key_pem: str
|
||||||
expires_at: datetime.datetime
|
expires_at: datetime.datetime
|
||||||
|
|
||||||
|
class DeviceResponse(BaseModel):
|
||||||
|
"""Response model for device information."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
location: str | None = None
|
||||||
|
created_at: datetime.datetime
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
|
|
||||||
from cert_manager import CertificateManager
|
|
||||||
from database import get_db_context
|
|
||||||
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
|
|
||||||
from models import DeviceRegistrationRequest, DeviceRegistrationResponse # Pydantic API models
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
cert_manager = CertificateManager()
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def hello():
|
|
||||||
return {"Hello": "World"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/devices/register")
|
|
||||||
async def register_device(
|
|
||||||
request: DeviceRegistrationRequest,
|
|
||||||
) -> DeviceRegistrationResponse:
|
|
||||||
"""
|
|
||||||
Register a new device and issue an X.509 certificate.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = cert_manager.register_device(
|
|
||||||
name=request.name,
|
|
||||||
location=request.location,
|
|
||||||
)
|
|
||||||
|
|
||||||
with get_db_context() as db:
|
|
||||||
device = Device(
|
|
||||||
id=response.device_id,
|
|
||||||
name=request.name,
|
|
||||||
location=request.location,
|
|
||||||
created_at=datetime.datetime.now(datetime.UTC),
|
|
||||||
)
|
|
||||||
db.add(device)
|
|
||||||
|
|
||||||
device_cert = DeviceCertificate(
|
|
||||||
device_id=response.device_id,
|
|
||||||
certificate_pem=response.certificate_pem,
|
|
||||||
private_key_pem=response.private_key_pem,
|
|
||||||
issued_at=datetime.datetime.now(datetime.UTC),
|
|
||||||
expires_at=response.expires_at,
|
|
||||||
)
|
|
||||||
db.add(device_cert)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to register device {request.name}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Failed to register device. Please try again."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ca_certificate")
|
|
||||||
async def get_ca_certificate() -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the CA certificate in PEM format.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
ca_cert_pem = cert_manager.get_ca_certificate_pem()
|
|
||||||
return ca_cert_pem
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Failed to retrieve CA certificate."
|
|
||||||
) from e
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "device-manager"
|
name = "device-manager"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Device Manager"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|||||||
Reference in New Issue
Block a user