This project covers training and evaluation of a Concept Bottleneck Model (CBM) on the CelebA dataset. The dataset contains ~200,000 celebrity images, each annotated with 40 binary facial attributes (e.g., Smiling, Bushy_Eyebrows, Eyeglasses, Wearing_Lipstick).
The CBM used in this project consists of two main components:
- Concept Predictor: a convolutional neural network backbone fine-tuned to predict all concept attributes of the CelebA dataset except the “Attractive” attribute,
- Decision Tree Classifier: an interpretable model that uses those concepts to produce the final “Attractive” prediction.
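The two-stage flow can be sketched end to end. Everything below (the toy linear stand-in for the backbone, the tensor shapes, the variable names) is illustrative and not the project's actual code:

```python
import torch
import torch.nn as nn
from sklearn.tree import DecisionTreeClassifier

torch.manual_seed(0)

# Stage 1: a CNN predicts 39 concept logits per image
# (a tiny linear model stands in for the ResNet-18 backbone here).
concept_predictor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 39))

images = torch.randn(8, 3, 224, 224)
with torch.no_grad():
    concepts = (torch.sigmoid(concept_predictor(images)) > 0.5).float()

# Stage 2: an interpretable Decision Tree maps concepts -> "Attractive".
labels = torch.randint(0, 2, (8,))  # ground-truth "Attractive" labels
tree = DecisionTreeClassifier(max_depth=5, criterion="gini", random_state=0)
tree.fit(concepts.numpy(), labels.numpy())
attractive = tree.predict(concepts.numpy())  # final, interpretable prediction
```

Because the tree only ever sees the concept vector, every final prediction can be traced back to human-readable attributes.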
- Backbone: ResNet-18 (pretrained on ImageNet)
- Head: single fully-connected layer (`Linear(512 → 39)`)
- Loss: binary cross-entropy across all 39 attributes
- Optimizer: AdamW
- Learning rates: 1e-5 (backbone) and 3e-4 (classifier head)
- Weight decay: 1e-4 (applied only to the classifier head)
- Scheduler: ReduceLROnPlateau (patience: 3, factor: 0.5)
- Epochs: 4 (best checkpoint)
- Batch size: 16
Augmentations: To improve robustness and prevent overfitting, the following image transformations were applied during training:
- Resize images to 224×224 (standard for ResNet architectures)
- Random horizontal flip with probability 0.5
- Color jitter: slight changes in brightness (±10%), contrast (±10%), saturation (±10%), and hue (±5%)
- Normalization with mean [0.5063, 0.4258, 0.3832] and standard deviation [0.3105, 0.2903, 0.2897], computed from the training data by `preprocessing/normalization_stats`:

uv run python -m preprocessing.normalization_stats
For validation and test sets, only resizing and normalization were applied.
- Max depth: 5 (to keep the model interpretable)
- Criterion: Gini impurity
Prediction of the 39 concepts (all attributes except the “Attractive” label) on the test set:
| Metric | Score |
|---|---|
| Accuracy | 0.92 |
| Precision | 0.78 |
| Recall | 0.69 |
| F1-score | 0.72 |
Predictions of the “Attractive” attribute (i.e., the Decision Tree trained on the Concept Model's predictions) on the test set:
| Metric | Score |
|---|---|
| Accuracy | 0.80 |
| ROC AUC | 0.88 |
Class-wise Performance:
| Class | Precision | Recall | F1-score |
|---|---|---|---|
| 0 (Not Attractive) | 0.84 | 0.74 | 0.79 |
| 1 (Attractive) | 0.76 | 0.85 | 0.81 |
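Metrics of this kind can be computed with scikit-learn. The sketch below uses synthetic labels and scores purely to show the calls (note that ROC AUC needs predicted probabilities, not hard labels):

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, 500)                # stand-in ground-truth "Attractive"
y_score = 0.6 * y_true + 0.4 * rng.random(500)  # stand-in tree probabilities
y_pred = (y_score >= 0.5).astype(int)           # hard predictions at threshold 0.5

print("Accuracy:", accuracy_score(y_true, y_pred))
print("ROC AUC :", roc_auc_score(y_true, y_score))  # scores, not hard labels
print(classification_report(y_true, y_pred,
                            target_names=["Not Attractive", "Attractive"]))
```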
Feature importances quantify how much each predicted concept contributed to the Decision Tree’s final prediction of the “Attractive” attribute. In this experiment, the most influential concepts were Bald, Blurry, and Chubby, which together account for the majority of the tree’s predictive power.
Below are the normalized importance scores computed by the decision tree:
| Concept | Importance |
|---|---|
| Bald | 0.636 |
| Blurry | 0.128 |
| Chubby | 0.127 |
| Young | 0.062 |
| Heavy_Makeup | 0.016 |
| Wearing_Lipstick | 0.012 |
| Pointy_Nose | 0.011 |
| Smiling | 0.004 |
| Oval_Face | 0.003 |
| Others | 0.000 |
Decision Tree Feature Importances

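Such a ranking comes from the fitted tree's `feature_importances_` attribute (normalized to sum to 1). A self-contained sketch on toy data with a few hypothetical concept names:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

concept_names = ["Bald", "Blurry", "Chubby", "Young"]  # illustrative subset
rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(200, len(concept_names)))
y = (X[:, 0] | X[:, 1]).astype(int)  # toy target driven by the first two concepts

tree = DecisionTreeClassifier(max_depth=5, criterion="gini", random_state=0)
tree.fit(X, y)

# Importances are normalized Gini-based scores; print them sorted descending.
for idx in np.argsort(tree.feature_importances_)[::-1]:
    print(f"{concept_names[idx]:<8} {tree.feature_importances_[idx]:.3f}")
```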
Predictions of the “Attractive” attribute by a Decision Tree trained only on the top k=3 concepts (Bald, Blurry, Chubby), on the test set:
| Metric | Score |
|---|---|
| Accuracy | 0.79 |
| ROC AUC | 0.87 |
Class-wise Performance:
| Class | Precision | Recall | F1-score |
|---|---|---|---|
| 0 (Not Attractive) | 0.82 | 0.75 | 0.78 |
| 1 (Attractive) | 0.76 | 0.83 | 0.80 |
- `ia176/`
  - `concept_model.py` - LightningModule for the Concept Model (ResNet-18 + linear head)
  - `data/` - DataModule, CelebA dataset, and data transforms
  - `modeling/` - backbone and head implementations
  - `callbacks/` - Lightning callbacks (prediction saver, evaluator)
  - `decision_tree.py` - train a Decision Tree on saved concepts
  - `__main__.py` - app entry point (Hydra + Lightning Trainer)
- `configs/`
  - `default.yaml` - global config, seeds, logger, and default experiment selection
  - `experiment/` - experiment presets (train, predict, evaluate, saliency, trees)
  - `data/` - dataset and transform configs
  - `model/` - backbone and head configs
  - `logger/` - MLflow logger config
  - `hydra/` - Hydra runtime config
- `analysis/`
  - `saliency_maps.py` - saliency visualizations for selected concepts
  - `feature_importance.py` - visualize and print Decision Tree importances
  - `ablations/train_on_top_k.py` - train a Decision Tree on top-k concepts only
- `preprocessing/` - scripts for data preprocessing (normalization statistics computation)
- `images/` - generated saliency maps and decision tree plots
The project uses PyTorch Lightning with a Hydra configuration setup. A standard `pyproject.toml` is used to resolve dependencies. You can run with uv in the following way (from the repo root):

uv sync

The main entry point is `configs/default.yaml`, which picks the experiment to run via the `experiment` group. You can specify it in the `default.yaml` configuration file and then just run:

uv run python -m ia176

or specify it at runtime:

uv run python -m ia176 experiment=predict mode="test"

Experiments are located in the `configs/experiment/` folder (training, prediction, evaluation, saliency, tree training, ...).
Logging is configured via `configs/logger/mlflow.yaml` and uses local file storage under `./mlruns`.
To launch the MLflow UI run:
uv run mlflow ui --backend-store-uri ./mlruns --host 127.0.0.1 --port 5000
Then open the UI in your browser at:
http://127.0.0.1:5000
To reproduce the results, you must first download the CelebA dataset: https://www.kaggle.com/datasets/kushsheth/face-vae.
Then update dataset locations in configs/data/datasets/CelebA.yaml to reflect your local file structure:
- `partition_csv_path`: CSV with the train/val/test split (columns: image_id, partition); corresponds to `list_eval_partition.csv`
- `images_dir`: directory with CelebA images; corresponds to `img_align_celeba/img_align_celeba`
- `metadata_path`: CSV with attributes (image_id, 39 attributes + Attractive); corresponds to `list_attr_celeba.csv`
Experiment configs: configs/experiment/train.yaml, configs/experiment/predict.yaml
Uses ResNet-18 + linear head to predict 39 concepts.
uv run python -m ia176 experiment=train

Training and validation metrics are logged to MLflow; the checkpoint of the best model (lowest validation loss) is saved locally.
To evaluate the model on the validation/test sets run:
uv run python -m ia176 experiment=predict mode="validate"

or

uv run python -m ia176 experiment=predict mode="test"

Experiment config: `configs/experiment/predict_on_train.yaml`
To save concept predictions for every train image run:
uv run python -m ia176 experiment=predict_on_train

The output path can be configured in the `save_path` entry.
Experiment config: configs/experiment/train_tree.yaml
To train a decision tree on the saved train concepts and ground-truth Attractive labels run:
uv run python -m ia176.decision_tree

The output path can be configured in the `save_path` entry.
Experiment config: configs/experiment/evaluate_cbm.yaml
To evaluate the CBM run:
uv run python -m ia176 experiment=evaluate_cbm

Notes:
- Requires a CBM checkpoint: `checkpoint: ...` (edit the path)
- The callback loads `tree_path` (edit accordingly) and prints Accuracy, AUC, Recall, Precision, and F1 score metrics
If you want to evaluate a tree trained on top-k concepts, use:

uv run python -m ia176 experiment=evaluate_cbm_top_k

Experiment config: `configs/experiment/evaluate_cbm_top_k.yaml`

Notes:
- Make sure `tree_path` matches your saved model
- Set `top_k_indices` to match the training
Experiment config: configs/experiment/saliency.yaml
To compute simple pixel-wise gradient saliency maps for chosen concept indices and test samples, run:

uv run python -m analysis.saliency_maps

Saliency map for the Bushy_Eyebrows concept:

Saliency Map for the Mouth_Slightly_Open concept:

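These maps follow the standard pixel-wise gradient recipe: backpropagate one concept's logit to the input image and take per-pixel absolute gradients. A minimal sketch with a stand-in model (the real script loads the trained checkpoint, and the concept index here is hypothetical):

```python
import torch
import torch.nn as nn

# Stand-in for the trained concept predictor (ResNet-18 + linear head).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 39))
model.eval()

image = torch.randn(1, 3, 224, 224, requires_grad=True)
concept_idx = 12  # hypothetical index of the concept of interest

# Backpropagate the chosen concept logit to the input pixels.
logit = model(image)[0, concept_idx]
logit.backward()

# Per-pixel saliency: max absolute gradient over the RGB channels.
saliency = image.grad.abs().max(dim=1).values.squeeze(0)  # shape (224, 224)
```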
Experiment config: configs/experiment/feature_importance.yaml
Prints feature importances and exports a full tree plot.
uv run python -m analysis.feature_importance

Notes:
- Set `tree_path` and `output_plot_path`
- The `concepts` list defines feature names in the order that was used during training
Experiment config: configs/experiment/train_tree_on_top_k.yaml
To train a Decision Tree using only a subset of concepts from the saved concept predictions run:
uv run python -m analysis.ablations.train_on_top_k

Notes:
- `top_k_indices`: indices of the concepts to keep
- `num_concepts`: total concept count (39)
- `concepts_path` and `save_path`: input and output files (edit accordingly)
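The ablation boils down to column-slicing the saved concept matrix before fitting the tree. A sketch with synthetic data and hypothetical indices:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

num_concepts = 39
top_k_indices = [3, 9, 12]  # hypothetical positions of Bald, Blurry, Chubby

rng = np.random.default_rng(0)
concepts = rng.integers(0, 2, size=(500, num_concepts))  # stand-in saved predictions
attractive = rng.integers(0, 2, 500)                     # stand-in labels

# Keep only the selected concept columns, then fit exactly as before.
X_top_k = concepts[:, top_k_indices]
tree = DecisionTreeClassifier(max_depth=5, criterion="gini", random_state=0)
tree.fit(X_top_k, attractive)
```

At evaluation time the same `top_k_indices` must be applied to the predicted concepts, which is why the evaluate config has to match the training.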