diff --git a/db_migrations/alembic/versions/20251030_2129_4f152b34e800_add_certificate_id_and_indices.py b/db_migrations/alembic/versions/20251030_2129_4f152b34e800_add_certificate_id_and_indices.py new file mode 100644 index 0000000..1024221 --- /dev/null +++ b/db_migrations/alembic/versions/20251030_2129_4f152b34e800_add_certificate_id_and_indices.py @@ -0,0 +1,60 @@ +"""add_certificate_id_and_indices + +Revision ID: 4f152b34e800 +Revises: f94393f57c35 +Create Date: 2025-10-30 21:29:43.843375+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4f152b34e800' +down_revision: Union[str, Sequence[str], None] = 'f94393f57c35' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Step 1: Add id column as nullable first + op.add_column('device_certificates', sa.Column('id', sa.Text(), nullable=True)) + + # Step 2: Generate IDs for existing records (use device_id as temporary ID) + op.execute(""" + UPDATE device_certificates + SET id = device_id || '-' || EXTRACT(EPOCH FROM issued_at)::text + WHERE id IS NULL + """) + + # Step 3: Drop old primary key constraint + op.drop_constraint('device_certificates_pkey', 'device_certificates', type_='primary') + + # Step 4: Make id NOT NULL now that all rows have values + op.alter_column('device_certificates', 'id', nullable=False) + + # Step 5: Create new primary key on id + op.create_primary_key('device_certificates_pkey', 'device_certificates', ['id']) + + # Step 6: Create indices + op.create_index('idx_device_certificates_active', 'device_certificates', ['device_id', 'revoked_at'], unique=False) + op.create_index('idx_device_certificates_device_id', 'device_certificates', ['device_id'], unique=False) + + +def downgrade() -> None: + """Downgrade schema.""" + # Drop indices + op.drop_index('idx_device_certificates_device_id', table_name='device_certificates') + op.drop_index('idx_device_certificates_active', table_name='device_certificates') + + # Drop new primary key + op.drop_constraint('device_certificates_pkey', 'device_certificates', type_='primary') + + # Recreate old primary key on device_id + op.create_primary_key('device_certificates_pkey', 'device_certificates', ['device_id']) + + # Drop id column + op.drop_column('device_certificates', 'id') diff --git a/db_migrations/models.py b/db_migrations/models.py index de6692a..973d85d 100644 --- a/db_migrations/models.py +++ b/db_migrations/models.py @@ -35,8 +35,9 @@ class DeviceCertificate(Base): __tablename__ = "device_certificates" + id = Column(Text, primary_key=True) device_id = Column( - Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True + Text, ForeignKey("devices.id", ondelete="CASCADE"), nullable=False ) certificate_pem = Column(Text, nullable=False) private_key_pem = Column(Text) # Optional: for backup/escrow @@ -44,8 +45,13 @@ class DeviceCertificate(Base): expires_at = Column(DateTime(timezone=True), nullable=False) revoked_at = Column(DateTime(timezone=True)) + __table_args__ = ( + Index("idx_device_certificates_device_id", "device_id"), + Index("idx_device_certificates_active", "device_id", "revoked_at"), + ) + def __repr__(self): - return f"" + return f"" class Telemetry(Base): diff --git a/infrastructure/.gitignore b/infrastructure/.gitignore new file mode 100644 index 0000000..fe44ad6 --- /dev/null +++ b/infrastructure/.gitignore @@ -0,0 +1,3 @@ +mosquitto/certs/ +mosquitto/data/ +mosquitto/logs/ \ No newline at end of file diff --git a/infrastructure/compose.yml b/infrastructure/compose.yml index c076c29..7b9c76a 100644 --- a/infrastructure/compose.yml +++ b/infrastructure/compose.yml @@ -15,7 +15,7 @@ services: - "9001:9001" - "8883:8883" volumes: - - ./mosquitto/:/mosquitto/config/ + - ./mosquitto/:/mosquitto/ restart: unless-stopped timescaledb: diff --git a/infrastructure/mosquitto/mosquitto.conf b/infrastructure/mosquitto/config/mosquitto.conf similarity index 72% rename from infrastructure/mosquitto/mosquitto.conf rename to infrastructure/mosquitto/config/mosquitto.conf index 7b3c566..4e8efcb 100644 --- a/infrastructure/mosquitto/mosquitto.conf +++ b/infrastructure/mosquitto/config/mosquitto.conf @@ -3,7 +3,7 @@ persistence true persistence_location /mosquitto/data/ # Logging -log_dest file /mosquitto/log/mosquitto.log +#log_dest file /mosquitto/log/mosquitto.log # Standard MQTT listener (for testing without certs) listener 1883 @@ -15,11 +15,14 @@ allow_anonymous true protocol mqtt # Server certificates (mosquitto's identity) -certfile /mosquitto/config/server.crt -keyfile /mosquitto/config/server.key +certfile /mosquitto/certs/server.crt +keyfile /mosquitto/certs/server.key # CA certificate to verify client certificates -cafile /mosquitto/config/ca.crt +cafile /mosquitto/certs/ca.crt + +# CRL file +crlfile /mosquitto/certs/ca.crl # Certificate-based authentication require_certificate true diff --git a/services/device_manager/app/app.py b/services/device_manager/app/app.py new file mode 100644 index 0000000..f38a3ae --- /dev/null +++ b/services/device_manager/app/app.py @@ -0,0 +1,270 @@ +import datetime +import logging + +from db_models import Device, DeviceCertificate # SQLAlchemy ORM models +from fastapi import FastAPI, HTTPException + +from cert_manager import CertificateManager +from database import get_db_context +from models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse + +logger = logging.getLogger(__name__) + +app = FastAPI() + +cert_manager = CertificateManager() + + +@app.get("/") +async def hello(): + return {"Hello": "World"} + + +@app.post("/devices/register") +async def register_device( + request: DeviceRegistrationRequest, +) -> DeviceRegistrationResponse: + """ + Register a new device and issue an X.509 certificate. + """ + try: + response = cert_manager.register_device( + name=request.name, + location=request.location, + ) + + with get_db_context() as db: + device = Device( + id=response.device_id, + name=request.name, + location=request.location, + created_at=datetime.datetime.now(datetime.UTC), + ) + db.add(device) + + device_cert = DeviceCertificate( + id =response.certificate_id, + device_id=response.device_id, + certificate_pem=response.certificate_pem, + private_key_pem=response.private_key_pem, + issued_at=datetime.datetime.now(datetime.UTC), + expires_at=response.expires_at, + ) + db.add(device_cert) + + except Exception as e: + logger.error( + f"Failed to register device {request.name}: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=500, detail="Failed to register device. Please try again." + ) from e + + return response + + +@app.get("/ca_certificate") +async def get_ca_certificate() -> str: + """ + Retrieve the CA certificate in PEM format. + """ + try: + ca_cert_pem = cert_manager.get_ca_certificate_pem() + return ca_cert_pem + except Exception as e: + logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to retrieve CA certificate." + ) from e + +@app.get("/devices/{device_id}") +async def get_device(device_id: str) -> DeviceResponse: + """ + Retrieve device information by ID. + """ + try: + with get_db_context() as db: + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + raise HTTPException(status_code=404, detail="Device not found") + + return Device( + id=device.id, + name=device.name, + location=device.location, + created_at=device.created_at, + ) + + except Exception as e: + logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to retrieve device information." + ) from e + +@app.get("/devices/") +async def list_devices() -> list[DeviceResponse]: + """ + List all registered devices. + """ + try: + with get_db_context() as db: + devices = db.query(Device).all() + return [ + DeviceResponse( + id=device.id, + name=device.name, + location=device.location, + created_at=device.created_at, + ) + for device in devices + ] + + except Exception as e: + logger.error(f"Failed to list devices: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to list devices." + ) from e + +@app.post("/devices/{device_id}/revoke") +async def revoke_device_certificate(device_id: str): + """ + Revoke a device's certificate by: + 1. Marking it as revoked in the database + 2. Adding it to the Certificate Revocation List (CRL) + """ + try: + with get_db_context() as db: + device_cert = ( + db.query(DeviceCertificate) + .filter(DeviceCertificate.device_id == device_id) + .first() + ) + if not device_cert: + raise HTTPException(status_code=404, detail="Device certificate not found") + + if device_cert.revoked_at: + raise HTTPException(status_code=400, detail="Certificate already revoked") + + cert_manager.revoke_certificate(device_cert.certificate_pem) + + device_cert.revoked_at = datetime.datetime.now(datetime.UTC) + db.commit() + + logger.info(f"Successfully revoked certificate for device {device_id}") + + return { + "device_id": device_id, + "revoked_at": device_cert.revoked_at.isoformat(), + "message": "Certificate revoked successfully" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to revoke device certificate." + ) from e + + +@app.get("/crl") +async def get_crl(): + """ + Get the Certificate Revocation List (CRL) in PEM format. + Mosquitto and other MQTT clients can check this to validate certificates. + """ + try: + crl_pem = cert_manager.get_crl_pem() + if not crl_pem: + return {"message": "No certificates have been revoked yet"} + return {"crl_pem": crl_pem} + except Exception as e: + logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to retrieve CRL." + ) from e + +@app.post("/devices/{device_id}/renew") +async def renew_certificate(device_id: str): + """ + Renew a device's certificate by issuing a new one and revoking the old one. + + This endpoint: + 1. Retrieves the current certificate from DB + 2. Generates a new certificate with new keys + 3. Revokes the old certificate (adds to CRL) + 4. Updates the database with the new certificate + 5. Returns the new credentials + """ + try: + with get_db_context() as db: + # Get current certificate + device_cert = ( + db.query(DeviceCertificate) + .filter( + DeviceCertificate.device_id == device_id, + # DeviceCertificate.revoked_at.is_(None) + ) + .first() + ) + if not device_cert: + raise HTTPException( + status_code=404, + detail="No active certificate found for device" + ) + + # Check if certificate is about to expire (optional warning) + days_until_expiry = (device_cert.expires_at - datetime.datetime.now(datetime.UTC)).days + if days_until_expiry > 30: + logger.warning( + f"Certificate for device {device_id} renewed early " + f"({days_until_expiry} days remaining)" + ) + + # Revoke old certificate and add to CRL + cert_manager.revoke_certificate(device_cert.certificate_pem) + device_cert.revoked_at = datetime.datetime.now(datetime.UTC) + + # Generate new certificate with new keys + new_cert_pem, new_key_pem = cert_manager.renew_certificate( + current_cert_pem=device_cert.certificate_pem, + validity_days=365, + key_size=4096 + ) + + # Extract certificate ID (serial number) from the new certificate + from cryptography import x509 + new_cert = x509.load_pem_x509_certificate(new_cert_pem) + new_cert_id = format(new_cert.serial_number, 'x') + + # Create new certificate record in DB + now = datetime.datetime.now(datetime.UTC) + new_device_cert = DeviceCertificate( + id=new_cert_id, + device_id=device_id, + certificate_pem=new_cert_pem.decode("utf-8"), + private_key_pem=new_key_pem.decode("utf-8"), + issued_at=now, + expires_at=now + datetime.timedelta(days=365), + ) + db.add(new_device_cert) + db.commit() + + logger.info(f"Successfully renewed certificate for device {device_id}") + + return DeviceRegistrationResponse( + certificate_id=new_cert_id, + device_id=device_id, + ca_certificate_pem=cert_manager.get_ca_certificate_pem(), + certificate_pem=new_device_cert.certificate_pem, + private_key_pem=new_device_cert.private_key_pem, + expires_at=new_device_cert.expires_at, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail="Failed to renew device certificate." + ) from e diff --git a/services/device_manager/cert_manager.py b/services/device_manager/app/cert_manager.py similarity index 54% rename from services/device_manager/cert_manager.py rename to services/device_manager/app/cert_manager.py index 0b624bb..0968cc9 100644 --- a/services/device_manager/cert_manager.py +++ b/services/device_manager/app/cert_manager.py @@ -1,4 +1,5 @@ import datetime +from pathlib import Path from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization @@ -55,12 +56,12 @@ class CertificateManager: ) -> tuple[bytes, bytes]: """Generate an X.509 certificate for a device signed by the CA.""" - # Build device certificate subject = x509.Name( [ x509.NameAttribute(NameOID.COMMON_NAME, device_id), ] ) + issuer = ca_cert.subject now = datetime.datetime.now(datetime.UTC) device_cert = ( @@ -78,7 +79,6 @@ class CertificateManager: .sign(private_key=ca_key, algorithm=hashes.SHA256()) ) - # Serialize certificate and key to PEM format cert_pem = device_cert.public_bytes(serialization.Encoding.PEM) key_pem = device_key.private_bytes( encoding=serialization.Encoding.PEM, @@ -93,7 +93,7 @@ class CertificateManager: ) -> dict: """Create device credentials: private key and signed certificate. Returns: - dict with device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at + dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at """ device_key = self.generate_device_key(key_size=key_size) @@ -106,11 +106,14 @@ class CertificateManager: key_size=key_size, ) - expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta( - days=validity_days - ) + # Extract serial number from certificate to use as ID + cert = x509.load_pem_x509_certificate(cert_pem) + cert_id = format(cert.serial_number, 'x') # Hex string of serial number + + expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days) return { + "certificate_id": cert_id, "device_id": device_id, "certificate_pem": cert_pem, "private_key_pem": key_pem, @@ -118,9 +121,7 @@ class CertificateManager: "expires_at": expires_at, } - def register_device( - self, name: str, location: str | None = None - ) -> DeviceRegistrationResponse: + def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse: """Register a new device and generate its credentials. Returns: DeviceRegistrationResponse @@ -129,6 +130,7 @@ class CertificateManager: credentials = self.create_device_credentials(device_id=device_id) return DeviceRegistrationResponse( + certificate_id=credentials["certificate_id"], device_id=credentials["device_id"], ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"), certificate_pem=credentials["certificate_pem"].decode("utf-8"), @@ -140,3 +142,99 @@ class CertificateManager: """Get the CA certificate in PEM format as a string.""" return self.ca_cert_pem.decode("utf-8") + def revoke_certificate( + self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified + ) -> None: + """ + Revoke a device certificate by adding it to the CRL. + + Args: + certificate_pem: PEM-encoded certificate to revoke + reason: Revocation reason (default: unspecified) + """ + # Load the certificate to get serial number + cert = x509.load_pem_x509_certificate(certificate_pem.encode()) + + # Load existing CRL or create new one + crl_path = Path(config.CRL_PATH) + revoked_certs = [] + + if crl_path.exists(): + with open(crl_path, "rb") as f: + existing_crl = x509.load_pem_x509_crl(f.read()) + # Copy existing revoked certificates + revoked_certs = list(existing_crl) + + # Add the new revoked certificate + revoked_cert = ( + x509.RevokedCertificateBuilder() + .serial_number(cert.serial_number) + .revocation_date(datetime.datetime.now(datetime.UTC)) + .add_extension( + x509.CRLReason(reason), + critical=False, + ) + .build() + ) + revoked_certs.append(revoked_cert) + + # Build new CRL with all revoked certificates + crl_builder = x509.CertificateRevocationListBuilder() + crl_builder = crl_builder.issuer_name(self.ca_cert.subject) + crl_builder = crl_builder.last_update(datetime.datetime.now(datetime.UTC)) + crl_builder = crl_builder.next_update( + datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30) + ) + + for revoked in revoked_certs: + crl_builder = crl_builder.add_revoked_certificate(revoked) + + # Sign the CRL with CA key + crl = crl_builder.sign(private_key=self.ca_key, algorithm=hashes.SHA256()) + + # Write CRL to file + crl_path.parent.mkdir(parents=True, exist_ok=True) + with open(crl_path, "wb") as f: + f.write(crl.public_bytes(serialization.Encoding.PEM)) + + def get_crl_pem(self) -> str | None: + """Get the current CRL in PEM format.""" + crl_path = Path(config.CRL_PATH) + if not crl_path.exists(): + return None + + with open(crl_path, "rb") as f: + return f.read().decode("utf-8") + + def renew_certificate( + self, + current_cert_pem: str, + validity_days: int = 365, + key_size: int = 4096, + ) -> tuple[bytes, bytes]: + """Renew a device certificate before expiration. + Args: + current_cert_pem: PEM-encoded current certificate + validity_days: Validity period for new certificate + key_size: Key size for new device key + Returns: + tuple of (new_cert_pem, new_key_pem) + """ + # Load current certificate + current_cert = x509.load_pem_x509_certificate(current_cert_pem.encode()) + device_id = current_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value + + # Generate new device key + new_device_key = self.generate_device_key(key_size=key_size) + + # Generate new device certificate + new_cert_pem, new_key_pem = self.generate_device_certificate( + device_id=device_id, + ca_cert=self.ca_cert, + ca_key=self.ca_key, + device_key=new_device_key, + validity_days=validity_days, + key_size=key_size, + ) + + return new_cert_pem, new_key_pem diff --git a/services/device_manager/config.py b/services/device_manager/app/config.py similarity index 93% rename from services/device_manager/config.py rename to services/device_manager/app/config.py index e299b5b..0223718 100644 --- a/services/device_manager/config.py +++ b/services/device_manager/app/config.py @@ -15,6 +15,7 @@ class Config: CERTS_DIR = SERVICE_DIR / "certs" CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt")) CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key")) + CRL_PATH = os.getenv("CRL_PATH", str(CERTS_DIR / "ca.crl")) # Certificate settings CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365")) diff --git a/services/device_manager/database.py b/services/device_manager/app/database.py similarity index 100% rename from services/device_manager/database.py rename to services/device_manager/app/database.py diff --git a/services/device_manager/db_models.py b/services/device_manager/app/db_models.py similarity index 97% rename from services/device_manager/db_models.py rename to services/device_manager/app/db_models.py index 4bec069..e20e2f4 100644 --- a/services/device_manager/db_models.py +++ b/services/device_manager/app/db_models.py @@ -31,6 +31,7 @@ class DeviceCertificate(Base): __tablename__ = "device_certificates" + id = Column(Text, primary_key=True) device_id = Column( Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True ) diff --git a/services/device_manager/models.py b/services/device_manager/app/models.py similarity index 68% rename from services/device_manager/models.py rename to services/device_manager/app/models.py index ea82836..ec997a5 100644 --- a/services/device_manager/models.py +++ b/services/device_manager/app/models.py @@ -11,8 +11,17 @@ class DeviceRegistrationRequest(BaseModel): class DeviceRegistrationResponse(BaseModel): """Response model after registering a new device.""" + certificate_id: str device_id: str ca_certificate_pem: str certificate_pem: str private_key_pem: str expires_at: datetime.datetime + +class DeviceResponse(BaseModel): + """Response model for device information.""" + + id: str + name: str + location: str | None = None + created_at: datetime.datetime diff --git a/services/device_manager/main.py b/services/device_manager/main.py deleted file mode 100644 index 57b591d..0000000 --- a/services/device_manager/main.py +++ /dev/null @@ -1,77 +0,0 @@ -import datetime -import logging - -from fastapi import FastAPI, HTTPException - -from cert_manager import CertificateManager -from database import get_db_context -from db_models import Device, DeviceCertificate # SQLAlchemy ORM models -from models import DeviceRegistrationRequest, DeviceRegistrationResponse # Pydantic API models - -logger = logging.getLogger(__name__) - -app = FastAPI() - -cert_manager = CertificateManager() - - -@app.get("/") -async def hello(): - return {"Hello": "World"} - - -@app.post("/devices/register") -async def register_device( - request: DeviceRegistrationRequest, -) -> DeviceRegistrationResponse: - """ - Register a new device and issue an X.509 certificate. - """ - try: - response = cert_manager.register_device( - name=request.name, - location=request.location, - ) - - with get_db_context() as db: - device = Device( - id=response.device_id, - name=request.name, - location=request.location, - created_at=datetime.datetime.now(datetime.UTC), - ) - db.add(device) - - device_cert = DeviceCertificate( - device_id=response.device_id, - certificate_pem=response.certificate_pem, - private_key_pem=response.private_key_pem, - issued_at=datetime.datetime.now(datetime.UTC), - expires_at=response.expires_at, - ) - db.add(device_cert) - - except Exception as e: - logger.error( - f"Failed to register device {request.name}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=500, detail="Failed to register device. Please try again." - ) from e - - return response - - -@app.get("/ca_certificate") -async def get_ca_certificate() -> str: - """ - Retrieve the CA certificate in PEM format. - """ - try: - ca_cert_pem = cert_manager.get_ca_certificate_pem() - return ca_cert_pem - except Exception as e: - logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True) - raise HTTPException( - status_code=500, detail="Failed to retrieve CA certificate." - ) from e \ No newline at end of file diff --git a/services/device_manager/pyproject.toml b/services/device_manager/pyproject.toml index 3ce8a5a..cc454e6 100644 --- a/services/device_manager/pyproject.toml +++ b/services/device_manager/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "device-manager" version = "0.1.0" -description = "Add your description here" +description = "Device Manager" readme = "README.md" requires-python = ">=3.13" dependencies = [