Init commit.

This commit is contained in:
2025-09-27 16:39:55 +08:00
commit 5942bbfd05
11 changed files with 1604 additions and 0 deletions

225
.clineignore Normal file
View File

@ -0,0 +1,225 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
# Python virtual environments
.venv
venv
.venv/
venv/
# Checkpoints for LLMs
llms/

225
.gitignore vendored Normal file
View File

@ -0,0 +1,225 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
# Python virtual environments
.venv
venv
.venv/
venv/
# Checkpoints for LLMs
llms/

1
models/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
vanilla_vllm/

0
models/__init__.py Normal file
View File

View File

489
models/log_expert/olmoe.py Normal file
View File

@ -0,0 +1,489 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.
This file is origin from vllm/model_executor/models/olmoe.py
"""
from collections.abc import Iterable
from functools import partial
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import OlmoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.utils import split_tensor_along_last_dim
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from utils import DataLogger
logger = init_logger(__name__)
class OlmoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
quant_config=None)
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(orig_shape)
class OlmoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.tp_size = tp_size
self.tp_rank = get_tensor_model_parallel_rank()
self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
eps=1e-5)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: OlmoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
4096)
self.self_attn = OlmoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = OlmoeMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class OlmoeForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

91
olmoe_inference.ipynb Normal file
View File

@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "928261ae",
"metadata": {},
"outputs": [],
"source": [
"from vllm import LLM, SamplingParams"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a35fc2b4",
"metadata": {},
"outputs": [],
"source": [
"model_id = \"./llms/OLMoE-1B-7B-0924-Instruct\"\n",
"\n",
"llm = LLM(\n",
" model=model_id,\n",
" # cpu_offload_gb=4,\n",
" tensor_parallel_size=2,\n",
" # gpu_memory_utilization=0.90,\n",
" max_model_len=4096,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a708f11",
"metadata": {},
"outputs": [],
"source": [
"sampling_params = SamplingParams(\n",
" temperature=0.6,\n",
" top_p=0.95,\n",
" top_k=20,\n",
" max_tokens=1024,\n",
")\n",
"\n",
"# Prepare the input to the model\n",
"prompt = \"Give me a short introduction to large language models.\"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
"]\n",
"\n",
"# Generate outputs\n",
"outputs = llm.chat(\n",
" messages, \n",
" sampling_params=sampling_params,\n",
" # chat_template_kwargs={\"enable_thinking\": True}, # Set to False to strictly disable thinking\n",
")\n",
"\n",
"# Print the outputs.\n",
"for out in outputs:\n",
" # out.prompt is the input prompt; out.outputs is a list of completion choices\n",
" print(\"=== PROMPT ===\")\n",
" print(out.prompt)\n",
" print(\"=== COMPLETION ===\")\n",
" print(out.outputs[0].text)\n",
" print(\"\\n---\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "moe-explore",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

70
olmoe_inference_vllm.py Normal file
View File

@ -0,0 +1,70 @@
# %%
import gc
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed.parallel_state import destroy_model_parallel
# %%
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
try:
llm = LLM(
model=model_id,
cpu_offload_gb=4,
# tensor_parallel_size=2,
gpu_memory_utilization=0.90,
max_model_len=4096,
# compilation_config=CompilationConfig(
# level=CompilationLevel.PIECEWISE,
# # By default, it goes up to max_num_seqs
# cudagraph_capture_sizes=[1, 2, 4, 8, 16],
# ),
enforce_eager=True,
)
sampling_params = SamplingParams(
temperature=0.6,
top_p=0.95,
top_k=20,
max_tokens=1024,
)
# Prepare the input to the model
prompt = "Give me a short introduction to large language models."
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
# Generate outputs
outputs = llm.chat(
messages,
sampling_params=sampling_params,
# chat_template_kwargs={"enable_thinking": True}, # Set to False to strictly disable thinking
)
# Print the outputs.
for out in outputs:
# out.prompt is the input prompt; out.outputs is a list of completion choices
# print("=== PROMPT ===")
# print(out.prompt)
# print("=== COMPLETION ===")
print(out.outputs[0].text)
print("\n---\n")
print("Finish completion")
except Exception as e:
print(e)
finally:
if llm := globals().get("llm", None):
if engine := getattr(llm, "llm_engine", None):
# llm.llm_engine
del engine
del llm
destroy_model_parallel()
torch.cuda.empty_cache()
gc.collect()

9
requirements.txt Normal file
View File

@ -0,0 +1,9 @@
vllm==0.10.1.1
notebook
ipywidgets
python-dotenv
pandas
datasets
accelerate
kernels
modelscope

1
utils/__init__.py Normal file
View File

@ -0,0 +1 @@
from .logger import DataLogger, LoggerConfig

493
utils/logger.py Normal file
View File

@ -0,0 +1,493 @@
"""
Asynchronous, batched, and schema-evolving Parquet logger.
This module provides the `DataLogger`, a high-performance logger for structured
data, designed for applications like machine learning experiments, simulations,
or any scenario requiring efficient serialization of row-based data.
Key Features:
- **Unified Interface**: Log data via a simple `DataLogger.log({"key": "value"})` call.
- **Asynchronous & Batched**: A dedicated background thread handles I/O,
batching rows to minimize disk writes and reduce application latency.
- **Schema Evolution**: Automatically adapts the Parquet schema if new data fields
are introduced, rewriting the file to maintain a consistent structure.
- **Singleton Pattern**: A global singleton instance is managed automatically,
providing a convenient, fire-and-forget logging experience.
- **Type Handling**: Natively handles Python primitives, NumPy arrays, and PyTorch
tensors, converting them to Parquet-compatible formats.
- **Robust & Thread-Safe**: Designed for use in multi-threaded environments.
Basic Usage:
-------------
.. code-block:: python
from logger.data_logger import DataLogger
# The first call creates and configures the singleton logger.
# A timestamped filename is generated by default.
DataLogger.log({"step": 0, "loss": 10.5, "accuracy": 0.5})
DataLogger.log({"step": 1, "loss": 9.8, "accuracy": 0.55})
# For the singleton, data is automatically flushed and saved on program exit.
# No explicit `close()` call is required for this simple case.
Advanced Usage (Instance-Based):
---------------------------------
.. code-block:: python
from logger.data_logger import DataLogger, LoggerConfig
config = LoggerConfig(batch_size=512, flush_interval=5.0)
with DataLogger("my_experiment.parquet", config=config) as logger:
for i in range(1000):
logger.submit({"value": i})
# The `with` statement ensures flush and close on exit.
"""
from __future__ import annotations
import datetime
import atexit
import os
import queue
import threading
import time
import traceback
import typing as t
from dataclasses import dataclass
from pathlib import Path
# Third-party libraries are imported with runtime checks to provide clear
# error messages if they are not installed.
try:
import numpy as np
except ImportError:
np = None # type: ignore
try:
import pandas as pd
except ImportError:
raise ImportError(
"pandas is required for DataLogger. Install with `pip install pandas`."
)
try:
import pyarrow as pa
import pyarrow.parquet as pq
except ImportError:
raise ImportError(
"pyarrow is required for DataLogger. Install with `pip install pyarrow`."
)
try:
import torch
except ImportError:
torch = None # type: ignore
# Type alias for a single row of data.
Row = t.Dict[str, t.Any]
@dataclass
class LoggerConfig:
"""Configuration for the DataLogger's writer behavior."""
batch_size: int = 1024
"""Number of rows to accumulate before writing a batch to the Parquet file."""
flush_interval: float = 1.0
"""Maximum time in seconds to wait before flushing the buffer, even if
`batch_size` is not reached."""
parquet_compression: str = "snappy"
"""Compression codec to use for the Parquet file.
Common options: 'snappy', 'gzip', 'brotli', 'none'."""
allow_schema_rewrite: bool = True
"""If True, the logger will automatically rewrite the entire Parquet file to
accommodate new columns. If False, it will raise an error."""
class DataLogger:
"""
An asynchronous, batched logger that writes data to a Parquet file.
This class manages a background thread to handle file I/O, allowing the
calling application to log data with minimal blocking. It supports schema
evolution, making it robust to changes in data structure over time.
"""
_singleton: t.Optional["DataLogger"] = None
_singleton_lock = threading.Lock()
# --- Public API ---
@classmethod
def get_instance(
cls,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
) -> "DataLogger":
"""
Get or create the global singleton instance of the DataLogger.
The first time this method is called, it creates a new `DataLogger`
instance and registers a cleanup function via `atexit` to ensure
`close()` is called automatically upon program termination.
Subsequent calls will ignore the arguments and return the existing
instance.
Args:
path: The file path for the log file. If None, a timestamped
filename like 'log_YYYYMMDD-HHMMSS.parquet' is created in the
current working directory.
config: A `LoggerConfig` object to configure the writer's behavior.
If None, default settings are used.
Returns:
The singleton `DataLogger` instance.
"""
if cls._singleton is None:
with cls._singleton_lock:
if cls._singleton is None:
# Create the singleton instance.
instance = cls(path, config)
# Register its close method to be called at program exit.
# This ensures data is saved even if the user forgets to call close().
atexit.register(instance.close)
cls._singleton = instance
return cls._singleton
@classmethod
def log(cls, row: Row) -> None:
"""
Log a data row using the singleton instance.
This is a convenience method that lazily initializes the singleton on
its first call. The operation is non-blocking; the data is placed in
an internal queue to be processed by the background writer thread.
Args:
row: A dictionary representing a single row of data, where keys
are column names and values are the data points.
"""
instance = cls.get_instance()
instance.submit(row)
def __init__(
self,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
):
"""
Initialize a DataLogger instance.
Args:
path: The file path for the log file. If None, a timestamped
filename is automatically generated.
config: A `LoggerConfig` object. If None, default settings are used.
"""
self.path = self._resolve_path(path)
self._config = config or LoggerConfig()
# Internal state for the writer thread
self._queue: queue.Queue[t.Optional[Row]] = queue.Queue()
self._stop_event = threading.Event()
self._flush_event = threading.Event()
self._writer_thread: t.Optional[threading.Thread] = None
self._writer_lock = threading.RLock() # Protects writer and schema
# Parquet-specific state, managed exclusively by the writer thread
self._parquet_writer: t.Optional[pq.ParquetWriter] = None
self._schema: t.Optional[pa.Schema] = None
self._buffer: t.List[Row] = []
self._start_writer_thread()
def submit(self, row: Row) -> None:
"""
Submit a data row to be written asynchronously by the logger instance.
Args:
row: A dictionary representing a single row of data.
Raises:
TypeError: If the provided row is not a dictionary.
RuntimeError: If the logger has already been closed.
"""
if self._stop_event.is_set():
raise RuntimeError("Logger has been closed and cannot accept new data.")
if not isinstance(row, dict):
raise TypeError(f"Expected a dict for a row, but got {type(row)}.")
normalized_row = self._normalize_row(row)
self._queue.put(normalized_row)
def flush(self, timeout: float = 10.0) -> None:
"""
Block until all currently queued and buffered data is written to disk.
Args:
timeout: Maximum time in seconds to wait for the flush to complete.
"""
if self._writer_thread is None or not self._writer_thread.is_alive():
return
self._flush_event.clear()
self._queue.put(None) # Sentinel to trigger a flush
self._flush_event.wait(timeout)
def close(self, timeout: float = 10.0) -> None:
"""
Flush all remaining data and shut down the background writer thread.
This method is idempotent and thread-safe. It is designed to be
called explicitly, via a `with` statement, or automatically at program
exit.
Args:
timeout: Maximum time in seconds to wait for the writer thread
to finish.
"""
if self._stop_event.is_set():
return
self._stop_event.set()
self._queue.put(None) # Wake up the writer thread if it's blocking.
# Do not join the writer thread from itself, which would cause a deadlock.
if self._writer_thread and threading.current_thread() != self._writer_thread:
self._writer_thread.join(timeout)
def __enter__(self) -> "DataLogger":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensures the logger is closed upon exiting a `with` block."""
self.close()
def __del__(self):
"""Ensures data is flushed when the logger object is destroyed."""
self.close()
# --- Internal Methods ---
def _resolve_path(self, path: t.Optional[t.Union[str, Path]]) -> Path:
"""Determine the final output path for the log file."""
if path is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"log_{timestamp}.parquet"
return Path.cwd() / filename
resolved_path = Path(path)
if resolved_path.suffix == "":
resolved_path = resolved_path.with_suffix(".parquet")
return resolved_path
def _start_writer_thread(self) -> None:
"""Initialize and start the background writer thread."""
if self._writer_thread is not None:
return
thread_name = f"DataLoggerWriter-{self.path.name}"
self._writer_thread = threading.Thread(
target=self._writer_loop, name=thread_name, daemon=True
)
self._writer_thread.start()
def _writer_loop(self) -> None:
"""
The main loop for the background writer thread.
This loop continuously pulls data from the queue, batches it, and
writes it to the Parquet file. It handles flush signals, stop events,
and schema evolution.
"""
try:
while not self._stop_event.is_set():
try:
# Block until an item is available or the flush interval times out.
item = self._queue.get(timeout=self._config.flush_interval)
except queue.Empty:
# Timeout occurred, treat as a periodic flush signal.
item = None
if item is not None:
self._buffer.append(item)
buffer_size = len(self._buffer)
is_flush_signal = item is None
is_batch_full = buffer_size >= self._config.batch_size
is_shutting_down = self._stop_event.is_set()
if self._buffer and (
is_flush_signal or is_batch_full or is_shutting_down
):
self._write_batch(self._buffer)
self._buffer.clear()
if is_flush_signal:
self._flush_event.set() # Signal that a flush completed
# Final drain of the queue and buffer after the stop event is set.
self._drain_remaining()
except Exception as e:
print(f"FATAL: DataLogger writer thread crashed: {e}", flush=True)
traceback.print_exc()
finally:
# This block ensures that the Parquet writer is always closed
# when the writer thread exits, for any reason.
with self._writer_lock:
if self._parquet_writer:
try:
self._parquet_writer.close()
except Exception as e:
print(
f"ERROR: Exception while closing Parquet writer: {e}",
flush=True,
)
self._parquet_writer = None
def _drain_remaining(self) -> None:
"""Process all remaining items in the queue and buffer during shutdown."""
while True:
try:
item = self._queue.get_nowait()
if item:
self._buffer.append(item)
except queue.Empty:
break
if self._buffer:
self._write_batch(self._buffer)
self._buffer.clear()
def _write_batch(self, rows: t.List[Row]) -> None:
"""
Convert a list of rows into a Parquet table and write it to the file.
This method handles schema creation, validation, and evolution.
It is always executed within the writer thread.
"""
if not rows:
return
try:
with self._writer_lock:
df = pd.DataFrame(rows)
# Ensure a consistent column order for schema stability.
df = df.reindex(sorted(df.columns), axis=1)
new_table = pa.Table.from_pandas(df, preserve_index=False)
if self.path.exists():
# File exists, need to append or evolve schema
existing_table = pq.read_table(self.path)
existing_schema = existing_table.schema
if existing_schema.equals(new_table.schema):
# Schema matches, append the data
combined_table = pa.concat_tables([existing_table, new_table])
else:
# Schema evolution needed
if not self._config.allow_schema_rewrite:
raise RuntimeError(
"Schema mismatch detected, and rewriting is disabled. "
f"Existing schema: {existing_schema}, New schema: {new_table.schema}"
)
print(
f"INFO: Schema evolution detected. Rewriting {self.path}...",
flush=True,
)
# Combine with schema evolution
combined_df = pd.concat(
[existing_table.to_pandas(), new_table.to_pandas()],
ignore_index=True,
sort=False,
)
combined_df = combined_df.reindex(
sorted(combined_df.columns), axis=1
)
combined_table = pa.Table.from_pandas(
combined_df, preserve_index=False
)
else:
# New file
self.path.parent.mkdir(parents=True, exist_ok=True)
combined_table = new_table
# Write the combined table atomically
temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp")
pq.write_table(
combined_table,
temp_path,
compression=self._config.parquet_compression,
)
os.replace(temp_path, self.path)
# Update our schema tracking
self._schema = combined_table.schema
except Exception as e:
print(f"ERROR: Failed to write batch to {self.path}: {e}", flush=True)
traceback.print_exc()
def _rewrite_with_new_schema(self, new_table: pa.Table) -> None:
"""
Rewrite the entire Parquet file to accommodate an evolved schema.
This is a potentially expensive operation as it reads the entire
existing file into memory.
Args:
new_table: The new batch of data with a different schema.
"""
print(f"INFO: Schema evolution detected. Rewriting {self.path}...", flush=True)
# Close the current writer before reading the file.
if self._parquet_writer:
self._parquet_writer.close()
# Read existing data, combine with new data, and create a unified table.
existing_table = pq.read_table(self.path)
combined_df = pd.concat(
[existing_table.to_pandas(), new_table.to_pandas()],
ignore_index=True,
sort=False,
)
# Re-sort columns for the new unified schema.
combined_df = combined_df.reindex(sorted(combined_df.columns), axis=1)
final_table = pa.Table.from_pandas(combined_df, preserve_index=False)
self._schema = final_table.schema
# Atomically replace the old file with the new one.
temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp")
pq.write_table(
final_table, temp_path, compression=self._config.parquet_compression
)
os.replace(temp_path, self.path)
# Re-initialize the writer with the new schema for subsequent writes.
self._parquet_writer = pq.ParquetWriter(
self.path, self._schema, compression=self._config.parquet_compression
)
def _normalize_row(self, row: Row) -> Row:
"""
Sanitize all values in a row for Parquet compatibility.
"""
return {key: self._normalize_value(value) for key, value in row.items()}
def _normalize_value(self, value: t.Any) -> t.Any:
"""
Convert a single value to a Parquet-friendly format.
- NumPy arrays and Torch tensors are converted to nested lists.
- Other types are passed through for pandas to handle.
"""
if value is None:
return None
if np and isinstance(value, np.ndarray):
return value.tolist()
if torch and isinstance(value, torch.Tensor):
return value.detach().cpu().numpy().tolist()
# return value.detach().cpu().numpy()
return value