mirror of
https://github.com/ferdzo/iotDashboard.git
synced 2026-04-05 17:16:26 +00:00
Functioning device manager with renew,revoke, updated model for cert id
This commit is contained in:
270
services/device_manager/app/app.py
Normal file
270
services/device_manager/app/app.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
from db_models import Device, DeviceCertificate # SQLAlchemy ORM models
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from cert_manager import CertificateManager
|
||||
from database import get_db_context
|
||||
from models import DeviceRegistrationRequest, DeviceRegistrationResponse, DeviceResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
cert_manager = CertificateManager()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def hello():
|
||||
return {"Hello": "World"}
|
||||
|
||||
|
||||
@app.post("/devices/register")
|
||||
async def register_device(
|
||||
request: DeviceRegistrationRequest,
|
||||
) -> DeviceRegistrationResponse:
|
||||
"""
|
||||
Register a new device and issue an X.509 certificate.
|
||||
"""
|
||||
try:
|
||||
response = cert_manager.register_device(
|
||||
name=request.name,
|
||||
location=request.location,
|
||||
)
|
||||
|
||||
with get_db_context() as db:
|
||||
device = Device(
|
||||
id=response.device_id,
|
||||
name=request.name,
|
||||
location=request.location,
|
||||
created_at=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
db.add(device)
|
||||
|
||||
device_cert = DeviceCertificate(
|
||||
id =response.certificate_id,
|
||||
device_id=response.device_id,
|
||||
certificate_pem=response.certificate_pem,
|
||||
private_key_pem=response.private_key_pem,
|
||||
issued_at=datetime.datetime.now(datetime.UTC),
|
||||
expires_at=response.expires_at,
|
||||
)
|
||||
db.add(device_cert)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to register device {request.name}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to register device. Please try again."
|
||||
) from e
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/ca_certificate")
|
||||
async def get_ca_certificate() -> str:
|
||||
"""
|
||||
Retrieve the CA certificate in PEM format.
|
||||
"""
|
||||
try:
|
||||
ca_cert_pem = cert_manager.get_ca_certificate_pem()
|
||||
return ca_cert_pem
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve CA certificate: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve CA certificate."
|
||||
) from e
|
||||
|
||||
@app.get("/devices/{device_id}")
|
||||
async def get_device(device_id: str) -> DeviceResponse:
|
||||
"""
|
||||
Retrieve device information by ID.
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
device = db.query(Device).filter(Device.id == device_id).first()
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
|
||||
return Device(
|
||||
id=device.id,
|
||||
name=device.name,
|
||||
location=device.location,
|
||||
created_at=device.created_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve device information."
|
||||
) from e
|
||||
|
||||
@app.get("/devices/")
|
||||
async def list_devices() -> list[DeviceResponse]:
|
||||
"""
|
||||
List all registered devices.
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
devices = db.query(Device).all()
|
||||
return [
|
||||
DeviceResponse(
|
||||
id=device.id,
|
||||
name=device.name,
|
||||
location=device.location,
|
||||
created_at=device.created_at,
|
||||
)
|
||||
for device in devices
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list devices: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to list devices."
|
||||
) from e
|
||||
|
||||
@app.post("/devices/{device_id}/revoke")
|
||||
async def revoke_device_certificate(device_id: str):
|
||||
"""
|
||||
Revoke a device's certificate by:
|
||||
1. Marking it as revoked in the database
|
||||
2. Adding it to the Certificate Revocation List (CRL)
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
device_cert = (
|
||||
db.query(DeviceCertificate)
|
||||
.filter(DeviceCertificate.device_id == device_id)
|
||||
.first()
|
||||
)
|
||||
if not device_cert:
|
||||
raise HTTPException(status_code=404, detail="Device certificate not found")
|
||||
|
||||
if device_cert.revoked_at:
|
||||
raise HTTPException(status_code=400, detail="Certificate already revoked")
|
||||
|
||||
cert_manager.revoke_certificate(device_cert.certificate_pem)
|
||||
|
||||
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Successfully revoked certificate for device {device_id}")
|
||||
|
||||
return {
|
||||
"device_id": device_id,
|
||||
"revoked_at": device_cert.revoked_at.isoformat(),
|
||||
"message": "Certificate revoked successfully"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to revoke device certificate."
|
||||
) from e
|
||||
|
||||
|
||||
@app.get("/crl")
|
||||
async def get_crl():
|
||||
"""
|
||||
Get the Certificate Revocation List (CRL) in PEM format.
|
||||
Mosquitto and other MQTT clients can check this to validate certificates.
|
||||
"""
|
||||
try:
|
||||
crl_pem = cert_manager.get_crl_pem()
|
||||
if not crl_pem:
|
||||
return {"message": "No certificates have been revoked yet"}
|
||||
return {"crl_pem": crl_pem}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve CRL: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve CRL."
|
||||
) from e
|
||||
|
||||
@app.post("/devices/{device_id}/renew")
|
||||
async def renew_certificate(device_id: str):
|
||||
"""
|
||||
Renew a device's certificate by issuing a new one and revoking the old one.
|
||||
|
||||
This endpoint:
|
||||
1. Retrieves the current certificate from DB
|
||||
2. Generates a new certificate with new keys
|
||||
3. Revokes the old certificate (adds to CRL)
|
||||
4. Updates the database with the new certificate
|
||||
5. Returns the new credentials
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# Get current certificate
|
||||
device_cert = (
|
||||
db.query(DeviceCertificate)
|
||||
.filter(
|
||||
DeviceCertificate.device_id == device_id,
|
||||
# DeviceCertificate.revoked_at.is_(None)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not device_cert:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No active certificate found for device"
|
||||
)
|
||||
|
||||
# Check if certificate is about to expire (optional warning)
|
||||
days_until_expiry = (device_cert.expires_at - datetime.datetime.now(datetime.UTC)).days
|
||||
if days_until_expiry > 30:
|
||||
logger.warning(
|
||||
f"Certificate for device {device_id} renewed early "
|
||||
f"({days_until_expiry} days remaining)"
|
||||
)
|
||||
|
||||
# Revoke old certificate and add to CRL
|
||||
cert_manager.revoke_certificate(device_cert.certificate_pem)
|
||||
device_cert.revoked_at = datetime.datetime.now(datetime.UTC)
|
||||
|
||||
# Generate new certificate with new keys
|
||||
new_cert_pem, new_key_pem = cert_manager.renew_certificate(
|
||||
current_cert_pem=device_cert.certificate_pem,
|
||||
validity_days=365,
|
||||
key_size=4096
|
||||
)
|
||||
|
||||
# Extract certificate ID (serial number) from the new certificate
|
||||
from cryptography import x509
|
||||
new_cert = x509.load_pem_x509_certificate(new_cert_pem)
|
||||
new_cert_id = format(new_cert.serial_number, 'x')
|
||||
|
||||
# Create new certificate record in DB
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
new_device_cert = DeviceCertificate(
|
||||
id=new_cert_id,
|
||||
device_id=device_id,
|
||||
certificate_pem=new_cert_pem.decode("utf-8"),
|
||||
private_key_pem=new_key_pem.decode("utf-8"),
|
||||
issued_at=now,
|
||||
expires_at=now + datetime.timedelta(days=365),
|
||||
)
|
||||
db.add(new_device_cert)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Successfully renewed certificate for device {device_id}")
|
||||
|
||||
return DeviceRegistrationResponse(
|
||||
certificate_id=new_cert_id,
|
||||
device_id=device_id,
|
||||
ca_certificate_pem=cert_manager.get_ca_certificate_pem(),
|
||||
certificate_pem=new_device_cert.certificate_pem,
|
||||
private_key_pem=new_device_cert.private_key_pem,
|
||||
expires_at=new_device_cert.expires_at,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to renew certificate for device {device_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to renew device certificate."
|
||||
) from e
|
||||
240
services/device_manager/app/cert_manager.py
Normal file
240
services/device_manager/app/cert_manager.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import NameOID
|
||||
from nanoid import generate
|
||||
|
||||
from config import config
|
||||
from models import DeviceRegistrationResponse
|
||||
|
||||
lowercase_numbers = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
|
||||
class CertificateManager:
|
||||
"""Manages device certificate generation and handling"""
|
||||
|
||||
def __init__(self):
|
||||
self.ca_cert: x509.Certificate = self.load_ca_certificate(config.CA_CERT_PATH)
|
||||
self.ca_key: rsa.RSAPrivateKey = self.load_ca_private_key(config.CA_KEY_PATH)
|
||||
self.ca_cert_pem: bytes = self.ca_cert.public_bytes(serialization.Encoding.PEM)
|
||||
|
||||
def generate_device_id(self) -> str:
|
||||
"""Generate a unique device ID using nanoid."""
|
||||
return generate(alphabet=lowercase_numbers, size=config.DEVICE_ID_LENGTH)
|
||||
|
||||
def load_ca_certificate(self, ca_cert_path: str) -> x509.Certificate:
|
||||
"""Load a CA certificate from file."""
|
||||
with open(ca_cert_path, "rb") as f:
|
||||
ca_data = f.read()
|
||||
ca_cert = x509.load_pem_x509_certificate(ca_data)
|
||||
return ca_cert
|
||||
|
||||
def load_ca_private_key(self, ca_key_path: str, password: bytes = None) -> rsa.RSAPrivateKey:
|
||||
"""Load a CA private key from file."""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
with open(ca_key_path, "rb") as f:
|
||||
key_data = f.read()
|
||||
ca_key = serialization.load_pem_private_key(key_data, password=password)
|
||||
return ca_key
|
||||
|
||||
def generate_device_key(self, key_size: int = 4096) -> rsa.RSAPrivateKey:
|
||||
"""Generate an RSA private key for a device."""
|
||||
return rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
||||
|
||||
def generate_device_certificate(
|
||||
self,
|
||||
device_id: str,
|
||||
ca_cert: x509.Certificate,
|
||||
ca_key: rsa.RSAPrivateKey,
|
||||
device_key: rsa.RSAPrivateKey,
|
||||
validity_days: int = 365,
|
||||
key_size: int = 4096,
|
||||
) -> tuple[bytes, bytes]:
|
||||
"""Generate an X.509 certificate for a device signed by the CA."""
|
||||
|
||||
subject = x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, device_id),
|
||||
]
|
||||
)
|
||||
|
||||
issuer = ca_cert.subject
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
device_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(device_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=validity_days))
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None),
|
||||
critical=True,
|
||||
)
|
||||
.sign(private_key=ca_key, algorithm=hashes.SHA256())
|
||||
)
|
||||
|
||||
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM)
|
||||
key_pem = device_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
return cert_pem, key_pem
|
||||
|
||||
def create_device_credentials(
|
||||
self, device_id: str, validity_days: int = 365, key_size: int = 4096
|
||||
) -> dict:
|
||||
"""Create device credentials: private key and signed certificate.
|
||||
Returns:
|
||||
dict with certificate_id, device_id, certificate_pem, private_key_pem, ca_certificate_pem, expires_at
|
||||
"""
|
||||
device_key = self.generate_device_key(key_size=key_size)
|
||||
|
||||
cert_pem, key_pem = self.generate_device_certificate(
|
||||
device_id=device_id,
|
||||
ca_cert=self.ca_cert,
|
||||
ca_key=self.ca_key,
|
||||
device_key=device_key,
|
||||
validity_days=validity_days,
|
||||
key_size=key_size,
|
||||
)
|
||||
|
||||
# Extract serial number from certificate to use as ID
|
||||
cert = x509.load_pem_x509_certificate(cert_pem)
|
||||
cert_id = format(cert.serial_number, 'x') # Hex string of serial number
|
||||
|
||||
expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days)
|
||||
|
||||
return {
|
||||
"certificate_id": cert_id,
|
||||
"device_id": device_id,
|
||||
"certificate_pem": cert_pem,
|
||||
"private_key_pem": key_pem,
|
||||
"ca_certificate_pem": self.ca_cert_pem,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
def register_device(self, name: str, location: str | None = None) -> DeviceRegistrationResponse:
|
||||
"""Register a new device and generate its credentials.
|
||||
Returns:
|
||||
DeviceRegistrationResponse
|
||||
"""
|
||||
device_id = self.generate_device_id()
|
||||
credentials = self.create_device_credentials(device_id=device_id)
|
||||
|
||||
return DeviceRegistrationResponse(
|
||||
certificate_id=credentials["certificate_id"],
|
||||
device_id=credentials["device_id"],
|
||||
ca_certificate_pem=credentials["ca_certificate_pem"].decode("utf-8"),
|
||||
certificate_pem=credentials["certificate_pem"].decode("utf-8"),
|
||||
private_key_pem=credentials["private_key_pem"].decode("utf-8"),
|
||||
expires_at=credentials["expires_at"],
|
||||
)
|
||||
|
||||
def get_ca_certificate_pem(self) -> str:
|
||||
"""Get the CA certificate in PEM format as a string."""
|
||||
return self.ca_cert_pem.decode("utf-8")
|
||||
|
||||
def revoke_certificate(
|
||||
self, certificate_pem: str, reason: x509.ReasonFlags = x509.ReasonFlags.unspecified
|
||||
) -> None:
|
||||
"""
|
||||
Revoke a device certificate by adding it to the CRL.
|
||||
|
||||
Args:
|
||||
certificate_pem: PEM-encoded certificate to revoke
|
||||
reason: Revocation reason (default: unspecified)
|
||||
"""
|
||||
# Load the certificate to get serial number
|
||||
cert = x509.load_pem_x509_certificate(certificate_pem.encode())
|
||||
|
||||
# Load existing CRL or create new one
|
||||
crl_path = Path(config.CRL_PATH)
|
||||
revoked_certs = []
|
||||
|
||||
if crl_path.exists():
|
||||
with open(crl_path, "rb") as f:
|
||||
existing_crl = x509.load_pem_x509_crl(f.read())
|
||||
# Copy existing revoked certificates
|
||||
revoked_certs = list(existing_crl)
|
||||
|
||||
# Add the new revoked certificate
|
||||
revoked_cert = (
|
||||
x509.RevokedCertificateBuilder()
|
||||
.serial_number(cert.serial_number)
|
||||
.revocation_date(datetime.datetime.now(datetime.UTC))
|
||||
.add_extension(
|
||||
x509.CRLReason(reason),
|
||||
critical=False,
|
||||
)
|
||||
.build()
|
||||
)
|
||||
revoked_certs.append(revoked_cert)
|
||||
|
||||
# Build new CRL with all revoked certificates
|
||||
crl_builder = x509.CertificateRevocationListBuilder()
|
||||
crl_builder = crl_builder.issuer_name(self.ca_cert.subject)
|
||||
crl_builder = crl_builder.last_update(datetime.datetime.now(datetime.UTC))
|
||||
crl_builder = crl_builder.next_update(
|
||||
datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30)
|
||||
)
|
||||
|
||||
for revoked in revoked_certs:
|
||||
crl_builder = crl_builder.add_revoked_certificate(revoked)
|
||||
|
||||
# Sign the CRL with CA key
|
||||
crl = crl_builder.sign(private_key=self.ca_key, algorithm=hashes.SHA256())
|
||||
|
||||
# Write CRL to file
|
||||
crl_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(crl_path, "wb") as f:
|
||||
f.write(crl.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
def get_crl_pem(self) -> str | None:
|
||||
"""Get the current CRL in PEM format."""
|
||||
crl_path = Path(config.CRL_PATH)
|
||||
if not crl_path.exists():
|
||||
return None
|
||||
|
||||
with open(crl_path, "rb") as f:
|
||||
return f.read().decode("utf-8")
|
||||
|
||||
def renew_certificate(
|
||||
self,
|
||||
current_cert_pem: str,
|
||||
validity_days: int = 365,
|
||||
key_size: int = 4096,
|
||||
) -> tuple[bytes, bytes]:
|
||||
"""Renew a device certificate before expiration.
|
||||
Args:
|
||||
current_cert_pem: PEM-encoded current certificate
|
||||
validity_days: Validity period for new certificate
|
||||
key_size: Key size for new device key
|
||||
Returns:
|
||||
tuple of (new_cert_pem, new_key_pem)
|
||||
"""
|
||||
# Load current certificate
|
||||
current_cert = x509.load_pem_x509_certificate(current_cert_pem.encode())
|
||||
device_id = current_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
|
||||
|
||||
# Generate new device key
|
||||
new_device_key = self.generate_device_key(key_size=key_size)
|
||||
|
||||
# Generate new device certificate
|
||||
new_cert_pem, new_key_pem = self.generate_device_certificate(
|
||||
device_id=device_id,
|
||||
ca_cert=self.ca_cert,
|
||||
ca_key=self.ca_key,
|
||||
device_key=new_device_key,
|
||||
validity_days=validity_days,
|
||||
key_size=key_size,
|
||||
)
|
||||
|
||||
return new_cert_pem, new_key_pem
|
||||
33
services/device_manager/app/config.py
Normal file
33
services/device_manager/app/config.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Config:
|
||||
"""Configuration settings for the Device Manager service."""
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
SERVICE_DIR = Path(__file__).parent
|
||||
CERTS_DIR = SERVICE_DIR / "certs"
|
||||
CA_CERT_PATH = os.getenv("CA_CERT_PATH", str(CERTS_DIR / "ca.crt"))
|
||||
CA_KEY_PATH = os.getenv("CA_KEY_PATH", str(CERTS_DIR / "ca.key"))
|
||||
CRL_PATH = os.getenv("CRL_PATH", str(CERTS_DIR / "ca.crl"))
|
||||
|
||||
# Certificate settings
|
||||
CERT_VALIDITY_DAYS = int(os.getenv("CERT_VALIDITY_DAYS", "365"))
|
||||
CERT_KEY_SIZE = int(os.getenv("CERT_KEY_SIZE", "4096"))
|
||||
|
||||
# Device ID settings
|
||||
DEVICE_ID_LENGTH = int(os.getenv("DEVICE_ID_LENGTH", "8"))
|
||||
|
||||
# Service settings
|
||||
SERVICE_HOST = os.getenv("DEVICE_MANAGER_HOST", "0.0.0.0")
|
||||
SERVICE_PORT = int(os.getenv("DEVICE_MANAGER_PORT", "8000"))
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
config = Config()
|
||||
83
services/device_manager/app/database.py
Normal file
83
services/device_manager/app/database.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Database session management for FastAPI with SQLAlchemy.
|
||||
|
||||
Uses dependency injection pattern for database sessions.
|
||||
"""
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config import config
|
||||
|
||||
# Create engine with connection pooling
|
||||
engine = create_engine(
|
||||
config.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session]:
|
||||
"""
|
||||
FastAPI dependency that provides a database session.
|
||||
|
||||
Usage in endpoints:
|
||||
@app.post("/devices")
|
||||
async def create_device(db: Session = Depends(get_db)):
|
||||
device = Device(...)
|
||||
db.add(device)
|
||||
db.commit()
|
||||
return device
|
||||
|
||||
The session is automatically closed after the request completes.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_context():
|
||||
"""
|
||||
Context manager for database sessions outside of FastAPI endpoints.
|
||||
Usage:
|
||||
with get_db_context() as db:
|
||||
device = db.query(Device).first()
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def check_db_connection() -> bool:
|
||||
"""
|
||||
Check if database connection is working.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Database connection failed: {e}")
|
||||
return False
|
||||
45
services/device_manager/app/db_models.py
Normal file
45
services/device_manager/app/db_models.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
SQLAlchemy ORM models for device manager service.
|
||||
|
||||
These models mirror the database schema defined in db_migrations.
|
||||
Kept separate to make the service independent.
|
||||
"""
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Device(Base):
|
||||
"""IoT devices registered in the system."""
|
||||
|
||||
__tablename__ = "devices"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
location = Column(Text)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Device(id={self.id}, name={self.name})>"
|
||||
|
||||
|
||||
class DeviceCertificate(Base):
|
||||
"""X.509 certificates issued to devices for mTLS authentication."""
|
||||
|
||||
__tablename__ = "device_certificates"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
device_id = Column(
|
||||
Text, ForeignKey("devices.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
certificate_pem = Column(Text, nullable=False)
|
||||
private_key_pem = Column(Text)
|
||||
issued_at = Column(DateTime(timezone=True), nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
revoked_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DeviceCertificate(device_id={self.device_id}, expires={self.expires_at})>"
|
||||
27
services/device_manager/app/models.py
Normal file
27
services/device_manager/app/models.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class DeviceRegistrationRequest(BaseModel):
|
||||
"""Request model for registering a new device."""
|
||||
|
||||
name: str
|
||||
location: str | None = None
|
||||
|
||||
class DeviceRegistrationResponse(BaseModel):
|
||||
"""Response model after registering a new device."""
|
||||
|
||||
certificate_id: str
|
||||
device_id: str
|
||||
ca_certificate_pem: str
|
||||
certificate_pem: str
|
||||
private_key_pem: str
|
||||
expires_at: datetime.datetime
|
||||
|
||||
class DeviceResponse(BaseModel):
|
||||
"""Response model for device information."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
location: str | None = None
|
||||
created_at: datetime.datetime
|
||||
Reference in New Issue
Block a user