Created
February 19, 2026 08:35
-
-
Save mrorigo/a604ff8a42fafdc57be00e204a8c73f8 to your computer and use it in GitHub Desktop.
GLiNER2 REST API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import logging | |
| import os | |
| import sys | |
| from functools import lru_cache | |
| from typing import Any, Dict, List, Literal, Union | |
| from fastapi import Depends, FastAPI, Header, HTTPException, Request | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field, field_validator | |
| from gliner2 import GLiNER2 | |
| # ----------------------------- | |
| # Logging (stderr only) | |
| # ----------------------------- | |
# Root logger setup: every record is written to stderr. ``force=True`` makes
# basicConfig replace any handlers already attached to the root logger.
logging.basicConfig(
    stream=sys.stderr,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    level=os.environ.get("GLINER2_API_LOG_LEVEL", "INFO").upper(),
    force=True,
)
# Named logger used by the rest of this module.
logger = logging.getLogger("gliner2_api")
| # ----------------------------- | |
| # Config | |
| # ----------------------------- | |
# Model identifier handed to GLiNER2.from_pretrained.
MODEL_ID = os.environ.get("GLINER2_MODEL_ID", "fastino/gliner2-base-v1")
# Optional device passed as ``device=`` when loading the model; None if unset.
DEVICE = os.environ.get("GLINER2_DEVICE")
# Upper bound on the UTF-8 encoded size of the request ``text`` field.
MAX_TEXT_BYTES = int(os.environ.get("GLINER2_API_MAX_TEXT_BYTES", "8192"))
# Expected bearer token; an empty value disables authentication.
API_TOKEN = os.environ.get("GLINER2_API_TOKEN", "")
| # ----------------------------- | |
| # Pydantic models | |
| # ----------------------------- | |
# Closed set of operations the inference endpoint accepts.
TaskType = Literal[
    "extract_entities",
    "classify_text",
    "extract_json",
]
# A schema is either a JSON object or a JSON array; which one is required
# depends on the task and is checked per request.
SchemaType = Union[Dict[str, Any], List[Any]]
class InferenceRequest(BaseModel):
    # Request body for the inference endpoint.
    # NOTE: documented with comments rather than a class docstring so the
    # generated schema description stays unchanged.

    # Which model operation to run.
    task: TaskType
    # Input text; must be non-empty and at most MAX_TEXT_BYTES when
    # encoded as UTF-8 (see validate_text_size).
    text: str = Field(min_length=1)
    # Task-specific schema (object or array).
    schema: SchemaType
    # Threshold forwarded to the model call; constrained to [0.0, 1.0].
    threshold: float = Field(default=0.5, ge=0.0, le=1.0)

    @field_validator("text")
    @classmethod
    def validate_text_size(cls, value: str) -> str:
        """Return ``value`` unchanged unless its UTF-8 size exceeds the limit.

        Raises:
            ValueError: if the encoded text is larger than MAX_TEXT_BYTES.
        """
        encoded_size = len(value.encode("utf-8"))
        if encoded_size <= MAX_TEXT_BYTES:
            return value
        raise ValueError(
            f"text exceeds max size of {MAX_TEXT_BYTES} bytes; reduce payload size"
        )
class InferenceResponse(BaseModel):
    # Successful response envelope: the model output is returned under
    # the single ``result`` key, with no constraint on its shape.
    result: Any
def error_payload(code: str, message: str) -> Dict[str, Dict[str, str]]:
    """Build the JSON error envelope used by the exception handlers.

    Args:
        code: Machine-readable error code.
        message: Human-readable description.

    Returns:
        ``{"error": {"code": code, "message": message}}``.
    """
    details = {"code": code, "message": message}
    return {"error": details}
| # ----------------------------- | |
| # Model lifecycle | |
| # ----------------------------- | |
@lru_cache(maxsize=1)
def get_model() -> GLiNER2:
    """Load the GLiNER2 model identified by MODEL_ID.

    The result is memoised by ``lru_cache``, so the loading code runs once
    per process after a successful call. When DEVICE is set it is passed as
    ``device=``; if that raises TypeError the model is loaded without it.
    """
    logger.info("Loading GLiNER2 model: %s", MODEL_ID)
    if not DEVICE:
        return GLiNER2.from_pretrained(MODEL_ID)
    try:
        return GLiNER2.from_pretrained(MODEL_ID, device=DEVICE)
    except TypeError:
        logger.warning("device= not supported by this GLiNER2 build; fallback")
    return GLiNER2.from_pretrained(MODEL_ID)
| # ----------------------------- | |
| # Auth | |
| # ----------------------------- | |
def require_auth(authorization: str | None = Header(default=None)) -> None:
    """Enforce bearer-token authentication when API_TOKEN is configured.

    Args:
        authorization: Value of the ``Authorization`` request header, if any.

    Raises:
        HTTPException: 401 when the header is missing, is not a ``Bearer``
            credential, or carries a token different from API_TOKEN.
    """
    # Local import keeps this block self-contained.
    import hmac

    # If token not configured, auth is effectively disabled.
    if not API_TOKEN:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    token = authorization.removeprefix("Bearer ").strip()
    # Constant-time comparison: a plain ``!=`` can return early at the first
    # differing character and so leak timing information about the secret.
    # Comparing bytes avoids compare_digest's ASCII-only restriction on str.
    if not hmac.compare_digest(token.encode("utf-8"), API_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token")
| # ----------------------------- | |
| # App + handlers | |
| # ----------------------------- | |
# ASGI application object; the handlers and route below are registered on it.
app = FastAPI(
    title="GLiNER2 API",
    version="1.0.0",
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    _request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Report request-validation failures as HTTP 400 in the error envelope.

    The stringified validation error becomes the message, under the
    ``INVALID_REQUEST`` code.
    """
    body = error_payload("INVALID_REQUEST", str(exc))
    return JSONResponse(status_code=400, content=body)
@app.exception_handler(HTTPException)
async def http_exception_handler(_request: Request, exc: HTTPException) -> JSONResponse:
    """Render an HTTPException in the API's error envelope.

    The original status code is kept. 401 maps to ``UNAUTHORIZED``, 400 to
    ``INVALID_REQUEST``, and every other status to ``SERVER_ERROR``.

    Args:
        _request: Incoming request (unused).
        exc: The raised HTTPException.

    Returns:
        JSONResponse carrying the error envelope and any headers attached
        to the exception.
    """
    codes_by_status = {401: "UNAUTHORIZED", 400: "INVALID_REQUEST"}
    error_code = codes_by_status.get(exc.status_code, "SERVER_ERROR")
    return JSONResponse(
        status_code=exc.status_code,
        content=error_payload(error_code, str(exc.detail)),
        # Forward headers set on the exception (e.g. WWW-Authenticate,
        # Retry-After) instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def generic_exception_handler(_request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler: log the failure and answer with HTTP 500.

    Args:
        _request: Incoming request (unused).
        exc: The unhandled exception.

    Returns:
        JSONResponse with the ``SERVER_ERROR`` envelope and a generic message.
    """
    # Pass the exception explicitly so the traceback is logged even if this
    # handler is invoked outside an active ``except`` block.
    logger.error("Unhandled inference error", exc_info=exc)
    # Do not echo str(exc) to the client: exception text can expose internal
    # details (paths, library internals). The full detail is in the log.
    return JSONResponse(
        status_code=500,
        content=error_payload("SERVER_ERROR", "Internal server error"),
    )
def validate_schema_for_task(task: TaskType, schema: SchemaType) -> None:
    """Check that ``schema`` has the container type required by ``task``.

    ``extract_entities`` needs a list whose items are all strings;
    ``classify_text`` and ``extract_json`` need a dict. Any other task
    value is accepted without checks.

    Raises:
        HTTPException: 400 when the schema has the wrong shape for the task.
    """
    if task == "extract_entities":
        is_label_list = isinstance(schema, list) and all(
            isinstance(label, str) for label in schema
        )
        if not is_label_list:
            raise HTTPException(
                status_code=400,
                detail="For task=extract_entities, schema must be list[str]",
            )
        return
    if task in {"classify_text", "extract_json"} and not isinstance(schema, dict):
        raise HTTPException(
            status_code=400,
            detail=f"For task={task}, schema must be an object",
        )
@app.post("/gliner-2", response_model=InferenceResponse)
async def run_gliner2_inference(
    payload: InferenceRequest,
    _auth: None = Depends(require_auth),
) -> InferenceResponse:
    """Run the requested GLiNER2 task on ``payload.text``.

    Args:
        payload: Validated request body (task, text, schema, threshold).
        _auth: Placeholder for the ``require_auth`` dependency; unused.

    Returns:
        InferenceResponse wrapping the model output. For
        ``extract_entities``, a dict result containing an ``"entities"``
        key is unwrapped to that value.

    Raises:
        HTTPException: 400 when the schema does not fit the task.
    """
    # Validate before touching the model, so a malformed request is rejected
    # without triggering the (cached but initially expensive) model load.
    validate_schema_for_task(payload.task, payload.schema)
    model = get_model()

    # NOTE(review): the model calls below are synchronous inside an async
    # handler; consider offloading them if event-loop blocking is a concern.
    if payload.task == "extract_entities":
        result = model.extract_entities(
            payload.text,
            payload.schema,  # list[str]
            threshold=payload.threshold,
        )
        # Keep output compact and consistent with docs expectation.
        if isinstance(result, dict) and "entities" in result:
            result = result["entities"]
    elif payload.task == "classify_text":
        result = model.classify_text(
            payload.text,
            payload.schema,  # dict task schema
            threshold=payload.threshold,
        )
    else:
        # task == extract_json
        result = model.extract_json(
            payload.text,
            payload.schema,  # dict structures
            threshold=payload.threshold,
        )
    return InferenceResponse(result=result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment