Init commit.
This commit is contained in:
70
olmoe_inference_vllm.py
Normal file
70
olmoe_inference_vllm.py
Normal file
@ -0,0 +1,70 @@
|
||||
# %%
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
from vllm.distributed.parallel_state import destroy_model_parallel
|
||||
|
||||
# %%
|
||||
# Run a single chat completion against a local OLMoE checkpoint with vLLM,
# print the generated text, and always release the engine afterwards.
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"

# Bind the name before entering ``try`` so the ``finally`` clause can drop it
# unconditionally, even when ``LLM(...)`` itself raises.
llm = None

try:
    llm = LLM(
        model=model_id,
        cpu_offload_gb=4,
        # tensor_parallel_size=2,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        # Optional piecewise compilation (kept for reference; unused while
        # enforce_eager=True is set below):
        # compilation_config=CompilationConfig(
        #     level=CompilationLevel.PIECEWISE,
        #     # By default, it goes up to max_num_seqs
        #     cudagraph_capture_sizes=[1, 2, 4, 8, 16],
        # ),
        enforce_eager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        max_tokens=1024,
    )

    # Prepare the input to the model
    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    # Generate outputs
    outputs = llm.chat(
        messages,
        sampling_params=sampling_params,
        # chat_template_kwargs={"enable_thinking": True},  # Set to False to strictly disable thinking
    )

    # Print the outputs.
    for out in outputs:
        # out.prompt is the input prompt; out.outputs is a list of completion choices
        # print("=== PROMPT ===")
        # print(out.prompt)
        # print("=== COMPLETION ===")
        print(out.outputs[0].text)
        print("\n---\n")

    print("Finish completion")

except Exception as e:
    # Best-effort script: report the failure instead of crashing, but include
    # the exception type because many exceptions carry an empty message.
    print(f"{type(e).__name__}: {e}")

finally:
    # Drop this script's reference to the engine (a no-op beyond unbinding the
    # name when construction failed and ``llm`` is still None).
    del llm
    destroy_model_parallel()
    # Collect garbage *before* emptying the CUDA cache: tensors that are only
    # kept alive by reference cycles must be freed first, otherwise their
    # memory is still in use when the cache is released.
    gc.collect()
    torch.cuda.empty_cache()
|
Reference in New Issue
Block a user