Exploring MLX Swift: Working with Generate Parameters for Language Models
In all the previous posts, I used the default parameters to generate text:
try MLXLMCommon.generate(input: input, parameters: .init(), context: context) { _ in .more }
In this blog post, we will explore parameters like temperature, top-p sampling, and repetition penalty: how they affect text generation and how to use them in practice.
What are Tokens?
As a refresher, language models (as of now) do not process words directly but smaller units called 'tokens'. Tokens can be parts of words, whole words, or even punctuation marks. The model predicts the next token in a sequence, one token at a time.
Why Customize Generation Parameters?
By default, language models aim to predict the most likely next token based on a given input (prompt). However, the default behavior may not align with what you want. Creative tasks like storytelling may need more randomness, while other tasks require more deterministic output.
Or balance the two like Thanos would have wanted: perfectly.
With GenerateParameters, you can tune these settings for your specific use case.
The Anatomy of Generate Parameters
Here is a look at the primary fields of GenerateParameters and what they do:
public struct GenerateParameters: Sendable {
    public var prefillStepSize: Int = 512       // Tokens processed per step during prompt processing
    public var temperature: Float = 0.6         // Adds randomness; lower values are more deterministic
    public var topP: Float = 1.0                // Limits token selection to top-probability tokens
    public var repetitionPenalty: Float?        // Penalizes repeated tokens; nil (the default) disables it
    public var repetitionContextSize: Int = 20  // Number of recent tokens considered for the penalty
}
- prefillStepSize: The number of tokens processed per step while the prompt is being processed (the "prefill" phase), before any new tokens are generated.
- temperature: This parameter controls how 'surprised' the model is allowed to be. A temperature of 0 always picks the most probable token, while a higher temperature makes less probable tokens more likely, resulting in more creative and diverse (but potentially less coherent) text. Think of it like choosing an outfit, something I have to do every winter. A low temperature is like picking the most obvious one from the wardrobe. A high temperature is like grabbing things you have not worn in a while, which might be interesting but may not match perfectly. Commonly used temperatures range from 0.0 to 1.0; values above 1 flatten the distribution further and often produce incoherent output.
- topP: Top-p (nucleus) sampling aims to strike a balance between diversity and coherence. Imagine the model has a list of possible next tokens, each with a probability. Top-p sampling sorts them from most to least likely and adds up their probabilities until the total reaches the topP value. The model then selects the next token only from this group. A value of 1.0 considers all tokens, while a lower value restricts the choices to the most likely ones. For example, a topP of 0.9 selects the smallest set of tokens whose probabilities together sum to at least 90%. Values range from 0.0 to 1.0, and lower values can make the model more deterministic.
- repetitionPenalty: Penalizes tokens that have already appeared, reducing repetitive output. Language models tend to repeat themselves because a token they just generated is often still among the most probable next tokens, which leads to boring and redundant text. The repetition penalty counteracts this by reducing the probability of tokens that already appear in the recent context. A value above 1.0 discourages repetition. Higher values make the text less repetitive but can also make it unusual or less natural-sounding. Think of it this way: each time a word is used, it gets a bit 'boring' for the model, and the penalty makes the model favor fresh words over 'boring' ones. The right value needs tuning for your model and task. It is all about experimentation!
- repetitionContextSize: Determines how far back the model looks for repeated tokens. For example, with a context size of 20, only tokens among the most recent 20 are penalized. Increasing this value can reduce repetition over longer text, but a very large window also penalizes tokens that naturally recur, such as punctuation and common words, which can hurt fluency.
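To make the temperature and repetition penalty descriptions concrete, here is a minimal sketch in plain Swift on a toy four-token vocabulary. The helpers softmax(_:temperature:) and applyRepetitionPenalty(_:history:penalty:contextSize:) are hypothetical and only illustrate the math: MLX Swift applies these steps to MLXArray logits internally. The penalty rule (divide positive logits, multiply negative ones) is the common CTRL-style formulation, and I am assuming MLX follows the same rule.

```swift
import Foundation

// Hypothetical helpers for illustration only; MLX Swift performs these steps on
// MLXArray logits inside its samplers and repetition handling.

/// Converts logits to probabilities after dividing them by temperature (> 0).
/// Temperature 0 is not handled here because that case uses argmax instead.
func softmax(_ logits: [Float], temperature: Float) -> [Float] {
    let scaled = logits.map { $0 / temperature }
    let maxLogit = scaled.max() ?? 0  // subtract the max for numerical stability
    let exps = scaled.map { exp($0 - maxLogit) }
    let total = exps.reduce(0, +)
    return exps.map { $0 / total }
}

/// CTRL-style repetition penalty (assumed to match MLX): for tokens seen in the
/// last contextSize tokens, positive logits are divided by penalty and negative
/// logits are multiplied by it, so with penalty > 1 both become less likely.
func applyRepetitionPenalty(_ logits: [Float], history: [Int], penalty: Float, contextSize: Int) -> [Float] {
    var result = logits
    let window = Set(history.suffix(contextSize))  // only the most recent tokens count
    for token in window where result.indices.contains(token) {
        result[token] = result[token] < 0 ? result[token] * penalty : result[token] / penalty
    }
    return result
}

// A toy four-token vocabulary; token 0 is the model's favorite.
let logits: [Float] = [2.0, 1.0, 0.5, -1.0]

let temperatures: [Float] = [0.2, 0.6, 1.0, 1.5]
for temperature in temperatures {
    let probs = softmax(logits, temperature: temperature)
    print("temperature \(temperature):", probs.map { String(format: "%.3f", Double($0)) })
}

// Tokens 1 and 0 are inside a window of 2; token 3 fell out of it.
let penalized = applyRepetitionPenalty(logits, history: [3, 1, 0], penalty: 1.2, contextSize: 2)
print("penalized logits:", penalized)
```

Token 0's probability climbs toward 1 as the temperature drops and spreads across the other tokens as it rises. In the penalty example, the logits of tokens 0 and 1 drop from 2.0 and 1.0 to about 1.67 and 0.83, while token 3 keeps its logit of -1.0 because it is outside the window.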
Choosing the Right Sampler
The GenerateParameters structure in the MLXLMCommon package already provides a sampler() method that selects the appropriate LogitSampler based on your configuration:
func sampler() -> LogitSampler {
    if temperature == 0 {
        return ArgMaxSampler()
    } else if topP > 0 && topP < 1 {
        return TopPSampler(temperature: temperature, topP: topP)
    } else {
        return CategoricalSampler(temperature: temperature)
    }
}
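Because sampler() is plain branching logic, it is easy to check which sampler a given configuration ends up with. The sketch below is a hypothetical mirror of the function above: SamplerKind, samplerKind(temperature:topP:), and Configuration are illustrative names that return a label instead of constructing MLX's sampler types, so it runs without MLX.

```swift
// Hypothetical mirror of GenerateParameters.sampler(): returns a label instead of
// constructing MLX's LogitSampler types, so it runs without MLX.
enum SamplerKind: String {
    case argMax = "ArgMaxSampler"
    case topP = "TopPSampler"
    case categorical = "CategoricalSampler"
}

func samplerKind(temperature: Float, topP: Float) -> SamplerKind {
    if temperature == 0 {
        return .argMax        // checked first, so topP is ignored at temperature 0
    } else if topP > 0 && topP < 1 {
        return .topP          // only a topP strictly between 0 and 1 enables nucleus sampling
    } else {
        return .categorical   // topP of 1.0 (the default), 0, or out of range samples everything
    }
}

struct Configuration {
    let label: String
    let temperature: Float
    let topP: Float
}

let configurations = [
    Configuration(label: "default", temperature: 0.6, topP: 1.0),
    Configuration(label: "creative writing", temperature: 1.0, topP: 0.95),
    Configuration(label: "technical documentation", temperature: 0.2, topP: 1.0),
    Configuration(label: "balanced chat", temperature: 0.7, topP: 0.8),
    Configuration(label: "temperature 0 with topP 0.9", temperature: 0.0, topP: 0.9),
    Configuration(label: "topP of 0", temperature: 0.8, topP: 0.0),
]

for config in configurations {
    let kind = samplerKind(temperature: config.temperature, topP: config.topP)
    print("\(config.label): \(kind.rawValue)")
}
```

This prints CategoricalSampler for the default and technical documentation settings, TopPSampler for creative writing and balanced chat, ArgMaxSampler when the temperature is 0 (even with topP set), and CategoricalSampler again when topP is 0. Two consequences are worth noting: the technical documentation preset used later in this post never applies top-p filtering, because its topP stays at the default 1.0; and a topP of 0 does not make generation greedy, it falls back to sampling from the full distribution.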
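- ArgMaxSampler: Used when temperature is 0. Always picks the most likely token, so the output is deterministic.
- TopPSampler: Used when topP is strictly between 0 and 1. Adds diversity by sampling, with temperature, from the restricted set of high-probability tokens.
- CategoricalSampler: Used otherwise, including the default topP of 1.0. Samples from the full distribution scaled by temperature, favoring diverse outputs.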
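Here is a minimal plain-Swift sketch of what the three strategies do once the logits have been turned into probabilities. The functions argMax(_:), nucleus(_:topP:), and sampleCategorical(_:using:), and the SplitMix64 generator used to make the draws reproducible, are hypothetical illustrations, not MLX's implementation, which operates on MLXArray logits and uses MLX's own random number generation.

```swift
// Hypothetical plain-Swift illustrations of the three strategies; not MLX's code.

/// ArgMax: the index of the most probable token (probs must be non-empty).
func argMax(_ probs: [Float]) -> Int {
    return probs.indices.max(by: { probs[$0] < probs[$1] })!
}

/// Top-p (nucleus) filtering: keep the smallest set of most likely tokens whose
/// cumulative probability reaches topP, then renormalize their probabilities.
func nucleus(_ probs: [Float], topP: Float) -> [Int: Float] {
    let byProbability = probs.indices.sorted { probs[$0] > probs[$1] }
    var kept: [Int: Float] = [:]
    var cumulative: Float = 0
    for index in byProbability {
        kept[index] = probs[index]
        cumulative += probs[index]
        if cumulative >= topP { break }
    }
    return kept.mapValues { $0 / cumulative }
}

/// Categorical sampling: draw one token with probability proportional to its weight.
func sampleCategorical<G: RandomNumberGenerator>(_ weights: [Int: Float], using rng: inout G) -> Int {
    let total = weights.values.reduce(0, +)
    var threshold = Float.random(in: 0..<total, using: &rng)
    for (token, weight) in weights.sorted(by: { $0.key < $1.key }) {
        if threshold < weight { return token }
        threshold -= weight
    }
    return weights.keys.max()!  // fallback for floating-point rounding at the upper end
}

/// Small seeded generator (SplitMix64) so the example is reproducible.
struct SplitMix64: RandomNumberGenerator {
    var state: UInt64
    mutating func next() -> UInt64 {
        state &+= 0x9E37_79B9_7F4A_7C15
        var z = state
        z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
        z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
        return z ^ (z >> 31)
    }
}

let probs: [Float] = [0.5, 0.3, 0.15, 0.05]
print("ArgMax picks token", argMax(probs))

let kept = nucleus(probs, topP: 0.9)
print("Top-p 0.9 keeps tokens", kept.keys.sorted())

var rng = SplitMix64(state: 42)
var draws: [Int] = []
for _ in 0..<10 {
    draws.append(sampleCategorical(kept, using: &rng))
}
print("Ten top-p draws:", draws)
```

With these toy probabilities, argmax always returns token 0, and a topP of 0.9 keeps tokens 0, 1, and 2 (0.5 + 0.3 + 0.15 = 0.95 is the first cumulative total to reach 0.9), so token 3 can never be drawn. To mimic CategoricalSampler instead, pass every token with its probability to sampleCategorical.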
Building a Custom Generate Workflow
Let us walk through creating a custom GenerateParameters instance:
let parameters = GenerateParameters(
    temperature: 0.8,
    topP: 0.9,
    repetitionPenalty: 1.2,
    repetitionContextSize: 50
)
Here are some starting points for common scenarios:
- Creative writing: for stories or poems, where you want a good amount of drama, like a Bollywood movie:
let parameters = GenerateParameters(temperature: 1.0, topP: 0.95)
- Technical documentation: when accuracy and consistency matter most, as in research papers, and you want the model to be predictable and reliable:
let parameters = GenerateParameters(temperature: 0.2, repetitionPenalty: 1.5)
- Balanced chat responses: for a conversational experience that is creative without drifting off-topic:
let parameters = GenerateParameters(temperature: 0.7, topP: 0.8)
Parameters Showdown
Here is an example macOS app that compares text generation side by side using the Llama 3.2 3B model. The app uses two sets of GenerateParameters, the default values and parameters tuned for creative writing, and shows the output of each for the prompt you enter:
import SwiftUI
import MLXLLM
import MLXLMCommon
struct MLXComparisonView: View {
    @State private var prompt: String = "Describe the future of AI in education."
    @State private var defaultOutput: String = "Loading..."
    @State private var creativeOutput: String = "Loading..."
    @State private var isLoading: Bool = false
    var body: some View {
        VStack {
            Text("Llama 3.2 3B Model Parameter Comparison")
                .font(.title)
                .padding()
            TextField("Enter your prompt here", text: $prompt)
                .textFieldStyle(RoundedBorderTextFieldStyle())
                .padding()
            HStack {
                VStack(alignment: .leading) {
                    Text("Default Parameters")
                        .font(.headline)
                    ScrollView {
                        Text(defaultOutput)
                            .frame(maxWidth: .infinity, alignment: .leading)
                            .padding()
                    }
                    .border(Color.gray)
                }
                .padding()
                VStack(alignment: .leading) {
                    Text("Creative Writing Parameters")
                        .font(.headline)
                    ScrollView {
                        Text(creativeOutput)
                            .frame(maxWidth: .infinity, alignment: .leading)
                            .padding()
                    }
                    .border(Color.gray)
                }
                .padding()
            }
            Button(action: generateText) {
                if isLoading {
                    ProgressView()
                        .progressViewStyle(CircularProgressViewStyle())
                } else {
                    Text("Generate")
                        .font(.headline)
                }
            }
            .disabled(isLoading) // Prevent starting a second run while one is in progress
            .padding()
        }
        .padding()
        .frame(width: 1200, height: 800)
    }
    private func generateText() {
        isLoading = true
        Task {
            do {
                // Load model container
                let modelContainer = try await LLMModelFactory.shared.loadContainer(
                    configuration: ModelRegistry.llama3_2_3B_4bit
                )
                // Run generation for both parameter sets, one after the other, on the same model
                let defaultResult = try await generate(prompt: prompt, parameters: .default, modelContainer: modelContainer)
                let creativeResult = try await generate(prompt: prompt, parameters: .creativeWriting, modelContainer: modelContainer)
                // Update the outputs on the main actor
                await MainActor.run {
                    self.defaultOutput = defaultResult
                    self.creativeOutput = creativeResult
                    self.isLoading = false
                }
            } catch {
                debugPrint("Error generating text: \(error)")
                await MainActor.run { self.isLoading = false }
            }
        }
    }
    private func generate(prompt: String, parameters: GenerateParameters, modelContainer: ModelContainer) async throws -> String {
        let result = try await modelContainer.perform { context in
            let input = try await context.processor.prepare(input: .init(prompt: prompt))
            let maxTokens = 1000
            return try MLXLMCommon.generate(input: input, parameters: parameters, context: context) { tokens in
                // `tokens` holds every token generated so far; stop once the cap is reached
                if tokens.count >= maxTokens {
                    return .stop
                } else {
                    return .more
                }
            }
        }
        return result.output
    }
}
extension GenerateParameters {
    static var `default`: GenerateParameters {
        GenerateParameters(temperature: 0.6, topP: 1.0) // Default values
    }
    static var creativeWriting: GenerateParameters {
        GenerateParameters(temperature: 1.0, topP: 0.9) // For creative writing
    }
}
And here is the output!
Discovery requires experimentation, and that holds here too. Start with small adjustments and observe how each parameter changes the model's output.
Moving Forward
With GenerateParameters, MLX Swift gives you precise control over text generation. In future posts, I will explore working with system prompts.
Here is a blog post for further reading:
If you have any questions or want to share your experiments, reach out on Twitter @rudrankriyam.
Happy MLXing!