TP 9 : Arbres de décision ID3
Introduction
Dans ce TP, nous allons implémenter l'algorithme ID3 pour classifier les survivants du Titanic. On rappelle que c'est un algorithme de classification supervisée : à partir de données d'entraînement connues, il construit un arbre de décision permettant de prédire la classe d'une nouvelle donnée.
Partie 1 : Découverte des données
-
Créer un nouveau Codespace sur https://github.com/mpi-lamartin/titanic. Lire le README.md pour comprendre la structure du code. Compiler et exécuter le code fourni en exemple qui classifie les femmes comme survivantes.
-
Créer un fichier
src/id3.mldans lequel vous allez écrire le code du TP. Ajouteropen Csv_loaderet chargez les données d'entraînement et de test. Vous pouvez vous inspirer dewoman_survive.ml. -
Écrire une fonction
statsqui affiche le nombre de survivants et de décédés dans les données d'entraînement. On rappelle qu'on peut afficher une variablexavec :Printf.printf "valeur de x : %d\n" x.
val stats : passenger list -> unit
Survivants: 342, Décédés: 549
Pour compiler et exécuter votre code, vous pouvez modifier le fichier Makefile en remplaçant woman_survive.ml par id3.ml, puis utiliser dans le terminal :
make run
Partie 2 : Attributs catégoriels
Pour construire un arbre de décision, nous devons transformer les attributs numériques en attributs catégoriels. Par exemple, l'âge sera transformé en catégories : "child" ( ans), "adult" ( ans), "senior" ( ans).
On définit un type attribute des attributs que l'on utilisera pour construire l'arbre de décision :
type attribute =
| Sex (* "male" ou "female" *)
| Pclass (* "1", "2" ou "3" *)
| AgeGroup (* "child", "adult", "senior", "unknown" *)
| FamilySize (* "alone", "small", "large" *)
| Embarked (* "S", "C", "Q", "unknown" *)
| FareGroup (* "low", "medium", "high", "unknown" *)
- Écrire une fonction
attribute_valuesqui retourne la liste des valeurs possibles pour un attribut :
val attribute_values : attribute -> string list
attribute_values Sex (* = ["male"; "female"] *)
attribute_values Pclass (* = ["1"; "2"; "3"] *)
attribute_values AgeGroup (* = ["child"; "adult"; "senior"; "unknown"] *)
attribute_values FamilySize (* = ["alone"; "small"; "large"] *)
- Écrire une fonction
get_attribute_valuequi retourne la valeur d'un attribut pour un passager :
val get_attribute_value : attribute -> passenger -> string
Les règles de conversion sont :
AgeGroup: "child" si , "adult" si , "senior" sinonFamilySize: "alone" si , "small" si , "large" sinonFareGroup: "low" si , "medium" si , "high" sinon
Pour le passager n°1 (Braund, Mr. Owen Harris - homme, 22 ans, classe 3, sibsp=1, parch=0, fare=7.25, embarqué à S) :
get_attribute_value Sex p1 (* = "male" *)
get_attribute_value Pclass p1 (* = "3" *)
get_attribute_value AgeGroup p1 (* = "adult" car 18 ≤ 22 ≤ 60 *)
get_attribute_value FamilySize p1 (* = "small" car 1+0=1 ≤ 3 *)
get_attribute_value FareGroup p1 (* = "low" car 7.25 < 10 *)
get_attribute_value Embarked p1 (* = "S" *)
Partie 3 : Entropie et gain d'information
L'algorithme ID3 utilise l'entropie pour mesurer l'entropie d'un ensemble de données, définie par :
où est la proportion d'éléments de classe (survived) dans . Par convention, .
- Écrire une fonction
entropyqui calcule l'entropie d'une liste de passagers :
val entropy : passenger list -> float
entropy train (* = 0.9607 *)
Indication : On peut calculer avec log x /. log 2.0 et convertir un entier n en float avec float_of_int n.
Le gain d'information d'un attribut par rapport à un ensemble mesure la réduction d'entropie obtenue en divisant selon les valeurs de :
où est le sous-ensemble des éléments ayant la valeur pour l'attribut .
- Écrire une fonction
information_gainqui calcule le gain d'information d'un attribut :
val information_gain : passenger list -> attribute -> float
information_gain train Sex (* = 0.2177 *)
information_gain train Pclass (* = 0.0838 *)
information_gain train AgeGroup (* = 0.0162 *)
information_gain train FamilySize (* = 0.0608 *)
information_gain train Embarked (* = 0.0240 *)
information_gain train FareGroup (* = 0.0916 *)
- Écrire une fonction
best_attributequi trouve l'attribut avec le meilleur gain d'information parmi une liste d'attributs :
val best_attribute : passenger list -> attribute list -> attribute option
(* Retourne None si aucun attribut n'a un gain > 0 *)
let all_attrs = [Sex; Pclass; AgeGroup; FamilySize; Embarked; FareGroup] in
best_attribute train all_attrs (* = Some Sex, car Sex a le gain le plus élevé : 0.2177 *)
Partie 4 : Construction de l'arbre
Un arbre de décision est soit une feuille (prédiction), soit un nœud interne (test sur un attribut).
- Définir le type
decision_tree:
type decision_tree =
| Leaf of int (* Prédiction : 0 ou 1 *)
| Node of attribute * (string * decision_tree) list (* Attribut et branches *)
- Écrire une fonction
majority_classqui retourne la classe majoritaire (0 pour décédé, 1 pour survivant) d'un ensemble de passagers :
val majority_class : passenger list -> int
majority_class train (* = 0, car 549 décédés > 342 survivants *)
(* Pour un sous-ensemble de femmes de 1ère classe *)
let femmes_1ere = List.filter (fun p -> p.sex = "female" && p.pclass = 1) train in
majority_class femmes_1ere (* = 1, car la plupart ont survécu *)
- Écrire la fonction récursive
build_treequi construit un arbre de décision avec l'algorithme ID3 :
val build_tree : passenger list -> attribute list -> decision_tree
(* build_tree data attributes *)
L'algorithme ID3 est le suivant :
- Si la liste est vide, retourner une feuille avec la prédiction 0
- Si tous les passagers ont la même classe, retourner une feuille avec cette classe
- Si la liste d'attributs est vide, retourner une feuille avec la classe majoritaire
- Sinon :
- Trouver le meilleur attribut selon le gain d'information
- Pour chaque valeur de , construire récursivement un sous-arbre sur le sous-ensemble
- Retourner un nœud avec l'attribut et les sous-arbres
Utiliser la fonction suivante pour afficher l'arbre obtenu :
let rec print_tree tree indent =
let spaces = String.make (indent * 2) ' ' in
match tree with
| Leaf pred ->
Printf.printf "%s└── Prédiction: %s\n" spaces
(if pred = 1 then "Survit" else "Décède")
| Node (attr, branches) ->
Printf.printf "%s[%s]\n" spaces (attribute_name attr);
List.iter (fun (value, subtree) ->
Printf.printf "%s ├─ %s :\n" spaces value;
print_tree subtree (indent + 2)
) branches
Partie 5 : Prédiction et évaluation
- Écrire une fonction
predictqui prédit la survie d'un passager avec l'arbre de décision :
val predict : decision_tree -> passenger -> int
- Écrire une fonction
accuracyqui calcule la précision de l'arbre sur un ensemble de données :
val accuracy : decision_tree -> passenger list -> float
accuracy tree train (* = 0.8530, soit 85.30% de précision *)
Partie 6 : Génération de soumission
- Écrire une fonction
generate_submissionqui génère un fichier CSV de soumission au format Kaggle. Vous pouvez vous inspirer dewoman_survive.ml. Utilisez le lien Kaggle envoyé par mail pour soumettre votre fichier et obtenir votre score.
val generate_submission : decision_tree -> passenger list -> string -> unit
(* generate_submission tree test_data filename *)
Le fichier doit avoir le format suivant :
PassengerId,Survived
892,0
893,1
...
generate_submission tree test "submission_id3.csv"
(* Affiche : Fichier de soumission généré : submission_id3.csv *)
Pour aller plus loin
- Quand on mesure la précision sur les données d'entraînement (85%), on obtient une estimation optimiste car le modèle a été construit sur ces mêmes données. C'est le problème du surapprentissage (overfitting). La validation croisée k-fold permet d'estimer la performance réelle du modèle sur des données non vues.
Principe de la validation croisée k-fold :
Algorithme :
- Mélanger aléatoirement les données
- Diviser en parties (folds) de taille égale
- Pour chaque fold de 1 à :
- Utiliser le fold comme ensemble de validation
- Utiliser les autres folds comme ensemble d'entraînement
- Construire l'arbre sur l'entraînement
- Calculer la précision sur la validation → Score
- Retourner la moyenne et l'écart-type des scores
Signature :
val cross_validation : passenger list -> int -> (passenger list -> decision_tree) -> float * float * float list
(* cross_validation data k build_fn retourne (moyenne, écart_type, liste_scores) *)
- Forêt aléatoire
Un arbre de décision a tendance à surapprendre. La forêt aléatoire construit plusieurs arbres indépendants et les fait voter pour obtenir une prédiction plus robuste.
Principe :
-
Pour chaque arbre (sur
n_treesarbres) :- Bootstrap : tirer données avec remplacement (certaines répétées, d'autres absentes)
- Sous-espace : sélectionner aléatoirement quelques attributs
- Construire l'arbre sur cet échantillon
-
Prédiction : vote majoritaire de tous les arbres
Fonctions à implémenter :
(** Échantillonnage bootstrap *)
val bootstrap_sample : passenger list -> passenger list
(** Construit une forêt *)
val build_forest : passenger list -> attribute list -> int -> int option -> decision_tree list
(** Prédiction par vote majoritaire *)
val predict_forest : decision_tree list -> passenger -> int
Indication pour le bootstrap :
let bootstrap_sample data =
let n = List.length data in
let arr = Array.of_list data in
List.init n (fun _ -> arr.(Random.int n))
Paramètres recommandés : 100-200 arbres, 4-5 attributs par arbre, profondeur max 6-8.
Résultats attendus (200 arbres, 4 attrs, depth=6) :
Précision (entraînement) : ~83-85%
Validation croisée : ~80-82%
- Algorithme C4.5
Implémenter l'algorithme C4.5, une amélioration d'ID3 qui utilise le ratio de gain au lieu du gain d'information pour éviter de favoriser les attributs avec beaucoup de valeurs :
où est la valeur intrinsèque de l'attribut .
gain_ratio train Sex (* = 0.2325 *)
gain_ratio train FareGroup (* = 0.0612 *)
gain_ratio train Pclass (* = 0.0582 *)
gain_ratio train FamilySize (* = 0.0492 *)
gain_ratio train Embarked (* = 0.0215 *)
gain_ratio train AgeGroup (* = 0.0118 *)
Implémenter build_tree_c45 qui utilise best_attribute_c45 au lieu de best_attribute :
val build_tree_c45 : passenger list -> attribute list -> decision_tree
Résultats attendus :
Précision C4.5 (entraînement) : ~85%
Validation croisée C4.5 (5-fold) : ~80%
Remarque : Sur ce jeu de données avec peu d'attributs, C4.5 donne des résultats similaires à ID3 mais avec un écart-type plus faible en validation croisée.