Skip to content

Instantly share code, notes, and snippets.

@hyunsik
Last active December 10, 2025 08:15
Show Gist options
  • Select an option

  • Save hyunsik/d7af65c18dfc7456753a497d8a79f4ed to your computer and use it in GitHub Desktop.

Select an option

Save hyunsik/d7af65c18dfc7456753a497d8a79f4ed to your computer and use it in GitHub Desktop.
vLLM Model Conversion (Pooling Model Adapter) 분석

vLLM Model Conversion (Pooling Model Adapter) 분석

기준 커밋: 83319b44c

개요

vLLM은 --convert <type> 옵션을 통해 기존 Text Generation 모델(예: *ForCausalLM)을 Pooling 모델(Embedding, Classification, Reward)로 변환하는 기능을 제공합니다. 이는 Adapter 패턴을 사용하여 원본 모델 코드 수정 없이 동적으로 클래스를 확장합니다.


1. 핵심 개념: 3계층 구조

vLLM은 모델 실행을 위해 3가지 독립적인 설정을 사용합니다:

┌─────────────────────────────────────────────────────────────────┐
│ RunnerType (실행 엔진 선택)                                      │
│   - "generate": 토큰 생성 (GPUModelRunner + Sampler)             │
│   - "pooling": 임베딩/분류 (GPUModelRunner + Pooler)             │
│   - "draft": Speculative Decoding용 드래프트 모델                 │
└─────────────────────────────────────────────────────────────────┘
         │
         ▼
┌─────────────────────────────────────────────────────────────────┐
│ ConvertType (모델 아키텍처 변환)                                  │
│   - "none": 변환 없음                                            │
│   - "embed": Embedding 모델로 변환                               │
│   - "classify": Classification 모델로 변환                       │
│   - "reward": Reward 모델 (네이티브만 지원)                       │
└─────────────────────────────────────────────────────────────────┘
         │
         ▼
┌─────────────────────────────────────────────────────────────────┐
│ Task (요청별 처리 방식)                                          │
│   - "embed", "token_embed": 임베딩 추출                          │
│   - "classify", "token_classify": 분류                          │
│   - "score": Cross-encoder 점수                                  │
│   - "generate", "transcription": 텍스트 생성                     │
└─────────────────────────────────────────────────────────────────┘

1.1 계층별 역할과 시점

계층 시점 역할 영향 범위
RunnerType 엔진 초기화 실행 파이프라인 선택 Worker, Scheduler, Output 형식
ConvertType 모델 로딩 모델 클래스 구조 변환 모델 아키텍처 (head, pooler)
Task 요청 처리 Pooler/처리 로직 선택 개별 추론 요청

1.2 허용되는 조합

_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
    "generate": [],                              # 변환 불가
    "pooling": ["embed", "classify", "reward"],  # 변환 가능
    "draft": [],                                 # 변환 불가
}

_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
    "generate": ["generate", "transcription"],
    "pooling": ["embedding", "embed", "classify", "score", "reward"],
    "draft": ["draft"],
}

1.3 ConvertType과 Task의 1:N 관계

하나의 convert_type이 여러 task를 지원합니다:

convert_type="embed"
    └── tasks: ["embed", "token_embed"]

convert_type="classify"
    └── tasks: ["classify", "score", "token_classify"]

실용적 예시:

# 모델 로딩 (한 번) - convert_type 결정
llm = LLM(model="llama", convert="embed")

# 추론 요청 1: 문장 임베딩
llm.embed(["Hello world"])  # task="embed" → LastPool + Normalize

# 추론 요청 2: 토큰별 임베딩
llm.encode(["Hello world"], task="token_embed")  # task="token_embed" → AllPool

1.4 3계층 분리의 설계 동기

  1. 명확한 책임 분리:

    • RunnerType: 어떤 출력을 생성할지 (토큰 vs 벡터)
    • ConvertType: 모델 구조를 어떻게 수정할지
    • Task: 요청을 어떻게 처리할지
  2. 효율적인 리소스 관리:

    • Pooling 모델은 KV 캐시 재사용 불필요 → 메모리 최적화 가능
    • Generate 모델은 Sampler, Pooling 모델은 Pooler만 초기화
  3. 확장성:

    • 새로운 RunnerType 추가 시 기존 로직 영향 최소화
    • Speculative Decoding(draft)처럼 특수 실행 모드 지원

2. Convert Type 상세

2.1 기대 원본 모델

Adapter를 통한 변환은 특정 원본 모델 아키텍처를 전제로 합니다.

_GENERATE_SUFFIXES (adapters.py#L33-L39):

_GENERATE_SUFFIXES = ("ForCausalLM", "ForConditionalGeneration", "ChatModel", "LMHeadModel")
Convert Type 기대 원본 모델 변환 결과 비고
embed *ForCausalLM, *ForConditionalGeneration, *ChatModel, *LMHeadModel *Embedding Text Generation → Embedding
classify *ForCausalLM, *ForConditionalGeneration, *ChatModel, *LMHeadModel *Classification Text Generation → Classification
reward ❌ 없음 (변환 불가) - 네이티브 모델만 지원
none 모든 모델 변환 없음 원본 그대로 사용

변환 예시:

# convert="embed"
LlamaForCausalLM       → LlamaForCausalLMEmbedding
Qwen2ForCausalLM       → Qwen2ForCausalLMEmbedding

# convert="classify"
LlamaForCausalLM       → LlamaForCausalLMClassification

# convert="reward" (변환 불가 - 네이티브만)
Qwen2ForRewardModel    → (변환 없이 그대로 사용)

2.2 convert="embed": Embedding 모델 변환

as_embedding_model() (adapters.py#L239-L271)

def as_embedding_model(cls: _T) -> _T:
    if is_pooling_model(cls):  # 이미 pooling 모델이면 그대로 반환
        return cls

    class ModelForEmbedding(_create_pooling_model_cls(cls)):
        def _init_pooler(self, vllm_config, prefix=""):
            self.pooler = DispatchPooler({
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            })

    return ModelForEmbedding

지원 Task: embed, token_embed

2.3 convert="classify": Classification 모델 변환

as_seq_cls_model() (adapters.py#L274-L351)

def as_seq_cls_model(cls: _T) -> _T:
    class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding):
        def _init_pooler(self, vllm_config, prefix=""):
            # Classification head 추가
            self.score = ReplicatedLinear(hidden_size, num_labels, ...)

            self.pooler = DispatchPooler({
                "token_classify": Pooler.for_token_classify(pooler_config, classifier=self.score),
                "classify": Pooler.for_classify(pooler_config, classifier=self.score, act_fn="classify"),
                "score": Pooler.for_classify(pooler_config, classifier=self.score, act_fn="score"),
            })

지원 Task: classify, score, token_classify

2.4 convert="reward": Reward 모델 (네이티브만)

⚠️ 주의: convert="reward"는 타입으로 정의되어 있지만, 별도의 adapter 함수(as_reward_model)가 존재하지 않습니다. CausalLM → Reward 변환은 지원되지 않으며, 네이티브 Reward 모델만 지원됩니다.

Convert Type Adapter 함수 CausalLM 변환 네이티브 모델
embed as_embedding_model() ✅ 지원 ✅ 지원
classify as_seq_cls_model() ✅ 지원 ✅ 지원
reward ❌ 없음 ❌ 미지원 ✅ 지원

네이티브 Reward 모델 예시:

Qwen2ForRewardModel (qwen2_rm.py#L99-L109)

@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        vllm_config.model_config.hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.pooler = DispatchPooler(
            {"token_classify": Pooler.for_token_classify(pooler_config)}
        )

InternLM2ForRewardModel (internlm2.py#L407-L440)

@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    is_pooling_model = True

    def __init__(self, *, vllm_config, prefix="", ...):
        super().__init__(...)

        # lm_head 대신 v_head (reward head) 사용
        for attr in ("output", "logits_processor"):
            delattr(self, attr)

        self.v_head = RowParallelLinear(config.hidden_size, 1, ...)

        self.pooler = DispatchPooler(
            {"token_classify": Pooler.for_token_classify(pooler_config)}
        )

Reward 모델의 특징:

특성 설명
Pooling Type @default_pooling_type("ALL") - 모든 토큰에 대한 reward 계산
Task "token_classify" - 토큰별 분류/점수
Head v_head 또는 score - 단일 값(reward) 출력
num_labels 보통 1 (스칼라 reward) 또는 2 (Process Reward Model)

3. 변환 흐름 (Conversion Flow)

사용자 요청: LLM(model="llama", convert="embed")
         │
         ▼
┌─────────────────────────────────────────────────────┐
│ 1. ModelConfig 초기화                                │
│    - _get_convert_type(): convert="auto"면 자동 결정  │
│    - _SUFFIX_TO_DEFAULTS로 아키텍처 suffix 매칭       │
└─────────────────────────────────────────────────────┘
         │
         ▼
┌─────────────────────────────────────────────────────┐
│ 2. _get_model_architecture() 호출                   │
│    (vllm/model_executor/model_loader/utils.py:167)  │
│    - registry에서 원본 모델 클래스 조회               │
│    - convert_type에 따라 adapter 함수 호출           │
└─────────────────────────────────────────────────────┘
         │
         ├── convert="embed"  ──▶ as_embedding_model(cls)
         │
         └── convert="classify" ──▶ as_seq_cls_model(cls)
                  │
                  ▼
┌─────────────────────────────────────────────────────┐
│ 3. Adapter 함수 실행                                 │
│    (vllm/model_executor/models/adapters.py)         │
│                                                     │
│    _create_pooling_model_cls(orig_cls):             │
│    - 새 클래스 ModelForPooling 동적 생성             │
│    - orig_cls + VllmModelForPooling 상속            │
│    - lm_head, logits_processor 제거                 │
│    - pooler 초기화                                  │
└─────────────────────────────────────────────────────┘
         │
         ▼
┌─────────────────────────────────────────────────────┐
│ 4. 변환된 모델 클래스 반환                            │
│    예: LlamaForCausalLM → LlamaForEmbedding         │
└─────────────────────────────────────────────────────┘

3.1 아키텍처 → Convert Type 자동 매핑

convert="auto" (기본값)일 때, vLLM은 HuggingFace 모델의 아키텍처 클래스 이름 suffix를 기반으로 RunnerTypeConvertType을 자동 결정합니다.

예를 들어:

  • meta-llama/Llama-3.1-8B → 아키텍처: LlamaForCausalLM → suffix ForCausalLM 매칭 → runner_type="generate", convert_type="none"
  • BAAI/bge-base-en-v1.5 → 아키텍처: BertModel → suffix Model 매칭 → runner_type="pooling", convert_type="embed"

_SUFFIX_TO_DEFAULTS (config/model.py#L1894-L1910):

모델 아키텍처 Suffix Runner Type Convert Type
*ForCausalLM generate none
*ForConditionalGeneration generate none
*ChatModel generate none
*LMHeadModel generate none
*ForTextEncoding pooling embed
*EmbeddingModel pooling embed
*ForSequenceClassification pooling classify
*For*Classification pooling classify
*ForRewardModeling pooling embed
*RewardModel pooling embed
*Model pooling embed

3.2 이미 Pooling 모델인 경우

원본 모델이 이미 is_pooling_model = True인 경우, adapter는 변환을 건너뜁니다:

def as_embedding_model(cls: _T) -> _T:
    if is_pooling_model(cls):  # 이미 pooling 모델이면
        return cls             # 그대로 반환
    # ...

따라서 네이티브 Pooling 모델은 convert 옵션을 지정해도 변환되지 않고 원본 구조를 유지합니다.


4. 구현 상세

4.1 Adapter 패턴: _create_pooling_model_cls()

_create_pooling_model_cls() (adapters.py#L159-L236)

def _create_pooling_model_cls(orig_cls: _T) -> _T:
    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(self, *, vllm_config, prefix="", **kwargs):
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # 불필요한 속성 제거 (생성 모델용)
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)

            # Pooler 초기화
            if not getattr(self, "pooler", None):
                self._init_pooler(vllm_config, prefix=prefix)

        def load_weights(self, weights, load_lm_head=False):
            # lm_head 가중치 로딩 스킵
            if not load_lm_head:
                weights = ((n, d) for n, d in weights if not n.startswith("lm_head."))
            ...

4.2 Pooler 구조

Pooler (ABC)
├── SimplePooler          # 기본 pooling + head
├── ClassifierPooler      # classification용
├── AllPooler             # 토큰별 출력
├── StepPooler            # 특정 토큰 추출
└── DispatchPooler        # 태스크별 라우팅
         │
         └── {"embed": Pooler, "classify": Pooler, "score": Pooler}

PoolingMethod:

  • CLSPool: 첫 번째 토큰 ([CLS])
  • LastPool: 마지막 토큰
  • MeanPool: 평균
  • AllPool: 모든 토큰

Task → Pooler 라우팅 (pooler.py#L780-L826):

class DispatchPooler(Pooler):
    def forward(self, hidden_states, pooling_metadata):
        for task, group in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task[task]  # task에 따라 pooler 선택
            outputs.extend(pooler(hidden_states, ...))

4.3 온라인 변환 (LLM → Reranker)

seq_cls_model_loader() (adapters.py#L496-L510)는 ForCausalLM 모델을 Reranker로 변환:

  1. from_2_way_softmax: Qwen3-Reranker, mxbai-rerank-v2 등

    • lm_head 가중치에서 true/false 토큰 가중치 추출
    • score.weight = lm_head[true_id] - lm_head[false_id]
  2. no_post_processing: bge-reranker-v2-gemma 등

    • 지정된 토큰들의 lm_head 가중치를 score 레이어로 복사

4.4 RunnerType이 별도로 존재하는 이유

1. 실행 경로 분기 (gpu_model_runner.py#L3102-L3114)

# GPUModelRunner.execute_model() 내부
if self.is_pooling_model:  # runner_type == "pooling"
    output = self._pool(hidden_states, ...)  # Pooler 사용
    return output

# runner_type == "generate"
logits = self.model.compute_logits(sample_hidden_states)
# → Sampler로 토큰 샘플링

2. 지원 태스크 결정 (gpu_model_runner.py#L2324-L2332)

def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    if self.model_config.runner_type == "generate":
        tasks.extend(self.get_supported_generation_tasks())
    if self.model_config.runner_type == "pooling":
        tasks.extend(self.get_supported_pooling_tasks())
    return tuple(tasks)

3. Tokenizer 설정 (tokenizers/registry.py#L220)

if runner_type == "generate" or runner_type == "draft":
    truncation_side = "left"   # 생성: 왼쪽 잘림 (최근 컨텍스트 유지)
elif runner_type == "pooling":
    truncation_side = "right"  # 임베딩: 오른쪽 잘림 (시작 컨텍스트 유지)

5. 핵심 코드 경로 (참조)

진입점

파일 위치 설명
vllm/engine/arg_utils.py line 628 --convert CLI 인자 정의
vllm/entrypoints/llm.py line 189 LLM(convert=...) 파라미터

타입 정의 및 설정 결정

파일 위치 설명
vllm/config/model.py#L74-L75 타입 정의 ConvertType = Literal["none", "embed", "classify", "reward"]
vllm/config/model.py#L1894-L1910 아키텍처 매핑 _SUFFIX_TO_DEFAULTS: 모델 suffix → (runner_type, convert_type)
vllm/config/model.py#L869-L919 자동 결정 _get_default_convert_type(), _get_convert_type()

모델 아키텍처 변환 (핵심)

파일 위치 설명
vllm/model_executor/model_loader/utils.py#L167-L212 변환 적용 _get_model_architecture(): convert_type에 따라 adapter 적용
vllm/model_executor/models/adapters.py 전체 Adapter 구현 핵심 파일

Pooler 구현

파일 설명
vllm/model_executor/layers/pooler.py Pooler 클래스들 (CLSPool, LastPool, MeanPool, etc.)
vllm/config/pooler.py PoolerConfig 설정

요약

vLLM의 Model Conversion은 동적 클래스 상속을 통한 Adapter 패턴을 사용합니다:

  1. 원본 모델 클래스를 변경하지 않음
  2. 런타임에 VllmModelForPooling 믹스인을 추가한 새 클래스 생성
  3. 불필요한 생성 모델 컴포넌트(lm_head) 제거
  4. 적절한 Pooler 초기화
  5. 가중치 로딩 시 lm_head 스킵 (또는 온라인 변환에 활용)

이를 통해 하나의 ForCausalLM 체크포인트로 Embedding, Classification, Reranking 등 다양한 Pooling 태스크를 수행할 수 있습니다.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment