
Few-shot batch inference for text classification (RAG)

This is a simple example of how to load an LLM from Hugging Face and use it to classify a list of review/label pairs with retrieval-augmented generation (RAG).

This example uses facebook/opt-350m as the base LLM.

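Semantically relevant examples are injected into the prompt as additional context, as specified by the prompt section of the Ludwig config. The retrieval block selects semantic (embedding-based) search, retrieves the k = 3 most similar labeled examples, and embeds text with the paraphrase-MiniLM-L3-v2 sentence-embedding model.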

prompt:
    task: "Classify the sample input as either negative, neutral, or positive."
    retrieval:
        type: semantic
        k: 3
        model_name: paraphrase-MiniLM-L3-v2

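The LLM generates free-form text, and the results are decoded into category labels by the string-matching ("contains") rules of the category_extractor decoder, specified in the Ludwig configuration. If no rule matches, the prediction falls back to the fallback_label ("neutral").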

output_features:
-
    name: label
    type: category
    preprocessing:
        fallback_label: "neutral"
    decoder:
        type: category_extractor
        match:
            "negative":
                type: contains
                value: "negative"
            "neutral":
                type: contains
                value: "neutral"
            "positive":
                type: contains
                value: "positive"
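The decoding step can be pictured as a rule table checked against the generated text. The following is a minimal sketch of that logic, not Ludwig's category_extractor code: MATCH_RULES and extract_category are hypothetical names, and the case-insensitive matching and first-match-wins ordering are assumptions about behavior.

```python
# Hypothetical rule table: label -> (match type, substring to look for).
MATCH_RULES = {
    "negative": ("contains", "negative"),
    "neutral": ("contains", "neutral"),
    "positive": ("contains", "positive"),
}


def extract_category(generated_text: str, rules: dict, fallback_label: str = "neutral") -> str:
    """Map free-form LLM output to a label; the first matching rule wins, else the fallback."""
    text = generated_text.lower()  # assumption: matching is case-insensitive
    for label, (match_type, value) in rules.items():
        if match_type == "contains" and value.lower() in text:
            return label
    return fallback_label
```

Each rule's value should name its own label. A rule such as "negative": ("contains", "positive") would silently report every positive answer as negative.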

Sample code

#!/usr/bin/env python

"""
This is a simple example of how to use the LLM model type to perform
few-shot text classification with retrieval-augmented generation (RAG).
It uses the facebook/opt-350m model as the base LLM; no fine-tuning is done.
"""

# Import required libraries
import logging
import shutil

import pandas as pd
import yaml

from ludwig.api import LudwigModel

# clean out prior results
shutil.rmtree("./results", ignore_errors=True)

review_label_pairs = [
    {"review": "I loved this movie!", "label": "positive"},
    {"review": "The food was okay, but the service was terrible.", "label": "negative"},
    {"review": "I can't believe how rude the staff was.", "label": "negative"},
    {"review": "This book was a real page-turner.", "label": "positive"},
    {"review": "The hotel room was dirty and smelled bad.", "label": "negative"},
    {"review": "I had a great experience at this restaurant.", "label": "positive"},
    {"review": "The concert was amazing!", "label": "positive"},
    {"review": "The traffic was terrible on my way to work this morning.", "label": "negative"},
    {"review": "The customer service was excellent.", "label": "positive"},
    {"review": "I was disappointed with the quality of the product.", "label": "negative"},
    {"review": "The scenery on the hike was breathtaking.", "label": "positive"},
    {"review": "I had a terrible experience at this hotel.", "label": "negative"},
    {"review": "The coffee at this cafe was delicious.", "label": "positive"},
    {"review": "The weather was perfect for a day at the beach.", "label": "positive"},
    {"review": "I would definitely recommend this product.", "label": "positive"},
    {"review": "The wait time at the doctor's office was ridiculous.", "label": "negative"},
    {"review": "The museum was a bit underwhelming.", "label": "neutral"},
    {"review": "I had a fantastic time at the amusement park.", "label": "positive"},
    {"review": "The staff at this store was extremely helpful.", "label": "positive"},
    {"review": "The airline lost my luggage and was very unhelpful.", "label": "negative"},
    {"review": "This album is a must-listen for any music fan.", "label": "positive"},
    {"review": "The food at this restaurant was just okay.", "label": "neutral"},
    {"review": "I was pleasantly surprised by how great this movie was.", "label": "positive"},
    {"review": "The car rental process was quick and easy.", "label": "positive"},
    {"review": "The service at this hotel was top-notch.", "label": "positive"},
]

df = pd.DataFrame(review_label_pairs)
# Fixed split: 0 = train (rows 0-14, the examples retrieval draws from), 2 = test (rows 15-24)
df["split"] = [0] * 15 + [2] * 10

config = yaml.safe_load(
    """
model_type: llm
base_model: facebook/opt-350m
generation:
    temperature: 0.1
    top_p: 0.75
    top_k: 40
    num_beams: 4
    max_new_tokens: 64
prompt:
    task: "Classify the sample input as either negative, neutral, or positive."
    retrieval:
        type: semantic
        k: 3
        model_name: paraphrase-MiniLM-L3-v2
input_features:
-
    name: review
    type: text
output_features:
-
    name: label
    type: category
    preprocessing:
        fallback_label: "neutral"
    decoder:
        type: category_extractor
        match:
            "negative":
                type: contains
                value: "negative"
            "neutral":
                type: contains
                value: "neutral"
            "positive":
                type: contains
                value: "positive"
preprocessing:
    split:
        type: fixed
    """
)

# Define the Ludwig model object that drives preprocessing and prediction
model = LudwigModel(config=config, logging_level=logging.INFO)

# Loads the model and performs no training.
(
    train_stats,  # dictionary containing training statistics
    preprocessed_data,  # tuple of Ludwig Dataset objects of pre-processed training data
    output_directory,  # location of training results stored on disk
) = model.train(
    dataset=df, experiment_name="simple_experiment", model_name="simple_model", skip_save_processed_input=True
)

training_set, val_set, test_set, _ = preprocessed_data

# batch prediction
preds, _ = model.predict(test_set, skip_save_predictions=False)
print(preds)
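The script only prints the raw predictions. To score them, a minimal pure-Python sketch can compare predicted labels with the gold labels of the 10 test rows (rows 15 to 24 of review_label_pairs). It assumes the predicted labels have already been pulled into a plain list, for example with preds["label_predictions"].tolist(); that column name follows Ludwig's usual <feature>_predictions naming but is an assumption here, as is the assumption that predictions come back in test-set order. accuracy and confusion_counts are hypothetical helpers, not Ludwig APIs.

```python
from collections import Counter

# Gold labels of the 10 test rows (split == 2) in review_label_pairs, in order.
GOLD_TEST_LABELS = [
    "negative", "neutral", "positive", "positive", "negative",
    "positive", "neutral", "positive", "positive", "positive",
]


def accuracy(predicted: list[str], gold: list[str]) -> float:
    """Fraction of predictions that exactly match the gold label."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have the same length")
    if not gold:
        return 0.0
    return sum(p == g for p, g in zip(predicted, gold)) / len(gold)


def confusion_counts(predicted: list[str], gold: list[str]) -> Counter:
    """Count (gold, predicted) pairs; keys where the two differ are the errors."""
    return Counter(zip(gold, predicted))
```

The confusion counts are worth a look alongside accuracy: with only 10 test rows, a single systematic confusion (for example neutral predicted as positive) moves accuracy by 10 to 20 points.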