From c8ae7f9c2dba5206d25fd3b171365660e5ba6805 Mon Sep 17 00:00:00 2001
From: SUR Frederic <frederic.sur@univ-lorraine.fr>
Date: Tue, 19 Nov 2024 19:37:02 +0000
Subject: [PATCH] Replace TP1_ex2_sujet.ipynb

---
 TP1/TP1_ex2_sujet.ipynb | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/TP1/TP1_ex2_sujet.ipynb b/TP1/TP1_ex2_sujet.ipynb
index 5b3e19e..9b15cc6 100755
--- a/TP1/TP1_ex2_sujet.ipynb
+++ b/TP1/TP1_ex2_sujet.ipynb
@@ -122,11 +122,13 @@
     "alphas = np.logspace(-4, 4, n_alphas)\n",
     "coefs=np.zeros((0,10))\n",
     "MSE_ridge=[]\n",
+    "MSE_linear=[]\n",
     "for a in alphas:\n",
     "    ridge = linear_model.Ridge(alpha=a)\n",
     "    ridge.fit(X_train, y_train)\n",
     "    coefs=np.vstack((coefs,ridge.coef_))\n",
-    "    MSE_ridge.append([MSE_lr, np.mean((ridge.predict(X_test) - y_test) ** 2)])\n",
+    "    MSE_ridge.append(np.mean((ridge.predict(X_test) - y_test) ** 2))\n",
+    "    MSE_linear.append(MSE_lr)\n",
     "\n",
     "plt.figure(figsize=(8,6))\n",
     "plt.semilogx(alphas, coefs)\n",
@@ -140,11 +142,12 @@
     "\n",
     "plt.figure(figsize=(8,6))\n",
     "plt.semilogx(alphas, MSE_ridge)\n",
+    "plt.semilogx(alphas, MSE_linear)\n",
     "plt.xlabel('alpha')\n",
     "plt.ylabel('MSE')\n",
     "plt.title('MSE régression linéaire et ridge vs. alpha ')\n",
     "plt.axis([1e-4,1e4,2750,5500])\n",
-    "plt.legend(['MSE lr','MSE ridge'])\n",
+    "plt.legend(['MSE ridge','MSE linear'])\n",
     "plt.grid()\n",
     "plt.show()"
    ]
@@ -175,7 +178,7 @@
     "    lasso = linear_model.Lasso(alpha=a)\n",
     "    lasso.fit(X_train, y_train)\n",
     "    coefs=np.vstack((coefs,lasso.coef_))\n",
-    "    MSE_lasso.append([MSE_lr, np.mean((lasso.predict(X_test) - y_test) ** 2)])\n",
+    "    MSE_lasso.append(np.mean((lasso.predict(X_test) - y_test) ** 2))\n",
     "\n",
     "#print(coefs)\n",
     "    \n",
@@ -192,11 +195,12 @@
     "plt.figure(figsize=(8,6))\n",
     "plt.semilogx(alphas, MSE_lasso)\n",
     "plt.semilogx(alphas, MSE_ridge, '--')\n",
+    "plt.semilogx(alphas, MSE_linear)\n",
     "plt.xlabel('alpha')\n",
     "plt.ylabel('MSE')\n",
     "plt.axis([1e-4,1e4,2750,5500])\n",
     "plt.title('MSE régression linéaire, lasso, et ridge vs. alpha ')\n",
-    "plt.legend(['MSE lr','MSE lasso','MSE ridge (rappel)'])\n",
+    "plt.legend(['MSE lasso','MSE ridge (rappel)','MSE linear'])\n",
     "plt.grid()\n",
     "plt.show()"
    ]
-- 
GitLab