Skip to content

Instantly share code, notes, and snippets.

@mrorigo
Created February 19, 2026 08:35
Show Gist options
  • Select an option

  • Save mrorigo/a604ff8a42fafdc57be00e204a8c73f8 to your computer and use it in GitHub Desktop.

Select an option

Save mrorigo/a604ff8a42fafdc57be00e204a8c73f8 to your computer and use it in GitHub Desktop.
GLiNER2 REST API
from __future__ import annotations
import logging
import os
import sys
from functools import lru_cache
from typing import Any, Dict, List, Literal, Union
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
from gliner2 import GLiNER2
# -----------------------------
# Logging (stderr only)
# -----------------------------
# Send every log record to stderr. The level comes from the
# GLINER2_API_LOG_LEVEL environment variable (default "INFO"), upper-cased so
# lowercase values such as "debug" are accepted.
# force=True removes handlers already attached to the root logger, so this
# configuration takes effect even if logging was configured earlier.
logging.basicConfig(
    stream=sys.stderr,
    level=os.getenv("GLINER2_API_LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    force=True,
)
# Module-wide logger used by the model loader and the exception handlers.
logger = logging.getLogger("gliner2_api")
# -----------------------------
# Config
# -----------------------------
# Runtime settings, each read once from the environment at import time.
# Model identifier handed to the model loader.
MODEL_ID = os.environ.get("GLINER2_MODEL_ID", "fastino/gliner2-base-v1")
# Optional device string; None when GLINER2_DEVICE is unset.
DEVICE = os.environ.get("GLINER2_DEVICE")
# Upper bound, in UTF-8 encoded bytes, for the request "text" field.
MAX_TEXT_BYTES = int(os.environ.get("GLINER2_API_MAX_TEXT_BYTES", "8192"))
# Bearer token clients must present; an empty string disables authentication.
API_TOKEN = os.environ.get("GLINER2_API_TOKEN", "")
# -----------------------------
# Pydantic models
# -----------------------------
# The inference operations a client may request through the "task" field.
TaskType = Literal["extract_entities", "classify_text", "extract_json"]
# A request schema is either a JSON object or a JSON array; which of the two
# is acceptable depends on the task and is checked per request.
SchemaType = Union[Dict[str, Any], List[Any]]
class InferenceRequest(BaseModel):
    """Request body accepted by the inference endpoint."""

    # Which operation to run; restricted to the values of TaskType.
    task: TaskType
    # Input text: non-empty, and size-limited by validate_text_size below.
    text: str = Field(min_length=1)
    # Task-specific schema (JSON object or array).
    # NOTE(review): `schema` is also the name of an attribute on pydantic's
    # BaseModel, so pydantic may warn about the shadowing -- confirm this is
    # acceptable or switch to an aliased field.
    schema: SchemaType
    # Score cutoff forwarded to the model call; must lie within [0.0, 1.0].
    threshold: float = Field(default=0.5, ge=0.0, le=1.0)

    @field_validator("text")
    @classmethod
    def validate_text_size(cls, value: str) -> str:
        """Return *value* unchanged unless its UTF-8 size exceeds MAX_TEXT_BYTES.

        Raises:
            ValueError: when the encoded text is larger than MAX_TEXT_BYTES.
        """
        # Measure encoded bytes, not characters: the limit is a byte budget.
        encoded_size = len(value.encode("utf-8"))
        if encoded_size <= MAX_TEXT_BYTES:
            return value
        raise ValueError(
            f"text exceeds max size of {MAX_TEXT_BYTES} bytes; reduce payload size"
        )
class InferenceResponse(BaseModel):
    """Response envelope that carries the task output under ``result``."""

    # Typed as Any because the output shape differs from task to task.
    result: Any
def error_payload(code: str, message: str) -> Dict[str, Dict[str, str]]:
    """Build the JSON error envelope returned by the exception handlers.

    Args:
        code: Machine-readable error identifier.
        message: Human-readable description of the failure.

    Returns:
        ``{"error": {"code": code, "message": message}}``.
    """
    details = {"code": code, "message": message}
    return {"error": details}
# -----------------------------
# Model lifecycle
# -----------------------------
@lru_cache(maxsize=1)
def get_model() -> GLiNER2:
    """Load the GLiNER2 model once and reuse it on every later call.

    The function takes no arguments, so ``lru_cache(maxsize=1)`` turns it
    into a lazily initialised singleton: the first call loads the model and
    later calls return the cached instance.

    Returns:
        The model loaded from MODEL_ID, placed on DEVICE when one is
        configured and the installed GLiNER2 accepts a ``device`` keyword.
    """
    logger.info("Loading GLiNER2 model: %s", MODEL_ID)
    if DEVICE:
        try:
            model_on_device = GLiNER2.from_pretrained(MODEL_ID, device=DEVICE)
        except TypeError:
            # Treated as "this build has no device= parameter"; fall through
            # to the plain load below.
            logger.warning("device= not supported by this GLiNER2 build; fallback")
        else:
            return model_on_device
    return GLiNER2.from_pretrained(MODEL_ID)
# -----------------------------
# Auth
# -----------------------------
def require_auth(authorization: str | None = Header(default=None)) -> None:
    """FastAPI dependency that enforces bearer-token authentication.

    Args:
        authorization: Value of the ``Authorization`` request header, or
            None when the header is absent.

    Raises:
        HTTPException: 401 when a token is configured and the header is
            missing, does not start with ``"Bearer "``, or carries a token
            that differs from API_TOKEN.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import hmac

    # If token not configured, auth is effectively disabled.
    if not API_TOKEN:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    token = authorization.removeprefix("Bearer ").strip()
    # Compare in constant time so response timing does not reveal how many
    # leading characters of a guessed token were correct. Bytes are compared
    # because compare_digest rejects non-ASCII str arguments; "surrogatepass"
    # keeps the encoding step from raising on unusual environment values.
    presented = token.encode("utf-8", "surrogatepass")
    expected = API_TOKEN.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(presented, expected):
        raise HTTPException(status_code=401, detail="Invalid token")
# -----------------------------
# App + handlers
# -----------------------------
# Application object; the route and exception handlers in this file register
# themselves on it through its decorators.
app = FastAPI(
    title="GLiNER2 API",
    version="1.0.0",
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    _request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Report request-validation failures as HTTP 400 with code INVALID_REQUEST.

    Args:
        _request: The incoming request (unused).
        exc: The validation error raised while parsing the request.

    Returns:
        A JSON error envelope whose message is ``str(exc)``.
    """
    body = error_payload("INVALID_REQUEST", str(exc))
    return JSONResponse(status_code=400, content=body)
@app.exception_handler(HTTPException)
async def http_exception_handler(_request: Request, exc: HTTPException) -> JSONResponse:
    """Render an HTTPException as the API's JSON error envelope.

    The response keeps the exception's status code. The envelope's ``code``
    is "UNAUTHORIZED" for 401, "INVALID_REQUEST" for 400 and "SERVER_ERROR"
    for every other status.

    Args:
        _request: The incoming request (unused).
        exc: The raised HTTPException.

    Returns:
        A JSONResponse carrying the error envelope and any headers that were
        attached to the exception.
    """
    codes_by_status = {401: "UNAUTHORIZED", 400: "INVALID_REQUEST"}
    error_code = codes_by_status.get(exc.status_code, "SERVER_ERROR")
    return JSONResponse(
        status_code=exc.status_code,
        content=error_payload(error_code, str(exc.detail)),
        # Forward headers set on the exception (for example WWW-Authenticate)
        # instead of silently dropping them; None when there are none.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def generic_exception_handler(_request: Request, exc: Exception) -> JSONResponse:
    """Log an otherwise unhandled exception and answer with HTTP 500.

    Args:
        _request: The incoming request (unused).
        exc: The exception that no more specific handler dealt with.

    Returns:
        A JSON error envelope with code SERVER_ERROR and ``str(exc)`` as the
        message.
    """
    logger.exception("Unhandled inference error")
    # NOTE(review): the exception text is sent to the client verbatim, which
    # can expose internal details -- confirm this is intended.
    body = error_payload("SERVER_ERROR", str(exc))
    return JSONResponse(status_code=500, content=body)
def validate_schema_for_task(task: TaskType, schema: SchemaType) -> None:
    """Check that *schema* has the container shape required by *task*.

    Args:
        task: The requested operation.
        schema: The schema supplied with the request.

    Raises:
        HTTPException: 400 when ``extract_entities`` is given anything other
            than a list of strings, or when ``classify_text`` /
            ``extract_json`` is given something that is not a dict.

    A task outside those three is accepted without any check.
    """
    if task == "extract_entities":
        is_label_list = isinstance(schema, list) and all(
            isinstance(label, str) for label in schema
        )
        if not is_label_list:
            raise HTTPException(
                status_code=400,
                detail="For task=extract_entities, schema must be list[str]",
            )
    elif task in {"classify_text", "extract_json"}:
        if not isinstance(schema, dict):
            raise HTTPException(
                status_code=400,
                detail=f"For task={task}, schema must be an object",
            )
@app.post("/gliner-2", response_model=InferenceResponse)
async def run_gliner2_inference(
    payload: InferenceRequest,
    _auth: None = Depends(require_auth),
) -> InferenceResponse:
    """Run the requested GLiNER2 task on the submitted text.

    Args:
        payload: Validated request body (task, text, schema, threshold).
        _auth: Placeholder for the auth dependency; its value is unused.

    Returns:
        The task output wrapped in an InferenceResponse. For
        ``extract_entities`` a dict result containing an ``"entities"`` key
        is unwrapped to that key's value.

    Raises:
        HTTPException: 400 when the schema does not fit the task.
    """
    # Validate before touching the model so a malformed schema is rejected
    # with a 400 without first paying for (or failing on) model loading.
    validate_schema_for_task(payload.task, payload.schema)
    model = get_model()

    # NOTE(review): the model calls below are synchronous inside an async
    # handler, so the event loop is blocked while they run -- confirm this is
    # acceptable for the expected load.
    if payload.task == "extract_entities":
        result = model.extract_entities(
            payload.text,
            payload.schema,  # list[str]
            threshold=payload.threshold,
        )
        # Keep output compact and consistent with docs expectation.
        if isinstance(result, dict) and "entities" in result:
            result = result["entities"]
    elif payload.task == "classify_text":
        result = model.classify_text(
            payload.text,
            payload.schema,  # dict task schema
            threshold=payload.threshold,
        )
    else:
        # task == extract_json
        result = model.extract_json(
            payload.text,
            payload.schema,  # dict structures
            threshold=payload.threshold,
        )
    return InferenceResponse(result=result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment