Utilize large model inference containers powered by DJL Serving & Nvidia TensorRT


The Generative AI space continues to expand at an unprecedented rate, with new Large Language Model (LLM) families introduced almost daily. Each family also comes in several sizes; for instance, there are Llama 7B, Llama 13B, and Llama 70B. Regardless of which model you select, hosting these LLMs for inference raises the same challenges.
Model size remains the most pressing challenge: many of these LLMs are difficult, if not impossible, to fit onto a single GPU. One approach to this problem is model partitioning, where techniques such as pipeline parallelism or tensor parallelism shard the model across multiple GPUs. Another popular approach is quantization, which converts the model weights to a lower precision to shrink the model itself, at some cost to accuracy.
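To make the size problem and the two mitigations concrete, here is a minimal pure-Python sketch. It assumes a round 7-billion-parameter count and a toy weight matrix, and every function name is hypothetical rather than taken from any serving library. It estimates per-GPU weight memory for a given precision and tensor-parallel degree, shows that splitting a weight matrix column-wise across shards and concatenating the partial outputs reproduces the unsharded result, and applies simple symmetric int8 quantization. Production frameworks use more elaborate schemes (for example row-parallel layers with an all-reduce, or per-channel and group-wise quantization).

```python
from typing import List, Tuple

Matrix = List[List[float]]


def weight_gib(num_params: float, bytes_per_param: float, tp_degree: int = 1) -> float:
    """Approximate per-GPU weight memory in GiB (ignores activations and KV cache)."""
    return num_params * bytes_per_param / tp_degree / 2**30


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain x @ w for small nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w: Matrix, num_shards: int) -> List[Matrix]:
    """Split w column-wise into equal shards, one per (simulated) GPU."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    step = cols // num_shards
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(num_shards)]


def column_parallel_matmul(x: Matrix, w: Matrix, num_shards: int) -> Matrix:
    # Each shard computes its own slice of the output columns independently...
    partials = [matmul(x, shard) for shard in split_columns(w, num_shards)]
    # ...and the slices are concatenated (an all-gather on real hardware).
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization: q = round(v / scale)."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # avoid a zero scale
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map int8 codes back to approximate float values."""
    return [qi * scale for qi in q]


if __name__ == "__main__":
    # Assumed round parameter count for a "7B" model.
    print(f"fp16, 1 GPU : {weight_gib(7e9, 2):.1f} GiB")     # 13.0 GiB
    print(f"fp16, TP=2  : {weight_gib(7e9, 2, 2):.1f} GiB")  # 6.5 GiB per GPU
    print(f"int8, 1 GPU : {weight_gib(7e9, 1):.1f} GiB")     # 6.5 GiB

    x = [[1.0, 2.0]]
    w = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
    assert column_parallel_matmul(x, w, 2) == matmul(x, w)

    vals = [0.5, -1.27, 0.01]
    q, s = quantize_int8(vals)
    print(q, dequantize(q, s))
```

The estimate covers weights only; activations and the KV Cache add to it at serving time.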
Beyond model size, decoder-based models face a second challenge during text generation: retaining the attention state of previously processed tokens. Text generation with these models is not as simple as traditional ML inference, where there is just one input and one output. The model generates one token at a time, and to calculate the next token it needs the attention state of all previous tokens to produce a coherent output. Rather than recomputing this state for the whole sequence at every step, the key and value tensors of previous tokens are stored in GPU memory and reused to generate the next tokens; this store is known as the KV Cache. Because the KV Cache grows with sequence length and batch size, it takes up a large amount of memory that needs to be accounted for during model inference.
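The mechanism can be illustrated with a toy, single-head attention sketch in pure Python. The 2-dimensional projection matrices are hypothetical, and the memory formula uses an assumed Mistral-7B-like configuration (32 layers, 8 KV heads, head dimension 128, fp16), which you should confirm against the model's own config.json. The sketch shows that caching keys and values, so only the newest token is projected at each step, produces exactly the same outputs as recomputing them for the whole prefix, and then estimates how large that cache becomes.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical toy projection matrices (2-dim embeddings, single head).
W_Q: Matrix = [[1.0, 0.0], [0.0, 1.0]]
W_K: Matrix = [[0.5, 0.5], [0.0, 1.0]]
W_V: Matrix = [[1.0, 1.0], [1.0, -1.0]]


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def decode_with_cache(tokens: List[Vector]) -> List[Vector]:
    """Each step projects only the newest token and reuses cached K/V."""
    k_cache: List[Vector] = []
    v_cache: List[Vector] = []
    outputs = []
    for x in tokens:
        k_cache.append(matvec(W_K, x))
        v_cache.append(matvec(W_V, x))
        outputs.append(attend(matvec(W_Q, x), k_cache, v_cache))
    return outputs


def decode_without_cache(tokens: List[Vector]) -> List[Vector]:
    """Each step recomputes K/V projections for every token in the prefix."""
    outputs = []
    for t in range(1, len(tokens) + 1):
        prefix = tokens[:t]
        keys = [matvec(W_K, x) for x in prefix]
        values = [matvec(W_V, x) for x in prefix]
        outputs.append(attend(matvec(W_Q, tokens[t - 1]), keys, values))
    return outputs


def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_elem: int = 2) -> int:
    # Factor 2: one key and one value vector per layer, KV head and token.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


if __name__ == "__main__":
    toks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    assert decode_with_cache(toks) == decode_without_cache(toks)
    # Assumed Mistral-7B-like config, 4096-token sequences, batch of 8, fp16.
    gib = kv_cache_bytes(32, 8, 128, seq_len=4096, batch_size=8) / 2**30
    print(f"KV cache: {gib:.1f} GiB")  # prints "KV cache: 4.0 GiB"
```

Under these assumptions the cache alone needs about 0.5 GiB per full-length sequence, which is why batch size and maximum sequence length must be budgeted alongside the model weights.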
To address these challenges, many model serving technologies have been introduced, such as vLLM, DeepSpeed, FasterTransformer, and more. In this article we look specifically at Nvidia TensorRT-LLM and how to integrate this serving stack with DJL Serving on Amazon SageMaker Real-Time Inference to efficiently host the popular Mistral 7B model.
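DJL Serving's large model inference (LMI) containers are typically configured through a serving.properties file of key=value lines. The sketch below renders one from a Python dict. The option names reflect our understanding of the LMI TensorRT-LLM settings, while the model id and numeric values are illustrative assumptions; check all of them against the DJL Serving documentation for the container version you deploy.

```python
from typing import Dict


def render_serving_properties(props: Dict[str, str]) -> str:
    """Render key=value lines in the format DJL Serving reads from serving.properties."""
    return "\n".join(f"{key}={value}" for key, value in props.items()) + "\n"


# Assumed, illustrative settings; verify option names and values against the
# DJL Serving LMI / TensorRT-LLM documentation for your container version.
trtllm_props = {
    "engine": "MPI",                                 # engine LMI uses for TensorRT-LLM
    "option.model_id": "mistralai/Mistral-7B-v0.1",  # Hugging Face Hub id or S3 URI
    "option.tensor_parallel_degree": "1",            # number of GPUs to shard across
    "option.rolling_batch": "trtllm",                # continuous batching via TensorRT-LLM
    "option.max_rolling_batch_size": "16",           # max concurrent requests per batch
}

if __name__ == "__main__":
    print(render_serving_properties(trtllm_props), end="")
```

Keeping the settings in a dict makes it easy to vary values such as the tensor parallel degree per instance type before packaging the file for SageMaker.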
NOTE: This article assumes an intermediate understanding of Python, LLMs, and Amazon SageMaker Inference. I would…