Retrieval Augmented Fine-Tuning (RAFT): Training LLMs to Excel at RAG Tasks

Introduction: Retrieval Augmented Fine-Tuning (RAFT) improves LLM performance on domain-specific tasks by combining fine-tuning with retrieval-augmented generation. Traditional RAG systems retrieve relevant documents at inference time and include them in the prompt, but the base model was never trained to use retrieved context effectively. RAFT addresses this by fine-tuning the model on examples that include both relevant and irrelevant retrieved documents, teaching it to identify and extract information from noisy retrieval results. The result is a model that is better at answering questions from your specific document corpus: it learns not just the domain knowledge, but also how to reason over retrieved context. This guide covers the RAFT methodology, from generating training data to fine-tuning strategies and evaluation approaches.
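
A minimal sketch of how one such training context might be assembled. It is an illustration under stated assumptions, not the reference RAFT implementation: `build_raft_context` and its `p_oracle` parameter (the probability of keeping the oracle document; otherwise the context holds distractors only, like the no-oracle examples later in this guide) are hypothetical names.

```python
import random
from typing import Optional


def build_raft_context(
    oracle: str,
    distractors: list[str],
    num_distractors: int = 4,
    p_oracle: float = 0.8,
    rng: Optional[random.Random] = None,
) -> tuple[list[str], bool]:
    """Assemble one shuffled RAFT context (hypothetical helper).

    With probability p_oracle the oracle document is mixed in with sampled
    distractors; otherwise the context contains distractors only.
    Returns the documents and whether the oracle was included.
    """
    rng = rng or random.Random()
    # Sample without replacement; never ask for more distractors than exist
    docs = rng.sample(distractors, min(num_distractors, len(distractors)))
    include_oracle = rng.random() < p_oracle
    if include_oracle:
        docs.append(oracle)
    # Shuffle so the oracle's position carries no signal
    rng.shuffle(docs)
    return docs, include_oracle


docs, has_oracle = build_raft_context(
    oracle="RAFT mixes oracle and distractor documents during fine-tuning.",
    distractors=["Unrelated note A.", "Unrelated note B.", "Unrelated note C."],
    num_distractors=2,
    rng=random.Random(0),
)
print(has_oracle, len(docs))
```

Passing a seeded `random.Random` makes the sampled mix reproducible across dataset builds.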

[Figure: Retrieval Augmented Fine-Tuning. RAFT: generate QA pairs, augment with retrieval, fine-tune.]

RAFT Data Generation

from dataclasses import dataclass, field
from typing import Any, Optional, List, Dict
from enum import Enum
import random
import json

class DocumentType(Enum):
    """Types of documents in training data."""
    
    ORACLE = "oracle"  # Contains the answer
    DISTRACTOR = "distractor"  # Doesn't contain the answer

@dataclass
class Document:
    """A document for RAFT training."""
    
    content: str
    doc_id: str
    doc_type: DocumentType = DocumentType.DISTRACTOR
    metadata: dict = field(default_factory=dict)

@dataclass
class RAFTExample:
    """A single RAFT training example."""
    
    question: str
    answer: str
    oracle_document: Document
    distractor_documents: list[Document]
    chain_of_thought: str = ""

@dataclass
class RAFTDataset:
    """Dataset for RAFT training."""
    
    examples: list[RAFTExample]
    
    def to_training_format(
        self,
        include_cot: bool = True,
        num_distractors: int = 4
    ) -> list[dict]:
        """Convert to training format."""
        
        training_data = []
        
        for example in self.examples:
            # Select distractors
            distractors = random.sample(
                example.distractor_documents,
                min(num_distractors, len(example.distractor_documents))
            )
            
            # Build context with oracle and distractors
            all_docs = [example.oracle_document] + distractors
            random.shuffle(all_docs)
            
            context = self._format_context(all_docs)
            
            # Build answer with optional CoT
            if include_cot and example.chain_of_thought:
                answer = f"{example.chain_of_thought}\n\nAnswer: {example.answer}"
            else:
                answer = example.answer
            
            training_data.append({
                "instruction": f"Answer the question based on the provided documents.\n\nDocuments:\n{context}\n\nQuestion: {example.question}",
                "response": answer
            })
        
        return training_data
    
    def _format_context(self, documents: list[Document]) -> str:
        """Format documents as context."""
        
        parts = []
        for i, doc in enumerate(documents):
            parts.append(f"[Document {i+1}]\n{doc.content}")
        
        return "\n\n".join(parts)

class QAGenerator:
    """Generate question-answer pairs from documents."""
    
    def __init__(self, llm_client: Any):
        self.llm = llm_client
    
    async def generate_qa_pairs(
        self,
        document: str,
        num_pairs: int = 5
    ) -> list[dict]:
        """Generate QA pairs from document."""
        
        prompt = f"""Generate {num_pairs} question-answer pairs based on the following document.
Each question should be answerable using only the information in the document.
Include a mix of factual, reasoning, and inference questions.

Document:
{document}

Generate the QA pairs in this JSON format:
[
  {{"question": "...", "answer": "...", "reasoning": "..."}}
]

QA Pairs:"""
        
        response = await self.llm.generate(prompt)
        
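        # Parse the JSON array from the response, keeping only well-formed pairs
        import re
        json_match = re.search(r'\[[\s\S]*\]', response)
        if not json_match:
            return []
        try:
            pairs = json.loads(json_match.group(0))
        except json.JSONDecodeError:
            return []
        if not isinstance(pairs, list):
            return []
        
        # Drop items missing a question or answer so downstream code can index them safely
        return [
            qa for qa in pairs
            if isinstance(qa, dict) and qa.get("question") and qa.get("answer")
        ]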
    
    async def generate_cot(
        self,
        question: str,
        answer: str,
        document: str
    ) -> str:
        """Generate chain-of-thought reasoning."""
        
        prompt = f"""Given a question, answer, and source document, generate a step-by-step reasoning chain that explains how to arrive at the answer.

Document:
{document}

Question: {question}
Answer: {answer}

Generate a clear chain-of-thought reasoning that:
1. Identifies relevant information in the document
2. Explains the logical steps to reach the answer
3. Cites specific parts of the document

Chain of Thought:"""
        
        return await self.llm.generate(prompt)

class RAFTDataGenerator:
    """Generate complete RAFT training dataset."""
    
    def __init__(
        self,
        llm_client: Any,
        embedding_model: Any
    ):
        self.llm = llm_client
        self.embedder = embedding_model
        self.qa_generator = QAGenerator(llm_client)
    
    async def generate_dataset(
        self,
        documents: list[str],
        qa_per_doc: int = 5,
        distractors_per_example: int = 4,
        include_cot: bool = True
    ) -> RAFTDataset:
        """Generate RAFT dataset from documents."""
        
        examples = []
        
        # Generate QA pairs for each document
        for i, doc in enumerate(documents):
            qa_pairs = await self.qa_generator.generate_qa_pairs(doc, qa_per_doc)
            
            for qa in qa_pairs:
                # Generate CoT if requested
                cot = ""
                if include_cot:
                    cot = await self.qa_generator.generate_cot(
                        qa["question"],
                        qa["answer"],
                        doc
                    )
                
                # Select distractor documents
                distractor_docs = [
                    Document(
                        content=d,
                        doc_id=f"doc_{j}",
                        doc_type=DocumentType.DISTRACTOR
                    )
                    for j, d in enumerate(documents)
                    if j != i
                ]
                
                # Sample distractors
                selected_distractors = random.sample(
                    distractor_docs,
                    min(distractors_per_example, len(distractor_docs))
                )
                
                example = RAFTExample(
                    question=qa["question"],
                    answer=qa["answer"],
                    oracle_document=Document(
                        content=doc,
                        doc_id=f"doc_{i}",
                        doc_type=DocumentType.ORACLE
                    ),
                    distractor_documents=selected_distractors,
                    chain_of_thought=cot
                )
                
                examples.append(example)
        
        return RAFTDataset(examples=examples)
    
    async def generate_hard_negatives(
        self,
        question: str,
        oracle_doc: str,
        candidate_docs: list[str],
        top_k: int = 4
    ) -> list[str]:
        """Generate hard negative distractors using embedding similarity."""
        
        # Embed question
        question_emb = await self.embedder.embed(question)
        
        # Embed candidates
        candidate_embs = await self.embedder.embed_batch(candidate_docs)
        
        # Calculate similarities
        import numpy as np
        similarities = []
        
        for i, emb in enumerate(candidate_embs):
            if candidate_docs[i] != oracle_doc:
                sim = np.dot(question_emb, emb) / (
                    np.linalg.norm(question_emb) * np.linalg.norm(emb)
                )
                similarities.append((i, sim))
        
        # Sort by similarity (highest = hardest negatives)
        similarities.sort(key=lambda x: x[1], reverse=True)
        
        # Return top-k hard negatives
        return [candidate_docs[i] for i, _ in similarities[:top_k]]
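
`generate_dataset` samples distractors at random and does not call `generate_hard_negatives`. Below is a vectorized sketch of the same cosine-similarity selection, assuming the question and document embeddings are already computed as NumPy arrays with non-zero norms. It excludes the oracle by index rather than by string comparison; `select_hard_negatives` is a hypothetical name.

```python
import numpy as np


def select_hard_negatives(
    question_emb: np.ndarray,
    doc_embs: np.ndarray,
    oracle_idx: int,
    top_k: int = 4,
) -> list[int]:
    """Return indices of the top_k documents most similar to the question,
    excluding the oracle. question_emb has shape (d,), doc_embs (n, d)."""
    # Cosine similarity of the question against every document at once
    q = question_emb / np.linalg.norm(question_emb)
    d = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    sims = d @ q
    # Push the oracle to the end of the ranking so it is never selected
    sims[oracle_idx] = -np.inf
    order = np.argsort(-sims)
    return [int(i) for i in order[:top_k] if i != oracle_idx]
```

Inside `generate_dataset`, you could embed every document once, then call this per question with the oracle's index `i` to choose which documents become distractors.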

Training Data Formatting

from dataclasses import dataclass
from typing import Any, Optional, List
import json
import random

@dataclass
class TrainingConfig:
    """Configuration for training data generation."""
    
    include_oracle_only_examples: float = 0.2  # fraction of examples with only the oracle document
    include_no_oracle_examples: float = 0.1  # fraction of examples with distractors only (no oracle)
    num_distractors: int = 4
    include_cot: bool = True
    shuffle_documents: bool = True

class TrainingDataFormatter:
    """Format RAFT data for different training frameworks."""
    
    def __init__(self, config: TrainingConfig):
        self.config = config
    
    def format_for_sft(
        self,
        dataset: RAFTDataset
    ) -> list[dict]:
        """Format for supervised fine-tuning."""
        
        formatted = []
        
        for example in dataset.examples:
            # Determine example type
            rand = random.random()
            distractors = example.distractor_documents[:self.config.num_distractors]
            # Work on local values so the shared dataset is never mutated
            answer_text = example.answer
            use_cot = self.config.include_cot and bool(example.chain_of_thought)
            
            if rand < self.config.include_oracle_only_examples:
                # Oracle only
                docs = [example.oracle_document]
            elif rand < self.config.include_oracle_only_examples + self.config.include_no_oracle_examples:
                # No oracle: teach the model to say the answer isn't there.
                # The CoT cites the missing oracle, so it is dropped here.
                docs = distractors
                answer_text = "I cannot find the answer in the provided documents."
                use_cot = False
            else:
                # Normal: oracle + distractors
                docs = [example.oracle_document] + distractors
            
            if self.config.shuffle_documents:
                random.shuffle(docs)
            
            # Format context
            context = self._format_documents(docs)
            
            # Format answer; "Answer:" separates the reasoning from the final answer
            if use_cot:
                answer = f"{example.chain_of_thought}\n\nAnswer: {answer_text}"
            else:
                answer = answer_text
            
            formatted.append({
                "messages": [
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided documents. If the answer is not in the documents, say so."
                    },
                    {
                        "role": "user",
                        "content": f"Documents:\n{context}\n\nQuestion: {example.question}"
                    },
                    {
                        "role": "assistant",
                        "content": answer
                    }
                ]
            })
        
        return formatted
    
    def format_for_dpo(
        self,
        dataset: RAFTDataset
    ) -> list[dict]:
        """Format for Direct Preference Optimization."""
        
        formatted = []
        
        for example in dataset.examples:
            docs = [example.oracle_document] + example.distractor_documents[:self.config.num_distractors]
            
            if self.config.shuffle_documents:
                random.shuffle(docs)
            
            context = self._format_documents(docs)
            
            # Chosen: correct answer with reasoning
            chosen = example.answer
            if self.config.include_cot and example.chain_of_thought:
                chosen = f"{example.chain_of_thought}\n\nAnswer: {example.answer}"
            
            # Rejected: wrong answer (from distractor or hallucinated)
            rejected = self._generate_wrong_answer(example)
            
            formatted.append({
                "prompt": f"Documents:\n{context}\n\nQuestion: {example.question}",
                "chosen": chosen,
                "rejected": rejected
            })
        
        return formatted
    
    def _format_documents(self, docs: list[Document]) -> str:
        """Format documents as context string."""
        
        parts = []
        for i, doc in enumerate(docs):
            parts.append(f"[Document {i+1}]\n{doc.content}")
        
        return "\n\n".join(parts)
    
    def _generate_wrong_answer(self, example: RAFTExample) -> str:
        """Generate a plausible but wrong answer."""
        
        # Use information from distractor documents
        if example.distractor_documents:
            distractor = random.choice(example.distractor_documents)
            # Use the first non-empty sentence of the distractor as the wrong answer
            sentences = [s.strip() for s in distractor.content.split('.') if s.strip()]
            if sentences:
                return sentences[0] + "."
        
        return "I'm not sure about the answer."

class AlpacaFormatter:
    """Format for Alpaca-style training."""
    
    def format(self, dataset: RAFTDataset) -> list[dict]:
        """Format as Alpaca dataset."""
        
        formatted = []
        
        for example in dataset.examples:
            docs = [example.oracle_document] + example.distractor_documents[:4]
            random.shuffle(docs)
            
            context = "\n\n".join([
                f"Document {i+1}: {doc.content}"
                for i, doc in enumerate(docs)
            ])
            
            formatted.append({
                "instruction": "Answer the question based on the provided documents.",
                "input": f"Documents:\n{context}\n\nQuestion: {example.question}",
                "output": example.answer
            })
        
        return formatted

class ShareGPTFormatter:
    """Format for ShareGPT-style training."""
    
    def format(self, dataset: RAFTDataset) -> list[dict]:
        """Format as ShareGPT dataset."""
        
        formatted = []
        
        for example in dataset.examples:
            docs = [example.oracle_document] + example.distractor_documents[:4]
            random.shuffle(docs)
            
            context = "\n\n".join([
                f"[Doc {i+1}] {doc.content}"
                for i, doc in enumerate(docs)
            ])
            
            formatted.append({
                "conversations": [
                    {
                        "from": "human",
                        "value": f"Based on these documents:\n\n{context}\n\nQuestion: {example.question}"
                    },
                    {
                        "from": "gpt",
                        "value": example.answer
                    }
                ]
            })
        
        return formatted

class DatasetSplitter:
    """Split dataset into train/val/test."""
    
    def split(
        self,
        dataset: RAFTDataset,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1
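    ) -> tuple[RAFTDataset, RAFTDataset, RAFTDataset]:
        """Split dataset; the test split receives all remaining examples."""
        
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError("train_ratio + val_ratio + test_ratio must sum to 1.0")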
        
        examples = list(dataset.examples)
        random.shuffle(examples)
        
        n = len(examples)
        train_end = int(n * train_ratio)
        val_end = train_end + int(n * val_ratio)
        
        train = RAFTDataset(examples=examples[:train_end])
        val = RAFTDataset(examples=examples[train_end:val_end])
        test = RAFTDataset(examples=examples[val_end:])
        
        return train, val, test
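
Because `ShareGPTFormatter` rebuilds prompts from scratch, its records omit the system prompt, the chain of thought, and the oracle-only / no-oracle mix that `format_for_sft` produces. One way to keep the two formats aligned is to convert the SFT records directly. This sketch assumes a ShareGPT variant that accepts a `"system"` turn (loaders differ on this); `messages_to_sharegpt` and `ROLE_MAP` are hypothetical names.

```python
# Assumed role names for a ShareGPT-style loader; adjust to your training framework
ROLE_MAP = {"system": "system", "user": "human", "assistant": "gpt"}


def messages_to_sharegpt(record: dict) -> dict:
    """Convert a {"messages": [...]} SFT record into the
    {"conversations": [{"from": ..., "value": ...}]} layout."""
    conversations = []
    for message in record["messages"]:
        role = message["role"]
        if role not in ROLE_MAP:
            raise ValueError(f"Unsupported role: {role}")
        conversations.append({"from": ROLE_MAP[role], "value": message["content"]})
    return {"conversations": conversations}
```

Applying it to the output of `format_for_sft` (for example, `[messages_to_sharegpt(r) for r in sft_records]`) gives both frameworks the same prompts and answers.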

Fine-Tuning Implementation

from dataclasses import dataclass
from typing import Any, Optional, List
import json

@dataclass
class FineTuningConfig:
    """Configuration for RAFT fine-tuning."""
    
    base_model: str = "meta-llama/Llama-2-7b-chat-hf"  # chat variant: apply_chat_template needs a tokenizer with a chat template
    output_dir: str = "./raft-model"
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    max_seq_length: int = 4096
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

class RAFTTrainer:
    """Train model using RAFT methodology."""
    
    def __init__(self, config: FineTuningConfig):
        self.config = config
    
    def prepare_model(self):
        """Prepare model for training."""
        
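        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        import torch
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load model with 4-bit quantization; passing load_in_4bit directly to
        # from_pretrained is deprecated in favor of quantization_config
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            quantization_config=bnb_config,
            torch_dtype=torch.float16,
            device_map="auto"
        )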
        
        # Prepare for k-bit training
        self.model = prepare_model_for_kbit_training(self.model)
        
        # Add LoRA adapters
        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        self.model = get_peft_model(self.model, lora_config)
        
        return self.model
    
    def prepare_dataset(self, training_data: list[dict]):
        """Prepare dataset for training."""
        
        from datasets import Dataset
        
        def tokenize(example):
            # Format as chat
            messages = example["messages"]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            
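            # Tokenize; Llama-style chat templates already insert BOS, so don't add it twice
            tokenized = self.tokenizer(
                text,
                truncation=True,
                max_length=self.config.max_seq_length,
                padding="max_length",
                add_special_tokens=False
            )
            
            # Ignore padding in the loss: pad_token == eos_token here, so copying
            # input_ids unchanged would train the model to emit runs of EOS
            tokenized["labels"] = [
                token_id if mask == 1 else -100
                for token_id, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
            ]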
            
            return tokenized
        
        dataset = Dataset.from_list(training_data)
        tokenized_dataset = dataset.map(tokenize, remove_columns=dataset.column_names)
        
        return tokenized_dataset
    
    def train(self, train_dataset, eval_dataset=None):
        """Run training."""
        
        from transformers import TrainingArguments, Trainer
        
        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            warmup_ratio=self.config.warmup_ratio,
            logging_steps=10,
            save_steps=100,
            # `eval_strategy` replaced the older `evaluation_strategy` argument
            eval_strategy="steps" if eval_dataset else "no",
            eval_steps=100 if eval_dataset else None,
            fp16=True,
            optim="paged_adamw_8bit",
            report_to="wandb"
        )
        
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # `processing_class` replaced `tokenizer` in recent transformers releases
            processing_class=self.tokenizer
        )
        
        trainer.train()
        
        # Save model
        trainer.save_model(self.config.output_dir)
        self.tokenizer.save_pretrained(self.config.output_dir)
        
        return trainer

class SFTTrainerWrapper:
    """Wrapper for TRL's SFTTrainer."""
    
    def __init__(self, config: FineTuningConfig):
        self.config = config
    
    def train(self, training_data: list[dict]):
        """Train using SFTTrainer."""
        
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTTrainer, SFTConfig
        from peft import LoraConfig
        from datasets import Dataset
        import torch
        
        # Load model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.config.base_model)
        tokenizer.pad_token = tokenizer.eos_token
        
        model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        
        # Prepare dataset
        def format_example(example):
            messages = example["messages"]
            return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
        
        dataset = Dataset.from_list(training_data)
        dataset = dataset.map(format_example, remove_columns=["messages"])
        
        # LoRA config
        peft_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        # Training config; argument names have moved between TRL releases
        # (older versions took dataset_text_field and tokenizer on SFTTrainer),
        # so check them against your installed version
        sft_config = SFTConfig(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            max_seq_length=self.config.max_seq_length,
            dataset_text_field="text",
            packing=True
        )
        
        # Train
        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
            peft_config=peft_config
        )
        
        trainer.train()
        trainer.save_model(self.config.output_dir)
        
        return trainer

class UnslothTrainer:
    """Fast training using Unsloth."""
    
    def __init__(self, config: FineTuningConfig):
        self.config = config
    
    def train(self, training_data: list[dict]):
        """Train using Unsloth's optimized kernels for faster, lower-memory LoRA training."""
        
        from unsloth import FastLanguageModel, is_bfloat16_supported
        from trl import SFTTrainer, SFTConfig
        from datasets import Dataset
        
        # Load model with Unsloth
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config.base_model,
            max_seq_length=self.config.max_seq_length,
            dtype=None,  # Auto-detect: bf16 on GPUs that support it, else fp16
            load_in_4bit=True
        )
        
        # Add LoRA
        model = FastLanguageModel.get_peft_model(
            model,
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"]
        )
        
        # Prepare dataset
        def format_example(example):
            messages = example["messages"]
            return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
        
        dataset = Dataset.from_list(training_data)
        dataset = dataset.map(format_example, remove_columns=["messages"])
        
        # Train; recent TRL releases expect dataset_text_field and
        # max_seq_length on SFTConfig rather than on SFTTrainer
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            args=SFTConfig(
                output_dir=self.config.output_dir,
                dataset_text_field="text",
                max_seq_length=self.config.max_seq_length,
                per_device_train_batch_size=self.config.batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_epochs,
                learning_rate=self.config.learning_rate,
                # Match the precision Unsloth auto-selected when loading the model
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=10,
                optim="adamw_8bit"
            )
        )
        
        trainer.train()
        
        # Save
        model.save_pretrained(self.config.output_dir)
        tokenizer.save_pretrained(self.config.output_dir)
        
        return trainer
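
`RAFTTrainer` builds `labels` from the full `input_ids`, and the `text`-field `SFTTrainer` setups as configured here likewise train on the whole formatted conversation. The loss therefore also covers the system prompt, the retrieved documents, and the question, which usually make up most of each sequence. A common alternative is to compute loss only on the assistant response. The sketch below shows that masking on plain token-id lists; `build_completion_only_example` is a hypothetical helper, and how you obtain `prompt_ids` and `response_ids` (for example, by tokenizing the chat template with and without the assistant turn) is left as an assumption.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy ignores by default


def build_completion_only_example(
    prompt_ids: list[int],
    response_ids: list[int],
    max_length: int,
) -> dict[str, list[int]]:
    """Concatenate prompt and response ids and mask the prompt in the labels,
    so only the assistant's answer contributes to the loss."""
    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids)[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
```

Note that right-truncation can cut the response away entirely when the documents fill `max_length`, leaving an example with no supervised tokens; trimming documents or filtering such examples avoids that.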

Evaluation and Inference

from dataclasses import dataclass
from typing import Any, Optional, List
import numpy as np

@dataclass
class EvaluationResult:
    """Result of RAFT evaluation."""
    
    accuracy: float
    f1_score: float
    exact_match: float
    retrieval_accuracy: float
    avg_response_length: float

class RAFTEvaluator:
    """Evaluate RAFT model performance."""
    
    def __init__(self, model: Any, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer
    
    def evaluate(
        self,
        test_dataset: RAFTDataset,
        num_distractors: int = 4
    ) -> EvaluationResult:
        """Evaluate model on test set."""
        
        predictions = []
        references = []
        retrieval_correct = 0
        response_lengths = []
        
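        import random
        
        for example in test_dataset.examples:
            # Prepare input with the same oracle + distractor layout used in training
            docs = [example.oracle_document] + example.distractor_documents[:num_distractors]
            random.shuffle(docs)
            
            context = "\n\n".join([
                f"[Document {i+1}]\n{doc.content}"
                for i, doc in enumerate(docs)
            ])
            
            # Same user-turn text as TrainingDataFormatter.format_for_sft
            user_content = f"Documents:\n{context}\n\nQuestion: {example.question}"
            
            # Generate response
            response = self._generate(user_content)
            
            predictions.append(response)
            references.append(example.answer)
            response_lengths.append(len(response.split()))
            
            # Heuristic check that the model used the oracle: it either cites the
            # oracle's document number or produces a matching answer
            oracle_idx = docs.index(example.oracle_document)
            if f"Document {oracle_idx + 1}" in response or self._answer_matches(response, example.answer):
                retrieval_correct += 1
        
        # Calculate metrics
        accuracy = self._calculate_accuracy(predictions, references)
        f1 = self._calculate_f1(predictions, references)
        exact_match = self._calculate_exact_match(predictions, references)
        
        return EvaluationResult(
            accuracy=accuracy,
            f1_score=f1,
            exact_match=exact_match,
            retrieval_accuracy=retrieval_correct / len(test_dataset.examples),
            avg_response_length=np.mean(response_lengths)
        )
    
    def _generate(self, user_content: str, max_tokens: int = 512) -> str:
        """Generate a response using the chat format the model was trained on."""
        
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on provided documents. If the answer is not in the documents, say so."
            },
            {"role": "user", "content": user_content}
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Llama-style chat templates already include BOS, so don't add it again
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        
        # Greedy decoding keeps evaluation runs reproducible
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id
        )
        
        # Decode only the newly generated tokens, not the echoed prompt
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
        
        # Keep only the final answer when the model emits reasoning first
        if "Answer:" in response:
            response = response.split("Answer:")[-1]
        
        return response.strip()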
    
    def _answer_matches(self, prediction: str, reference: str) -> bool:
        """Check if prediction matches reference."""
        
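        pred_lower = prediction.lower().strip()
        ref_lower = reference.lower().strip()
        
        # An empty string is a substring of everything, so it must never count as a match
        if not pred_lower or not ref_lower:
            return False
        
        return ref_lower in pred_lower or pred_lower in ref_lower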
    
    def _calculate_accuracy(
        self,
        predictions: list[str],
        references: list[str]
    ) -> float:
        """Calculate accuracy."""
        
        correct = sum(
            1 for p, r in zip(predictions, references)
            if self._answer_matches(p, r)
        )
        
        return correct / len(predictions)
    
    def _calculate_f1(
        self,
        predictions: list[str],
        references: list[str]
    ) -> float:
        """Calculate token-level F1 score (multiset token overlap, as in SQuAD)."""
        
        from collections import Counter
        
        f1_scores = []
        
        for pred, ref in zip(predictions, references):
            pred_tokens = pred.lower().split()
            ref_tokens = ref.lower().split()
            
            if not pred_tokens or not ref_tokens:
                f1_scores.append(0.0)
                continue
            
            # Count repeated tokens as often as they occur in both strings
            num_common = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
            if num_common == 0:
                f1_scores.append(0.0)
                continue
            
            precision = num_common / len(pred_tokens)
            recall = num_common / len(ref_tokens)
            f1_scores.append(2 * precision * recall / (precision + recall))
        
        return np.mean(f1_scores)
    
    def _calculate_exact_match(
        self,
        predictions: list[str],
        references: list[str]
    ) -> float:
        """Calculate exact match score."""
        
        matches = sum(
            1 for p, r in zip(predictions, references)
            if p.lower().strip() == r.lower().strip()
        )
        
        return matches / len(predictions)

class RAFTInference:
    """Run inference with RAFT model."""
    
    def __init__(self, model_path: str):
        self.model_path = model_path
        self._load_model()
    
    def _load_model(self):
        """Load fine-tuned model."""
        
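        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch
        
        # If model_path holds LoRA adapters (as saved by the trainers above),
        # transformers' PEFT integration loads the base model and attaches them;
        # this requires the `peft` package to be installed.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)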
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
    
    def answer(
        self,
        question: str,
        documents: list[str],
        max_tokens: int = 512
    ) -> str:
        """Answer question using retrieved documents."""
        
        # Format context
        context = "\n\n".join([
            f"[Document {i+1}]\n{doc}"
            for i, doc in enumerate(documents)
        ])
        
        # Build prompt
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on provided documents. If the answer is not in the documents, say so."
            },
            {
                "role": "user",
                "content": f"Documents:\n{context}\n\nQuestion: {question}"
            }
        ]
        
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
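        # Generate (Llama-style chat templates already add BOS, so skip special tokens here)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)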
        
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )
        
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        
        return response.strip()
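
`RAFTEvaluator` compares lowercased strings without removing punctuation or articles, so "The Eiffel Tower." and "eiffel tower" count as different answers. The SQuAD evaluation script normalizes answers (lowercasing, stripping punctuation and the articles a/an/the, collapsing whitespace) before computing exact match and token F1. The sketch below follows that convention; it is an independent re-implementation for illustration, not the official script.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and the articles a/an/the, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(reference)


def token_f1(prediction: str, reference: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    # If either side is empty, F1 is 1.0 only when both are empty
    if not pred_tokens or not ref_tokens:
        return float(pred_tokens == ref_tokens)
    # Multiset overlap: a repeated token counts only as often as it appears in both
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```

These functions could back `_calculate_exact_match` and `_calculate_f1` by averaging them over the prediction/reference pairs.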

Production RAFT Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List

app = FastAPI()

class GenerateDataRequest(BaseModel):
    documents: List[str]
    qa_per_doc: int = 5
    include_cot: bool = True

class AnswerRequest(BaseModel):
    question: str
    documents: List[str]
    max_tokens: int = 512

class TrainRequest(BaseModel):
    training_data_path: str
    output_dir: str
    num_epochs: int = 3
    batch_size: int = 4

# Initialize components (would be loaded from config in production)
raft_inference = None

@app.post("/v1/generate-data")
async def generate_training_data(request: GenerateDataRequest) -> dict:
    """Generate RAFT training data from documents."""
    
    # This would use RAFTDataGenerator in production
    return {
        "status": "Data generation started",
        "num_documents": len(request.documents),
        "estimated_examples": len(request.documents) * request.qa_per_doc
    }

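# Plain `def` instead of `async def`: FastAPI runs sync endpoints in a thread pool,
# so the blocking model.generate() call does not stall the event loop
@app.post("/v1/answer")
def answer_question(request: AnswerRequest) -> dict: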
    """Answer question using RAFT model."""
    
    global raft_inference
    
    if raft_inference is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    answer = raft_inference.answer(
        request.question,
        request.documents,
        request.max_tokens
    )
    
    return {
        "question": request.question,
        "answer": answer,
        "num_documents": len(request.documents)
    }

@app.post("/v1/train")
async def start_training(request: TrainRequest) -> dict:
    """Start RAFT fine-tuning job."""
    
    # This would start async training job in production
    return {
        "status": "Training job started",
        "output_dir": request.output_dir,
        "num_epochs": request.num_epochs
    }

# Loading weights blocks for a long time, so this is also a plain `def` endpoint
@app.post("/v1/load-model")
def load_model(model_path: str) -> dict:
    """Load a fine-tuned RAFT model."""
    
    global raft_inference
    
    try:
        raft_inference = RAFTInference(model_path)
        return {"status": "Model loaded", "path": model_path}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "model_loaded": raft_inference is not None
    }


Conclusion

RAFT bridges the gap between generic LLMs and domain-specific RAG systems by teaching models how to use retrieved context. The key insight is that including both relevant (oracle) and irrelevant (distractor) documents during training teaches the model to identify and extract the right information from noisy retrieval results. Start with high-quality QA pair generation: the quality of your training data directly determines model performance. Include chain-of-thought reasoning so the model learns not just what to answer, but how to find the answer in the documents. Use hard negative mining to select distractors that are semantically similar to the query but don't contain the answer; this teaches the model to distinguish relevant content from merely similar content. For training, LoRA fine-tunes large models efficiently by updating small adapter matrices instead of all parameters. Evaluate on held-out data with the same oracle/distractor setup and prompt format used in training. Done well, the result is a model that can outperform both vanilla RAG and standard fine-tuning on domain-specific question answering; confirm the gain on your own held-out set before deploying it.