AKASA
September 03, 2025

The Gist

At AKASA, we're pushing the frontier of generative AI in healthcare by tackling one of the most technically daunting and impactful problems in the industry: medical coding at scale. Our team has built a code verification pipeline using attention-based multiple instance learning (MIL) to overcome large language model (LLM) context window limits and deliver interpretable, high-confidence predictions across massive clinical encounters — think 50,000+ word medical novels. This work combines deep domain-specific LLMs, clever optimization techniques, and distributed training infrastructure to solve a real-world bottleneck with elegance and speed. If you’re excited by scalable ML architectures, rich healthcare data, and building innovative solutions, keep reading.

The Challenge: Medical Coding of Long, Complex Inpatient Stays

Our mission at AKASA is to use the incredible power of generative AI and large language models (LLMs) to help ease one of the most critical bottlenecks in the U.S. healthcare system: the reimbursement channel between healthcare providers and insurers. A core part of this reimbursement process involves medical coding, the act of translating medical encounters to a short list of compliant, billable insurance codes.

Medical coding is currently performed by mid-revenue cycle staff at hospitals. Their day-to-day tasks involve reviewing upwards of 100 complex clinical documents per medical encounter, the equivalent of a short novel at 50,000 words. Think The Great Gatsby. They must then map these documents to, on average, 19 ICD codes (out of a set of 140,000 codes!) to accurately describe that patient visit.

The good news for medical coders — who are expected to code each encounter in around 30 minutes — is that LLM-based assistants are very well-suited to supercharge their workflow by surfacing relevant parts of medical documentation and suggesting relevant codes.

When applying LLMs to complex inpatient hospital stays, one is immediately confronted with several technical challenges. In this blog post, we outline a powerful technique called multiple instance learning (MIL), as described in the paper Attention-based Deep Multiple Instance Learning, to tackle the two most pressing of these challenges:

1. Documentation for many encounters can get very long, far exceeding typical context window sizes for LLMs

2. Code predictions, which are used to compile the final codeset, should be accompanied by confidence scores, along with justifications and quotes from the underlying documentation in order to:

  • be useful to medical coders reviewing their work
  • minimize compliance risk
  • improve interpretability of code attribution
  • boost confidence in the underlying automation

Interestingly, while the most common industry approach to dealing with large document sets is retrieval-augmented generation (RAG), we find that MIL offers an enticing probabilistic alternative that is end-to-end trainable. Let’s dive in!

 

Multiple-Instance Learning

Multiple-instance learning (MIL) is a supervised learning paradigm in which each training example is a “bag” of instances sharing a single label, rather than each instance being individually labeled. A bag is labeled positive if at least one instance in it is positive, and negative if all instances are negative.
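The bag-labeling rule above can be written in one line. This is a minimal sketch of the standard MIL assumption only; the function name is hypothetical, and in practice instance labels are not observed during training, which is why the model has to learn the aggregation.

```python
def bag_label(instance_labels):
    """Standard MIL assumption: a bag is positive (1) if at least one
    instance is positive, and negative (0) if every instance is negative."""
    return int(any(instance_labels))


# One positive segment is enough to make the whole encounter positive.
print(bag_label([0, 0, 1, 0]))
print(bag_label([0, 0, 0, 0]))
```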

In our medical coding pipeline, we use MIL as a critical code verification step, which ensures that each code prediction is supported by the underlying documentation. Each clinical encounter is chunked into a bag of text segments (instances) and paired with a candidate medical (ICD) code. Our goal is to verify whether that code applies to the encounter.

The MIL model learns to aggregate evidence from these segments: if any instance (text segment) strongly indicates the code, the bag (encounter) is marked positive for that code.

Modern attention-based MIL architectures embed all instances and then use a permutation-invariant aggregator to compute the bag-level label. The aggregator uses learned attention weights to pool the instance embeddings. Intuitively, it “looks” across segments to find the most predictive ones. Importantly, this attention mechanism provides interpretability, since the weights highlight which segments contribute most to the decision.
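For reference, the attention pooling in the Attention-based Deep Multiple Instance Learning paper can be written as follows. Here h_k is the embedding of the k-th instance in a bag of K instances, and w and V are learned parameters. This is the formulation from that paper, shown for orientation; it is not necessarily the exact variant used in our module.

```latex
z = \sum_{k=1}^{K} a_k \, h_k,
\qquad
a_k = \frac{\exp\!\left( w^{\top} \tanh\!\left( V h_k \right) \right)}
           {\sum_{j=1}^{K} \exp\!\left( w^{\top} \tanh\!\left( V h_j \right) \right)},
\qquad
h_k \in \mathbb{R}^{M},\; V \in \mathbb{R}^{L \times M},\; w \in \mathbb{R}^{L}
```

Because the weights a_k are non-negative and sum to 1 over the bag, they can be read directly as the relative contribution of each segment, and the sum is unchanged if the segments are reordered.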

In our module, we use our in-house clinical LLM as the segment encoder to compute a high-dimensional embedding for each text chunk and for the medical code in question. The segment embeddings and the code embedding are fed into an MIL network that outputs a binary (0/1) classification: whether the code is relevant to the given encounter.

Thanks to MIL’s learned pooling strategy, the model not only predicts whether a code is relevant to an encounter, but also produces attention scores that pinpoint the relevant segments in the encounter documentation. Because MIL operates over a bag of embeddings, it also handles our long-context problem: the LLM encoder only ever runs over small text segments rather than the full encounter.

 

Model Architecture and Pipeline

Our verification module processes a medical encounter in several stages, which we illustrate below:

Figure 1: Illustration of how a large medical encounter is chunked and embedded jointly with a candidate medical code description (bottom). The core MIL pipeline, where the bag of instances (embeddings) is combined to form a probabilistic code verification, along with segment weights that can be used for evidence attribution.

First, the text of each document in the encounter is split into overlapping chunks (instances) of manageable length. Each chunk is embedded using the AKASA clinical foundation model, which has been trained using internal/proprietary datasets. We use the hidden state of the final token to get a fixed vector representation of that segment. These segment vectors make up the bag of instances for the encounter. Next, we feed the medical ICD-10 code plus code description through the same foundation model to get a vector representation of the medical code.
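The overlapping chunking step can be sketched as a sliding window over tokens. This is an illustrative sketch only: the function name, the chunk size, and the overlap are assumptions, since the post does not state the values used in production.

```python
def chunk_tokens(tokens, chunk_size=512, overlap=64):
    """Split a token list into overlapping windows (hypothetical sizes).

    Consecutive chunks share `overlap` tokens so that evidence straddling
    a boundary still appears intact in at least one chunk.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    if not tokens:
        return []
    stride = chunk_size - overlap
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reached the end of the document
        start += stride
    return chunks


# Toy example: 10 "tokens", windows of 4 with 1 token of overlap.
for chunk in chunk_tokens(list(range(10)), chunk_size=4, overlap=1):
    print(chunk)
```

Each resulting chunk would then be passed through the encoder separately, so the encoder never sees more than `chunk_size` tokens at a time.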

We then pass the bag of text segment embeddings, as well as the code embedding, to the MIL model. The MIL model starts by concatenating the code embedding with each text segment embedding and feeding each concatenated vector through a multi-layer perceptron (MLP). The MLP projects each concatenated segment-and-code vector down to a single scalar score, and a softmax is then computed over these scores across all text segments.

Essentially, the MLP learns how to “weight” the importance of each segment. We have found that this type of MLP-based pooling is more effective than standard pooling mechanisms, such as attention pooling.

The softmax output is used to compute a weighted average of hidden states across the text segments, which is then fed to a small feed-forward binary classifier. The final sigmoid output is thresholded to produce the binary code-verification decision, and the softmax output can be used to identify the relevant text segments for a given medical code.
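The forward pass described above can be sketched in plain Python. Everything here is a simplified stand-in: the function names, the one-hidden-layer scoring MLP, the single linear layer used as the classifier, and the toy weights are assumptions for illustration, not our production architecture.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mlp_score(x, hidden_weights, out_weights):
    """Tiny MLP: one tanh hidden layer, then a projection to a scalar."""
    hidden = [math.tanh(dot(row, x)) for row in hidden_weights]
    return dot(out_weights, hidden)


def mil_verify(segments, code, hidden_weights, out_weights,
               clf_weights, clf_bias=0.0, threshold=0.5):
    # 1. Concatenate the code embedding with each segment embedding and score it.
    scores = [mlp_score(seg + code, hidden_weights, out_weights) for seg in segments]
    # 2. Softmax across segments gives one weight per segment.
    weights = softmax(scores)
    # 3. Weighted average of the segment embeddings.
    dim = len(segments[0])
    pooled = [sum(w * seg[i] for w, seg in zip(weights, segments)) for i in range(dim)]
    # 4. Binary classifier (a single linear layer here) followed by a sigmoid.
    prob = 1.0 / (1.0 + math.exp(-(dot(clf_weights, pooled) + clf_bias)))
    return prob >= threshold, prob, weights


# Toy inputs: two 2-d segment embeddings and a 2-d code embedding.
segments = [[1.0, 0.0], [0.0, 1.0]]
code = [1.0, 0.0]
decision, prob, weights = mil_verify(
    segments, code,
    hidden_weights=[[1.0, 0.0, 1.0, 0.0]], out_weights=[5.0],
    clf_weights=[2.0, -2.0],
)
print(decision, round(prob, 3), [round(w, 3) for w in weights])
```

The returned `weights` list is what would be surfaced to a coder as evidence attribution: the segments with the largest weights are the ones that contributed most to the pooled representation.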

This end-to-end pipeline is designed with coder workflow in mind. Medical notes can be thousands of words long, and coders must manually find evidence for each code they consider.

Our system significantly shortens the evidence-gathering process by highlighting the text segments that drive each decision. By combining rich embeddings from our LLM with on-the-fly aggregation through MIL, the code verification pipeline can leverage information from anywhere in the encounter without reducing the problem to simple keyword matching or semantic search, as used in RAG.

Training

Though the model architecture is straightforward, we leverage multiple training techniques to achieve training stability. We highlight a few of the particularly notable areas below.

Negative sampling

While our dataset provides clear positive code labels, the selection of negative examples requires careful consideration. We face a fundamental trade-off: easy negatives (completely unrelated codes) offer little learning signal, while overly hard negatives (highly similar codes) can impede model convergence.

To strike this balance, we employ a multi-source negative sampling strategy:

  • Hierarchical negatives: We leverage the ICD code hierarchy to select clinically similar codes that are not present on the current record, providing meaningful contrast without being overly difficult.
  • Random negatives: We include randomly sampled codes to ensure the model learns broad discriminative patterns across the entire code space.
  • Production-informed negatives: We incorporate historical human labels from our production system, which allows us to continually improve and tailor it to client behavior.

This diverse negative sampling approach allows our model to learn both coarse-grained distinctions (avoiding obviously incorrect codes) and fine-grained discrimination (distinguishing between closely related conditions), while maintaining stable training dynamics.
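A multi-source sampler along these lines might look like the sketch below. Several details are assumptions made for illustration: the function and parameter names are hypothetical, "hierarchically similar" is approximated by sharing the three-character ICD-10 category with a positive code, and production-informed negatives are modeled as a list of codes that human coders previously rejected.

```python
import random


def sample_negatives(positive_codes, all_codes, production_rejects=(),
                     n_hier=2, n_random=2, seed=0):
    rng = random.Random(seed)
    positives = set(positive_codes)

    # Hierarchical negatives: same 3-character category as a positive code,
    # but not present on the current record.
    categories = {code[:3] for code in positives}
    siblings = sorted(c for c in all_codes
                      if c[:3] in categories and c not in positives)
    hier = rng.sample(siblings, min(n_hier, len(siblings)))

    # Random negatives: drawn from the rest of the code space.
    chosen = set(hier)
    pool = sorted(c for c in all_codes if c not in positives and c not in chosen)
    rand = rng.sample(pool, min(n_random, len(pool)))
    chosen |= set(rand)

    # Production-informed negatives: previously rejected codes (assumed input).
    prod = [c for c in production_rejects if c not in positives and c not in chosen]
    return {"hierarchical": hier, "random": rand, "production": prod}


codes = ["E11.9", "E11.65", "E11.21", "I10", "J18.9", "N17.9"]
print(sample_negatives(["E11.9", "I10"], codes,
                       production_rejects=["N17.9"], n_random=1))
```

Keeping the three sources separate in the return value makes it easy to tune the mix, for example by lowering the share of hierarchical negatives if training becomes unstable.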

Memory efficiency

While the MIL architecture itself is lightweight, we backpropagate through the entire LLM to refine the embeddings directly. This end-to-end training approach is both compute and memory-intensive, particularly when processing large clinical encounters with extensive documentation.

To manage these memory demands, we implement several optimization strategies:

  • Gradient checkpointing: We trade additional computation for reduced memory usage by recomputing intermediate activations during the backward pass rather than storing them.
  • Mixed precision training: We train in mixed precision using the bfloat16 reduced-precision floating-point format to shrink the memory footprint while maintaining numerical stability.
  • ZeRO optimizer: We employ the Zero Redundancy Optimizer (ZeRO) to shard optimizer states and gradients across devices, further reducing per-device memory requirements.
  • Mini-batch chunking: We process large bags of text segments in smaller chunks during each forward pass, accumulating gradients to keep peak memory consumption manageable.

This combination of techniques enables our system to perform end-to-end training on even the largest clinical encounters, refining LLM embeddings while staying within practical memory constraints.

Training throughput

Single-GPU training is prohibitively slow for our model size and dataset scale, necessitating a move to distributed training approaches.

We initially implemented multi-node distributed training using distributed data parallel (DDP). However, we encountered a critical bottleneck: clinical encounters vary dramatically in length, with some containing only tens of text segments while others include hundreds. When encounters of vastly different sizes are distributed across GPUs, significant load imbalance occurs — GPUs processing shorter encounters sit idle while waiting for GPUs handling longer encounters to complete their computations.

We address this by pre-chunking encounters and implementing batch balancing to ensure each GPU receives encounters of similar total length within each training batch, dramatically reducing GPU idle time.
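One simple way to balance a batch is a greedy "longest first" assignment: sort encounters by segment count and always hand the next one to the least-loaded GPU. This heuristic and the function name are assumptions chosen for illustration; the post does not specify the exact balancing rule we use.

```python
import heapq


def balance_across_gpus(segment_counts, num_gpus):
    """Assign encounter indices to GPUs so total segment counts are similar."""
    heap = [(0, gpu) for gpu in range(num_gpus)]  # (current load, gpu id)
    heapq.heapify(heap)
    assignment = [[] for _ in range(num_gpus)]
    # Place the longest encounters first, each on the least-loaded GPU.
    order = sorted(range(len(segment_counts)),
                   key=lambda i: segment_counts[i], reverse=True)
    for idx in order:
        load, gpu = heapq.heappop(heap)
        assignment[gpu].append(idx)
        heapq.heappush(heap, (load + segment_counts[idx], gpu))
    return assignment


counts = [300, 20, 150, 130, 40, 10]  # segments per encounter (toy values)
assignment = balance_across_gpus(counts, num_gpus=2)
loads = [sum(counts[i] for i in group) for group in assignment]
print(assignment, loads)
```

With the toy values above, the two GPUs end up with 330 and 320 segments, compared with 470 and 180 if the encounters were simply split in their original order.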

To scale to larger models and batch sizes, we explore fully sharded data parallel (FSDP) training, where the model parameters, gradients, and optimizer states are sharded across multiple GPUs. During forward and backward passes, FSDP dynamically gathers the necessary parameters from other GPUs when needed, then discards them to free memory. FSDP’s parameter sharding mechanism requires strict synchronization: all GPUs must execute each layer in the forward pass in lockstep. This constraint is incompatible with our variable-length bags, where one GPU might need to run additional calls through the text embedding function for longer encounters, while others have already finished.

While our encounter balancing approach partially mitigates this issue, we cannot guarantee perfectly uniform segment counts across all encounters in a batch. To fully resolve this, we modify the text embedding portion of our forward pass to ensure all GPUs execute the same number of embedding calls. The algorithm works as follows:

1. Determine the maximum number of calls needed across all encounters in a batch (based on encounter length and embedding mini-batch size).

2. Redistribute text chunks within shorter encounters to match this loop count, adjusting mini-batch sizes dynamically.

Because every GPU now runs the same number of embedding calls, FSDP’s synchronization requirement is satisfied even for variable-length encounters. We illustrate the algorithm in Figure 2 below.

Figure 2: An illustration of the speed gains obtained using our DDP and FSDP-compatible batch load balancing techniques.
2a: Single-GPU training, with multiple encounters processed sequentially
2b: DDP with unbalanced encounter size (GPU 0) and balanced encounter size (GPU 1 and 2). Note the empty spaces representing GPU idle time while GPU 1 finishes.
2c: FSDP with dynamic mini-batching for segment embedding computation
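The two-step algorithm for keeping embedding calls in sync can be sketched as a planning function. This is a single-process illustration under stated assumptions: the function name is hypothetical, the per-GPU chunk counts are passed in as a list (in a real distributed job the maximum would be obtained with a collective operation such as an all-reduce with a max reduction), and the even split of chunks across calls is one possible redistribution rule.

```python
import math


def plan_embedding_calls(chunks_per_gpu, mini_batch_size):
    """Return (num_calls, per-GPU list of mini-batch sizes).

    Every GPU gets exactly `num_calls` mini-batches, so all ranks run the
    embedding function the same number of times.
    """
    # Step 1: the GPU with the most chunks determines the loop count.
    num_calls = max(1, max(math.ceil(n / mini_batch_size) for n in chunks_per_gpu))
    plans = []
    for n in chunks_per_gpu:
        # Step 2: spread this GPU's chunks as evenly as possible over num_calls.
        base, extra = divmod(n, num_calls)
        plans.append([base + 1 if i < extra else base for i in range(num_calls)])
    return num_calls, plans


num_calls, plans = plan_embedding_calls([10, 3, 7], mini_batch_size=4)
print(num_calls)
for gpu, sizes in enumerate(plans):
    print(gpu, sizes)
```

No mini-batch in the plan exceeds `mini_batch_size`, so peak memory per call is unchanged. One case the sketch does not handle is a GPU with fewer chunks than `num_calls`: it would receive empty mini-batches and would need a dummy input to keep its calls in step with the other GPUs.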

 

Results

In Figure 3 below, we illustrate the effectiveness of the MIL code verification module. We use MIL as a post-processing step to re-score and verify candidate codes produced by our baseline LLM, and we compare the two on predictive performance, measured as precision over the set of codes for an encounter.

The baseline LLM directly generates medical codes after ingesting the full encounter, ensembling predictions across “chunks” when the encounter exceeds the context window limit.

As expected, the baseline model outperforms MIL when the entire encounter can fit into its context window (since its effective “chunk size” is larger), but MIL starts dominating as soon as the context window limit is reached.

Additionally, we see improvements across all encounter lengths in the high-confidence regime when we combine the baseline LLM scores with MIL using a weighted average. This is a very promising result because there are many real-world problems in which it is easy to hit LLM context window limits. MIL provides an effective way to work around this limit without sacrificing predictive performance.

Figure 3: Comparing the predictive performance of MIL versus our baseline model in scenarios where the entire encounter can fit in the context window (left) vs. cases where it does not (middle and right)
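Combining the baseline LLM score with the MIL score by weighted average is a one-line computation. The function name and the weight value below are hypothetical; the post does not state the weight we use, and in practice it would be tuned on held-out data.

```python
def combine_scores(baseline_score, mil_score, mil_weight=0.5):
    """Weighted average of two per-code confidence scores in [0, 1]."""
    if not 0.0 <= mil_weight <= 1.0:
        raise ValueError("mil_weight must be between 0 and 1")
    return (1.0 - mil_weight) * baseline_score + mil_weight * mil_score


# A code the baseline is confident in but MIL finds weak evidence for.
print(combine_scores(0.9, 0.5, mil_weight=0.25))
```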

Conclusion

Medical coding is a complex, nuanced, and critical process in the relationship between hospitals and insurance companies. With roughly 140,000 possible ICD codes and lengthy clinical encounters, ensuring accurate code assignment requires significant expertise and time.

Our MIL-based code verification module is a key component of our solution, addressing both human and technical challenges in medical coding review. For coders, it replaces the tedious process of manually searching through hundreds of pages of clinical notes by surfacing relevant codes with supporting evidence, dramatically reducing coding time and inaccuracies. Our system helps reduce healthcare staff burnout and ensures that medical codes more faithfully reflect the care provided to patients, supporting healthcare systems in their mission to deliver quality care.

If you’re excited about training/finetuning domain-specific LLMs on large healthcare datasets to solve immediate problems in our healthcare system today, we’d love to have you join AKASA. We’re always on the lookout for talented ML engineers and researchers who share our mission of making healthcare more efficient and accurate through generative AI. And we’re hiring!

 
