기준 커밋:
83319b44c
vLLM은 --convert <type> 옵션을 통해 기존 Text Generation 모델(예: *ForCausalLM)을 Pooling 모델(Embedding, Classification, Reward)로 변환하는 기능을 제공합니다. 이는 Adapter 패턴을 사용하여 원본 모델 코드 수정 없이 동적으로 클래스를 확장합니다.
vLLM은 모델 실행을 위해 3가지 독립적인 설정을 사용합니다:
┌─────────────────────────────────────────────────────────────────┐
│ RunnerType (실행 엔진 선택) │
│ - "generate": 토큰 생성 (GPUModelRunner + Sampler) │
│ - "pooling": 임베딩/분류 (GPUModelRunner + Pooler) │
│ - "draft": Speculative Decoding용 드래프트 모델 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ ConvertType (모델 아키텍처 변환) │
│ - "none": 변환 없음 │
│ - "embed": Embedding 모델로 변환 │
│ - "classify": Classification 모델로 변환 │
│ - "reward": Reward 모델 (네이티브만 지원) │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Task (요청별 처리 방식) │
│ - "embed", "token_embed": 임베딩 추출 │
│ - "classify", "token_classify": 분류 │
│ - "score": Cross-encoder 점수 │
│ - "generate", "transcription": 텍스트 생성 │
└─────────────────────────────────────────────────────────────────┘
| 계층 | 시점 | 역할 | 영향 범위 |
|---|---|---|---|
| RunnerType | 엔진 초기화 | 실행 파이프라인 선택 | Worker, Scheduler, Output 형식 |
| ConvertType | 모델 로딩 | 모델 클래스 구조 변환 | 모델 아키텍처 (head, pooler) |
| Task | 요청 처리 | Pooler/처리 로직 선택 | 개별 추론 요청 |
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [], # 변환 불가
"pooling": ["embed", "classify", "reward"], # 변환 가능
"draft": [], # 변환 불가
}
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"],
"pooling": ["embedding", "embed", "classify", "score", "reward"],
"draft": ["draft"],
}하나의 convert_type이 여러 task를 지원합니다:
convert_type="embed"
└── tasks: ["embed", "token_embed"]
convert_type="classify"
└── tasks: ["classify", "score", "token_classify"]
실용적 예시:
# 모델 로딩 (한 번) - convert_type 결정
llm = LLM(model="llama", convert="embed")
# 추론 요청 1: 문장 임베딩
llm.embed(["Hello world"]) # task="embed" → LastPool + Normalize
# 추론 요청 2: 토큰별 임베딩
llm.encode(["Hello world"], task="token_embed") # task="token_embed" → AllPool-
명확한 책임 분리:
RunnerType: 어떤 출력을 생성할지 (토큰 vs 벡터)ConvertType: 모델 구조를 어떻게 수정할지Task: 요청을 어떻게 처리할지
-
효율적인 리소스 관리:
- Pooling 모델은 KV 캐시 재사용 불필요 → 메모리 최적화 가능
- Generate 모델은 Sampler, Pooling 모델은 Pooler만 초기화
-
확장성:
- 새로운 RunnerType 추가 시 기존 로직 영향 최소화
- Speculative Decoding(
draft)처럼 특수 실행 모드 지원
Adapter를 통한 변환은 특정 원본 모델 아키텍처를 전제로 합니다.
_GENERATE_SUFFIXES (adapters.py#L33-L39):
_GENERATE_SUFFIXES = ("ForCausalLM", "ForConditionalGeneration", "ChatModel", "LMHeadModel")| Convert Type | 기대 원본 모델 | 변환 결과 | 비고 |
|---|---|---|---|
embed |
*ForCausalLM, *ForConditionalGeneration, *ChatModel, *LMHeadModel |
*Embedding |
Text Generation → Embedding |
classify |
*ForCausalLM, *ForConditionalGeneration, *ChatModel, *LMHeadModel |
*Classification |
Text Generation → Classification |
reward |
❌ 없음 (변환 불가) | - | 네이티브 모델만 지원 |
none |
모든 모델 | 변환 없음 | 원본 그대로 사용 |
변환 예시:
# convert="embed"
LlamaForCausalLM → LlamaForCausalLMEmbedding
Qwen2ForCausalLM → Qwen2ForCausalLMEmbedding
# convert="classify"
LlamaForCausalLM → LlamaForCausalLMClassification
# convert="reward" (변환 불가 - 네이티브만)
Qwen2ForRewardModel → (변환 없이 그대로 사용)
as_embedding_model() (adapters.py#L239-L271)
def as_embedding_model(cls: _T) -> _T:
if is_pooling_model(cls): # 이미 pooling 모델이면 그대로 반환
return cls
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config, prefix=""):
self.pooler = DispatchPooler({
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
return ModelForEmbedding지원 Task: embed, token_embed
as_seq_cls_model() (adapters.py#L274-L351)
def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding):
def _init_pooler(self, vllm_config, prefix=""):
# Classification head 추가
self.score = ReplicatedLinear(hidden_size, num_labels, ...)
self.pooler = DispatchPooler({
"token_classify": Pooler.for_token_classify(pooler_config, classifier=self.score),
"classify": Pooler.for_classify(pooler_config, classifier=self.score, act_fn="classify"),
"score": Pooler.for_classify(pooler_config, classifier=self.score, act_fn="score"),
})지원 Task: classify, score, token_classify
⚠️ 주의:convert="reward"는 타입으로 정의되어 있지만, 별도의 adapter 함수(as_reward_model)가 존재하지 않습니다. CausalLM → Reward 변환은 지원되지 않으며, 네이티브 Reward 모델만 지원됩니다.
| Convert Type | Adapter 함수 | CausalLM 변환 | 네이티브 모델 |
|---|---|---|---|
embed |
as_embedding_model() |
✅ 지원 | ✅ 지원 |
classify |
as_seq_cls_model() |
✅ 지원 | ✅ 지원 |
reward |
❌ 없음 | ❌ 미지원 | ✅ 지원 |
네이티브 Reward 모델 예시:
Qwen2ForRewardModel (qwen2_rm.py#L99-L109)
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)InternLM2ForRewardModel (internlm2.py#L407-L440)
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
def __init__(self, *, vllm_config, prefix="", ...):
super().__init__(...)
# lm_head 대신 v_head (reward head) 사용
for attr in ("output", "logits_processor"):
delattr(self, attr)
self.v_head = RowParallelLinear(config.hidden_size, 1, ...)
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)Reward 모델의 특징:
| 특성 | 설명 |
|---|---|
| Pooling Type | @default_pooling_type("ALL") - 모든 토큰에 대한 reward 계산 |
| Task | "token_classify" - 토큰별 분류/점수 |
| Head | v_head 또는 score - 단일 값(reward) 출력 |
| num_labels | 보통 1 (스칼라 reward) 또는 2 (Process Reward Model) |
사용자 요청: LLM(model="llama", convert="embed")
│
▼
┌─────────────────────────────────────────────────────┐
│ 1. ModelConfig 초기화 │
│ - _get_convert_type(): convert="auto"면 자동 결정 │
│ - _SUFFIX_TO_DEFAULTS로 아키텍처 suffix 매칭 │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ 2. _get_model_architecture() 호출 │
│ (vllm/model_executor/model_loader/utils.py:167) │
│ - registry에서 원본 모델 클래스 조회 │
│ - convert_type에 따라 adapter 함수 호출 │
└─────────────────────────────────────────────────────┘
│
├── convert="embed" ──▶ as_embedding_model(cls)
│
└── convert="classify" ──▶ as_seq_cls_model(cls)
│
▼
┌─────────────────────────────────────────────────────┐
│ 3. Adapter 함수 실행 │
│ (vllm/model_executor/models/adapters.py) │
│ │
│ _create_pooling_model_cls(orig_cls): │
│ - 새 클래스 ModelForPooling 동적 생성 │
│ - orig_cls + VllmModelForPooling 상속 │
│ - lm_head, logits_processor 제거 │
│ - pooler 초기화 │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ 4. 변환된 모델 클래스 반환 │
│ 예: LlamaForCausalLM → LlamaForEmbedding │
└─────────────────────────────────────────────────────┘
convert="auto" (기본값)일 때, vLLM은 HuggingFace 모델의 아키텍처 클래스 이름 suffix를 기반으로 RunnerType과 ConvertType을 자동 결정합니다.
예를 들어:
meta-llama/Llama-3.1-8B→ 아키텍처:LlamaForCausalLM→ suffixForCausalLM매칭 →runner_type="generate",convert_type="none"BAAI/bge-base-en-v1.5→ 아키텍처:BertModel→ suffixModel매칭 →runner_type="pooling",convert_type="embed"
_SUFFIX_TO_DEFAULTS (config/model.py#L1894-L1910):
| 모델 아키텍처 Suffix | Runner Type | Convert Type |
|---|---|---|
*ForCausalLM |
generate | none |
*ForConditionalGeneration |
generate | none |
*ChatModel |
generate | none |
*LMHeadModel |
generate | none |
*ForTextEncoding |
pooling | embed |
*EmbeddingModel |
pooling | embed |
*ForSequenceClassification |
pooling | classify |
*For*Classification |
pooling | classify |
*ForRewardModeling |
pooling | embed |
*RewardModel |
pooling | embed |
*Model |
pooling | embed |
원본 모델이 이미 is_pooling_model = True인 경우, adapter는 변환을 건너뜁니다:
def as_embedding_model(cls: _T) -> _T:
if is_pooling_model(cls): # 이미 pooling 모델이면
return cls # 그대로 반환
# ...따라서 네이티브 Pooling 모델은 convert 옵션을 지정해도 변환되지 않고 원본 구조를 유지합니다.
_create_pooling_model_cls() (adapters.py#L159-L236)
def _create_pooling_model_cls(orig_cls: _T) -> _T:
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__(self, *, vllm_config, prefix="", **kwargs):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# 불필요한 속성 제거 (생성 모델용)
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
# Pooler 초기화
if not getattr(self, "pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
def load_weights(self, weights, load_lm_head=False):
# lm_head 가중치 로딩 스킵
if not load_lm_head:
weights = ((n, d) for n, d in weights if not n.startswith("lm_head."))
...Pooler (ABC)
├── SimplePooler # 기본 pooling + head
├── ClassifierPooler # classification용
├── AllPooler # 토큰별 출력
├── StepPooler # 특정 토큰 추출
└── DispatchPooler # 태스크별 라우팅
│
└── {"embed": Pooler, "classify": Pooler, "score": Pooler}
PoolingMethod:
CLSPool: 첫 번째 토큰 ([CLS])LastPool: 마지막 토큰MeanPool: 평균AllPool: 모든 토큰
Task → Pooler 라우팅 (pooler.py#L780-L826):
class DispatchPooler(Pooler):
def forward(self, hidden_states, pooling_metadata):
for task, group in groupby(pooling_metadata.tasks):
pooler = self.poolers_by_task[task] # task에 따라 pooler 선택
outputs.extend(pooler(hidden_states, ...))seq_cls_model_loader() (adapters.py#L496-L510)는 ForCausalLM 모델을 Reranker로 변환:
-
from_2_way_softmax: Qwen3-Reranker, mxbai-rerank-v2 등
- lm_head 가중치에서 true/false 토큰 가중치 추출
score.weight = lm_head[true_id] - lm_head[false_id]
-
no_post_processing: bge-reranker-v2-gemma 등
- 지정된 토큰들의 lm_head 가중치를 score 레이어로 복사
1. 실행 경로 분기 (gpu_model_runner.py#L3102-L3114)
# GPUModelRunner.execute_model() 내부
if self.is_pooling_model: # runner_type == "pooling"
output = self._pool(hidden_states, ...) # Pooler 사용
return output
# runner_type == "generate"
logits = self.model.compute_logits(sample_hidden_states)
# → Sampler로 토큰 샘플링2. 지원 태스크 결정 (gpu_model_runner.py#L2324-L2332)
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks)3. Tokenizer 설정 (tokenizers/registry.py#L220)
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left" # 생성: 왼쪽 잘림 (최근 컨텍스트 유지)
elif runner_type == "pooling":
truncation_side = "right" # 임베딩: 오른쪽 잘림 (시작 컨텍스트 유지)| 파일 | 위치 | 설명 |
|---|---|---|
vllm/engine/arg_utils.py |
line 628 | --convert CLI 인자 정의 |
vllm/entrypoints/llm.py |
line 189 | LLM(convert=...) 파라미터 |
| 파일 | 위치 | 설명 |
|---|---|---|
vllm/config/model.py#L74-L75 |
타입 정의 | ConvertType = Literal["none", "embed", "classify", "reward"] |
vllm/config/model.py#L1894-L1910 |
아키텍처 매핑 | _SUFFIX_TO_DEFAULTS: 모델 suffix → (runner_type, convert_type) |
vllm/config/model.py#L869-L919 |
자동 결정 | _get_default_convert_type(), _get_convert_type() |
| 파일 | 위치 | 설명 |
|---|---|---|
vllm/model_executor/model_loader/utils.py#L167-L212 |
변환 적용 | _get_model_architecture(): convert_type에 따라 adapter 적용 |
vllm/model_executor/models/adapters.py |
전체 | Adapter 구현 핵심 파일 |
| 파일 | 설명 |
|---|---|
vllm/model_executor/layers/pooler.py |
Pooler 클래스들 (CLSPool, LastPool, MeanPool, etc.) |
vllm/config/pooler.py |
PoolerConfig 설정 |
vLLM의 Model Conversion은 동적 클래스 상속을 통한 Adapter 패턴을 사용합니다:
- 원본 모델 클래스를 변경하지 않음
- 런타임에
VllmModelForPooling믹스인을 추가한 새 클래스 생성 - 불필요한 생성 모델 컴포넌트(lm_head) 제거
- 적절한 Pooler 초기화
- 가중치 로딩 시 lm_head 스킵 (또는 온라인 변환에 활용)
이를 통해 하나의 ForCausalLM 체크포인트로 Embedding, Classification, Reranking 등 다양한 Pooling 태스크를 수행할 수 있습니다.