Build a competitive baseline for the CAFA-6 Kaggle competition: predicting Gene Ontology (GO) terms from protein sequences with ESM-2 protein language model embeddings and multi-task learning. The pipeline is two-stage: pre-compute embeddings once, then iterate rapidly on classifiers.
This skill guides you through implementing a two-stage deep learning pipeline for protein function prediction:
1. **Stage 1**: Generate ESM-2 protein embeddings once (~1.5-2 hours)
2. **Stage 2**: Train multi-task classifiers rapidly (~30 minutes each)
The approach handles key challenges: extreme class imbalance (26,125 labels, 51.6% appearing in <10 proteins), long sequences (24.8% exceed 1024 amino acids), and multi-aspect complexity (Cellular Component, Molecular Function, Biological Process).
Expected performance: F1 0.45-0.55 for baseline, scalable to 0.60-0.65 with optimizations.
1. **Set up the project structure**
- Create directories: `src/data/`, `src/models/`, `src/utils/`, `scripts/`, `notebooks/`, `data/`
- Initialize a Python virtual environment
- Install core dependencies: `torch`, `transformers`, `fair-esm`, `pandas`, `numpy`, `scikit-learn`, `tqdm`
2. **Download and explore the CAFA-6 dataset**
- Use the Kaggle API to download: `kaggle competitions download -c cafa-6-protein-function-prediction`
- Extract to `data/raw/`
- Load `train_sequences.fasta` and `train_annotations.tsv`
- Analyze key statistics:
- 82,404 proteins, 26,125 unique GO terms
- Average 6.52 labels/protein (range: 1-233)
- Sequence lengths: 3-35,213 AA (median: 376)
- Label frequency distribution (29.4% ultra-rare: 1-5 proteins)
3. **Create exploratory data analysis scripts**
- Write `notebooks/eda.py` to compute:
- GO term distribution by aspect (C/F/P)
- Protein coverage per aspect
- Sequence length percentiles
- Label frequency histogram
- Write `notebooks/eda_advanced.py` for co-occurrence analysis and aspect overlap
4. **Implement preprocessing utilities** (`src/data/preprocessing.py`)
- Function to load FASTA sequences into dictionary
- Function to parse TSV annotations into protein→GO term mapping
- Function to create multi-hot label matrix (82,404 × 26,125)
- Function to split labels by aspect (C/F/P) and create separate label matrices
- Handle missing/invalid sequences
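A minimal sketch of the two core loaders, assuming the annotation TSV has `protein_id` and `GO_term` columns (verify the column names against the actual competition file):

```python
import numpy as np
import pandas as pd

def load_fasta(path):
    """Parse a FASTA file into {protein_id: sequence}. Assumes the ID is the
    first whitespace-delimited token of each header line."""
    sequences, current_id, chunks = {}, None, []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith(">"):
                if current_id is not None:
                    sequences[current_id] = "".join(chunks)
                current_id, chunks = line[1:].split()[0], []
            elif line:
                chunks.append(line)
    if current_id is not None:
        sequences[current_id] = "".join(chunks)
    return sequences

def build_label_matrix(annotations_tsv, protein_ids):
    """Build a multi-hot (n_proteins x n_terms) matrix from the annotation TSV."""
    df = pd.read_csv(annotations_tsv, sep="\t")
    terms = sorted(df["GO_term"].unique())
    term_idx = {t: j for j, t in enumerate(terms)}
    prot_idx = {p: i for i, p in enumerate(protein_ids)}
    Y = np.zeros((len(protein_ids), len(terms)), dtype=np.uint8)
    for p, t in zip(df["protein_id"], df["GO_term"]):
        if p in prot_idx:
            Y[prot_idx[p], term_idx[t]] = 1
    return Y, terms
```

At full scale the dense matrix (82,404 × 26,125, uint8) is roughly 2.2 GB; a `scipy.sparse.lil_matrix` is a drop-in alternative if memory is tight.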
5. **Build PyTorch datasets** (`src/data/dataset.py`)
- `ProteinSequenceDataset`: wraps sequences and labels for training
- `PrecomputedEmbeddingDataset`: loads saved embeddings from disk
- Implement `__len__` and `__getitem__` methods
- Add collate function for batching
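A sketch of the embedding-backed dataset, assuming the embedding file was saved with `torch.save` as a dict containing an `embeddings` tensor (as in step 7 below):

```python
import torch
from torch.utils.data import Dataset

class PrecomputedEmbeddingDataset(Dataset):
    """Serves (embedding, label) pairs from tensors already loaded into RAM."""
    def __init__(self, embedding_path, label_matrix):
        data = torch.load(embedding_path)
        self.embeddings = data["embeddings"]                   # (N, 640) float32
        self.labels = torch.from_numpy(label_matrix).float()   # (N, n_terms)
        assert len(self.embeddings) == len(self.labels)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]
```

Because embeddings are fixed-size vectors, the default collate works here; the custom collate function is only needed for `ProteinSequenceDataset`, where variable-length raw sequences must be padded.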
6. **Implement sequence handling strategies** (`src/utils/sequence_handling.py`)
- `truncate_sequence()`: Keep first 512 + last 512 AA for sequences >1024
- `sliding_window()`: Generate overlapping windows for long sequences
- `tokenize_sequence()`: Convert AA strings to ESM-2 token IDs
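Sketches of the two truncation strategies; the head+tail split is a heuristic motivated by the fact that N- and C-terminal regions often carry signal peptides and localization motifs:

```python
def truncate_sequence(seq, max_len=1024):
    """Head+tail truncation: keep the first and last max_len//2 residues."""
    if len(seq) <= max_len:
        return seq
    half = max_len // 2
    return seq[:half] + seq[-half:]

def sliding_window(seq, window=1024, stride=512):
    """Yield overlapping windows covering the whole sequence."""
    if len(seq) <= window:
        yield seq
        return
    for start in range(0, len(seq) - window + 1, stride):
        yield seq[start:start + window]
    # cover the tail if the last stride didn't land exactly on the end
    if (len(seq) - window) % stride != 0:
        yield seq[-window:]
```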
7. **Create embedding generation script** (`scripts/generate_embeddings.py`)
- Load the ESM-2 150M model: `esm.pretrained.esm2_t30_150M_UR50D()` (640-dim embeddings); step 18 swaps in the larger `esm2_t33_650M_UR50D()`
- Set model to eval mode, move to GPU
- Iterate through all 82,404 proteins:
- Tokenize sequence (handle truncation for >1024 AA)
- Extract embeddings (use mean pooling over sequence length)
- Save to `data/embeddings/train_embeddings_esm2_150M.pt` (use torch.save)
- Include progress bar (tqdm) and time estimates
- Save metadata: protein IDs, sequence lengths, truncation flags
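A sketch of the core extraction loop using the fair-esm API (model `esm2_t30_150M_UR50D`, representation layer 30, 640-dim output); sequences are assumed already truncated to ≤1024 AA:

```python
import torch
import esm

model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
model.eval().cuda()
batch_converter = alphabet.get_batch_converter()

@torch.no_grad()
def embed_batch(named_seqs, layer=30):
    """named_seqs: list of (protein_id, sequence) pairs.
    Returns mean-pooled per-protein embeddings of shape (len(batch), 640)."""
    _, _, tokens = batch_converter(named_seqs)
    out = model(tokens.cuda(), repr_layers=[layer])
    reps = out["representations"][layer]
    pooled = []
    for i, (_, seq) in enumerate(named_seqs):
        # position 0 is BOS, position len(seq)+1 is EOS; padding follows,
        # so mean-pool only over the real residue positions
        pooled.append(reps[i, 1:len(seq) + 1].mean(0).cpu())
    return torch.stack(pooled)
```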
8. **Create Kaggle-specific embedding scripts**
- `scripts/generate_embeddings_kaggle.py`: Single P100 GPU version
- `scripts/generate_embeddings_kaggle_dual_gpu.py`: T4 x2 with DataParallel
- Add batch processing (batch_size=8-16 depending on GPU memory)
- Add checkpoint saving every 10k proteins (resume on failure)
9. **Implement baseline classifiers** (`src/models/baseline.py`)
- `SimpleClassifier`: Dense(640→512)→ReLU→Dropout→Dense(512→26125); keep the sigmoid out of the model so forward returns raw logits (the focal loss below works on logits; apply sigmoid only at inference)
- `ESM2Classifier`: End-to-end model wrapping ESM-2 + classification head
- `CNNClassifier`: Conv1D baseline for comparison
- Add forward methods returning logits
10. **Implement multi-task architecture** (`src/models/multitask.py`)
- `MultiTaskClassifier`:
- Shared encoder: Dense(640→1024)→ReLU→Dropout(0.3)
- C head: Dense(1024→512)→ReLU→Dense(512→2651)
- F head: Dense(1024→512)→ReLU→Dense(512→6616)
- P head: Dense(1024→512)→ReLU→Dense(512→16858)
- Forward returns dict: `{'C': logits_C, 'F': logits_F, 'P': logits_P}`
- Add learnable task weights (log space): `self.task_weights = nn.Parameter(torch.zeros(3))`
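A sketch of this architecture; the per-aspect head sizes come from the label splits above:

```python
import torch
import torch.nn as nn

N_LABELS = {"C": 2651, "F": 6616, "P": 16858}   # per-aspect term counts from the EDA

class MultiTaskClassifier(nn.Module):
    """Shared encoder with one classification head per GO aspect."""
    def __init__(self, embed_dim=640, hidden=1024, dropout=0.3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(embed_dim, hidden), nn.ReLU(), nn.Dropout(dropout))
        self.heads = nn.ModuleDict({
            aspect: nn.Sequential(
                nn.Linear(hidden, 512), nn.ReLU(), nn.Linear(512, n))
            for aspect, n in N_LABELS.items()})
        # one learnable log-space weight per task (C, F, P)
        self.task_weights = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        h = self.encoder(x)
        return {aspect: head(h) for aspect, head in self.heads.items()}
```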
11. **Implement Focal Loss** (`src/utils/losses.py`)
- Binary focal loss (Lin et al., 2017): `FL = -α(1-p)^γ log(p)` for positives, `-(1-α) p^γ log(1-p)` for negatives
- Default hyperparameters: α=0.25, γ=2.0
- Add reduction options: 'mean', 'sum', 'none'
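A sketch computed on raw logits for numerical stability, matching the formula above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLoss(nn.Module):
    """Focal loss for multi-label classification on raw logits."""
    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, targets):
        # bce = -log(p_t), so alpha_t * (1 - p_t)^gamma * bce is exactly FL
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)          # prob of the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - p_t) ** self.gamma * bce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
```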
12. **Implement Multi-Task Loss** (`src/utils/losses.py`)
- `MultiTaskFocalLoss`:
- Compute focal loss for C, F, P predictions separately
- Weight by learned task weights: `w_i = exp(log_weight_i)`
- Final loss: `L = w_C * L_C + w_F * L_F + w_P * L_P - λ * (log w_C + log w_F + log w_P)`
- The `-λ Σ log w_i` regularization term prevents weight collapse: without it, driving every `w_i` toward zero would trivially minimize the weighted loss
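A sketch consistent with the weighting scheme above, reusing `BinaryFocalLoss` from step 11:

```python
import torch
import torch.nn as nn

class MultiTaskFocalLoss(nn.Module):
    """Combines per-aspect focal losses via the model's learned task weights."""
    def __init__(self, reg=0.01):
        super().__init__()
        self.focal = BinaryFocalLoss(reduction="mean")   # step-11 sketch
        self.reg = reg                                   # λ in the formula above

    def forward(self, logits, targets, log_weights):
        # logits/targets: dicts keyed by aspect; log_weights: model.task_weights
        weights = torch.exp(log_weights)
        total = 0.0
        for i, aspect in enumerate(("C", "F", "P")):
            total = total + weights[i] * self.focal(logits[aspect], targets[aspect])
        # -λ Σ log w_i keeps the weights from collapsing to zero
        return total - self.reg * log_weights.sum()
```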
13. **Create baseline training script** (`scripts/train_baseline.py`)
- Load pre-computed embeddings from disk
- Create `PrecomputedEmbeddingDataset` and DataLoader (batch_size=512)
- Initialize `SimpleClassifier` and move to GPU
- Set up optimizer: AdamW (lr=1e-3, weight_decay=1e-4)
- Set up scheduler: OneCycleLR or ReduceLROnPlateau
- Training loop:
- Forward pass → compute focal loss → backward → optimizer step
- Track training loss per epoch
- Validate every N epochs (compute F1 score on val set)
- Save best checkpoint based on val F1
- Log training progress to console and CSV
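A condensed sketch of this loop; `evaluate_f1` and the checkpoint path are placeholders for the utilities described in step 15:

```python
import torch
from torch.utils.data import DataLoader

def train(model, train_ds, val_ds, epochs=50, lr=1e-3, device="cuda"):
    loader = DataLoader(train_ds, batch_size=512, shuffle=True)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, total_steps=epochs * len(loader))
    criterion = BinaryFocalLoss()                    # step-11 sketch
    model.to(device)
    best_f1 = 0.0
    for epoch in range(epochs):
        model.train()
        for emb, labels in loader:
            opt.zero_grad()
            loss = criterion(model(emb.to(device)), labels.to(device))
            loss.backward()
            opt.step()
            sched.step()
        f1 = evaluate_f1(model, val_ds, device)      # placeholder, see step 15
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), "models/baseline_best.pt")
        print(f"epoch {epoch}: loss {loss.item():.4f}  val F1 {f1:.4f}")
```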
14. **Create multi-task training script** (`scripts/train_multitask.py`)
- Similar to baseline but:
- Split labels by aspect (C/F/P) during data loading
- Use `MultiTaskClassifier` model
- Use `MultiTaskFocalLoss`
- Track per-aspect losses and F1 scores
- Log task weights over training
15. **Implement evaluation utilities** (`scripts/evaluate.py`)
- Load trained model checkpoint
- Generate predictions on validation/test set
- Compute metrics:
- Per-label F1 score
- Macro-averaged F1 (average across all GO terms)
- Micro-averaged F1 (global TP/FP/FN)
- Per-aspect F1 (C/F/P)
- Implement threshold optimization:
- Grid search thresholds (0.1 to 0.9, step 0.05) per label
- Select threshold maximizing F1 on validation set
- Save optimal thresholds to JSON
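A sketch of the per-label grid search; labels with no positives in the validation set keep the default 0.5 threshold:

```python
import numpy as np

def optimize_thresholds(probs, targets, grid=np.arange(0.1, 0.91, 0.05)):
    """Per-label threshold search maximizing F1 on the validation set.
    probs, targets: (n_val, n_terms) arrays."""
    best = np.full(probs.shape[1], 0.5)
    for j in range(probs.shape[1]):
        y = targets[:, j].astype(bool)
        if not y.any():                 # no positives in val -> keep default
            continue
        best_f1 = -1.0
        for t in grid:
            pred = probs[:, j] >= t
            tp = np.sum(pred & y)
            fp = np.sum(pred & ~y)
            fn = np.sum(~pred & y)
            f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
            if f1 > best_f1:
                best_f1, best[j] = f1, t
    return best
```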
16. **Create submission generation script** (`scripts/make_submission.py`)
- Load test sequences
- Generate embeddings (or load pre-computed)
- Load trained model + optimal thresholds
- Generate predictions
- Format as Kaggle submission CSV: `protein_id,GO_term,confidence`
- Save to `submissions/submission_v1.csv`
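A sketch of the submission writer; verify the exact header names against the competition's sample submission file:

```python
import numpy as np
import pandas as pd

def write_submission(protein_ids, go_terms, probs, thresholds, path):
    """One row per (protein, GO term) prediction above its per-label threshold.
    probs: (n_test, n_terms); thresholds: (n_terms,) from step 15."""
    rows = []
    for i, pid in enumerate(protein_ids):
        for j in np.flatnonzero(probs[i] >= thresholds):
            rows.append((pid, go_terms[j], float(probs[i, j])))
    pd.DataFrame(rows, columns=["protein_id", "GO_term", "confidence"]).to_csv(
        path, index=False)
```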
17. **Implement sliding window for long sequences**
- Modify embedding generation to use overlapping windows (stride=512)
- Average embeddings across windows (mean pooling)
- Expected improvement: +2-3% F1
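A sketch combining the `sliding_window` helper from step 6 with the `embed_batch` function from step 7:

```python
def embed_long_sequence(pid, seq, window=1024, stride=512):
    """Embed each overlapping window with ESM-2, then average across windows."""
    windows = list(sliding_window(seq, window, stride))   # step-6 sketch
    embs = embed_batch([(pid, w) for w in windows])       # step-7 sketch
    return embs.mean(dim=0)
```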
18. **Upgrade to ESM-2 650M model**
- Change model loading in embedding scripts
- Adjust batch size for larger memory footprint
- Re-generate all embeddings (~3-4 hours)
- Expected improvement: +5-8% F1
19. **Ensemble multiple models**
- Train 3 models with different random seeds
- Average predictions (or weighted average based on val F1)
- Expected improvement: +3-5% F1
20. **Optimize thresholds per label**
- For each GO term, find threshold maximizing F1 on validation set
- Apply label-specific thresholds during inference
- Expected improvement: +3-5% F1
| Stage | Model | Time | F1 Score |
|-------|-------|------|----------|
| Baseline | ESM-2 150M + Simple NN | 2h | 0.45-0.50 |
| Multi-Task | ESM-2 150M + Multi-Task | 2.5h | 0.50-0.55 |
| Larger Model | ESM-2 650M + Multi-Task | 4h | 0.55-0.60 |
| Optimizations | + Sliding Window + Thresholds | 5h | 0.58-0.63 |
| Ensemble | Average 3 Models | 15h | 0.60-0.65 |
**Generate embeddings on Kaggle:**
```python
!python scripts/generate_embeddings_kaggle.py \
    --model esm2_t30_150M_UR50D \
--input /kaggle/input/cafa-6/train_sequences.fasta \
--output /kaggle/working/embeddings_esm2_150M.pt \
--batch_size 16
```
**Train multi-task model:**
```bash
python scripts/train_multitask.py \
--embeddings data/embeddings/train_embeddings_esm2_150M.pt \
--annotations data/raw/train_annotations.tsv \
--output models/multitask_v1.pt \
--epochs 50 \
--batch_size 512 \
--lr 1e-3
```
**Generate submission:**
```bash
python scripts/make_submission.py \
--model models/multitask_v1.pt \
--thresholds outputs/optimal_thresholds.json \
--test_sequences data/raw/test_sequences.fasta \
--output submissions/submission_v1.csv
```