Details of the CONCH Model
Summary
- The pre-trained weights of the CONtrastive learning from Captions for Histopathology (CONCH) model are available on Hugging Face.
- The architecture of the CONCH model is based on Contrastive Captioners (CoCa).
- The CONCH model supports image-to-text tasks (captioning) and image-text alignment tasks (e.g., retrieval), but not text-to-image generation.
How to get the details of the CONCH model
In Google Colab:
!pip install git+https://github.com/Mahmoodlab/CONCH.git
from conch.open_clip_custom import create_model_from_pretrained
# Local path to the checkpoint, downloaded beforehand from the Hugging Face repository
path_to_model = "checkpoints/conch/pytorch_model.bin"
# Returns the model and the matching image preprocessing transform
model, preprocess = create_model_from_pretrained('conch_ViT-B-16', path_to_model)
print(model)
# See the last section of this page for the printed output
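print(model) shows the structure but not how large each part is. The helper below is a small sketch that tallies parameters per top-level submodule (text, visual, text_decoder). It relies only on the standard torch.nn.Module methods named_children() and parameters() and on Tensor.numel(). Any parameters registered directly on the top-level model rather than inside a child (a learnable logit scale, for example, if the model has one) are not included in these per-child totals.

```python
def summarize_children(model):
    """Map each top-level child of a torch.nn.Module to its parameter count."""
    return {
        # parameters() recurses into nested submodules by default
        name: sum(p.numel() for p in child.parameters())
        for name, child in model.named_children()
    }


# Usage, with `model` loaded as above:
# for name, count in summarize_children(model).items():
#     print(f"{name:>12}: {count:,} parameters")
```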
Details of the CONCH model
This section was generated by GitHub Copilot.
CoCa Model Architecture Overview
The CoCa model is a multimodal transformer architecture designed to process both text and visual (image) data. It consists of three main components:
1. Text Encoder (text)
   - TextTransformer
     - token_embedding: Embeds input tokens (vocabulary size: 32,007, embedding dim: 768).
     - transformer: A stack of 12 ResidualAttentionBlocks, each with:
       - LayerNorm layers (ln_1, ln_2)
       - MultiheadAttention (self-attention)
       - MLP (2-layer feedforward, 768 → 3072 → 768, with GELU activation)
       - Identity layers (ls_1, ls_2): placeholders for LayerScale, which is disabled in this configuration
     - ln_final: Final LayerNorm.
2. Visual Encoder (visual)
   - VisualModel
     - trunk: VisionTransformer
       - PatchEmbed (patch_embed): Converts the image into patches using a Conv2d layer (3 input channels, 768 output channels, 16x16 kernel with stride 16).
       - blocks: 12 transformer Blocks, each with:
         - LayerNorm (norm1, norm2)
         - Attention (self-attention with a fused qkv projection, 768 → 2304 = 3 × 768)
         - MLP (2-layer feedforward, 768 → 3072 → 768, with GELU)
         - Identity layers (ls1, ls2, drop_path1, drop_path2): LayerScale and stochastic-depth slots, disabled in this configuration; all Dropout layers use p=0.0
       - norm: Final LayerNorm.
       - fc_norm, head_drop, head: Classification-head slots of the ViT; all are no-ops here (Identity, Dropout with p=0.0, Identity).
     - attn_pool_contrast: Attentional pooling for the contrastive embedding; it attends over the 768-d trunk tokens and outputs 512-d features (MultiheadAttention plus LayerNorms ln_q and ln_k).
     - ln_contrast: LayerNorm for the 512-d contrastive output.
     - head: Empty Sequential, so it passes its input through unchanged.
     - attn_pool_caption: Attentional pooling for captioning; it outputs 768-d features for the text decoder (MultiheadAttention plus LayerNorms ln_q and ln_k).
     - ln_caption: LayerNorm for the caption features.
3. Text Decoder (text_decoder)
   - MultimodalTransformer
     - resblocks: 12 ResidualAttentionBlocks (same structure as in the text encoder) applying causal self-attention over the text tokens for autoregressive decoding.
     - cross_attn: 12 ResidualAttentionBlocks with cross-attention:
       - Each includes an extra LayerNorm (ln_1_kv) for the key/value inputs.
       - They let the decoder attend to the visual features.
     - ln_final: Final LayerNorm.
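The printed dimensions of the visual encoder can be connected to the data flow with a little arithmetic. The sketch below works through the shapes for a square input. Neither the input resolution nor the number of attention heads appears in the printout, so both are parameters here: num_heads=12 is the usual ViT-B value and is an assumption, and the 448-pixel input in the usage line is purely illustrative. The patch count excludes any class token the configuration might add.

```python
def vit_shapes(image_size, patch_size=16, width=768, num_heads=12, mlp_ratio=4):
    """Shape bookkeeping for a ViT-B/16-style encoder (square input assumed)."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size      # patches per side (Conv2d kernel = stride = 16)
    return {
        "grid": grid,
        "num_patches": grid * grid,      # tokens produced by patch_embed
        "qkv_out": 3 * width,            # fused q, k, v projection: 768 -> 2304
        "head_dim": width // num_heads,  # per-head width (head count is assumed)
        "mlp_hidden": mlp_ratio * width, # MLP fc1: 768 -> 3072
    }


print(vit_shapes(448))
# {'grid': 28, 'num_patches': 784, 'qkv_out': 2304, 'head_dim': 64, 'mlp_hidden': 3072}
```

The qkv_out and mlp_hidden values reproduce the 2304 and 3072 out_features in the printed Blocks, which is a quick sanity check that the width is 768 and the MLP ratio is 4.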
Key Architectural Features
- ResidualAttentionBlock: Core transformer block with LayerNorm, MultiheadAttention, MLP, and residual connections.
- MultiheadAttention: Used throughout for both self-attention and cross-attention.
- GELU Activation: Used in all MLPs for non-linearity.
- LayerNorm: Applied before each attention and MLP sub-layer (pre-norm), plus a final LayerNorm at the end of each stack, for training stability.
- Attentional Poolers: Used to aggregate features for contrastive and captioning tasks.
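An attentional pooler is essentially cross-attention from a small, fixed set of learned queries to all image tokens, so the output size depends on the number of queries rather than on the number of patches. The pure-Python sketch below shows that mechanism. It is simplified: single-head, no query projection, no LayerNorms (ln_q, ln_k) and no output projection. The toy sizes (2-d queries, 3-d tokens, 4 tokens) are hypothetical stand-ins for the 512-d queries and 768-d trunk tokens of attn_pool_contrast.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attentional_pool(queries, tokens, W_k, W_v):
    """Single-head attentional pooling (simplified sketch).

    queries: n_q learned query vectors of width d_model
    tokens:  n_tok context vectors of width d_ctx (e.g. ViT output tokens)
    W_k, W_v: d_model x d_ctx matrices projecting context to the model width
    Returns n_q pooled vectors of width d_model, regardless of n_tok.
    """
    keys = [matvec(W_k, t) for t in tokens]
    values = [matvec(W_v, t) for t in tokens]
    scale = 1.0 / math.sqrt(len(queries[0]))
    pooled = []
    for q in queries:
        # Scaled dot-product scores of this query against every token
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        # Attention-weighted average of the projected values
        pooled.append([sum(w * v[j] for w, v in zip(weights, values))
                       for j in range(len(values[0]))])
    return pooled


# Toy example: 1 query (width 2) pools 4 tokens (width 3) into one 2-d vector
query = [[1.0, 0.0]]
tokens = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
W_k = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
W_v = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
print(attentional_pool(query, tokens, W_k, W_v))
```

Using one query yields a single pooled vector (the contrastive case); using many queries yields a sequence that a decoder can cross-attend to (the captioning case).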
Summary Table
| Component | Layers/Blocks | Main Operations | Purpose |
|---|---|---|---|
| TextTransformer (text) | 12 ResidualAttentionBlocks | Self-attention, MLP | Encode text |
| VisionTransformer (visual.trunk) | 12 Blocks | Self-attention, MLP | Encode images |
| MultimodalTransformer (text_decoder) | 12 self-attention + 12 cross-attention blocks | Self/cross-attention, MLP | Decode text, attend to vision |
| AttentionalPoolers (attn_pool_contrast, attn_pool_caption) | 2 poolers, one MultiheadAttention each | Pooling, LayerNorm | Aggregate image features |
Usage
- Contrastive Learning: Aligns text and image representations (like CLIP).
- Captioning: Generates text conditioned on images via cross-attention in the decoder.
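- Multimodal understanding: The shared image-text embedding space supports tasks such as image-text retrieval and zero-shot classification.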
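In the contrastive setting, image and text embeddings are compared by cosine similarity. The sketch below scores one image embedding against several text-prompt embeddings with a CLIP-style scaled cosine similarity followed by a softmax, which is the usual recipe for zero-shot classification with such models. The toy 3-d vectors and logit_scale=100.0 are hypothetical. In CONCH the image side ends in the 512-d ln_contrast output, and in CLIP-style models the scale is a learned value.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zero_shot_probs(image_emb, text_embs, logit_scale=100.0):
    """Softmax over scaled cosine similarities between one image and N texts."""
    img = l2_normalize(image_emb)
    logits = [logit_scale * sum(a * b for a, b in zip(img, l2_normalize(t)))
              for t in text_embs]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical embeddings: one image and three class prompts
image_emb = [0.9, 0.1, 0.0]
prompt_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
probs = zero_shot_probs(image_emb, prompt_embs)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted)  # 0
```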
In summary: CoCa is a dual-encoder (text/image) transformer with a multimodal decoder, supporting both contrastive and generative (captioning) tasks, using standard transformer components throughout.
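Generative captioning is autoregressive: the decoder predicts one token at a time, conditioned on the tokens generated so far and, through cross-attention, on the image features. The sketch below shows a greedy decoding loop around a hypothetical next_token_logits function that stands in for a full text_decoder forward pass. The token ids and the toy scorer are made up for illustration (the toy ignores the image features), and real implementations may use beam search or sampling instead of greedy selection.

```python
from typing import Callable, List, Sequence

# Hypothetical interface: (token ids so far, image features) -> logits over the vocabulary
NextTokenFn = Callable[[List[int], Sequence[float]], List[float]]


def greedy_caption(next_token_logits: NextTokenFn, image_features: Sequence[float],
                   bos_id: int, eos_id: int, max_len: int = 20) -> List[int]:
    """Greedy autoregressive decoding: always append the highest-scoring token."""
    tokens = [bos_id]
    for _ in range(max_len):
        logits = next_token_logits(tokens, image_features)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:  # stop once the end-of-sequence token is produced
            break
    return tokens


# Toy scorer over a 5-token vocabulary: emits 3, then 4, then EOS (id 1)
def toy_logits(tokens, image_features):
    script = {1: 3, 2: 4}  # current sequence length -> preferred token id
    preferred = script.get(len(tokens), 1)
    return [1.0 if i == preferred else 0.0 for i in range(5)]


print(greedy_caption(toy_logits, image_features=[0.0], bos_id=0, eos_id=1))
# [0, 3, 4, 1]
```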
Output of the model
CoCa(
(text): TextTransformer(
(token_embedding): Embedding(32007, 768)
(transformer): Transformer(
(resblocks): ModuleList(
(0-11): 12 x ResidualAttentionBlock(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ls_1): Identity()
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ls_2): Identity()
)
)
)
(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(visual): VisualModel(
(trunk): VisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(norm): Identity()
)
(pos_drop): Dropout(p=0.0, inplace=False)
(patch_drop): Identity()
(norm_pre): Identity()
(blocks): Sequential(
(0): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(1): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(2): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(3): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(4): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(5): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(6): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(7): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(8): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(9): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(10): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(11): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(fc_norm): Identity()
(head_drop): Dropout(p=0.0, inplace=False)
(head): Identity()
)
(attn_pool_contrast): AttentionalPooler(
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(ln_q): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(ln_k): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(ln_contrast): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(head): Sequential()
(attn_pool_caption): AttentionalPooler(
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ln_q): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln_k): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(ln_caption): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(text_decoder): MultimodalTransformer(
(resblocks): ModuleList(
(0-11): 12 x ResidualAttentionBlock(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ls_1): Identity()
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ls_2): Identity()
)
)
(cross_attn): ModuleList(
(0-11): 12 x ResidualAttentionBlock(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ls_1): Identity()
(ln_1_kv): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ls_2): Identity()
)
)
(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
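The printout lists layer shapes but not parameter totals. The sketch below derives them from the printed in/out features for one ViT Block and one ResidualAttentionBlock. It assumes, as PyTorch's nn.MultiheadAttention does when the key/value width equals the model width, that the q/k/v input projection is stored as a single 768 x 2304 weight plus a 2304 bias; that tensor is a bare parameter, so it does not appear in the repr, where only submodules such as out_proj are printed. For the cross-attention blocks the same packed layout is assumed because the caption pooler outputs 768-d features. Positional embeddings, the text projection, and other bare parameters are not counted.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of an nn.Linear-style layer."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(dim):
    """LayerNorm with elementwise_affine=True: one weight and one bias per feature."""
    return 2 * dim


def transformer_block_params(width=768, mlp_hidden=3072, extra_norms=0):
    """Parameters of a pre-norm transformer block with the printed dimensions."""
    attn = linear_params(width, 3 * width) + linear_params(width, width)  # qkv + output proj
    mlp = linear_params(width, mlp_hidden) + linear_params(mlp_hidden, width)
    norms = (2 + extra_norms) * layernorm_params(width)  # ln_1/ln_2 (+ ln_1_kv)
    return attn + mlp + norms


vit_block = transformer_block_params()
cross_block = transformer_block_params(extra_norms=1)
print(f"ViT / text block: {vit_block:,}")       # 7,087,872
print(f"cross-attn block: {cross_block:,}")     # 7,089,408
print(f"12 ViT blocks:    {12 * vit_block:,}")  # 85,054,464
```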