Skip to content

Instantly share code, notes, and snippets.

@sluipmoord
Last active January 9, 2026 07:54
Show Gist options
  • Select an option

  • Save sluipmoord/05bde6f54875283b4dfce040120ee2d6 to your computer and use it in GitHub Desktop.

Select an option

Save sluipmoord/05bde6f54875283b4dfce040120ee2d6 to your computer and use it in GitHub Desktop.
Custom PGP encryption TypeDecorators for SQLAlchemy and Flask-SQLAlchemy
import os
import sys
from typing import Any, ClassVar
from flask import current_app
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import ColumnElement, Dialect, FunctionElement, Integer, String, TypeDecorator, func, text, type_coerce
from sqlalchemy.dialects.postgresql import BYTEA, JSON, JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.sql.operators import OperatorType
class Base(DeclarativeBase):
    """Declarative base class for the models in this module.

    Supplies a generic ``__repr__`` that lists the instance's attributes.
    """

    def __repr__(self: "Base") -> str:
        """Return ``<ClassName(key=value,...)>`` built from the instance ``__dict__``.

        SQLAlchemy's internal ``_sa_instance_state`` entry is left out.
        """
        shown = []
        for key, value in vars(self).items():
            if key == "_sa_instance_state":
                continue
            shown.append(f"{key}={value!r}")
        body = ",".join(shown)
        return f"<{type(self).__name__}({body})>"
class PGPEncryptString(TypeDecorator):
    """A type for storing encrypted strings in the database.

    Values are encrypted server side with the pgcrypto extension's
    ``pgp_sym_encrypt`` when they are written, and decrypted with
    ``pgp_sym_decrypt`` when they are selected. Both use the passphrase in
    ``current_app.config["PGP_PASSPHRASE"]``. The ``bytea`` ciphertext is
    stored in a ``String`` column and cast back to ``bytea`` for decryption.

    Every SQL operator applied to the column (``==``, ``like``, ``in_``,
    ``between``, ``desc`` ...) is applied to the *decrypted* value. Filtering
    works, but the database must decrypt each candidate row, so a plain index
    on the column cannot be used.

    Typical usage::

        class MyModel(Base):
            __tablename__ = "my_model"
            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            encrypted_string: Mapped[str] = mapped_column(PGPEncryptString())

        # Inserting data
        my_model = MyModel(encrypted_string="my secret data")
        db.session.add(my_model)
        db.session.commit()

        # Filtering data
        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_string == "my secret data"
        ).first()

    NOTE(review): the passphrase is sent to the server as an ordinary bound
    parameter, so it may appear in SQL/statement logs. It is also captured when
    SQLAlchemy compiles a statement, so a compiled statement reused from
    SQLAlchemy's cache may keep an older passphrase if the config value changes
    at runtime. Confirm both are acceptable.
    """

    impl = String
    cache_ok = True

    # pgcrypto options passed to pgp_sym_encrypt; subclasses may override this.
    pgp_options = "cipher-algo=aes256, s2k-mode=1"

    class Comparator(TypeDecorator.Comparator):
        """Comparator that applies every operator to the decrypted column value."""

        def operate(
            self: "PGPEncryptString.Comparator", op: OperatorType, *other: Any, **kwargs: Any
        ) -> ColumnElement[Any]:
            # ``*other`` (rather than exactly one ``other``) is required so that
            # unary operators such as ``desc()``/``asc()`` and multi-operand
            # operators such as ``between()`` do not raise TypeError.
            return op(pgp_sym_decrypt(self.expr), *other, **kwargs)

    @property
    def comparator_factory(self: "PGPEncryptString") -> Any:
        """Use ``PGPEncryptString.Comparator`` for SQL expressions of this type."""
        return self.Comparator

    @staticmethod
    def _require_passphrase() -> str:
        """Return the configured PGP passphrase.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is missing or empty. This is an
                explicit raise rather than ``assert`` so that the check survives
                ``python -O``; otherwise a ``None`` passphrase would be passed
                to pgcrypto as NULL.
        """
        passphrase = current_app.config.get("PGP_PASSPHRASE")
        if not passphrase:
            raise RuntimeError("PGP_PASSPHRASE must be set before using PGPEncryptString")
        return passphrase

    def bind_expression(self: "PGPEncryptString", bindvalue: Any) -> FunctionElement:
        """Wrap an outgoing bound value in ``pgp_sym_encrypt(value, passphrase, options)``.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is not configured.
        """
        passphrase = self._require_passphrase()
        # Treat the raw value as a plain String parameter inside the
        # pgp_sym_encrypt() call instead of as this TypeDecorator's type.
        bindvalue = type_coerce(bindvalue, String)
        return func.pgp_sym_encrypt(bindvalue, passphrase, self.pgp_options)

    def column_expression(self: "PGPEncryptString", col: Any) -> FunctionElement:
        """Wrap a selected column in ``pgp_sym_decrypt`` so Python receives plaintext.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is not configured.
        """
        self._require_passphrase()
        return pgp_sym_decrypt(col)
class PGPEncryptJSONB(TypeDecorator):
    """A type for storing JSONB documents whose leaf values are encrypted.

    Before a document is stored, every scalar leaf is encrypted with pgcrypto's
    ``pgp_sym_encrypt`` and stored as a JSON string. The type recurses through
    dicts and lists to find the leaves. Dictionary keys and the nesting
    structure are *not* encrypted; they stay in plain text, which is what lets
    JSONB path queries such as ``col["key"].astext`` work.

    When a document is read back, each leaf is decrypted with
    ``pgp_sym_decrypt`` and ``infer_primitive_type`` guesses its Python type.
    Every leaf is encrypted as text, so the original type is not preserved:
    for example, the string ``"47"`` comes back as the int ``47``. JSON
    ``null`` leaves are not encrypted and come back as ``None``.

    NOTE: encryption and decryption each issue one extra ``SELECT`` through
    ``db.session`` per leaf value, so large documents are costly to write and
    read.

    Typical usage::

        class MyModel(Base):
            __tablename__ = "my_model"
            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            encrypted_data: Mapped[dict] = mapped_column(PGPEncryptJSONB())

        # Inserting data
        data = {"key": "value", "count": 1}
        my_model = MyModel(encrypted_data=data)
        db.session.add(my_model)
        db.session.commit()

        # Filtering data
        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_data["key"].astext == "value"
        ).first()
        my_model = db.session.query(MyModel).filter(
            MyModel.encrypted_data["count"].astext.cast(Integer) == 1
        ).first()
    """

    impl = JSONB
    cache_ok = True

    # pgcrypto options passed to pgp_sym_encrypt / pgp_sym_decrypt; subclasses may override.
    pgp_options = "cipher-algo=aes256, s2k-mode=1"

    class Comparator(JSON.Comparator):
        """Comparator whose ``astext`` decrypts the extracted JSON value."""

        # add custom operators here
        @property
        def astext(self: "PGPEncryptJSONB.Comparator") -> Any:
            # ``->>`` yields the stored ciphertext text; decrypt it in SQL.
            return pgp_sym_decrypt(super().astext)

    @property
    def comparator_factory(self: "PGPEncryptJSONB") -> Any:
        """Use ``PGPEncryptJSONB.Comparator`` for SQL expressions of this type."""
        return self.Comparator

    @staticmethod
    def _require_passphrase() -> str:
        """Return the configured PGP passphrase.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is missing or empty. This is an
                explicit raise rather than ``assert`` so that the check survives
                ``python -O``.
        """
        passphrase = current_app.config.get("PGP_PASSPHRASE")
        if not passphrase:
            raise RuntimeError("PGP_PASSPHRASE must be set before using PGPEncryptJSONB")
        return passphrase

    def infer_primitive_type(
        self: "PGPEncryptJSONB", value: str | None
    ) -> bool | int | float | str | None:
        """Guess the Python type of a decrypted leaf value.

        ``"true"``/``"false"`` (any case) become booleans. Anything ``int()``
        accepts becomes an int, then anything ``float()`` accepts becomes a
        float, and everything else is returned unchanged. ``None`` (SQL NULL)
        is returned as ``None``.
        """
        if value is None:
            return None
        lowered = value.lower()
        if lowered in ("true", "false"):
            # bool("false") is True, so compare the text instead.
            return lowered == "true"
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value

    def process_bind_param(
        self: "PGPEncryptJSONB", value: Any, dialect: Dialect  # noqa: ARG002
    ) -> dict[Any, Any] | list[Any] | Any:
        """Encrypt every leaf of ``value`` before it is sent to the database.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is not configured.
        """
        passphrase = self._require_passphrase()
        if value is None:
            return value

        def encrypt_value(val: Any) -> dict[Any, Any] | list[Any] | Any:
            if isinstance(val, dict):
                return {k: encrypt_value(v) for k, v in val.items()}
            if isinstance(val, list):
                return [encrypt_value(v) for v in val]
            if val is None:
                # pgp_sym_encrypt(NULL) yields NULL anyway; skip the round trip.
                return None
            # One SELECT per leaf: cast to text, encrypt, and cast the bytea
            # ciphertext back to text so it can live in a JSON string.
            return db.session.query(
                func.cast(
                    func.pgp_sym_encrypt(func.cast(val, String), passphrase, self.pgp_options),
                    String,
                )
            ).scalar()

        return encrypt_value(value)

    def process_result_value(
        self: "PGPEncryptJSONB",
        value: Any,
        dialect: Dialect,  # noqa: ARG002
    ) -> dict[Any, Any] | list[Any] | Any:
        """Decrypt every leaf of a document loaded from the database.

        Raises:
            RuntimeError: if ``PGP_PASSPHRASE`` is not configured.
        """
        passphrase = self._require_passphrase()
        if value is None:
            return value

        def decrypt_value(val: Any) -> dict[Any, Any] | list[Any] | Any:
            if isinstance(val, dict):
                return {k: decrypt_value(v) for k, v in val.items()}
            if isinstance(val, list):
                return [decrypt_value(v) for v in val]
            if val is None:
                # JSON null leaves were never encrypted; decrypting them would
                # yield NULL and previously crashed in infer_primitive_type.
                return None
            plaintext = db.session.query(
                func.pgp_sym_decrypt(val, passphrase, self.pgp_options)
            ).scalar()
            return self.infer_primitive_type(plaintext)

        return decrypt_value(value)
def pgp_sym_decrypt(col: Any) -> FunctionElement:
    """Build ``pgp_sym_decrypt(CAST(col AS BYTEA), passphrase)`` as a SQL expression.

    ``col`` is a SQL expression holding ciphertext in text form, as written by
    ``PGPEncryptString`` and ``PGPEncryptJSONB``. It is cast to ``BYTEA``
    because pgcrypto's ``pgp_sym_decrypt`` takes a ``bytea`` argument. The
    passphrase is read from ``current_app.config["PGP_PASSPHRASE"]``.

    Raises:
        RuntimeError: if ``PGP_PASSPHRASE`` is missing or empty. This is an
            explicit raise rather than ``assert`` so that it survives
            ``python -O``; otherwise a ``None`` passphrase would be passed to
            pgcrypto as NULL.
    """
    passphrase = current_app.config.get("PGP_PASSPHRASE")
    if not passphrase:
        raise RuntimeError("PGP_PASSPHRASE must be set")
    return func.pgp_sym_decrypt(func.cast(col, BYTEA), passphrase)
# Flask-SQLAlchemy extension object, using ``Base`` as its declarative model class.
db = SQLAlchemy(model_class=Base)


def create_app():
    """Create the Flask app for the demo below and bind ``db`` to it.

    ``SQLALCHEMY_DATABASE_URI`` and ``PGP_PASSPHRASE`` are copied from the
    environment variables of the same name (``None`` when a variable is unset).
    The app logger is set to DEBUG.
    """
    from flask import Flask

    application = Flask(__name__)
    application.config.update(
        {
            setting: os.getenv(setting)
            for setting in ("SQLALCHEMY_DATABASE_URI", "PGP_PASSPHRASE")
        }
    )
    application.logger.setLevel("DEBUG")
    db.init_app(application)
    return application
class EncryptionTest(Base):
    """Demo model in ``test_schema`` with one encrypted string and one encrypted JSONB column."""

    __tablename__ = "encryption_test"
    __table_args__: ClassVar[dict[str, str]] = {"schema": "test_schema"}

    # Plain, unencrypted columns (SQL types inferred from the Mapped[...] annotations).
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column()

    # Columns encrypted server side with pgcrypto.
    encrypted_string: Mapped[str] = mapped_column(PGPEncryptString())
    encrypted_json: Mapped[dict] = mapped_column(PGPEncryptJSONB())
if __name__ == "__main__":
    app = create_app()
    with app.app_context(), app.test_request_context():

        def _fetch_one(criterion):
            """Run a filtered query that must match exactly one row, log the row, and return it."""
            row = db.session.query(EncryptionTest).filter(criterion).one()
            current_app.logger.info(row)
            return row

        # Start from a fresh schema. Assumes the pgcrypto extension is already installed.
        reset_schema_sql = text(
            """
DROP SCHEMA IF EXISTS test_schema CASCADE;
CREATE SCHEMA IF NOT EXISTS test_schema;
SET search_path TO test_schema, public;
"""
        )
        db.session.execute(reset_schema_sql)
        db.session.commit()
        db.create_all()

        # Insert one row; the string and every JSON leaf are encrypted in the database.
        db.session.add(
            EncryptionTest(
                name="John Cena",
                encrypted_string="You can't see me",
                encrypted_json={
                    "movies": ["The Marine", "12 Rounds"],
                    "age": 47,
                    "nested": {"key": "value", "list": [1, 2, 3]},
                },
            )
        )
        db.session.commit()

        # Filter on the decrypted string column.
        found = _fetch_one(EncryptionTest.encrypted_string == "You can't see me")
        assert found.name == "John Cena"
        assert found.encrypted_json["age"] == 47
        assert found.encrypted_string == "You can't see me"

        # Filter on a decrypted JSON leaf cast to an integer.
        found = _fetch_one(EncryptionTest.encrypted_json["age"].astext.cast(Integer) < 50)
        assert found.name == "John Cena"
        assert found.encrypted_string == "You can't see me"
        assert found.encrypted_json["age"] == 47

        # Remove everything the demo created.
        db.session.execute(text("DROP SCHEMA test_schema CASCADE"))
        db.session.commit()
        sys.exit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment