moe_explore/logits_post_process.py
Huxley 5d1d818138 Implement OlMoE logits logging.
Update: logging items in OlMoE.

Update: some changes for main script.

Add: logits post processing notebook.
2025-09-29 19:52:44 +08:00

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.3
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---
# %%
import numpy as np
import pandas as pd
# import torch
# %%
log_file = "olmoe_20250929-165403.parquet"
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
# %%
df = pd.read_parquet(log_file)
df.head()
# %%
# logit = df.loc[1, "router_logits"]
# logit
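# %%
# Sketch only (assumptions throughout): one way the logged "router_logits"
# could be turned into expert choices. It assumes one layer's router logits
# form a 2-D array of shape (num_tokens, num_experts) and that routing is a
# softmax followed by top-k. NUM_EXPERTS = 64 and TOP_K = 8 are assumed
# OLMoE-1B-7B settings, not values read from the log, and the helpers
# `route_topk` / `expert_load` are hypothetical. Synthetic logits stand in
# for real ones; the weights returned are the raw softmax probabilities
# (no renormalisation over the selected k).
import numpy as np

NUM_EXPERTS = 64  # assumed experts per MoE layer
TOP_K = 8  # assumed experts activated per token


def route_topk(router_logits, k: int = TOP_K):
    """Return (expert_ids, expert_weights), each of shape (num_tokens, k)."""
    logits = np.asarray(router_logits, dtype=np.float64)
    # Numerically stable softmax over the expert axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Indices of the k largest probabilities, ordered from highest to lowest.
    expert_ids = np.argsort(-probs, axis=-1)[:, :k]
    expert_weights = np.take_along_axis(probs, expert_ids, axis=-1)
    return expert_ids, expert_weights


def expert_load(expert_ids, num_experts: int = NUM_EXPERTS) -> np.ndarray:
    """Count how often each expert was selected across all tokens."""
    return np.bincount(np.asarray(expert_ids).ravel(), minlength=num_experts)


rng = np.random.default_rng(0)
demo_logits = rng.normal(size=(5, NUM_EXPERTS))  # 5 synthetic tokens
demo_ids, demo_weights = route_topk(demo_logits)
demo_load = expert_load(demo_ids)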
# %%
main_mask = df["src"] == "main"
# .item() assumes exactly one "main" row; it raises ValueError otherwise.
token_ids = df.loc[main_mask, "token_ids"].item()
output_text = df.loc[main_mask, "output_text"].item()
# %%
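lm_mask = df["src"] == "lm_logit"
# Stack the nested arrays in each "lm_logit" cell and flatten them into one
# 1-D logits vector per row.
df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(lambda arr: np.stack(list(arr)).flatten())
# The argmax over that vector is the highest-scoring token id for the row.
df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(lambda logits: np.argmax(logits, axis=-1))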
df.head()
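# %%
# Sketch on synthetic data, assuming each "lm_logit" cell is a nested
# (1, vocab_size) list, which is roughly how a list<list<float>> Parquet
# column reads back. It mirrors the two .apply calls above: np.stack turns
# the nested cell into a dense array, ravel flattens it to one logits vector,
# and argmax picks the greedy token id. `to_flat_logits` and the `demo`
# frame are hypothetical and exist only for illustration.
import numpy as np
import pandas as pd


def to_flat_logits(cell) -> np.ndarray:
    """Stack a nested logits cell into a 1-D float32 vector."""
    return np.stack(list(cell)).astype(np.float32).ravel()


demo = pd.DataFrame(
    {
        "src": ["main", "lm_logit", "lm_logit", "lm_logit"],
        "logits": [
            None,
            [[0.1, 2.0, -1.0, 0.0]],  # argmax -> 1
            [[3.0, 0.5, 0.2, 0.1]],  # argmax -> 0
            [[0.0, 0.0, 0.0, 4.5]],  # argmax -> 3
        ],
    }
)
demo_mask = demo["src"] == "lm_logit"
# Computing ids only on the masked rows keeps them as plain ints; a new
# full-frame column is likely float because unmasked rows are filled with NaN.
greedy_ids = (
    demo.loc[demo_mask, "logits"]
    .apply(lambda cell: int(np.argmax(to_flat_logits(cell))))
    .to_numpy()
)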
# %%
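# Boolean-mask plus column selection needs .loc; cast back to int because the
# new "token_id" column is float (NaN outside lm_mask).
df.loc[lm_mask, "token_id"].astype(int).to_numpy()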
# %%
token_ids
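# %%
# Sketch (assumptions): checking whether the greedy ids recovered from the
# "lm_logit" rows reproduce the generated part of `token_ids`. This assumes
# `token_ids` holds the prompt followed by the generated tokens, that
# decoding was greedy, and that there is one "lm_logit" row per generated
# step. The helper `generated_tail_matches` is hypothetical. If the logged
# ids come back as a nested object array, np.stack(list(...)) them first, as
# the logits cell above does.
import numpy as np


def generated_tail_matches(greedy_ids, token_ids) -> bool:
    """True if greedy_ids equals the last len(greedy_ids) entries of token_ids."""
    greedy = np.asarray(greedy_ids, dtype=np.int64).ravel()
    full = np.asarray(token_ids, dtype=np.int64).ravel()
    if greedy.size == 0 or greedy.size > full.size:
        return False
    return bool(np.array_equal(full[-greedy.size:], greedy))


# Demo on made-up ids: a 3-token prompt followed by 3 generated tokens.
demo_prompt_and_output = [101, 7, 42, 5, 9, 2]
print(generated_tail_matches([5, 9, 2], demo_prompt_and_output))  # True
print(generated_tail_matches([5, 9, 3], demo_prompt_and_output))  # False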
# %%
from transformers import AutoTokenizer, GPTNeoXTokenizerFast

# Load the tokenizer shipped with the local OLMoE checkpoint to decode the logged ids.
tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)
# %%
tokenizer.decode(token_ids)