Irrelevant Responses with Unsloth Fine-tuned Llama 3.1 8B using vLLM

I’m encountering an issue while trying to run inference on a Llama 3.1 8B model fine-tuned using Unsloth. When I serve the model using vLLM, it generates irrelevant responses instead of following the instructions or the chat history provided.

vLLM code:

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# Model path (where the Unsloth fine-tuned model is saved)
model_path = "YOUR_MODEL_PATH"
system_prompt = "..." 
user_message = "Hello..." 


# Initialize vLLM
llm = LLM(
    model=model_path,
    trust_remote_code=True,
    tensor_parallel_size=1,
    dtype="float16",
    max_model_len=15000,
    load_format="safetensors",
    tokenizer=model_path,
    enable_prefix_caching=False
)

# Tokenizer & Chat Template Application
tokenizer = llm.get_tokenizer()
conversation = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_message},
]
token_ids = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True 
)
prompt_input = TokensPrompt(prompt_token_ids=token_ids)  # TokensPrompt comes from vllm.inputs

# Sampling Parameters
sampling_params = SamplingParams(
    temperature=0.6,
    top_p=0.9,
    max_tokens=1024,
)

# Generate Response
outputs = llm.generate([prompt_input], sampling_params)

# Process and print outputs
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

Unsloth code:

import time
import torch
from transformers import TextStreamer
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# --- Configuration ---

# Model loading parameters (should match training/saving)
max_seq_length = 4096
dtype = None          # Auto-detect, or set to float16/bfloat16
load_in_4bit = False # Set to True if the saved model is 4-bit

# Device for inference
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Inference parameters
inference_output_tokens = 128 # Max new tokens to generate

# Load the saved PEFT model for inference
output_dir = "YOUR_SAVED_MODEL_DIR"  # Directory the fine-tuned model was saved to
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=output_dir,  # Load the saved directory
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Set the correct chat template (important for formatting)
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1", # Specify the template name
)

# Enable Unsloth's fast inference mode
FastLanguageModel.for_inference(model)
model.to(device) # Move model to the desired device



# --- Prepare Input for Inference ---

# Example conversation - structure should match training data format
messages = [....]



# Apply the chat template to format the input correctly
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  
    return_tensors="pt",
).to(device)


# --- Run Inference ---

print("\nRunning batch inference...")
start_time = time.time()
with torch.inference_mode(): # Use inference mode for efficiency
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=inference_output_tokens,
        use_cache=True,
        # Optional: Add sampling parameters if desired (temperature, top_p, etc.)
        # temperature=0.6,
        # top_p=0.9,
        eos_token_id=tokenizer.eos_token_id 
    )
# Decode only the newly generated tokens
generated_text = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]

Some comments on your LLM instantiation:

llm = LLM(
    model=model_path,
    trust_remote_code=True,  # we natively support Llama so this shouldn't be necessary
    tensor_parallel_size=1, # the default is 1 so this shouldn't be necessary
    dtype="float16",
    max_model_len=15000,
    load_format="safetensors", # I believe this should be automatically detected
    tokenizer=model_path,
    enable_prefix_caching=False # why disable this?
)
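
For reference, here is a sketch of the same instantiation with those arguments dropped (tokenizer and load_format fall back to sensible defaults derived from model_path; whether to keep enable_prefix_caching=False is a separate question):

# Same instantiation with the unnecessary arguments removed
llm = LLM(
    model=model_path,
    dtype="float16",
    max_model_len=15000,
)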

Instead of formatting the prompt yourself, could you use llm.chat(conversation, sampling_params) instead?
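
For example, something like this (a sketch reusing the conversation and sampling_params from your snippet; llm.chat applies the model's chat template internally):

# Let vLLM build and tokenize the prompt from the chat messages
outputs = llm.chat(conversation, sampling_params)
print(outputs[0].outputs[0].text)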

@hmellor, thanks for the reply. I tried llm.chat too, but unfortunately it gave the same irrelevant responses. I also disabled prefix caching for this test: I've noticed it sometimes repeats responses, so I wanted to rule that out. I've addressed the other comments as well.

It could also be because you are loading the model in float16; the original Llama 3.1 8B checkpoint (and presumably your local fine-tune) is bfloat16.
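
To rule that out, a one-line sketch of the same call with only the dtype changed:

# Keep the checkpoint's native bfloat16 instead of casting to float16
llm = LLM(model=model_path, dtype="bfloat16", max_model_len=15000)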

Could you provide some minimal code I can use to repro it? Perhaps upload your model to hf.co for full reproducibility.
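
Something along these lines would be enough (a sketch; the Hub repo id below is a hypothetical placeholder for wherever you upload the model):

from vllm import LLM, SamplingParams

# Hypothetical Hub repo id - replace with the actual uploaded model
llm = LLM(model="your-username/llama-3.1-8b-unsloth-ft", dtype="bfloat16")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
outputs = llm.chat(conversation, SamplingParams(temperature=0.6, top_p=0.9, max_tokens=256))
print(outputs[0].outputs[0].text)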