High Level

Limited local storage (up to 50 GB) and the inevitability of node interruptions introduce complexity when managing services that serve inference from hundreds or even thousands of models.

We can combine 3 techniques to manage that complexity in an arbitrarily scalable way:

  1. Preload the container image with the most popular models - This way your container can become productive immediately on start, while it downloads more models in the background.
  2. Local Least-Recently-Used Caching - Only keep the most recently used 50 GB of models stored locally on any given node.
  3. Smart Job Scheduling - Assign jobs to nodes in a way that minimizes model downloading and swapping.

Preloading Popular Models

Your service likely has some models that are significantly more popular than others. SaladCloud allows a maximum compressed container image size of 35 GB, and it does not charge during the downloading phase, only once the instance enters the running state. This means it is often prudent to include some of your most popular models in the container image, as you effectively get that download time for free. It also means your container can start doing inference work immediately once it starts, even as it downloads more models in the background. Finally, SaladCloud’s 50 GB storage limit is in addition to any space taken up by your container, so you can get more total local storage by including models in the container image. The main downside is reduced elasticity in scaling, as larger images take longer to download to new nodes.
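As a rough sketch of how this looks in practice, the container’s entrypoint can kick off a background thread that fetches additional models while the inference server starts serving from the models baked into the image. The model URLs, the /models directory, and the bare urlretrieve call are placeholder assumptions; in a real service you would route these downloads through the LRU cache described in the next section.

import os
import threading
import urllib.request

MODEL_DIR = "/models"

# Hypothetical list of additional models to warm up after start; the most
# popular models are assumed to already be baked into the container image
BACKGROUND_MODELS = [
    "https://example.com/models/popular-model-2.safetensors",
    "https://example.com/models/popular-model-3.safetensors",
]

def prefetch_models():
    # Download one model at a time so the serving process keeps most of the
    # network and disk bandwidth
    for url in BACKGROUND_MODELS:
        local_path = os.path.join(MODEL_DIR, os.path.basename(url))
        if not os.path.exists(local_path):
            urllib.request.urlretrieve(url, local_path)

if __name__ == "__main__":
    # Start prefetching in the background, then start the inference server
    # (startup omitted); it can serve the baked-in models immediately
    threading.Thread(target=prefetch_models, daemon=True).start()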

Local LRU Caching

When you have potentially terabytes of model checkpoints, LoRAs, upscale models, and more, it’s never going to be possible to cache it all locally on a SaladCloud node, due to storage size restrictions. Beyond that, you wouldn’t want to pay for node uptime while downloading all of those models at start, especially when the node may be interrupted at any time, forcing the process to start over from the beginning on another node.

You also don’t want to download a 2-6 GB checkpoint for potentially every single request, as that introduces unacceptable latency into user generation requests.

The solution here is to implement an LRU cache, keeping only the most recently used 50 GB of models locally and automatically clearing out others as needed. A simple Python implementation may look like this:

import os
import requests
from collections import OrderedDict

class LRUCacheFileDownloader:
    """Downloads files into a local cache directory, evicting the least
    recently used files once the cache would exceed max_size_bytes."""

    def __init__(self, cache_dir, max_size_bytes):
        self.cache_dir = cache_dir
        self.max_size_bytes = max_size_bytes
        self.current_size_bytes = 0
        self.cache = OrderedDict()  # filename -> size in bytes, least recently used first
        self._init_cache()

    def _init_cache(self):
        # Rebuild the index from files already on disk (e.g. models baked into
        # the container image), oldest first so they are evicted before newer files
        entries = []
        for filename in os.listdir(self.cache_dir):
            file_path = os.path.join(self.cache_dir, filename)
            if os.path.isfile(file_path):
                entries.append((os.path.getmtime(file_path), filename, os.path.getsize(file_path)))
        for _, filename, file_size in sorted(entries):
            self.cache[filename] = file_size
            self.current_size_bytes += file_size

    def _update_cache(self, filename):
        # Mark the file as most recently used; touching the mtime keeps the
        # on-disk ordering consistent with _init_cache after a restart
        self.cache.move_to_end(filename)
        os.utime(os.path.join(self.cache_dir, filename), None)

    def _evict(self, required_space):
        # Remove least recently used files until required_space fits under the limit
        while self.current_size_bytes + required_space > self.max_size_bytes:
            if not self.cache:
                raise ValueError("Cannot evict enough files to make space")
            # popitem(last=False) returns the least recently used entry
            oldest_file, oldest_file_size = self.cache.popitem(last=False)
            oldest_file_path = os.path.join(self.cache_dir, oldest_file)
            os.remove(oldest_file_path)
            self.current_size_bytes -= oldest_file_size

    def download_file(self, url, local_path):
        filename = os.path.basename(local_path)
        file_path = os.path.join(self.cache_dir, filename)

        # Cache hit: mark the file as most recently used and return it
        if filename in self.cache:
            self._update_cache(filename)
            return file_path

        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()

        # Evict based on the advertised size up front; some servers omit
        # Content-Length, so we re-check against the real size after the download
        expected_size = int(response.headers.get('content-length', 0))
        self._evict(expected_size)

        # Stream to a temporary file so an interrupted download never leaves
        # a partial model at the final path
        temp_path = file_path + ".part"
        with open(temp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

        file_size = os.path.getsize(temp_path)
        self._evict(file_size)  # no-op unless Content-Length was missing or wrong
        os.replace(temp_path, file_path)

        # Record the new file as most recently used
        self.cache[filename] = file_size
        self.current_size_bytes += file_size

        return file_path

    def get_cache_info(self):
        return {
            "current_size_bytes": self.current_size_bytes,
            "max_size_bytes": self.max_size_bytes,
            "num_files": len(self.cache),
            "files": list(self.cache.keys())
        }

# Example usage
if __name__ == "__main__":
    cache_dir = "/models"
    max_size_bytes = 1024 * 1024 * 1024 * 50  # 50 GB

    os.makedirs(cache_dir, exist_ok=True)

    downloader = LRUCacheFileDownloader(cache_dir, max_size_bytes)

    # Download some files
    urls = [
        "https://example.com/file1.safetensors",
        "https://example.com/file2.safetensors",
        "https://example.com/file3.safetensors"
    ]

    for url in urls:
        local_path = os.path.join(cache_dir, os.path.basename(url))
        downloaded_path = downloader.download_file(url, local_path)
        print(f"Downloaded {url} to {downloaded_path}")

    # Print cache info
    print(downloader.get_cache_info())
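In a real worker, you would call download_file just before loading a checkpoint for a generation, so the cache ordering always reflects actual usage. Note that this sketch assumes a single process owns the cache directory; if multiple worker processes share /models, you would want file locking (or a single dedicated download process) so that concurrent downloads and evictions don’t corrupt the size accounting.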

Smart Job Scheduling

Minimizing model downloading and model swapping on any given node is key to maintaining good overall performance, as these processes may take significantly longer than the generations themselves.

Doing this requires a “pull” method of job distribution, where inference workers request work from an API, rather than work being “pushed” to nodes through a load balancer. However, a simple job queue is insufficient, as nodes would regularly receive requests for models they do not have stored locally yet.

The basic pattern here is that workers should include information about themselves and their cache when they request work. For example, a worker may report every model it currently has loaded in VRAM, as well as every model it has downloaded to local storage. The API can then use that information to preferentially return inference jobs for models that are already loaded, or at least already downloaded. The API response can also include instructions to begin downloading models not required for the currently assigned generation, in order to expand system capacity for a model that is growing in popularity.

Example Request

POST /availability

{
  "node": {
    "id": "random-uuid",
    "gpu": "rtx4090",
    "vram_gb": 24
  },
  "vram_models": {
    "checkpoints": ["dreamshaper_8.safetensors"],
    "loras": ["add_detail.safetensors"]
  },
  "disk_models": {
    "checkpoints": ["Juggernaut_X_RunDiffusion_Hyper.safetensors"],
    "upscalers": ["nomos8khatLOtf_v20.safetensors"]
  }
}

Example Response

{
  "generation": {
    "checkpoint": "dreamshaper_8.safetensors",
    "params": {
      "prompt": "mechanical cat flying a spaceship",
      "negative_prompt": "text, watermark",
      "cfg": 5.6,
      "steps": 20
    }
  },
  "download": {
    "checkpoints": ["https://civitai.com/api/download/models/132760?type=Model&format=SafeTensor&size=pruned&fp=fp16"]
  }
}
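Putting the pieces together, a worker’s main loop might look something like the sketch below. The scheduler URL, the exact payload fields, and the commented-out run_generation call are assumptions that mirror the example request and response above; the downloader argument is assumed to be an instance of the LRUCacheFileDownloader from the caching section.

import os
import time
import requests

SCHEDULER_URL = "https://scheduler.example.com/availability"  # hypothetical endpoint

def worker_loop(node_info, downloader):
    loaded_checkpoints = []  # checkpoints currently held in VRAM

    while True:
        # Tell the scheduler what this node already has loaded and downloaded
        payload = {
            "node": node_info,
            "vram_models": {"checkpoints": loaded_checkpoints},
            "disk_models": {"checkpoints": downloader.get_cache_info()["files"]},
        }
        response = requests.post(SCHEDULER_URL, json=payload, timeout=30)
        response.raise_for_status()
        job = response.json()

        # Prefetch any models the scheduler wants cached on this node; a real
        # worker would do this in a background thread so it overlaps with inference
        for url in job.get("download", {}).get("checkpoints", []):
            # Deriving a filename from the URL is a simplification; a real API
            # would likely return an explicit filename alongside each URL
            downloader.download_file(url, os.path.basename(url.split("?")[0]))

        generation = job.get("generation")
        if generation is None:
            time.sleep(5)  # no work available right now; poll again shortly
            continue

        # run_generation is a placeholder for your inference call (e.g. a
        # diffusers or ComfyUI pipeline) that loads the checkpoint if needed
        # run_generation(generation["checkpoint"], generation["params"])
        loaded_checkpoints = [generation["checkpoint"]]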

In this way, the API can proactively keep each model cached locally on a configurable number of nodes, ensuring coverage in case of node interruptions and an adequate supply of warm inference servers for any given model. It also allows nodes to do other useful work while preparing their local cache for future work.
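On the API side, matching a job to the requesting worker can be a simple scoring pass over the pending queue: prefer jobs whose checkpoint is already in the worker’s VRAM, then jobs whose checkpoint is already on its disk, and fall back to anything else. The sketch below assumes an in-memory list of pending jobs, a map of how many nodes currently hold each model, and a map of model names to download URLs; the target of three replicas per model is just an example value.

def select_job(pending_jobs, vram_checkpoints, disk_checkpoints):
    # Prefer jobs that need no load or download, then ones that only need a
    # VRAM load, then anything else
    vram = set(vram_checkpoints)
    disk = set(disk_checkpoints)

    def cost(job):
        if job["checkpoint"] in vram:
            return 0  # already loaded: no download, no model swap
        if job["checkpoint"] in disk:
            return 1  # on disk: needs a VRAM load, but no download
        return 2      # needs a download before it can run

    if not pending_jobs:
        return None
    job = min(pending_jobs, key=cost)
    pending_jobs.remove(job)
    return job


def pick_downloads(replica_counts, disk_checkpoints, download_urls, target_replicas=3, limit=1):
    # Ask this worker to prefetch models that are cached on fewer than
    # target_replicas nodes, to guard against interruptions and demand spikes
    suggestions = []
    for model, count in replica_counts.items():
        if count < target_replicas and model not in disk_checkpoints:
            suggestions.append(download_urls[model])
            if len(suggestions) >= limit:
                break
    return suggestions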

Detecting interrupted nodes typically involves a heartbeat mechanism: if a node hasn’t requested work within a certain amount of time, assume it’s dead (this can be verified with the SaladCloud API) and reassign its work as needed.
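A heartbeat check can be as simple as recording the last time each node requested work and periodically requeueing the jobs assigned to any node that has gone quiet. A minimal sketch, where the 120-second threshold and the requeue callback are assumptions:

import time

HEARTBEAT_TIMEOUT_SECONDS = 120  # assumed threshold; tune to your polling interval

last_seen = {}       # node_id -> unix timestamp of last /availability request
assigned_jobs = {}   # node_id -> list of in-flight jobs

def record_heartbeat(node_id):
    last_seen[node_id] = time.time()

def reap_dead_nodes(requeue):
    # Requeue work from nodes that haven't asked for work recently.
    # Before requeueing, the node's state could be double-checked against
    # the SaladCloud API to confirm it was actually interrupted.
    now = time.time()
    for node_id, seen in list(last_seen.items()):
        if now - seen > HEARTBEAT_TIMEOUT_SECONDS:
            for job in assigned_jobs.pop(node_id, []):
                requeue(job)
            del last_seen[node_id]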