Introduction

Long contexts contain valuable information, but they also contain noise, redundancy, and irrelevant details that consume tokens and dilute model attention. Context distillation extracts the essential information from lengthy documents, conversations, or retrieved passages, producing compact representations that preserve what matters while discarding what doesn’t. It is crucial for RAG systems that process multiple documents, for conversation summarization, and for any application where context length limits constrain what you can include.

This guide covers practical distillation strategies: extractive methods that select key sentences, abstractive approaches that generate summaries, hierarchical distillation for very long documents, and techniques for validating that distilled context preserves the information needed for downstream tasks. Whether you’re building document QA, multi-turn assistants, or knowledge synthesis systems, context distillation lets you fit more relevant information into limited context windows.

Extractive Distillation
from dataclasses import dataclass, field
from typing import Any, Optional
from abc import ABC, abstractmethod
@dataclass
class Sentence:
"""A sentence with metadata."""
text: str
index: int
score: float = 0.0
source_doc: str = ""
@dataclass
class DistilledContext:
"""Distilled context result."""
content: str
original_length: int
distilled_length: int
compression_ratio: float
sentences_kept: int
sentences_total: int
class Distiller(ABC):
"""Abstract distiller interface."""
@abstractmethod
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Distill text to shorter form."""
pass
class TextRankDistiller(Distiller):
"""TextRank-based extractive distillation."""
def __init__(self, similarity_threshold: float = 0.1):
self.similarity_threshold = similarity_threshold
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Extract key sentences using TextRank."""
# Split into sentences
sentences = self._split_sentences(text)
if len(sentences) <= 3:
return DistilledContext(
content=text,
original_length=len(text),
distilled_length=len(text),
compression_ratio=1.0,
sentences_kept=len(sentences),
sentences_total=len(sentences)
)
# Build similarity matrix
similarity_matrix = self._build_similarity_matrix(sentences)
# Run PageRank
scores = self._pagerank(similarity_matrix)
# Score sentences
for i, sent in enumerate(sentences):
sent.score = scores[i]
# If query provided, boost query-relevant sentences
if query:
self._boost_query_relevant(sentences, query)
# Select top sentences
target = target_length or len(text) // 3
selected = self._select_sentences(sentences, target)
# Reconstruct in original order
selected.sort(key=lambda s: s.index)
distilled = " ".join(s.text for s in selected)
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=len(selected),
sentences_total=len(sentences)
)
def _split_sentences(self, text: str) -> list[Sentence]:
"""Split text into sentences."""
import re
# Simple sentence splitting
pattern = r'(?<=[.!?])\s+'
parts = re.split(pattern, text)
sentences = []
for i, part in enumerate(parts):
part = part.strip()
if part:
sentences.append(Sentence(text=part, index=i))
return sentences
def _build_similarity_matrix(
self,
sentences: list[Sentence]
) -> list[list[float]]:
"""Build sentence similarity matrix."""
n = len(sentences)
matrix = [[0.0] * n for _ in range(n)]
for i in range(n):
for j in range(i + 1, n):
sim = self._sentence_similarity(
sentences[i].text,
sentences[j].text
)
matrix[i][j] = sim
matrix[j][i] = sim
return matrix
def _sentence_similarity(self, s1: str, s2: str) -> float:
"""Compute sentence similarity using word overlap."""
words1 = set(s1.lower().split())
words2 = set(s2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
return len(intersection) / len(union)
    def _pagerank(
        self,
        matrix: list[list[float]],
        damping: float = 0.85,
        iterations: int = 100
    ) -> list[float]:
        """Run weighted PageRank over the sentence similarity graph."""
        n = len(matrix)
        scores = [1.0 / n] * n
        # Precompute each node's outgoing edge weight once, counting only
        # edges above the similarity threshold so both ends of an edge are
        # gated consistently (and the inner loop stays O(n^2) per iteration).
        out_sums = [
            sum(w for w in row if w > self.similarity_threshold)
            for row in matrix
        ]
        for _ in range(iterations):
            new_scores = []
            for i in range(n):
                score = (1 - damping) / n
                for j in range(n):
                    if matrix[j][i] > self.similarity_threshold and out_sums[j] > 0:
                        score += damping * scores[j] * matrix[j][i] / out_sums[j]
                new_scores.append(score)
            scores = new_scores
        return scores
def _boost_query_relevant(
self,
sentences: list[Sentence],
query: str
):
"""Boost scores of query-relevant sentences."""
query_words = set(query.lower().split())
for sent in sentences:
sent_words = set(sent.text.lower().split())
overlap = len(query_words & sent_words)
if overlap > 0:
sent.score *= (1 + overlap * 0.2)
    def _select_sentences(
        self,
        sentences: list[Sentence],
        target_length: int
    ) -> list[Sentence]:
        """Select highest-scoring sentences up to the target length."""
        sorted_sents = sorted(sentences, key=lambda s: s.score, reverse=True)
        selected = []
        current_length = 0
        for sent in sorted_sents:
            if current_length + len(sent.text) <= target_length:
                selected.append(sent)
                current_length += len(sent.text) + 1  # +1 for the joining space
        # Guarantee at least one sentence even if every candidate alone
        # exceeds the target length.
        if not selected and sorted_sents:
            selected.append(sorted_sents[0])
        return selected
class EmbeddingDistiller(Distiller):
"""Embedding-based extractive distillation."""
def __init__(self, embedding_model: Any):
self.embedding_model = embedding_model
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Extract sentences using embedding similarity."""
sentences = self._split_sentences(text)
if len(sentences) <= 3:
return DistilledContext(
content=text,
original_length=len(text),
distilled_length=len(text),
compression_ratio=1.0,
sentences_kept=len(sentences),
sentences_total=len(sentences)
)
# Embed all sentences
sent_texts = [s.text for s in sentences]
embeddings = await self.embedding_model.embed(sent_texts)
if query:
# Score by query similarity
query_emb = (await self.embedding_model.embed([query]))[0]
for i, sent in enumerate(sentences):
sent.score = self._cosine_similarity(query_emb, embeddings[i])
else:
# Score by centrality (similarity to centroid)
centroid = self._compute_centroid(embeddings)
for i, sent in enumerate(sentences):
sent.score = self._cosine_similarity(centroid, embeddings[i])
# Select diverse sentences using MMR
target = target_length or len(text) // 3
selected = self._mmr_select(sentences, embeddings, target)
# Reconstruct
selected.sort(key=lambda s: s.index)
distilled = " ".join(s.text for s in selected)
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=len(selected),
sentences_total=len(sentences)
)
def _split_sentences(self, text: str) -> list[Sentence]:
"""Split into sentences."""
import re
parts = re.split(r'(?<=[.!?])\s+', text)
return [Sentence(text=p.strip(), index=i) for i, p in enumerate(parts) if p.strip()]
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Compute cosine similarity."""
import math
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
def _compute_centroid(self, embeddings: list[list[float]]) -> list[float]:
"""Compute centroid of embeddings."""
n = len(embeddings)
dim = len(embeddings[0])
centroid = [0.0] * dim
for emb in embeddings:
for i, v in enumerate(emb):
centroid[i] += v / n
return centroid
def _mmr_select(
self,
sentences: list[Sentence],
embeddings: list[list[float]],
target_length: int,
lambda_param: float = 0.7
) -> list[Sentence]:
"""Select using Maximal Marginal Relevance."""
selected = []
selected_indices = set()
current_length = 0
while current_length < target_length and len(selected_indices) < len(sentences):
best_score = -float('inf')
best_idx = -1
for i, sent in enumerate(sentences):
if i in selected_indices:
continue
# Relevance score
relevance = sent.score
# Diversity penalty
max_sim = 0.0
for j in selected_indices:
sim = self._cosine_similarity(embeddings[i], embeddings[j])
max_sim = max(max_sim, sim)
# MMR score
mmr = lambda_param * relevance - (1 - lambda_param) * max_sim
if mmr > best_score:
best_score = mmr
best_idx = i
if best_idx >= 0:
selected.append(sentences[best_idx])
selected_indices.add(best_idx)
current_length += len(sentences[best_idx].text) + 1
else:
break
return selected
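A minimal usage sketch for the extractive distillers above. TextRank needs no external dependencies; the embedding variant assumes an embedding model object exposing an async `embed(texts: list[str]) -> list[list[float]]` method, which is an assumption of this guide's interface rather than a specific library API:

import asyncio

async def demo_extractive():
    # Pure-Python TextRank: fast, interpretable, no model calls
    distiller = TextRankDistiller(similarity_threshold=0.1)
    document = "..."  # a long document string
    result = await distiller.distill(
        document,
        target_length=500,
        query="pricing changes"
    )
    print(f"Kept {result.sentences_kept}/{result.sentences_total} sentences "
          f"at compression {result.compression_ratio:.2f}")

asyncio.run(demo_extractive())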
Abstractive Distillation
from dataclasses import dataclass
from typing import Any, Optional
class LLMDistiller(Distiller):
"""LLM-based abstractive distillation."""
def __init__(
self,
client: Any,
model: str = "gpt-4o-mini"
):
self.client = client
self.model = model
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Distill using LLM summarization."""
target = target_length or len(text) // 3
target_words = target // 5 # Rough word estimate
if query:
prompt = f"""Summarize this text, focusing on information relevant to the query.
Keep the summary under {target_words} words.
Preserve key facts, numbers, and specific details.
Query: {query}
Text:
{text}
Summary:"""
else:
prompt = f"""Summarize this text in under {target_words} words.
Preserve key facts, numbers, names, and specific details.
Focus on the most important information.
Text:
{text}
Summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0, # N/A for abstractive
sentences_total=0
)
class QueryFocusedDistiller(Distiller):
"""Query-focused distillation."""
def __init__(
self,
client: Any,
model: str = "gpt-4o-mini"
):
self.client = client
self.model = model
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Extract information relevant to query."""
if not query:
# Fall back to general summarization
return await self._general_distill(text, target_length)
target = target_length or len(text) // 3
prompt = f"""Extract all information from this text that is relevant to answering the question.
Include specific facts, numbers, dates, and details.
Omit information that is not relevant to the question.
Keep the extraction under {target // 5} words.
Question: {query}
Text:
{text}
Relevant information:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0,
sentences_total=0
)
    async def _general_distill(
        self,
        text: str,
        target_length: Optional[int]
    ) -> DistilledContext:
"""General distillation without query."""
target = target_length or len(text) // 3
prompt = f"""Summarize this text, preserving key information.
Keep under {target // 5} words.
Text:
{text}
Summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0,
sentences_total=0
)
class IncrementalDistiller(Distiller):
"""Incrementally distill as new content arrives."""
def __init__(
self,
client: Any,
model: str = "gpt-4o-mini",
max_context: int = 2000
):
self.client = client
self.model = model
self.max_context = max_context
self._current_summary = ""
async def add_content(self, new_content: str) -> str:
"""Add new content and update summary."""
if not self._current_summary:
# First content
result = await self.distill(new_content, self.max_context)
self._current_summary = result.content
return self._current_summary
        # Re-distill the running summary together with the new content
prompt = f"""Update this summary with the new information.
Keep the updated summary under {self.max_context // 5} words.
Merge overlapping information and remove redundancy.
Current summary:
{self._current_summary}
New information:
{new_content}
Updated summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
self._current_summary = response.choices[0].message.content.strip()
return self._current_summary
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Distill text."""
target = target_length or self.max_context
prompt = f"""Summarize this text in under {target // 5} words.
Text:
{text}
Summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0,
sentences_total=0
)
def reset(self):
"""Reset accumulated summary."""
self._current_summary = ""
def get_summary(self) -> str:
"""Get current summary."""
return self._current_summary
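A sketch of the incremental distiller in a streaming setting, assuming an OpenAI-style async client (the `AsyncOpenAI` import reflects that assumption; any client with a matching `chat.completions.create` interface works):

import asyncio
from openai import AsyncOpenAI  # assumed client; any compatible async client works

async def demo_incremental(turns: list[str]) -> str:
    client = AsyncOpenAI()
    distiller = IncrementalDistiller(client, model="gpt-4o-mini", max_context=2000)
    # Fold each new turn into the running summary; memory stays bounded
    for turn in turns:
        await distiller.add_content(turn)
    return distiller.get_summary()

summary = asyncio.run(demo_incremental(["First turn ...", "Second turn ..."]))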
Hierarchical Distillation
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class DocumentChunk:
"""A chunk of a document."""
content: str
index: int
summary: str = ""
importance: float = 0.0
class HierarchicalDistiller(Distiller):
"""Hierarchical distillation for very long documents."""
def __init__(
self,
client: Any,
model: str = "gpt-4o-mini",
chunk_size: int = 2000
):
self.client = client
self.model = model
self.chunk_size = chunk_size
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Hierarchically distill long document."""
# Split into chunks
chunks = self._split_into_chunks(text)
if len(chunks) <= 1:
# Short enough for direct distillation
return await self._direct_distill(text, target_length, query)
# Level 1: Summarize each chunk
chunk_summaries = await self._summarize_chunks(chunks, query)
# Combine summaries
combined = "\n\n".join([
f"Section {i+1}: {s}"
for i, s in enumerate(chunk_summaries)
])
# Level 2: Final distillation
target = target_length or len(text) // 5
if query:
prompt = f"""Create a final summary focusing on information relevant to the query.
Keep under {target // 5} words.
Query: {query}
Section summaries:
{combined}
Final summary:"""
else:
prompt = f"""Create a coherent final summary from these section summaries.
Keep under {target // 5} words.
Section summaries:
{combined}
Final summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0,
sentences_total=len(chunks)
)
def _split_into_chunks(self, text: str) -> list[DocumentChunk]:
"""Split text into chunks."""
chunks = []
words = text.split()
chunk_words = self.chunk_size // 5 # Rough word count
for i in range(0, len(words), chunk_words):
chunk_text = " ".join(words[i:i + chunk_words])
chunks.append(DocumentChunk(
content=chunk_text,
index=len(chunks)
))
return chunks
    async def _summarize_chunks(
        self,
        chunks: list[DocumentChunk],
        query: Optional[str] = None
    ) -> list[str]:
"""Summarize each chunk."""
import asyncio
tasks = [
self._summarize_chunk(chunk, query)
for chunk in chunks
]
return await asyncio.gather(*tasks)
    async def _summarize_chunk(
        self,
        chunk: DocumentChunk,
        query: Optional[str] = None
    ) -> str:
"""Summarize single chunk."""
if query:
prompt = f"""Summarize this section, focusing on information relevant to: {query}
Keep under 100 words.
Section:
{chunk.content}
Summary:"""
else:
prompt = f"""Summarize this section in under 100 words.
Section:
{chunk.content}
Summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
return response.choices[0].message.content.strip()
    async def _direct_distill(
        self,
        text: str,
        target_length: Optional[int],
        query: Optional[str]
    ) -> DistilledContext:
        """Direct distillation for text short enough to skip chunking."""
        target = target_length or len(text) // 3
        # Honor the query even on the direct path
        focus = f"Focus on information relevant to: {query}\n" if query else ""
        prompt = f"""Summarize in under {target // 5} words.
{focus}Text:
{text}
Summary:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
distilled = response.choices[0].message.content.strip()
return DistilledContext(
content=distilled,
original_length=len(text),
distilled_length=len(distilled),
compression_ratio=len(distilled) / len(text),
sentences_kept=0,
sentences_total=0
)
class MapReduceDistiller(Distiller):
"""Map-reduce style distillation."""
def __init__(
self,
client: Any,
model: str = "gpt-4o-mini",
chunk_size: int = 3000
):
self.client = client
self.model = model
self.chunk_size = chunk_size
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Map-reduce distillation."""
# Split into chunks
chunks = self._split_text(text)
# Map: Extract key information from each chunk
mapped = await self._map_phase(chunks, query)
# Reduce: Combine and synthesize
reduced = await self._reduce_phase(mapped, target_length, query)
return DistilledContext(
content=reduced,
original_length=len(text),
distilled_length=len(reduced),
compression_ratio=len(reduced) / len(text),
sentences_kept=0,
sentences_total=len(chunks)
)
def _split_text(self, text: str) -> list[str]:
"""Split text into chunks."""
words = text.split()
chunk_words = self.chunk_size // 5
chunks = []
for i in range(0, len(words), chunk_words):
chunks.append(" ".join(words[i:i + chunk_words]))
return chunks
    async def _map_phase(
        self,
        chunks: list[str],
        query: Optional[str] = None
    ) -> list[str]:
"""Map phase: extract from each chunk."""
import asyncio
tasks = [
self._extract_from_chunk(chunk, query)
for chunk in chunks
]
return await asyncio.gather(*tasks)
    async def _extract_from_chunk(
        self,
        chunk: str,
        query: Optional[str] = None
    ) -> str:
"""Extract key information from chunk."""
if query:
prompt = f"""Extract key facts relevant to: {query}
Be concise, use bullet points.
Text:
{chunk}
Key facts:"""
else:
prompt = f"""Extract the key facts from this text.
Be concise, use bullet points.
Text:
{chunk}
Key facts:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
return response.choices[0].message.content.strip()
    async def _reduce_phase(
        self,
        mapped: list[str],
        target_length: Optional[int],
        query: Optional[str] = None
    ) -> str:
"""Reduce phase: combine extractions."""
combined = "\n\n".join(mapped)
target = target_length or len(combined) // 2
if query:
prompt = f"""Synthesize these extractions into a coherent summary.
Focus on answering: {query}
Keep under {target // 5} words.
Remove redundancy and merge related points.
Extractions:
{combined}
Synthesis:"""
else:
prompt = f"""Synthesize these extractions into a coherent summary.
Keep under {target // 5} words.
Remove redundancy and merge related points.
Extractions:
{combined}
Synthesis:"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.3
)
return response.choices[0].message.content.strip()
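The two long-document strategies trade off differently: hierarchical summarization preserves narrative flow, while map-reduce's bullet-point extraction is better at retaining discrete facts. A sketch of running both over the same document, again assuming an OpenAI-style async client:

import asyncio
from openai import AsyncOpenAI

async def demo_long_document(long_text: str, question: str):
    client = AsyncOpenAI()
    hierarchical = HierarchicalDistiller(client, chunk_size=2000)
    map_reduce = MapReduceDistiller(client, chunk_size=3000)
    # Run both strategies concurrently and compare the results
    narrative, facts = await asyncio.gather(
        hierarchical.distill(long_text, target_length=800, query=question),
        map_reduce.distill(long_text, target_length=800, query=question)
    )
    return narrative, facts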
Information Preservation Validation
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class ValidationResult:
"""Result of distillation validation."""
is_valid: bool
information_preserved: float # 0-1
factual_accuracy: float # 0-1
missing_facts: list[str]
hallucinations: list[str]
class DistillationValidator:
"""Validate distillation quality."""
def __init__(
self,
client: Any,
model: str = "gpt-4o"
):
self.client = client
self.model = model
    async def validate(
        self,
        original: str,
        distilled: str,
        query: Optional[str] = None
    ) -> ValidationResult:
"""Validate distillation preserves information."""
# Check for missing facts
missing = await self._find_missing_facts(original, distilled, query)
# Check for hallucinations
hallucinations = await self._find_hallucinations(original, distilled)
# Compute scores
info_preserved = 1.0 - (len(missing) * 0.1) # Rough estimate
info_preserved = max(0.0, min(1.0, info_preserved))
factual_accuracy = 1.0 - (len(hallucinations) * 0.2)
factual_accuracy = max(0.0, min(1.0, factual_accuracy))
return ValidationResult(
is_valid=info_preserved > 0.7 and factual_accuracy > 0.9,
information_preserved=info_preserved,
factual_accuracy=factual_accuracy,
missing_facts=missing,
hallucinations=hallucinations
)
    async def _find_missing_facts(
        self,
        original: str,
        distilled: str,
        query: Optional[str] = None
    ) -> list[str]:
"""Find important facts missing from distillation."""
if query:
prompt = f"""Compare the original text and summary.
List any important facts from the original that are missing from the summary,
especially facts relevant to: {query}
Original:
{original[:3000]}
Summary:
{distilled}
Missing facts (one per line, or "None" if nothing important is missing):"""
else:
prompt = f"""Compare the original text and summary.
List any important facts from the original that are missing from the summary.
Original:
{original[:3000]}
Summary:
{distilled}
Missing facts (one per line, or "None" if nothing important is missing):"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
        content = response.choices[0].message.content.strip()
        # Tolerate minor variations such as "None."
        if content.lower().rstrip(".") == "none":
            return []
        return [line.strip() for line in content.split("\n") if line.strip()]
async def _find_hallucinations(
self,
original: str,
distilled: str
) -> list[str]:
"""Find facts in distillation not in original."""
prompt = f"""Compare the summary to the original text.
List any claims in the summary that are NOT supported by the original text.
Original:
{original[:3000]}
Summary:
{distilled}
Unsupported claims (one per line, or "None" if all claims are supported):"""
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
        content = response.choices[0].message.content.strip()
        # Tolerate minor variations such as "None."
        if content.lower().rstrip(".") == "none":
            return []
        return [line.strip() for line in content.split("\n") if line.strip()]
class AdaptiveDistiller(Distiller):
"""Adaptively distill with validation."""
def __init__(
self,
distiller: Distiller,
validator: DistillationValidator,
max_attempts: int = 3
):
self.distiller = distiller
self.validator = validator
self.max_attempts = max_attempts
    async def distill(
        self,
        text: str,
        target_length: Optional[int] = None,
        query: Optional[str] = None
    ) -> DistilledContext:
"""Distill with validation and retry."""
best_result = None
best_score = -1
for attempt in range(self.max_attempts):
# Distill
result = await self.distiller.distill(text, target_length, query)
# Validate
validation = await self.validator.validate(
text, result.content, query
)
# Score
score = (
validation.information_preserved * 0.6 +
validation.factual_accuracy * 0.4
)
if score > best_score:
best_score = score
best_result = result
# Good enough?
if validation.is_valid:
return result
# Adjust target length for next attempt
if validation.information_preserved < 0.7:
target_length = int((target_length or len(text) // 3) * 1.2)
return best_result
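Putting the pieces together: the adaptive wrapper composes any base distiller with the validator and retries with a looser length budget when too much information is lost. A sketch, assuming the same OpenAI-style client as above:

import asyncio
from openai import AsyncOpenAI

async def demo_validated(text: str, query: str) -> DistilledContext:
    client = AsyncOpenAI()
    adaptive = AdaptiveDistiller(
        distiller=LLMDistiller(client, model="gpt-4o-mini"),
        validator=DistillationValidator(client, model="gpt-4o"),
        max_attempts=3
    )
    # Retries automatically if validation finds missing facts or hallucinations
    return await adaptive.distill(text, target_length=600, query=query)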
Production Distillation Service
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from enum import Enum
app = FastAPI()
# Initialize distillers
textrank_distiller = TextRankDistiller()
# llm_distiller = LLMDistiller(client)
# hierarchical_distiller = HierarchicalDistiller(client)
class DistillerType(str, Enum):
TEXTRANK = "textrank"
EMBEDDING = "embedding"
LLM = "llm"
HIERARCHICAL = "hierarchical"
MAP_REDUCE = "map_reduce"
class DistillRequest(BaseModel):
text: str
target_length: Optional[int] = None
query: Optional[str] = None
distiller: DistillerType = DistillerType.TEXTRANK
    run_validation: bool = False  # "validate" would shadow a BaseModel attribute
class DistillResponse(BaseModel):
distilled: str
original_length: int
distilled_length: int
compression_ratio: float
validation: Optional[dict] = None
class BatchDistillRequest(BaseModel):
texts: list[str]
target_length: Optional[int] = None
query: Optional[str] = None
distiller: DistillerType = DistillerType.TEXTRANK
class BatchDistillResponse(BaseModel):
results: list[DistillResponse]
@app.post("/v1/distill")
async def distill_text(request: DistillRequest) -> DistillResponse:
"""Distill text to shorter form."""
# Select distiller
if request.distiller == DistillerType.TEXTRANK:
distiller = textrank_distiller
else:
# Would select appropriate distiller
distiller = textrank_distiller
# Distill
result = await distiller.distill(
request.text,
request.target_length,
request.query
)
response = DistillResponse(
distilled=result.content,
original_length=result.original_length,
distilled_length=result.distilled_length,
compression_ratio=result.compression_ratio
)
# Validate if requested
    if request.run_validation:
# Would run validation
response.validation = {
"is_valid": True,
"information_preserved": 0.9,
"factual_accuracy": 1.0
}
return response
@app.post("/v1/distill/batch")
async def batch_distill(request: BatchDistillRequest) -> BatchDistillResponse:
"""Distill multiple texts."""
import asyncio
distiller = textrank_distiller
tasks = [
distiller.distill(text, request.target_length, request.query)
for text in request.texts
]
results = await asyncio.gather(*tasks)
return BatchDistillResponse(
results=[
DistillResponse(
distilled=r.content,
original_length=r.original_length,
distilled_length=r.distilled_length,
compression_ratio=r.compression_ratio
)
for r in results
]
)
class IncrementalSession:
"""Session for incremental distillation."""
def __init__(self, max_context: int = 2000):
self.summary = ""
self.max_context = max_context
self.chunks_processed = 0
sessions: dict[str, IncrementalSession] = {}
class IncrementalAddRequest(BaseModel):
session_id: str
content: str
class IncrementalResponse(BaseModel):
session_id: str
current_summary: str
chunks_processed: int
@app.post("/v1/distill/incremental/create")
async def create_incremental_session(max_context: int = 2000):
"""Create incremental distillation session."""
import uuid
session_id = str(uuid.uuid4())
sessions[session_id] = IncrementalSession(max_context)
return {"session_id": session_id}
@app.post("/v1/distill/incremental/add")
async def add_to_session(request: IncrementalAddRequest) -> IncrementalResponse:
"""Add content to incremental session."""
if request.session_id not in sessions:
raise HTTPException(404, "Session not found")
session = sessions[request.session_id]
# Would use incremental distiller
# For now, simple append and truncate
session.summary = (session.summary + " " + request.content)[:session.max_context]
session.chunks_processed += 1
return IncrementalResponse(
session_id=request.session_id,
current_summary=session.summary,
chunks_processed=session.chunks_processed
)
@app.get("/v1/distill/incremental/{session_id}")
async def get_session(session_id: str) -> IncrementalResponse:
"""Get incremental session state."""
if session_id not in sessions:
raise HTTPException(404, "Session not found")
session = sessions[session_id]
return IncrementalResponse(
session_id=session_id,
current_summary=session.summary,
chunks_processed=session.chunks_processed
)
@app.delete("/v1/distill/incremental/{session_id}")
async def delete_session(session_id: str):
"""Delete incremental session."""
if session_id in sessions:
del sessions[session_id]
return {"deleted": True}
@app.get("/health")
async def health():
return {"status": "healthy"}
References
- TextRank: https://web.eecs.umich.edu/~mihalcea/papers/mihalcea.emnlp04.pdf
- LongLLMLingua: https://arxiv.org/abs/2310.06839
- RECOMP: https://arxiv.org/abs/2310.04408
- LlamaIndex Summarization: https://docs.llamaindex.ai/en/stable/examples/response_synthesizers/tree_summarize/
Conclusion
Context distillation extracts signal from noise in long documents. Start with extractive methods like TextRank for fast, interpretable distillation that preserves original phrasing. Use embedding-based selection with MMR for diversity when you need coverage across topics. Move to abstractive LLM distillation when you need coherent summaries that synthesize information. For very long documents, use hierarchical approaches: summarize chunks first, then synthesize chunk summaries into a final distillation.

Query-focused distillation is crucial for RAG: extract only what's relevant to the question rather than generic summaries. Always validate distillation quality by checking that key facts are preserved and that no hallucinations were introduced. Build incremental distillers for streaming contexts like conversations that grow over time.

The key insight is that distillation is lossy compression: you're trading information for token efficiency. The goal is to lose the right information (redundancy, tangents, and details irrelevant to your task) while preserving the facts and relationships that matter for downstream generation.
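One rough way to encode these rules of thumb as a starting point (the length threshold is illustrative, not a tuned value):

def choose_distiller(text: str, query: str | None, latency_sensitive: bool) -> str:
    """Map the rules of thumb above to a distiller type."""
    if latency_sensitive:
        return "textrank"       # fast, interpretable, preserves phrasing
    if len(text) > 20_000:
        return "hierarchical"   # summarize chunks, then synthesize
    if query:
        return "llm"            # query-focused abstractive distillation
    return "embedding"          # centroid scoring + MMR for diverse coverage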
