Last active
January 9, 2026 07:54
-
-
Save sluipmoord/05bde6f54875283b4dfce040120ee2d6 to your computer and use it in GitHub Desktop.
Custom PGP Encryption TypeDecorator for sqlalchemy and flask_sqlalchemy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| from typing import Any, ClassVar | |
| from flask import current_app | |
| from flask_sqlalchemy import SQLAlchemy | |
| from sqlalchemy import ColumnElement, Dialect, FunctionElement, Integer, String, TypeDecorator, func, text, type_coerce | |
| from sqlalchemy.dialects.postgresql import BYTEA, JSON, JSONB | |
| from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | |
| from sqlalchemy.sql.operators import OperatorType | |
class Base(DeclarativeBase):
    """Declarative base class that gives every model a readable ``repr``."""

    def __repr__(self: "Base") -> str:
        """Return ``<ClassName(attr=value,...)>`` built from the instance ``__dict__``."""
        # The ``_sa_instance_state`` entry is skipped so only the remaining
        # attributes show up in the output.
        pairs = [
            f"{attr}={value!r}"
            for attr, value in vars(self).items()
            if attr != "_sa_instance_state"
        ]
        return f"<{type(self).__name__}({','.join(pairs)})>"
class PGPEncryptString(TypeDecorator):
    """String column type that is encrypted and decrypted by the database.

    Bound values are wrapped in pgcrypto's ``pgp_sym_encrypt`` when they are sent
    to the database, and selected columns are wrapped in ``pgp_sym_decrypt`` when
    they are read back. The passphrase is taken from the Flask application config
    key ``PGP_PASSPHRASE``, so an application context must be active whenever SQL
    involving this type is built.

    Typical usage::

        class MyModel(Base):
            __tablename__ = "my_model"

            id = Column(Integer, primary_key=True)
            encrypted_string: Mapped[str] = mapped_column(PGPEncryptString())

        # Inserting data
        my_model = MyModel(encrypted_string="my secret data")
        db.session.add(my_model)
        db.session.commit()

        # Filtering data
        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_string == "my secret data"
        ).first()
    """

    impl = String
    cache_ok = True

    class Comparator(TypeDecorator.Comparator):
        """Comparator that applies every operator to the decrypted column value."""

        def operate(
            self: "PGPEncryptString.Comparator", op: OperatorType, other: Any, **kwargs: Any
        ) -> ColumnElement[Any]:
            """Apply ``op`` to the decrypted expression and ``other``."""
            decrypted = pgp_sym_decrypt(self)
            return op(decrypted, other, **kwargs)

    @property
    def comparator_factory(self: "PGPEncryptString") -> Any:
        """Return the comparator class used for expressions of this type."""
        return self.Comparator

    def bind_expression(self: "PGPEncryptString", bindvalue: Any):
        """Wrap a bound parameter in ``pgp_sym_encrypt`` (AES-256, s2k-mode 1)."""
        passphrase = current_app.config["PGP_PASSPHRASE"]
        assert passphrase, "PGP_PASSPHRASE must be set before using PGPString"
        # Coerce to String so the parameter is rendered as plain text inside the
        # encrypt call rather than being processed by this type again.
        plaintext = type_coerce(bindvalue, String)
        return func.pgp_sym_encrypt(plaintext, passphrase, "cipher-algo=aes256, s2k-mode=1")

    def column_expression(self: "PGPEncryptString", col: Any):
        """Wrap a selected column in ``pgp_sym_decrypt``."""
        assert current_app.config["PGP_PASSPHRASE"], "PGP_PASSPHRASE must be set before using PGPString"
        return pgp_sym_decrypt(col)
class PGPEncryptJSONB(TypeDecorator):
    """A JSONB column type whose leaf values are individually PGP-encrypted.

    On write, the document is walked recursively and every non-container value is
    cast to text and encrypted server side with pgcrypto's ``pgp_sym_encrypt``.
    Dict keys and the dict/list structure are stored unencrypted, which is what
    makes filtering on individual paths possible (see ``Comparator.astext``).

    On read, every leaf is decrypted with ``pgp_sym_decrypt`` and
    ``infer_primitive_type`` converts the decrypted text back to ``bool``,
    ``int``, ``float`` or ``str``. The inference is purely textual, so a stored
    string such as ``"47"`` or ``"true"`` comes back as ``47`` / ``True``.
    JSON ``null`` leaves are passed through unchanged in both directions.

    Each leaf value costs one query through ``db.session`` on write and on read.
    The passphrase is read from the Flask config key ``PGP_PASSPHRASE``.

    Typical usage::

        class MyModel(Base):
            __tablename__ = "my_model"

            id = Column(Integer, primary_key=True)
            encrypted_data: Mapped[dict] = mapped_column(PGPEncryptJSONB())

        # Inserting data
        data = {"key": "value", "count": 1}
        my_model = MyModel(encrypted_data=data)
        db.session.add(my_model)
        db.session.commit()

        # Filtering data
        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_data["key"].astext == "value"
        ).first()

        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_data["count"].astext.cast(Integer) == 1
        ).first()
    """

    impl = JSONB
    cache_ok = True

    class Comparator(JSON.Comparator):
        """Custom comparator for PGPEncryptJSONB."""

        # add custom operators here
        @property
        def astext(self: "PGPEncryptJSONB.Comparator") -> Any:
            """Return the text form of the indexed element wrapped in ``pgp_sym_decrypt``."""
            res = super().astext
            return pgp_sym_decrypt(res)

    @property
    def comparator_factory(self: "PGPEncryptJSONB") -> Any:
        """Return the comparator class used for expressions of this type."""
        return self.Comparator

    def infer_primitive_type(self: "PGPEncryptJSONB", value: str) -> bool | int | float | str:
        """Convert decrypted text to the most specific primitive it parses as.

        Args:
            value: The decrypted text of a single leaf value.

        Returns:
            ``True``/``False`` for the case-insensitive words "true"/"false",
            otherwise an ``int`` or ``float`` if the text parses as one, otherwise
            the text itself.
        """
        lowered = value.lower()
        if lowered in ("true", "false"):
            # bool("false") is True because the string is non-empty, so compare
            # against the literal instead of calling bool().
            return lowered == "true"
        for convert in (int, float):
            try:
                return convert(value)
            except ValueError:
                continue
        return value

    def process_bind_param(self: "PGPEncryptJSONB", value: dict, dialect: Dialect) -> dict[Any, Any] | list[Any] | Any:  # noqa: ARG002
        """Return ``value`` with every leaf replaced by its encrypted text form.

        Args:
            value: The document to store, or ``None``.
            dialect: Unused.

        Returns:
            A structure of the same shape whose leaves are the text rendering of
            the ``pgp_sym_encrypt`` output; ``None`` stays ``None``.
        """
        passphrase = current_app.config["PGP_PASSPHRASE"]
        assert passphrase, "PGP_PASSPHRASE must be set before using PGPEncryptJSONB"
        if value is None:
            return value

        def encrypt_value(val: Any) -> dict[Any, Any] | list[Any] | Any:
            if isinstance(val, dict):
                return {k: encrypt_value(v) for k, v in val.items()}
            if isinstance(val, list):
                return [encrypt_value(v) for v in val]
            if val is None:
                # Keep JSON null as null without a database round trip.
                return None
            return db.session.query(
                func.cast(
                    func.pgp_sym_encrypt(func.cast(val, String), passphrase, "cipher-algo=aes256, s2k-mode=1"),
                    String,
                )
            ).scalar()

        return encrypt_value(value)

    def process_result_value(
        self: "PGPEncryptJSONB",
        value: dict,
        dialect: Dialect,  # noqa: ARG002
    ) -> dict[Any, Any] | list[Any] | Any:
        """Return ``value`` with every encrypted leaf decrypted and type-inferred.

        Args:
            value: The stored document as loaded from the column, or ``None``.
            dialect: Unused.

        Returns:
            A structure of the same shape whose leaves are the results of
            ``infer_primitive_type`` on the decrypted text; ``None`` leaves and a
            ``None`` document are returned unchanged.
        """
        passphrase = current_app.config["PGP_PASSPHRASE"]
        assert passphrase, "PGP_PASSPHRASE must be set before using PGPEncryptJSONB"
        if value is None:
            return value

        def decrypt_value(val: Any) -> dict[Any, Any] | list[Any] | Any:
            if isinstance(val, dict):
                return {k: decrypt_value(v) for k, v in val.items()}
            if isinstance(val, list):
                return [decrypt_value(v) for v in val]
            if val is None:
                # A null leaf has nothing to decrypt, and infer_primitive_type
                # only accepts text.
                return None
            result = db.session.query(
                func.pgp_sym_decrypt(val, passphrase, "cipher-algo=aes256, s2k-mode=1")
            ).scalar()
            return self.infer_primitive_type(result)

        return decrypt_value(value)
def pgp_sym_decrypt(col: Any) -> FunctionElement:
    """Build a ``pgp_sym_decrypt`` SQL call for ``col`` using the app passphrase.

    ``col`` is cast to ``BYTEA`` before being handed to the function. The
    passphrase is read from the Flask config key ``PGP_PASSPHRASE``.
    """
    passphrase = current_app.config["PGP_PASSPHRASE"]
    assert passphrase, "PGP_PASSPHRASE must be set"
    encrypted_bytes = func.cast(col, BYTEA)
    return func.pgp_sym_decrypt(encrypted_bytes, passphrase)
# Flask-SQLAlchemy extension instance, configured to use ``Base`` as its model class.
db = SQLAlchemy(model_class=Base)
def create_app():
    """Create the Flask app, configure it from the environment and bind ``db`` to it.

    ``SQLALCHEMY_DATABASE_URI`` and ``PGP_PASSPHRASE`` are copied from the
    environment variables of the same names (``None`` when unset), and the app
    logger is set to DEBUG.
    """
    from flask import Flask

    flask_app = Flask(__name__)
    flask_app.config.update(
        SQLALCHEMY_DATABASE_URI=os.environ.get("SQLALCHEMY_DATABASE_URI"),
        PGP_PASSPHRASE=os.environ.get("PGP_PASSPHRASE"),
    )
    flask_app.logger.setLevel("DEBUG")
    db.init_app(flask_app)
    return flask_app
class EncryptionTest(Base):
    """Demo model with one plain column and one column of each encrypted type."""

    __tablename__ = "encryption_test"
    # The table lives in the ``test_schema`` schema.
    __table_args__: ClassVar = {"schema": "test_schema"}

    id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
    # Stored as an ordinary, unencrypted string.
    name: Mapped[str] = mapped_column(String)
    # Encrypted as a single value by PGPEncryptString.
    encrypted_string: Mapped[str] = mapped_column(PGPEncryptString())
    # Encrypted leaf by leaf by PGPEncryptJSONB.
    encrypted_json: Mapped[dict] = mapped_column(PGPEncryptJSONB())
if __name__ == "__main__":
    app = create_app()
    with app.app_context(), app.test_request_context():
        session = db.session

        # Recreate the schema from scratch.
        # Assumes the pgcrypto extension is already installed.
        reset_schema_sql = text(
            """
            DROP SCHEMA IF EXISTS test_schema CASCADE;
            CREATE SCHEMA IF NOT EXISTS test_schema;
            SET search_path TO test_schema, public;
            """
        )
        session.execute(reset_schema_sql)
        session.commit()
        db.create_all()

        # Insert one row that exercises both encrypted column types.
        secret_payload = {
            "movies": ["The Marine", "12 Rounds"],
            "age": 47,
            "nested": {
                "key": "value",
                "list": [1, 2, 3],
            },
        }
        test_model = EncryptionTest(
            name="John Cena",
            encrypted_string="You can't see me",
            encrypted_json=secret_payload,
        )
        session.add(test_model)
        session.commit()

        # Filter on the encrypted string data.
        string_matches = EncryptionTest.encrypted_string == "You can't see me"
        result = session.query(EncryptionTest).filter(string_matches).one()
        current_app.logger.info(result)
        assert result.name == "John Cena"
        assert result.encrypted_json["age"] == 47
        assert result.encrypted_string == "You can't see me"

        # Filter on the encrypted JSON data.
        younger_than_fifty = EncryptionTest.encrypted_json["age"].astext.cast(Integer) < 50
        result = session.query(EncryptionTest).filter(younger_than_fifty).one()
        current_app.logger.info(result)
        assert result.name == "John Cena"
        assert result.encrypted_string == "You can't see me"
        assert result.encrypted_json["age"] == 47

        # Clean up the throwaway schema.
        session.execute(text("DROP SCHEMA test_schema CASCADE"))
        session.commit()
        sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment