In our last post, we learnt some basics about model merging and how it works in practice. We will now dive deeper into the technique we used in that post, SLERP (Spherical LinEar inteRPolation), and look at some code showing how it is implemented.
SLERP is one of the earliest approaches used for model merging, and it has been used to create some of the top-performing merged models on several benchmarks. The idea stems from merging the "task vectors" (which we covered in the last post) of two models, each fine-tuned on a specific task such as coding, translation or math, to create a composite model capable of performing both tasks. First, let's review LERP, which forms the basis for it.
At its core, interpolation is a technique for estimating intermediate values between two known points. In linear interpolation, you essentially draw a straight line between two points in space and pick a point along that line according to an interpolation factor. Given two tensors and an interpolation factor t, linear interpolation (LERP) can be performed as follows:
```python
import torch

def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """
    Performs linear interpolation between two vectors or tensors.
    Args:
        t (float): The interpolation factor, where 0.0 returns `v0`, and 1.0 returns `v1`. Values between 0.0 and 1.0 will return a weighted average of `v0` and `v1`.
        v0 (torch.Tensor): The starting vector/tensor, from which the interpolation begins.
        v1 (torch.Tensor): The ending vector/tensor, to which the interpolation goes.
    Returns:
        torch.Tensor: The interpolated vector/tensor between `v0` and `v1` based on the interpolation factor `t`.
    Explanation:
        - The formula (1 - t) * v0 + t * v1 calculates a point along the straight line between `v0` and `v1`.
        - When `t` is 0, this formula gives `v0`, and when `t` is 1, it gives `v1`.
        - For values of `t` between 0 and 1, the formula returns a weighted average of `v0` and `v1`,
          where the weighting changes linearly from `v0` to `v1` as `t` moves from 0 to 1.
    """
    # Only arithmetic operators are used, so this also works on NumPy arrays.
    return (1 - t) * v0 + t * v1
```
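To see the formula in action, here is a minimal NumPy sketch. It uses NumPy instead of PyTorch only so that it runs without a PyTorch install; the arithmetic is the same. `lerp_np` is a hypothetical name for a NumPy copy of `lerp`, and the two vectors are made-up examples.

```python
import numpy as np

def lerp_np(t: float, v0: np.ndarray, v1: np.ndarray) -> np.ndarray:
    """Same formula as `lerp`, written for NumPy arrays."""
    return (1 - t) * v0 + t * v1

# Two made-up "parameter vectors" for illustration.
v0 = np.array([1.0, 0.0, 2.0])
v1 = np.array([3.0, 4.0, 0.0])

# t = 0 reproduces v0, t = 1 reproduces v1, and t = 0.5 gives the midpoint [2, 2, 1].
for t in (0.0, 0.25, 0.5, 1.0):
    print(t, lerp_np(t, v0, v1))
```

Each coordinate moves in equal steps as t moves in equal steps, which is exactly the "straight line" behaviour described above.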
However, linear interpolation has its limits. Redundant parameters, which change only slightly during fine-tuning and have minimal impact on performance, can obscure more influential parameter values when merged if they have larger magnitudes, and this lowers overall model performance. SLERP addresses this by interpolating along the surface of a sphere instead of along a straight line. That sounds like a handful, but we'll break it down next!
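A related and easily checked side effect of interpolating along a straight line (a geometric observation, separate from the redundant-parameter argument above) is that the result shrinks whenever the two vectors point in different directions. The minimal NumPy sketch below uses made-up unit vectors; `lerp_np` is a hypothetical NumPy copy of `lerp`.

```python
import numpy as np

def lerp_np(t, v0, v1):
    # Straight-line interpolation, same formula as `lerp`.
    return (1 - t) * v0 + t * v1

# Two made-up unit vectors at 90 degrees to each other.
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])

mid = lerp_np(0.5, a, b)
print(mid, np.linalg.norm(mid))  # [0.5 0.5], with a norm of about 0.707 instead of 1.0

# For unit vectors separated by angle theta, the LERP midpoint has norm cos(theta / 2).
for deg in (10, 45, 90, 150):
    theta = np.radians(deg)
    u = np.array([np.cos(theta), np.sin(theta)])
    m = np.linalg.norm(lerp_np(0.5, a, u))
    print(deg, round(m, 3), round(np.cos(theta / 2), 3))
```

The further apart the two directions are, the more the blended vector shrinks. Interpolating along the arc instead keeps every intermediate point of two unit vectors at norm 1.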
Things get a bit more interesting with SLERP. Instead of dealing with straight lines, we're working on the surface of a sphere. Imagine you have two points on a globe and want to find the shortest path between them. This path isn't a straight line through the Earth but a curve along the surface.
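The globe analogy can be made concrete for unit vectors. If θ is the angle between two points on a unit sphere, the straight line through the sphere (the chord) has length 2·sin(θ/2), while the shortest path along the surface (the great-circle arc) has length θ. A small NumPy sketch with made-up points; `angle_between` is a hypothetical helper written for this illustration.

```python
import numpy as np

def angle_between(u, v):
    """Angle in radians between two vectors, via their normalized dot product."""
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    # Clip to [-1, 1] to guard against tiny floating-point overshoot before arccos.
    return np.arccos(np.clip(cos, -1.0, 1.0))

p = np.array([1.0, 0.0, 0.0])  # a point on the unit sphere
q = np.array([0.0, 1.0, 0.0])  # another point, 90 degrees away

theta = angle_between(p, q)
chord = np.linalg.norm(q - p)  # straight line through the sphere
arc = theta                     # great-circle distance on a unit sphere
print(chord, arc)               # about 1.414 vs about 1.571
```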
SLERP gives a smooth transition between any two points on a sphere, moving along the arc at a constant angular speed. This property is essential in graphics and animation, and it is also useful when merging the parameters of different AI models while maintaining their unique characteristics. Here's a GIF that shows the difference between spherical and linear interpolation.
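The smooth, constant-speed transition can be checked numerically. In the minimal NumPy sketch below, `slerp_unit` is a hypothetical helper written only for this illustration: it assumes unit-length inputs that are not parallel or anti-parallel, and it is not the implementation discussed later. Every interpolated point stays on the unit sphere, and the angle from the start point grows in equal steps as t grows in equal steps.

```python
import numpy as np

def slerp_unit(t, u, v):
    """SLERP between two unit vectors that are not (anti-)parallel."""
    theta = np.arccos(np.clip(np.dot(u, v), -1.0, 1.0))
    return (np.sin((1 - t) * theta) * u + np.sin(t * theta) * v) / np.sin(theta)

u = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 0.6, 0.8])  # unit length, since 0.36 + 0.64 = 1

# Norms stay at 1.000 and the angle from u goes 0, 22.5, 45, 67.5, 90 degrees.
for t in np.linspace(0.0, 1.0, 5):
    p = slerp_unit(t, u, v)
    angle_from_u = np.degrees(np.arccos(np.clip(np.dot(p, u), -1.0, 1.0)))
    print(f"t={t:.2f}  norm={np.linalg.norm(p):.3f}  angle from u={angle_from_u:.1f} deg")
```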
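SLERP finds points along this curved path smoothly. Here's a simplified breakdown: normalize both vectors to unit length, use their dot product to compute the angle θ between them, and then weight the original vectors by sin((1 − t)θ)/sin θ and sin(tθ)/sin θ respectively. If the vectors are nearly parallel, sin θ is close to zero and these weights become numerically unstable, so the computation falls back to LERP.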
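Written out, the steps in the commented code below correspond to the following formula, where $\hat{v}_0, \hat{v}_1$ are the normalized vectors and the weights are applied to the original, unnormalized $v_0, v_1$:

```latex
\hat{v}_i = \frac{v_i}{\lVert v_i \rVert}, \qquad
\theta = \arccos\left(\hat{v}_0 \cdot \hat{v}_1\right)

\operatorname{slerp}(t; v_0, v_1) =
\frac{\sin\big((1 - t)\,\theta\big)}{\sin\theta}\, v_0
+ \frac{\sin(t\,\theta)}{\sin\theta}\, v_1
```

As $\theta \to 0$, the two weights tend to $1 - t$ and $t$, which is exactly LERP. This is why falling back to LERP for nearly parallel vectors is a reasonable approximation rather than a different method.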
The implementation of SLERP for two vectors can be found here.
Below is the same code but with added comments for additional context.
```python
import numpy as np

# `lerp` is the linear interpolation function defined earlier in this post;
# it only uses arithmetic operators, so it works on NumPy arrays as well.

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical Linear Interpolation (SLERP) between two vectors.
    SLERP provides a smooth transition between two vectors over the surface of a sphere,
    which is particularly useful in animations, rotations, and merging model parameters.
    Args:
        t (float): Interpolation factor between 0.0 and 1.0, where 0 returns v0, and 1 returns v1.
        v0 (np.ndarray): The starting vector, from which the interpolation starts.
        v1 (np.ndarray): The destination vector, to which the interpolation goes.
        DOT_THRESHOLD (float): Threshold on the absolute dot product of the normalized vectors;
            above it the vectors are treated as (anti-)parallel and SLERP falls back to LERP.
    Returns:
        np.ndarray: The interpolated vector between v0 and v1 at the position specified by t.
    '''
    # Keep copies of the original (unnormalized) vectors; the final result is a weighted sum of these.
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to ensure they lie on the unit sphere. This is crucial for the geometric calculations in SLERP.
    v0 = v0 / np.linalg.norm(v0)
    v1 = v1 / np.linalg.norm(v1)
    # The dot product of the normalized vectors is the cosine of the angle between them,
    # i.e. a measure of how 'aligned' the vectors are.
    dot = np.sum(v0 * v1)
    # If the vectors are nearly parallel or anti-parallel (|dot| close to 1), sin(theta) is close to 0
    # and the SLERP weights become numerically unstable, so fall back to linear interpolation (LERP).
    if np.abs(dot) > DOT_THRESHOLD:
        # Directly interpolate between the original vectors without spherical adjustment.
        return lerp(t, v0_copy, v1_copy)
    # Calculate the angle between the vectors using the arc cosine of the dot product.
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)  # The sine of the angle is used in the SLERP formula.
    # Angle covered at interpolation factor t.
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Calculate the scale factors for each vector, based on the interpolation factor and the sine values.
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    # Compute the final interpolated vector as a weighted sum of the original vectors.
    v2 = s0 * v0_copy + s1 * v1_copy
    return v2
```
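To show how a function like this could be applied to model weights, here is a minimal sketch that merges two hypothetical "models", each represented as a plain dict mapping a layer name to a NumPy weight array, one tensor at a time. The layer names, shapes and random weights are made up for illustration, and `slerp` below is a condensed copy of the commented version above so the block runs on its own.

```python
import numpy as np

def lerp(t, v0, v1):
    return (1 - t) * v0 + t * v1

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # Condensed copy of the commented `slerp` above.
    n0 = v0 / np.linalg.norm(v0)
    n1 = v1 / np.linalg.norm(v1)
    dot = np.sum(n0 * n1)
    if np.abs(dot) > DOT_THRESHOLD:
        return lerp(t, v0, v1)
    theta_0 = np.arccos(dot)
    s0 = np.sin(theta_0 * (1 - t)) / np.sin(theta_0)
    s1 = np.sin(theta_0 * t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

# Two hypothetical "models": dicts of layer name -> weight array (made-up data).
rng = np.random.default_rng(0)
model_a = {"layer1.weight": rng.normal(size=(4, 3)), "layer1.bias": rng.normal(size=3)}
model_b = {name: rng.normal(size=w.shape) for name, w in model_a.items()}

# Merge tensor by tensor with a single, global interpolation factor.
t = 0.5
merged = {name: slerp(t, model_a[name], model_b[name]) for name in model_a}

for name, w in merged.items():
    print(name, w.shape)
```

Matrices need no flattening here: `np.sum(v0 * v1)` and `np.linalg.norm` (with its default arguments) treat the whole array as one long vector. Using one global t for every tensor is the simplest choice; varying t from layer to layer would also be possible, but this sketch does not attempt it.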
SLERP is one of the most commonly used techniques for model merging, owing to its effectiveness and how easy it is to explain. By blending the strengths of specialized models, it opens up new ways to improve performance across diverse tasks. For AI practitioners, understanding and applying SLERP is a valuable addition to the toolkit.