Shrinking the Impossible (Part 3): Embedding a Custom-Defined LLaVA-OneVision Model with MLC

Image credit: Generated using FLUX


Introduction

I attended this year’s Google I/O Connect in Berlin, and seeing Google’s latest work in Edge AI was inspiring. Since then, I’ve been on a personal mission to deploy my own edge model. Given how rapidly the open source community is catching up with the tech giants in edge foundation models, I’ve decided to test the limits of what I can realistically achieve as a solo developer. Can I embed a multi-modal foundation model onto my iPhone?

In my first blog post, I introduced the Machine Learning Compiler (MLC) Project as the open source solution that will make this possible. In my last blog post, I gave you the technical background needed to successfully deploy a multi-modal foundation model.

Now, we’re going to put this theory into practice in a hands-on activity. I’m going to show you how to embed a custom foundation model onto an edge device — and all the things that can go wrong in the process.

Selected Architecture

Remember that my end vision is to have a multi-modal foundation model deployed on my iPhone. Ideally, I want it to be multi-lingual so I can get some help the next time that I’m lost in a foreign country.

At a high level, my implementation will consist of two different models: (a) the smallest Large Language and Vision Assistant (LLaVA) model that I can find, and (b) an instruction-tuned version of Gemma 2B.

I chose these models because (a) the MLC Engine natively supports them, (b) they’re lightweight enough to be compiled on my laptop, and (c) their respective families have earned a reputation as high performers.

Each model will work on a specific task. Whenever the end user shares an image with our multi-modal chatbot, the mini LLaVA model will generate a text description for Gemma. Gemma will then use this information to respond to the user’s original prompt. If no image is shared, our user will only interact with Gemma 2B. Of course, this model specialization will be abstracted away. Our user won’t be aware that they’re actually conversing with two different foundation models rather than just one.

My multi-modal chat pipeline consists of two models: (1) the eloquent and multilingual Gemma 2B model, and (2) a mini LLaVA-OneVision model. LLaVA acts as a translator for Gemma, generating a text description for any image the user shares.
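The routing logic described above fits in a few lines. Below is a minimal Python sketch, assuming each embedded model can be wrapped as a plain function: `describe_image` and `generate_reply` are hypothetical stand-ins (not MLC's iOS API), and the way the image description is spliced into Gemma's prompt is just one possible choice.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class TwoModelChat:
    """Hypothetical router; the callables stand in for the two embedded models."""

    describe_image: Callable[[bytes], str]  # stand-in for mini LLaVA-OneVision
    generate_reply: Callable[[str], str]  # stand-in for Gemma 2B

    def respond(self, user_prompt: str, image: Optional[bytes] = None) -> str:
        if image is None:
            # Text-only turn: the user only ever talks to Gemma.
            return self.generate_reply(user_prompt)
        # Image turn: LLaVA "translates" the image into text for Gemma.
        description = self.describe_image(image)
        prompt = f"Image description: {description}\nUser question: {user_prompt}"
        return self.generate_reply(prompt)


# Toy stand-ins instead of real models, just to show the control flow.
chat = TwoModelChat(
    describe_image=lambda image: "a street sign written in German",
    generate_reply=lambda prompt: f"[Gemma sees] {prompt}",
)
print(chat.respond("Where am I?"))
print(chat.respond("What does this sign say?", image=b"fake image bytes"))
```

From the user's point of view there is a single `respond` entry point, which is what keeps the two-model split invisible.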

Why Not Only Use the LLaVA Model?

It may sound a bit strange that we’re deploying two models instead of one, especially since all LLaVA models understand both text and images. However, we need to distinguish between the smallest LLaVA-OneVision model, which has fewer than 1B parameters, and its much larger 7B+ counterparts.

According to the 0.5B LLaVA-OneVision model card, this LLaVA model has just under 900M parameters in total. Given its 80,000+ downloads, we can reasonably assume it does a decent job at its primary task: annotating images (source).

Larger LLaVA models are perfectly capable of holding a coherent conversation and of understanding whatever images we show them. You can find a few interesting examples of large LLaVA models answering questions about images here.

The LLaVA-v1.6 series contains models ranging from 7B to 34B parameters. At these sizes, a LLaVA model can easily identify Leonardo da Vinci's world-famous masterpiece from a screenshot and give us a quick art history lesson. Unfortunately, we can't expect the same from the smallest LLaVA models (image credit).

However, we need to adjust our expectations for a 900M parameter model. There’s only so much that a small model can do, and the LLaVA-OneVision Qwen2 0.5B model has been optimized for image annotation rather than instruction following. We can probably converse with it directly, but we may not be thrilled with the quality of its responses.

If we wanted to deploy a similar application in a cloud environment, then it would probably just make more sense to quantize a 7B LLaVA model and accept a slightly larger monthly bill. However, since we’re working on edge, we have tight resource constraints and need to make do with what we have.

The Overall Process

For each model, we need to:

  1. Quantize its weights; and
  2. Apply hardware-specific optimizations to it.
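To give a feel for step 1, here is a simplified sketch of symmetric group-wise 4-bit quantization in the spirit of MLC's `q4f16_1` scheme (4-bit weights with float16 scales). The group size of 32 and the rounding details are assumptions for illustration; MLC's real implementation packs the 4-bit codes tightly and uses its own compiled kernels.

```python
import numpy as np


def quantize_q4_groupwise(weights: np.ndarray, group_size: int = 32):
    """Symmetric 4-bit quantization with one float16 scale per group of weights.

    Assumes weights.size is divisible by group_size.
    """
    groups = weights.astype(np.float32).reshape(-1, group_size)
    # Map each group's largest magnitude onto the signed 4-bit limit of 7.
    scales = np.abs(groups).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0  # all-zero groups: any scale works, avoid 0/0
    # Codes are kept one per int8 here for readability; real kernels pack two per byte.
    codes = np.clip(np.round(groups / scales), -8, 7).astype(np.int8)
    return codes, scales.astype(np.float16)


def dequantize_q4_groupwise(codes: np.ndarray, scales: np.ndarray, shape) -> np.ndarray:
    """Recover approximate float32 weights from 4-bit codes and float16 scales."""
    return (codes.astype(np.float32) * scales.astype(np.float32)).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float32)
codes, scales = quantize_q4_groupwise(w)
w_hat = dequantize_q4_groupwise(codes, scales, w.shape)
print("max abs error:", float(np.abs(w - w_hat).max()))
```

In this accounting, each weight costs 4 bits plus its share of one 16-bit scale per 32 weights, roughly 4.5 bits instead of 16, which is what makes sub-1B models practical on a phone.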

How simple this process really is comes down to the degree of built-in MLC Engine support. MLC has already pre-quantized an instruction-tuned version of Gemma 2B for us. Hence, deploying Gemma as a standalone model in an iOS application is a relatively straightforward task. Just follow MLC’s Quick Start Documentation for packaging Gemma 2B, along with their iOS Swift SDK instructions.

On the other hand, applying the same process to our LLaVA model is a bit trickier. As of November 2024, MLC’s list of pre-quantized models contains no LLaVA model at all, much less our desired mini LLaVA model.

Since the process for deploying a pre-quantized model from HuggingFace is so well documented, I’m not going to focus on it. Instead, I’ll show you how I ported a new model into the MLC framework.

Manually Porting LLaVA-OneVision to MLC

At this point, you may be a bit confused. I’ve stated the MLC Engine supports LLaVA models, but I’m also talking about manually porting the 0.5B LLaVA-OneVision model into the MLC Framework.

What’s going on? At first glance, it looks like MLC fully supports the LLaVA model family, but that’s only partially true. At the time of writing (November 2024), MLC only natively supports specific LLaVA implementations, namely those that (a) use a Llama or Mistral model as their text decoder, and (b) use a CLIP-trained vision encoder. Unfortunately, all LLaVA variants that meet these requirements have 7B+ parameters, which makes them too large for my laptop to handle, much less my smartphone.

As a result, I need to manually port the much smaller llava-onevision-qwen2-0.5b-ov-hf model definition into the MLC framework. Practically speaking, this means defining the model in the style used by the mlc-llm-cpu Python library, which is only available through MLC AI’s own Python code repository. Afterwards, we need to recompile this library locally so that it contains our new model definition.

Once that’s done, I can quantize and hardware-optimize my selected LLaVA model just like any other MLC-supported model. As its full name suggests, the 0.5B LLaVA-OneVision model uses the Qwen2 0.5B LLM as its text decoder. Fortunately for us, MLC already supports the Qwen2 architecture, so this part of the change is quite easy. We just copy and paste MLC’s definition of LLaVA, rename the files, and change a few key-value pairs.

The original MLC LLaVA model definition looks like this:

from ..llama.llama_model import LlamaConfig, LlamaForCausalLM
from ..mistral.mistral_model import MistralConfig, MistralForCasualLM


CONFIG_MAP = {
    "LlamaForCausalLM": LlamaConfig,
    "MistralForCausalLM": MistralConfig,
}

ARCHITECTURE_MAP = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "MistralForCausalLM": MistralForCasualLM,
}

Now, we rewrite these maps in our LLaVA-OneVision model definition so that they point to the Qwen2 text decoder used by the LLaVA-OneVision Qwen2 0.5B model:

# Defined by MLC
from ..qwen2.qwen2_model import QWen2Config, QWen2LMHeadModel


CONFIG_MAP = {
    "QWen2LMHeadModel": QWen2Config,
    "Qwen2ForCausalLM": QWen2Config
}
ARCHITECTURE_MAP = {
    "QWen2LMHeadModel": QWen2LMHeadModel,
    "Qwen2ForCausalLM": QWen2LMHeadModel,
}
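The two maps act as lookup tables from the architecture name found in a HuggingFace-style config to the classes that implement it. The sketch below is a hypothetical, self-contained illustration of that dispatch: the stub classes stand in for MLC's real `QWen2Config` and `QWen2LMHeadModel`, `build_text_decoder` is our own helper rather than MLC's loading code, and the hidden size is only an example value.

```python
from dataclasses import dataclass


@dataclass
class QWen2Config:  # hypothetical stub for MLC's QWen2Config
    hidden_size: int


class QWen2LMHeadModel:  # hypothetical stub for MLC's QWen2LMHeadModel
    def __init__(self, config: QWen2Config):
        self.config = config


CONFIG_MAP = {
    "QWen2LMHeadModel": QWen2Config,
    "Qwen2ForCausalLM": QWen2Config,
}
ARCHITECTURE_MAP = {
    "QWen2LMHeadModel": QWen2LMHeadModel,
    "Qwen2ForCausalLM": QWen2LMHeadModel,
}


def build_text_decoder(text_config: dict):
    """Pick the config and model classes from the architecture name."""
    arch = text_config["architectures"][0]
    if arch not in ARCHITECTURE_MAP:
        raise ValueError(f"Unsupported text decoder architecture: {arch}")
    config = CONFIG_MAP[arch](hidden_size=text_config["hidden_size"])
    return ARCHITECTURE_MAP[arch](config)


decoder = build_text_decoder({"architectures": ["Qwen2ForCausalLM"], "hidden_size": 896})
print(type(decoder).__name__)  # QWen2LMHeadModel
```

Registering both spellings lets either name resolve to the same MLC implementation, while an unknown decoder (such as a Llama one in this trimmed map) fails loudly.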

So far, so good. Let’s take a closer look at the 0.5B LLaVA-OneVision config.json file. We see that this LLaVA model’s vision encoder was trained using SigLIP rather than CLIP, the only vision encoder training framework that MLC natively supports.

As seen in the LLaVA-OneVision Qwen2 0.5B model's configuration file, the vision encoder was trained using SigLIP. As of now (November 2024), MLC only natively supports CLIP-trained vision encoders (source).

This means there’s no MLC definition for us to simply import. We need to write a custom Python model definition using MLC wrappers. Luckily, we already dove into the details of the SigLIP vision encoder in the previous blog post, so we’re ready to get started.

I’ve included the final SigLIP vision encoder definition in this blog post series’ corresponding GitHub repository. For the sake of brevity, I’m only going to focus on the technical differences between these two vision encoders and how these changes translate into code.

MLC’s CLIP vs. Our Selected LLaVA SigLIP Implementation

Once again, we’re going to reference our LLaVA model’s trusty config.json file to get some clues about where to start. In particular, we see the key-value pair: "vision_feature_layer": -1, whereas the original LLaVA config uses "vision_feature_layer": -2. This is a hint about how a vision encoder aggregates its sequence of embeddings into a single vector.
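Checking these keys programmatically is a cheap way to spot a non-CLIP encoder before porting. This is a small sketch of our own; it assumes, as in HuggingFace's LLaVA-style configs, that the encoder type lives under `vision_config.model_type`, and the excerpt below is hand-trimmed (the model-type strings are HuggingFace's names as we understand them), not the full file.

```python
def inspect_vision_encoder(config: dict) -> dict:
    """Summarize the config.json keys that decide how we port the vision encoder."""
    vision_type = config.get("vision_config", {}).get("model_type", "")
    return {
        "vision_model_type": vision_type,
        "vision_feature_layer": config.get("vision_feature_layer"),
        # MLC's LLaVA definition only covers CLIP-style encoders.
        "needs_custom_port": not vision_type.startswith("clip"),
    }


# A hand-trimmed excerpt, not the full file. In practice you would load it with
# json.load(open("dist/models/llava-onevision-qwen2-0.5b-ov-hf/config.json")).
excerpt = {
    "vision_config": {"model_type": "siglip_vision_model"},
    "vision_feature_layer": -1,
}
print(inspect_vision_encoder(excerpt))
```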

Embedding Aggregation

Both LLaVA models use a Vision Transformer, which outputs a sequence of embeddings. Training CLIP and SigLIP, however, requires a single vector per image. In this specific HuggingFace implementation, CLIP and SigLIP perform this aggregation in different ways.

CLIP uses a class embedding, similar to common adaptations of the Vision Transformer for classification. Here, we add an extra token to each image sequence, which is fed through the Transformer along with the image tokens. At the end, we take the output of this class token as the aggregated image feature vector. Because the class token participates in every self-attention layer, the Transformer can gather information about the whole image into this token as it moves through the layers. In the implementation, we see this class embedding token being added to the image token sequence:

class CLIPVisionEmbeddings(Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        # Class Embedding Token, added to each image token sequence.
        self.class_embedding = nn.Parameter((self.embed_dim,))
        ...

    def forward(self, pixel_values: Tensor) -> Tensor:
        patch_embeds = self.patch_embedding(pixel_values)
        ...
        class_embeds = broadcast_to(
            self.class_embedding, shape=(batch_size, 1, self.embed_dim)
        )
        # Add class embedding token to image token sequence.
        embeddings = concat([class_embeds, patch_embeds], dim=1)
        ...
        return embeddings
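The shape bookkeeping behind `broadcast_to` and `concat` is easy to check with NumPy. This sketch mirrors those two steps on random arrays; the sizes are arbitrary toy values, not the real model's dimensions.

```python
import numpy as np

batch_size, num_patches, embed_dim = 2, 9, 8
rng = np.random.default_rng(0)

patch_embeds = rng.standard_normal((batch_size, num_patches, embed_dim))
class_embedding = rng.standard_normal((embed_dim,))  # one learned vector, shared

# Same idea as broadcast_to + concat in the MLC code: one class token per image.
class_embeds = np.broadcast_to(class_embedding, (batch_size, 1, embed_dim))
embeddings = np.concatenate([class_embeds, patch_embeds], axis=1)

print(embeddings.shape)  # (2, 10, 8)
```

The sequence grows by exactly one token, and that first position is what CLIP's loss reads out at the end.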

In contrast, SigLIP pools its output features. Hence, the sequence of image tokens is fed through the layers unchanged, and a pooling layer on top aggregates the feature information over the sequence dimension. This pooling can be a simple average or, in our specific case, multi-head attention pooling. Attention pooling works like our self-attention layers, except that it uses a single fixed, learned query instead of one query per token.
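Here is a minimal NumPy sketch of that idea. It is a simplification, assuming a single head and omitting any output projection or extra layers a full pooling head may have; `attention_pool` and its arguments are our own names, not the HuggingFace or MLC API.

```python
import numpy as np


def attention_pool(tokens, probe, w_k, w_v):
    """Single-head attention pooling with one learned query (the "probe").

    tokens: (seq_len, dim), probe: (dim,), w_k and w_v: (dim, dim).
    Returns the pooled (dim,) vector and the (seq_len,) attention weights.
    """
    keys = tokens @ w_k
    values = tokens @ w_v
    scores = keys @ probe / np.sqrt(tokens.shape[-1])  # scaled dot product
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    return weights @ values, weights


rng = np.random.default_rng(0)
seq_len, dim = 9, 8
tokens = rng.standard_normal((seq_len, dim))
probe = rng.standard_normal(dim)
w_k, w_v = rng.standard_normal((dim, dim)), rng.standard_normal((dim, dim))

pooled, weights = attention_pool(tokens, probe, w_k, w_v)
print(pooled.shape, round(float(weights.sum()), 6))
```

If the key projection were all zeros, every score would be equal and the weights uniform, so attention pooling would collapse into the simple average mentioned above; learning the probe and projections lets the model weight informative patches more heavily.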

Note: We’ve removed the code parts that are common to both models. This simplifies the code and allows us to highlight the key differences.

In our SigLIP implementation, shown below, we see that, unlike CLIP, there is no class embedding.

class SiglipVisionEmbeddings(Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        ...

    def forward(self, pixel_values: Tensor, interpolate_pos_encoding: bool = False) -> Tensor:
        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds
        ...
        return embeddings

Output Features Explained

In the previous blog post, we discussed why LLaVA needs a whole sequence of image embeddings rather than a single image feature vector: the sequence provides more detailed information to the decoder.

While we don’t use the output heads of CLIP and SigLIP themselves, they do affect which layer we should take our features from. This is what the config argument vision_feature_layer controls ($-1$ for SigLIP and $-2$ for CLIP).

In other words, we choose the last layer in SigLIP, since the model was trained with an image embedding that is literally a weighted average of all of the last layer’s token outputs. Every token therefore contributes to the training loss, so the training process pushes all of these image embeddings to carry valuable information.

Attention pooling is a weighted average, where the weights are determined by the normalized dot product between a static query and each token’s key. While this example shows the averaging over text tokens, the same idea applies to the image patch tokens in our vision encoder (image credit).

class SiglipVisionModel(Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.vision_model = SiglipVisionTransformer(config)

    def forward(self, pixel_values: Tensor) -> Tensor:
        # Siglip Qwen2 is using last layer, CLIP pre-last due to different
        # Transformer encoder.
        return self.vision_model(pixel_values)[-1]

For CLIP, choosing the last layer is suboptimal because of its class token. The CLIP loss only uses the output features of the class token, so the last-layer output features of all other tokens, namely our image embeddings, are never used. In other words, these features did not receive any gradients during training, and we cannot be sure that the model has stored useful information in them. Most likely, the model has specialized its last layer for the class embedding token, making the outputs of the other tokens (possibly) meaningless.

Hence, we need to go back one more layer (i.e., the penultimate layer). In the last self-attention layer, the class token attends to every image token, so the penultimate-layer outputs of those tokens do receive gradients during training. This gives us good reason to expect strong features in these embeddings, which makes them usable in our LLaVA model implementation.

class CLIPVisionModel(Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(config)

    def forward(self, pixel_values: Tensor) -> Tensor:
        return self.vision_model(pixel_values)[-2]
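The selection logic can be summarized as indexing into a stack of per-layer outputs. This NumPy sketch is our own illustration: it assumes the encoder exposes one `(batch, seq_len, dim)` array per layer, and the optional class-token drop loosely mirrors HuggingFace LLaVA's `vision_feature_select_strategy` setting; `drop_class_token` is a hypothetical parameter name.

```python
import numpy as np


def select_image_features(hidden_states, vision_feature_layer, drop_class_token):
    """Pick one layer's token embeddings from a list of per-layer outputs.

    hidden_states: list of (batch, seq_len, dim) arrays, one per layer.
    """
    features = hidden_states[vision_feature_layer]
    if drop_class_token:
        features = features[:, 1:]  # the CLIP class token is not an image patch
    return features


rng = np.random.default_rng(0)
num_layers, num_patches, dim = 4, 9, 8
# CLIP-style stack: one extra (class) token per sequence.
clip_states = [rng.standard_normal((1, num_patches + 1, dim)) for _ in range(num_layers)]
# SigLIP-style stack: patch tokens only.
siglip_states = [rng.standard_normal((1, num_patches, dim)) for _ in range(num_layers)]

clip_feats = select_image_features(clip_states, -2, drop_class_token=True)
siglip_feats = select_image_features(siglip_states, -1, drop_class_token=False)
print(clip_feats.shape, siglip_feats.shape)
```

Both paths end up handing the decoder one embedding per image patch; they just read them from different depths.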

GELU Approximation

The Gaussian Error Linear Unit (GELU) is a very popular activation function for Transformers and is used in both our CLIP and SigLIP implementations. However, there are a few subtle details in how the GELU activation can be implemented.

The “true”, precise implementation of GELU involves the cumulative distribution function (CDF) of the Gaussian distribution $\Phi(x)$:

$\text{gelu}(x)=x\cdot\Phi(x)$

This CDF, however, is expensive to compute, which makes it a poor fit for edge devices, where every inference optimization counts. Instead, people commonly use GELU approximations that are good enough. The standard approximation, often used during training and available in frameworks like JAX and PyTorch, is the tanh approximation:

$\text{gelu}(x)\approx 0.5x\left(1+\tanh\left[\sqrt{\frac{2}{\pi}}(x+0.044715\cdot x^3)\right]\right)$

The HuggingFace implementation of SigLIP also uses this approximation, and we port it over as shown below:

class QuickGELU(Module):  # SigLIP implementation
    def forward(self, input_tensor: Tensor) -> Tensor:
        c = (2 / math.pi)**0.5
        return 0.5 * input_tensor * (
          1 + tanh(c * (input_tensor + 0.044715 * input_tensor**3))
        )

In the MLC implementation of CLIP, a different approximation is used. This one involves the sigmoid function $\sigma$ and is simply:

$\text{gelu}(x)\approx x\cdot \sigma(1.702x)$

class QuickGELU(Module):  # CLIP implementation
    def forward(self, input_tensor: Tensor) -> Tensor:
        return input_tensor * sigmoid(input_tensor * 1.702)

While the sigmoid GELU approximation is simpler and even cheaper to compute, it is also less accurate, so we have to make a tradeoff between efficiency and accuracy. Since our selected SigLIP vision encoder was trained using the tanh approximation, we’ll stick with it: a mismatch between the GELU implementation used at training time and at inference time can cause a slight but noticeable drop in performance.

A visualization of the different implementations of the GELU activation function. The exact GELU function is in red, the tanh approximation is in purple, and the sigmoid approximation is in green. All of them are quite similar. The sigmoid approximation deviates noticeably in the negative range between -1.5 and -4, whereas for the tanh approximation we need to zoom in closely to see any difference, which is why it is so often used as a close match.
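The gap between the two approximations can also be measured numerically. This pure-Python sketch evaluates the exact GELU (via `math.erf`) and both approximations from the formulas above on a grid over [-6, 6] and reports the largest absolute error of each; the grid range and spacing are arbitrary choices.

```python
import math


def gelu_exact(x: float) -> float:
    # x * Phi(x), with the Gaussian CDF written via the error function.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x**3)))


def gelu_sigmoid(x: float) -> float:
    # x * sigmoid(1.702 * x)
    return x / (1.0 + math.exp(-1.702 * x))


xs = [i / 100 for i in range(-600, 601)]  # grid over [-6, 6]
tanh_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
sigmoid_err = max(abs(gelu_exact(x) - gelu_sigmoid(x)) for x in xs)
print(f"max |error| tanh:    {tanh_err:.1e}")
print(f"max |error| sigmoid: {sigmoid_err:.1e}")
```

On this grid the tanh approximation stays well under 1e-3, while the sigmoid version is off by roughly a few hundredths, which matches what the plot shows.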

Embedding Normalization

Another minor design choice is whether we normalize the embedding features before feeding them into the main Transformer model. Both models use the pre-activation Transformer implementation, which applies a LayerNorm before each Self-Attention and MLP layer. However, we can also apply a LayerNorm on the embeddings themselves, or leave the model to learn the scaling of the residual part.

In the CLIP implementation, we find a LayerNorm applied to the embeddings before feeding it through the layers.

class CLIPVisionTransformer(Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        embed_dim = config.hidden_size
        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, pixel_values: Tensor) -> Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layernorm(hidden_states)
        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
        return encoder_outputs

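In contrast, the SigLIP implementation omits this normalization. Since every encoder layer still applies a LayerNorm before its self-attention and MLP sublayers, we don’t expect this to cause a major performance difference; what matters is that our port matches how the checkpoint we load was trained.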


class SiglipVisionTransformer(Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        # Defined but not actually used.
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, pixel_values: Tensor) -> Tensor:
        hidden_states = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
        return encoder_outputs
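To make the difference concrete, here is a toy NumPy sketch of a pre-norm encoder with an optional embedding LayerNorm. It is an illustration under simplifying assumptions: the LayerNorm has no learnable scale or shift, and the toy `blocks` functions stand in for the real self-attention and MLP sublayers.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without the learnable scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def encode(embeddings, blocks, embedding_norm):
    """Pre-norm encoder: every sublayer reads a LayerNorm-ed residual stream.

    embedding_norm=True mimics CLIP's pre_layernorm; False mimics SigLIP.
    """
    hidden = layer_norm(embeddings) if embedding_norm else embeddings
    for block in blocks:
        hidden = hidden + block(layer_norm(hidden))  # residual connection
    return hidden


rng = np.random.default_rng(0)
embeddings = 5.0 + 3.0 * rng.standard_normal((10, 8))  # deliberately off-scale
toy_blocks = [lambda h: 0.1 * h, lambda h: np.tanh(h)]

clip_like = encode(embeddings, toy_blocks, embedding_norm=True)
siglip_like = encode(embeddings, toy_blocks, embedding_norm=False)
```

Because each sublayer already sees a normalized input, the only thing the extra embedding LayerNorm changes is the scale of the residual stream that the sublayer outputs are added to.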

Summary

Overall, our SigLIP implementation has some small but crucial differences compared to MLC’s CLIP implementation. The table below summarizes the differences that we needed to account for.

| | MLC LLaVA Model (CLIP) | LLaVA-OneVision Qwen2 0.5B Model (SigLIP) |
|---|---|---|
| Output Feature Aggregation | Class Token | Attention Pooling |
| Feature Layer | Pre-Last Layer | Last Layer |
| Embedding Normalization | Yes | No |
| GELU Implementation | Sigmoid Approximation | Tanh Approximation |

Now that we’ve implemented SigLIP in the MLC framework, integrating the LLaVA-OneVision models is straightforward. We can now proceed with quantizing and optimizing our model for deployment on an edge device.

Packaging Our Custom Model

Once we’ve extended the mlc-llm-cpu library to include our custom model definition, we can proceed as usual.

First, we want to quantize our newly ported LLaVA model. To do so, we run the following commands from our MLC LLM project’s repo directory:

# Create a directory for the original model weights
mkdir -p dist/models && cd dist/models

# Clone the original LLaVA model's weights from HuggingFace
git lfs install
git clone https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf
cd ../..  # return to the repo root so the relative paths below resolve

# Apply the `q4f16_1` quantization method to our model
mlc_llm convert_weight ./dist/models/llava-onevision-qwen2-0.5b-ov-hf/ \
       --quantization q4f16_1 \
       --model-type llava_onevision \
       -o dist/llava-onevision-qwen2-0.5b-ov-hf

Fortunately, the command ran successfully.

After defining my target LLaVA-OneVision model as a new MLC `model-type`, I was able to easily quantize this model.

Next, we need to apply some hardware-specific optimizations to our model.

# Generate a MLC config file for our quantized model
mkdir dist/libs
mlc_llm gen_config ./dist/models/llava-onevision-qwen2-0.5b-ov-hf/ \
  --quantization q4f16_1 \
  --conv-template redpajama_chat \
  --context-window-size 768 \
  -o dist/llava-onevision-qwen2-0.5b-ov-hf

# Optimize LLaVA OneVision for an iOS app implementation
mlc_llm compile ./dist/llava-onevision-qwen2-0.5b-ov-hf/mlc-chat-config.json \
  --device iphone \
  -o dist/libs/llava-onevision-qwen2-0.5b-ov-hf-iphone.tar

Finally, we need to package the quantized and optimized model for our iOS app. To make my life easier, I’ve uploaded the pre-quantized LLaVA-OneVision Qwen2 0.5B model to my personal HuggingFace account.

My mlc-package-config.json file located in the ios/MLCChat/ subdirectory (following MLC LLM’s project structure) now looks like this:

{
    "device": "iphone",
    "model_list": [
        {
            "model": "HF://bella-nich/llava-onevision-qwen2-0.5b-ov-q4f16_1-mlc",
            "model_id": "llava-onevision-qwen2-0.5b-ov-hf",
            "estimated_vram_bytes": 1000000000,
            "overrides": {
                "prefill_chunk_size": 128
            }
        }
    ]
}
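A quick way to catch syntax slips, such as a trailing comma after the last entry, before running `mlc_llm package` is to parse the file with a strict JSON parser. The helper below is our own sketch built on Python's standard `json` module, not part of MLC.

```python
import json


def check_package_config(text: str) -> list:
    """Parse mlc-package-config.json strictly and list its model_id entries."""
    config = json.loads(text)  # raises json.JSONDecodeError on e.g. trailing commas
    return [entry["model_id"] for entry in config["model_list"]]


# Against the real file (path follows the MLC LLM layout used above):
# with open("ios/MLCChat/mlc-package-config.json") as f:
#     print(check_package_config(f.read()))
good = '{"device": "iphone", "model_list": [{"model_id": "llava-onevision-qwen2-0.5b-ov-hf"}]}'
print(check_package_config(good))
```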

Next, I package my quantized and optimized LLaVA-OneVision model into a proper iOS app.

# Package the model for the iOS app
cd /path/to/MLCChat  # e.g., "ios/MLCChat"
export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm  # e.g., "../.."
mlc_llm package

I was able to successfully package my newly ported LLaVA-OneVision Qwen2 0.5B model without any errors, which is a hopeful sign that everything was quantized and compiled correctly.

End Result

While the lack of compilation errors is promising, the only way to validate this entire process is to talk to the embedded LLaVA model. So, I built and deployed my packaged iOS application.

Once I do, I’m greeted with a few strange but mostly understandable sentences.

My embedded LLaVA-OneVision Qwen2 0.5B model is able to form mostly coherent sentences. However, its sense of humor doesn't seem fully developed.

In other words, LLaVA isn’t spouting pure gibberish. I take this as a sign that the quantization and model compilation processes have gone well. Of course, the longer the conversation goes on, the less coherent LLaVA becomes. Pretty soon, LLaVA is giving me random responses strung together.

After my dissatisfaction with the embedded model's sense of humor, I decided to see if LLaVA can tell me a fun fact. To my surprise, I'm greeted with a sudden and odd request.

The chat snippets of <im_start>, <im_end>, and assistant give us clues about how this LLaVA model was tuned for image annotation tasks. More specifically, they tell us about the structure of LLaVA-OneVision Qwen2 0.5B’s chat template. A chat template restructures our current conversation, which is a list of strings, into the single, tokenizable format that the model expects. Here, we can see that the chat template’s assistant role is prompting LLaVA to continue, but in ways that are completely disconnected from my original text-only prompts.
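To make the idea of a chat template concrete, here is a minimal sketch of a ChatML-style template, the `<|im_start|>` / `<|im_end|>` format used by Qwen2-family chat models; markers like the ones in the chat above appear to come from a template of this kind. We're assuming this format here: the exact special tokens, default system prompt, and whitespace of LLaVA-OneVision's real template may differ, and `apply_chatml_template` is our own hypothetical helper.

```python
def apply_chatml_template(messages, add_generation_prompt=True):
    """Flatten a list of {"role", "content"} dicts into one ChatML-style string."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    ]
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = apply_chatml_template([{"role": "user", "content": "Tell me a fun fact."}])
print(prompt)
```

The trailing, still-open assistant turn is what invites the model to keep generating, which is why leaked template markers show up in its replies when its output drifts.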

The good news is that the chat template’s assistant role should be more useful when we provide LLaVA with image inputs. However, this LLaVA model’s current performance on a task it wasn’t fine-tuned for highlights the differences between sub-1B and 7B+ parameter models: larger foundation models are simply more versatile than smaller ones.

Conclusion

In this blog post, I’ve shown you everything it takes to embed an unsupported model onto an edge device using the Machine Learning Compiler (MLC) framework. As you can see, the devil is in the details. You need to be well-versed in the different permutations of a given neural architecture’s implementation and able to spot those differences in the wild.

Of course, the only thing that’s more important than getting this process right is to choose the correct model to embed. The smaller we go in size, the more portable our foundation model becomes, but that portability comes at the cost of performance.

What’s next?

While embedding a custom model is an exciting milestone, we’re not done yet. As you can see, the embedded LLaVA model works, but it doesn’t make for a scintillating conversation partner. Hence, we need to get Gemma 2B and the LLaVA-OneVision Qwen2 0.5B model to work together, which is exactly what I do in my next (and final) blog post in this series. Stay tuned!

Bella Nicholson
Machine Learning Engineer
