Implement OLMoE logits logging.
Update: logging items in OLMoE. Update: some changes to the main script. Add: logits post-processing notebook.
This commit is contained in:
59
logits_post_process.py
Normal file
59
logits_post_process.py
Normal file
@ -0,0 +1,59 @@
|
||||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.17.3
|
||||
# kernelspec:
|
||||
# display_name: venv
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# import torch
|
||||
|
||||
# %%
|
||||
# Parquet log file to post-process (loaded below with pd.read_parquet).
log_file = "olmoe_20250929-165403.parquet"
|
||||
# Path handed to AutoTokenizer.from_pretrained further down to load the tokenizer.
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
|
||||
|
||||
# %%
|
||||
# Load the logged records from the parquet file named by log_file.
df = pd.read_parquet(log_file)
|
||||
df.head()
|
||||
|
||||
# %%
|
||||
# logit = df.loc[1, "router_logits"]
|
||||
# logit
|
||||
|
||||
# %%
|
||||
# Boolean mask selecting the rows whose "src" column equals "main".
main_mask = df["src"].eq("main")
|
||||
|
||||
# .item() returns the single selected value and raises ValueError unless
# exactly one row matches main_mask.
token_ids = df.loc[main_mask, "token_ids"].item()
|
||||
# "output_text" value of the single "main" row; .item() raises ValueError if
# the mask does not select exactly one row.
output_text = df.loc[main_mask, "output_text"].item()
|
||||
|
||||
# %%
|
||||
# Boolean mask selecting the rows whose "src" column equals "lm_logit".
lm_mask = df["src"].eq("lm_logit")
|
||||
|
||||
# Stack the elements of each "lm_logit" row's logits entry and flatten the
# stacked result into a single 1-D array.
df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(
    lambda nested: np.stack(list(nested)).flatten()
)
|
||||
# Store, for each "lm_logit" row, the index of its largest logit (argmax
# along the last axis) in the "token_id" column.
df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(
    lambda row_logits: np.argmax(row_logits, axis=-1)
)
|
||||
df.head()
|
||||
|
||||
# %%
|
||||
# Show the argmax token ids stored on the "lm_logit" rows as a NumPy array.
# Row-mask + column selection has to go through .loc: plain df[mask, col]
# is interpreted as one (unhashable) column key and raises instead of
# selecting anything.
df.loc[lm_mask, "token_id"].to_numpy()
|
||||
|
||||
# %%
|
||||
token_ids
|
||||
|
||||
# %%
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, GPTNeoXTokenizerFast
|
||||
|
||||
# Load the tokenizer from the location given by model_id.
# NOTE(review): the annotation assumes AutoTokenizer resolves to
# GPTNeoXTokenizerFast for this checkpoint — confirm.
tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# %%
|
||||
# Decode the token ids taken from the "main" row back into text; as the last
# expression of the cell, the result is displayed as the cell output.
tokenizer.decode(token_ids)
|
Reference in New Issue
Block a user