From 5942bbfd05c65dee1988c8a65a288761bd4a3d65 Mon Sep 17 00:00:00 2001 From: Huxley Date: Sat, 27 Sep 2025 16:39:55 +0800 Subject: [PATCH] Init commit. --- .clineignore | 225 ++++++++++++++++ .gitignore | 225 ++++++++++++++++ models/.gitignore | 1 + models/__init__.py | 0 models/log_expert/__init__.py | 0 models/log_expert/olmoe.py | 489 +++++++++++++++++++++++++++++++++ olmoe_inference.ipynb | 91 +++++++ olmoe_inference_vllm.py | 70 +++++ requirements.txt | 9 + utils/__init__.py | 1 + utils/logger.py | 493 ++++++++++++++++++++++++++++++++++ 11 files changed, 1604 insertions(+) create mode 100644 .clineignore create mode 100644 .gitignore create mode 100644 models/.gitignore create mode 100644 models/__init__.py create mode 100644 models/log_expert/__init__.py create mode 100644 models/log_expert/olmoe.py create mode 100644 olmoe_inference.ipynb create mode 100644 olmoe_inference_vllm.py create mode 100644 requirements.txt create mode 100644 utils/__init__.py create mode 100644 utils/logger.py diff --git a/.clineignore b/.clineignore new file mode 100644 index 0000000..5d600fd --- /dev/null +++ b/.clineignore @@ -0,0 +1,225 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml + +# Python virtual environments +.venv +venv +.venv/ +venv/ + +# Checkpoints for LLMs +llms/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5d600fd --- /dev/null +++ b/.gitignore @@ -0,0 +1,225 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. 
+# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml + +# Python virtual environments +.venv +venv +.venv/ +venv/ + +# Checkpoints for LLMs +llms/ diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..b17529c --- /dev/null +++ b/models/.gitignore @@ -0,0 +1 @@ +vanilla_vllm/ \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/log_expert/__init__.py b/models/log_expert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/log_expert/olmoe.py b/models/log_expert/olmoe.py new file mode 100644 index 0000000..734b6d7 --- /dev/null +++ b/models/log_expert/olmoe.py @@ -0,0 +1,489 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMoE model compatible with HuggingFace weights. 
+ +This file is origin from vllm/model_executor/models/olmoe.py +""" +from collections.abc import Iterable +from functools import partial +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import OlmoeConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.distributed.utils import split_tensor_along_last_dim +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +from utils import DataLogger + +logger = init_logger(__name__) + + +class OlmoeMoE(nn.Module): + """A tensor-parallel MoE implementation for Olmoe that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + return final_hidden_states.view(orig_shape) + + +class OlmoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 4096, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.tp_size = tp_size + self.tp_rank = get_tensor_model_parallel_rank() + self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5) + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, + eps=1e-5) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoeDecoderLayer(nn.Module): + + def __init__( + self, 
+ config: OlmoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = OlmoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = OlmoeMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class OlmoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoeDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + 
hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class OlmoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/olmoe_inference.ipynb b/olmoe_inference.ipynb new file mode 100644 index 0000000..275efe7 --- /dev/null +++ b/olmoe_inference.ipynb @@ -0,0 +1,91 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "928261ae", + "metadata": {}, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a35fc2b4", + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"./llms/OLMoE-1B-7B-0924-Instruct\"\n", + "\n", + "llm = LLM(\n", + " model=model_id,\n", + " # cpu_offload_gb=4,\n", + " tensor_parallel_size=2,\n", + " # gpu_memory_utilization=0.90,\n", + " max_model_len=4096,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a708f11", + "metadata": {}, + "outputs": [], + "source": [ + "sampling_params = SamplingParams(\n", + " temperature=0.6,\n", + " top_p=0.95,\n", + " top_k=20,\n", + " max_tokens=1024,\n", + ")\n", + "\n", + "# Prepare the input to the model\n", + "prompt = \"Give me a short introduction to large language models.\"\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + "]\n", + "\n", + "# Generate outputs\n", + "outputs = llm.chat(\n", + " messages, \n", + " sampling_params=sampling_params,\n", + " # chat_template_kwargs={\"enable_thinking\": True}, # Set to False to strictly disable thinking\n", + 
")\n", + "\n", + "# Print the outputs.\n", + "for out in outputs:\n", + " # out.prompt is the input prompt; out.outputs is a list of completion choices\n", + " print(\"=== PROMPT ===\")\n", + " print(out.prompt)\n", + " print(\"=== COMPLETION ===\")\n", + " print(out.outputs[0].text)\n", + " print(\"\\n---\\n\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "moe-explore", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/olmoe_inference_vllm.py b/olmoe_inference_vllm.py new file mode 100644 index 0000000..77b3e00 --- /dev/null +++ b/olmoe_inference_vllm.py @@ -0,0 +1,70 @@ +# %% +import gc + +import torch +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel +from vllm.distributed.parallel_state import destroy_model_parallel + +# %% +model_id = "./llms/OLMoE-1B-7B-0924-Instruct" + +try: + llm = LLM( + model=model_id, + cpu_offload_gb=4, + # tensor_parallel_size=2, + gpu_memory_utilization=0.90, + max_model_len=4096, + # compilation_config=CompilationConfig( + # level=CompilationLevel.PIECEWISE, + # # By default, it goes up to max_num_seqs + # cudagraph_capture_sizes=[1, 2, 4, 8, 16], + # ), + enforce_eager=True, + ) + + sampling_params = SamplingParams( + temperature=0.6, + top_p=0.95, + top_k=20, + max_tokens=1024, + ) + + # Prepare the input to the model + prompt = "Give me a short introduction to large language models." + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + + # Generate outputs + outputs = llm.chat( + messages, + sampling_params=sampling_params, + # chat_template_kwargs={"enable_thinking": True}, # Set to False to strictly disable thinking + ) + + # Print the outputs. + for out in outputs: + # out.prompt is the input prompt; out.outputs is a list of completion choices + # print("=== PROMPT ===") + # print(out.prompt) + # print("=== COMPLETION ===") + print(out.outputs[0].text) + print("\n---\n") + + print("Finish completion") + +except Exception as e: + print(e) + +finally: + if llm := globals().get("llm", None): + if engine := getattr(llm, "llm_engine", None): + # llm.llm_engine + del engine + del llm + destroy_model_parallel() + torch.cuda.empty_cache() + gc.collect() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5702385 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +vllm==0.10.1.1 +notebook +ipywidgets +python-dotenv +pandas +datasets +accelerate +kernels +modelscope diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..648b87a --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .logger import DataLogger, LoggerConfig \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..53c2c79 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,493 @@ +""" +Asynchronous, batched, and schema-evolving Parquet logger. + +This module provides the `DataLogger`, a high-performance logger for structured +data, designed for applications like machine learning experiments, simulations, +or any scenario requiring efficient serialization of row-based data. 
+ +Key Features: +- **Unified Interface**: Log data via a simple `DataLogger.log({"key": "value"})` call. +- **Asynchronous & Batched**: A dedicated background thread handles I/O, + batching rows to minimize disk writes and reduce application latency. +- **Schema Evolution**: Automatically adapts the Parquet schema if new data fields + are introduced, rewriting the file to maintain a consistent structure. +- **Singleton Pattern**: A global singleton instance is managed automatically, + providing a convenient, fire-and-forget logging experience. +- **Type Handling**: Natively handles Python primitives, NumPy arrays, and PyTorch + tensors, converting them to Parquet-compatible formats. +- **Robust & Thread-Safe**: Designed for use in multi-threaded environments. + +Basic Usage: +------------- +.. code-block:: python + + from logger.data_logger import DataLogger + + # The first call creates and configures the singleton logger. + # A timestamped filename is generated by default. + DataLogger.log({"step": 0, "loss": 10.5, "accuracy": 0.5}) + DataLogger.log({"step": 1, "loss": 9.8, "accuracy": 0.55}) + + # For the singleton, data is automatically flushed and saved on program exit. + # No explicit `close()` call is required for this simple case. + +Advanced Usage (Instance-Based): +--------------------------------- +.. code-block:: python + + from logger.data_logger import DataLogger, LoggerConfig + + config = LoggerConfig(batch_size=512, flush_interval=5.0) + with DataLogger("my_experiment.parquet", config=config) as logger: + for i in range(1000): + logger.submit({"value": i}) + # The `with` statement ensures flush and close on exit. +""" + +from __future__ import annotations + +import datetime +import atexit +import os +import queue +import threading +import time +import traceback +import typing as t +from dataclasses import dataclass +from pathlib import Path + +# Third-party libraries are imported with runtime checks to provide clear +# error messages if they are not installed. +try: + import numpy as np +except ImportError: + np = None # type: ignore + +try: + import pandas as pd +except ImportError: + raise ImportError( + "pandas is required for DataLogger. Install with `pip install pandas`." + ) + +try: + import pyarrow as pa + import pyarrow.parquet as pq +except ImportError: + raise ImportError( + "pyarrow is required for DataLogger. Install with `pip install pyarrow`." + ) + +try: + import torch +except ImportError: + torch = None # type: ignore + +# Type alias for a single row of data. +Row = t.Dict[str, t.Any] + + +@dataclass +class LoggerConfig: + """Configuration for the DataLogger's writer behavior.""" + + batch_size: int = 1024 + """Number of rows to accumulate before writing a batch to the Parquet file.""" + + flush_interval: float = 1.0 + """Maximum time in seconds to wait before flushing the buffer, even if + `batch_size` is not reached.""" + + parquet_compression: str = "snappy" + """Compression codec to use for the Parquet file. + Common options: 'snappy', 'gzip', 'brotli', 'none'.""" + + allow_schema_rewrite: bool = True + """If True, the logger will automatically rewrite the entire Parquet file to + accommodate new columns. If False, it will raise an error.""" + + +class DataLogger: + """ + An asynchronous, batched logger that writes data to a Parquet file. + + This class manages a background thread to handle file I/O, allowing the + calling application to log data with minimal blocking. 
It supports schema + evolution, making it robust to changes in data structure over time. + """ + + _singleton: t.Optional["DataLogger"] = None + _singleton_lock = threading.Lock() + + # --- Public API --- + + @classmethod + def get_instance( + cls, + path: t.Optional[t.Union[str, Path]] = None, + config: t.Optional[LoggerConfig] = None, + ) -> "DataLogger": + """ + Get or create the global singleton instance of the DataLogger. + + The first time this method is called, it creates a new `DataLogger` + instance and registers a cleanup function via `atexit` to ensure + `close()` is called automatically upon program termination. + + Subsequent calls will ignore the arguments and return the existing + instance. + + Args: + path: The file path for the log file. If None, a timestamped + filename like 'log_YYYYMMDD-HHMMSS.parquet' is created in the + current working directory. + config: A `LoggerConfig` object to configure the writer's behavior. + If None, default settings are used. + + Returns: + The singleton `DataLogger` instance. + """ + if cls._singleton is None: + with cls._singleton_lock: + if cls._singleton is None: + # Create the singleton instance. + instance = cls(path, config) + # Register its close method to be called at program exit. + # This ensures data is saved even if the user forgets to call close(). + atexit.register(instance.close) + cls._singleton = instance + return cls._singleton + + @classmethod + def log(cls, row: Row) -> None: + """ + Log a data row using the singleton instance. + + This is a convenience method that lazily initializes the singleton on + its first call. The operation is non-blocking; the data is placed in + an internal queue to be processed by the background writer thread. + + Args: + row: A dictionary representing a single row of data, where keys + are column names and values are the data points. + """ + instance = cls.get_instance() + instance.submit(row) + + def __init__( + self, + path: t.Optional[t.Union[str, Path]] = None, + config: t.Optional[LoggerConfig] = None, + ): + """ + Initialize a DataLogger instance. + + Args: + path: The file path for the log file. If None, a timestamped + filename is automatically generated. + config: A `LoggerConfig` object. If None, default settings are used. + """ + self.path = self._resolve_path(path) + self._config = config or LoggerConfig() + + # Internal state for the writer thread + self._queue: queue.Queue[t.Optional[Row]] = queue.Queue() + self._stop_event = threading.Event() + self._flush_event = threading.Event() + self._writer_thread: t.Optional[threading.Thread] = None + self._writer_lock = threading.RLock() # Protects writer and schema + + # Parquet-specific state, managed exclusively by the writer thread + self._parquet_writer: t.Optional[pq.ParquetWriter] = None + self._schema: t.Optional[pa.Schema] = None + self._buffer: t.List[Row] = [] + + self._start_writer_thread() + + def submit(self, row: Row) -> None: + """ + Submit a data row to be written asynchronously by the logger instance. + + Args: + row: A dictionary representing a single row of data. + + Raises: + TypeError: If the provided row is not a dictionary. + RuntimeError: If the logger has already been closed. 
+ """ + if self._stop_event.is_set(): + raise RuntimeError("Logger has been closed and cannot accept new data.") + if not isinstance(row, dict): + raise TypeError(f"Expected a dict for a row, but got {type(row)}.") + + normalized_row = self._normalize_row(row) + self._queue.put(normalized_row) + + def flush(self, timeout: float = 10.0) -> None: + """ + Block until all currently queued and buffered data is written to disk. + + Args: + timeout: Maximum time in seconds to wait for the flush to complete. + """ + if self._writer_thread is None or not self._writer_thread.is_alive(): + return + + self._flush_event.clear() + self._queue.put(None) # Sentinel to trigger a flush + self._flush_event.wait(timeout) + + def close(self, timeout: float = 10.0) -> None: + """ + Flush all remaining data and shut down the background writer thread. + + This method is idempotent and thread-safe. It is designed to be + called explicitly, via a `with` statement, or automatically at program + exit. + + Args: + timeout: Maximum time in seconds to wait for the writer thread + to finish. + """ + if self._stop_event.is_set(): + return + + self._stop_event.set() + self._queue.put(None) # Wake up the writer thread if it's blocking. + + # Do not join the writer thread from itself, which would cause a deadlock. + if self._writer_thread and threading.current_thread() != self._writer_thread: + self._writer_thread.join(timeout) + + def __enter__(self) -> "DataLogger": + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensures the logger is closed upon exiting a `with` block.""" + self.close() + + def __del__(self): + """Ensures data is flushed when the logger object is destroyed.""" + self.close() + + # --- Internal Methods --- + + def _resolve_path(self, path: t.Optional[t.Union[str, Path]]) -> Path: + """Determine the final output path for the log file.""" + if path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + filename = f"log_{timestamp}.parquet" + return Path.cwd() / filename + + resolved_path = Path(path) + if resolved_path.suffix == "": + resolved_path = resolved_path.with_suffix(".parquet") + return resolved_path + + def _start_writer_thread(self) -> None: + """Initialize and start the background writer thread.""" + if self._writer_thread is not None: + return + thread_name = f"DataLoggerWriter-{self.path.name}" + self._writer_thread = threading.Thread( + target=self._writer_loop, name=thread_name, daemon=True + ) + self._writer_thread.start() + + def _writer_loop(self) -> None: + """ + The main loop for the background writer thread. + + This loop continuously pulls data from the queue, batches it, and + writes it to the Parquet file. It handles flush signals, stop events, + and schema evolution. + """ + try: + while not self._stop_event.is_set(): + try: + # Block until an item is available or the flush interval times out. + item = self._queue.get(timeout=self._config.flush_interval) + except queue.Empty: + # Timeout occurred, treat as a periodic flush signal. 
+ item = None + + if item is not None: + self._buffer.append(item) + + buffer_size = len(self._buffer) + is_flush_signal = item is None + is_batch_full = buffer_size >= self._config.batch_size + is_shutting_down = self._stop_event.is_set() + + if self._buffer and ( + is_flush_signal or is_batch_full or is_shutting_down + ): + self._write_batch(self._buffer) + self._buffer.clear() + + if is_flush_signal: + self._flush_event.set() # Signal that a flush completed + + # Final drain of the queue and buffer after the stop event is set. + self._drain_remaining() + + except Exception as e: + print(f"FATAL: DataLogger writer thread crashed: {e}", flush=True) + traceback.print_exc() + finally: + # This block ensures that the Parquet writer is always closed + # when the writer thread exits, for any reason. + with self._writer_lock: + if self._parquet_writer: + try: + self._parquet_writer.close() + except Exception as e: + print( + f"ERROR: Exception while closing Parquet writer: {e}", + flush=True, + ) + self._parquet_writer = None + + def _drain_remaining(self) -> None: + """Process all remaining items in the queue and buffer during shutdown.""" + while True: + try: + item = self._queue.get_nowait() + if item: + self._buffer.append(item) + except queue.Empty: + break + if self._buffer: + self._write_batch(self._buffer) + self._buffer.clear() + + def _write_batch(self, rows: t.List[Row]) -> None: + """ + Convert a list of rows into a Parquet table and write it to the file. + + This method handles schema creation, validation, and evolution. + It is always executed within the writer thread. + """ + if not rows: + return + + try: + with self._writer_lock: + df = pd.DataFrame(rows) + # Ensure a consistent column order for schema stability. + df = df.reindex(sorted(df.columns), axis=1) + new_table = pa.Table.from_pandas(df, preserve_index=False) + + if self.path.exists(): + # File exists, need to append or evolve schema + existing_table = pq.read_table(self.path) + existing_schema = existing_table.schema + + if existing_schema.equals(new_table.schema): + # Schema matches, append the data + combined_table = pa.concat_tables([existing_table, new_table]) + else: + # Schema evolution needed + if not self._config.allow_schema_rewrite: + raise RuntimeError( + "Schema mismatch detected, and rewriting is disabled. " + f"Existing schema: {existing_schema}, New schema: {new_table.schema}" + ) + print( + f"INFO: Schema evolution detected. Rewriting {self.path}...", + flush=True, + ) + # Combine with schema evolution + combined_df = pd.concat( + [existing_table.to_pandas(), new_table.to_pandas()], + ignore_index=True, + sort=False, + ) + combined_df = combined_df.reindex( + sorted(combined_df.columns), axis=1 + ) + combined_table = pa.Table.from_pandas( + combined_df, preserve_index=False + ) + else: + # New file + self.path.parent.mkdir(parents=True, exist_ok=True) + combined_table = new_table + + # Write the combined table atomically + temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp") + pq.write_table( + combined_table, + temp_path, + compression=self._config.parquet_compression, + ) + os.replace(temp_path, self.path) + + # Update our schema tracking + self._schema = combined_table.schema + + except Exception as e: + print(f"ERROR: Failed to write batch to {self.path}: {e}", flush=True) + traceback.print_exc() + + def _rewrite_with_new_schema(self, new_table: pa.Table) -> None: + """ + Rewrite the entire Parquet file to accommodate an evolved schema. 
+ + This is a potentially expensive operation as it reads the entire + existing file into memory. + + Args: + new_table: The new batch of data with a different schema. + """ + print(f"INFO: Schema evolution detected. Rewriting {self.path}...", flush=True) + + # Close the current writer before reading the file. + if self._parquet_writer: + self._parquet_writer.close() + + # Read existing data, combine with new data, and create a unified table. + existing_table = pq.read_table(self.path) + combined_df = pd.concat( + [existing_table.to_pandas(), new_table.to_pandas()], + ignore_index=True, + sort=False, + ) + # Re-sort columns for the new unified schema. + combined_df = combined_df.reindex(sorted(combined_df.columns), axis=1) + final_table = pa.Table.from_pandas(combined_df, preserve_index=False) + self._schema = final_table.schema + + # Atomically replace the old file with the new one. + temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp") + pq.write_table( + final_table, temp_path, compression=self._config.parquet_compression + ) + os.replace(temp_path, self.path) + + # Re-initialize the writer with the new schema for subsequent writes. + self._parquet_writer = pq.ParquetWriter( + self.path, self._schema, compression=self._config.parquet_compression + ) + + def _normalize_row(self, row: Row) -> Row: + """ + Sanitize all values in a row for Parquet compatibility. + """ + return {key: self._normalize_value(value) for key, value in row.items()} + + def _normalize_value(self, value: t.Any) -> t.Any: + """ + Convert a single value to a Parquet-friendly format. + - NumPy arrays and Torch tensors are converted to nested lists. + - Other types are passed through for pandas to handle. + """ + if value is None: + return None + if np and isinstance(value, np.ndarray): + return value.tolist() + if torch and isinstance(value, torch.Tensor): + return value.detach().cpu().numpy().tolist() + # return value.detach().cpu().numpy() + return value
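
The custom implementation in models/log_expert/olmoe.py mirrors vLLM's in-tree OLMoE model, so vLLM will only use it if it is registered as an out-of-tree model. Below is a minimal sketch of that wiring, assuming vLLM's ModelRegistry is used to overwrite the built-in "OlmoeForCausalLM" architecture; the helper script name and the registration call are assumptions and are not part of this commit:

    # register_log_expert_olmoe.py -- hypothetical helper, not part of this commit.
    from vllm import LLM, ModelRegistry, SamplingParams

    from models.log_expert.olmoe import OlmoeForCausalLM

    # Overwrite the built-in "OlmoeForCausalLM" architecture so vLLM instantiates the
    # local copy instead of its in-tree implementation (vLLM warns when an existing
    # architecture is overwritten). Recent versions may also accept a lazy
    # "models.log_expert.olmoe:OlmoeForCausalLM" string here to defer the import.
    ModelRegistry.register_model("OlmoeForCausalLM", OlmoeForCausalLM)

    llm = LLM(
        model="./llms/OLMoE-1B-7B-0924-Instruct",
        max_model_len=4096,
        enforce_eager=True,
    )
    outputs = llm.chat(
        [{"role": "user", "content": "Give me a short introduction to large language models."}],
        sampling_params=SamplingParams(temperature=0.6, top_p=0.95, max_tokens=256),
    )
    print(outputs[0].outputs[0].text)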
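models/log_expert/olmoe.py imports DataLogger from utils but never calls it; the package name log_expert suggests the intent is to record per-token expert routing. The following is only a sketch of how OlmoeMoE.forward could emit routing records, under the assumptions that the layer `prefix` is stored on the module in __init__ and that OLMoE's top-k of 8 is acceptable hard-coded (it should really come from config.num_experts_per_tok); none of this instrumentation exists in the patch, and the last two lines simply repeat the existing expert call for context:

    # Hypothetical instrumentation inside OlmoeMoE.forward -- not present in this patch.
    # Assumes `self.prefix = prefix` was added in __init__ so each record can be
    # attributed to a layer.
    router_logits, _ = self.gate(hidden_states)
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, k=8, dim=-1)
    DataLogger.log({
        "layer": self.prefix,        # e.g. "model.layers.3.mlp" (assumed attribute)
        "topk_ids": topk_ids,        # (num_tokens, 8); tensors are converted to lists by the logger
        "topk_weights": topk_weights,
    })
    final_hidden_states = self.experts(hidden_states=hidden_states,
                                       router_logits=router_logits)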
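The writer thread in utils/logger.py batches rows and, when a row introduces a new column, rewrites the whole Parquet file against the union schema (see _write_batch). A short usage sketch of that behaviour, assuming the repository root is on PYTHONPATH so `utils` is importable; the file name is illustrative. Rows written before a column existed surface as null/NaN when the file is read back:

    import pandas as pd

    from utils import DataLogger, LoggerConfig

    config = LoggerConfig(batch_size=8, flush_interval=0.5)
    with DataLogger("routing_demo.parquet", config=config) as logger:
        logger.submit({"step": 0, "loss": 1.32})
        # This row adds an `expert_ids` column; the file is rewritten with the evolved schema.
        logger.submit({"step": 1, "loss": 1.27, "expert_ids": [3, 7, 12]})
    # Leaving the `with` block flushes the buffer and shuts down the writer thread.

    df = pd.read_parquet("routing_demo.parquet")
    print(df)  # step 0 holds a null value in the expert_ids column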