diff --git a/autotests/testPMKSA-SAE/connection_test.py b/autotests/testPMKSA-SAE/connection_test.py index 5bab3ff8..749ebd44 100644 --- a/autotests/testPMKSA-SAE/connection_test.py +++ b/autotests/testPMKSA-SAE/connection_test.py @@ -4,7 +4,7 @@ import sys sys.path.append('../util') -from iwd import IWD +from iwd import IWD, FailedEx from iwd import PSKAgent from iwd import NetworkType from hostapd import HostapdCLI @@ -94,6 +94,36 @@ def test_pmksa_sae(self): self.hostapd.wait_for_event("AP-ENABLED") self.validate_connection(self.wd, "ssidSAE", self.hostapd, 19) + def test_pmksa_forget_network(self): + psk_agent = PSKAgent(["secret123", "wrong_password"]) + self.wd.register_psk_agent(psk_agent) + + devices = self.wd.list_devices(1) + self.assertIsNotNone(devices) + device = devices[0] + + device.disconnect() + + network = device.get_ordered_network("ssidSAE", full_scan=True) + + self.assertEqual(network.type, NetworkType.psk) + + network.network_object.connect() + + condition = 'obj.state == DeviceState.connected' + self.wd.wait_for_object_condition(device, condition) + + self.wd.wait(2) + + testutil.test_iface_operstate(intf=device.name) + testutil.test_ifaces_connected(if0=device.name, if1=self.hostapd.ifname) + + known_network = self.wd.list_known_networks()[0] + known_network.forget() + + with self.assertRaises(FailedEx): + network.network_object.connect() + def setUp(self): self.hostapd.default() self.wd = IWD(True) diff --git a/src/network.c b/src/network.c index a5a2375a..570745cb 100644 --- a/src/network.c +++ b/src/network.c @@ -58,6 +58,7 @@ #include "src/handshake.h" #include "src/band.h" #include "src/util.h" +#include "src/pmksa.h" #define SAE_PT_SETTING "SAE-PT-Group%u" @@ -2051,6 +2052,8 @@ static void emit_known_network_removed(struct station *station, void *user_data) l_queue_destroy(network->secrets, eap_secret_info_free); network->secrets = NULL; + + pmksa_cache_flush_ssid(info->ssid); } connected_network = station_get_connected_network(station); diff --git a/src/pmksa.c b/src/pmksa.c index a50c8208..9b02d7e4 100644 --- a/src/pmksa.c +++ b/src/pmksa.c @@ -213,6 +213,33 @@ int pmksa_cache_flush(void) return 0; } +/* + * Flushes all PMKSA entries that match an SSID + */ +int pmksa_cache_flush_ssid(const char ssid[static 32]) +{ + int i; + int used = cache.used; + int remaining = 0; + + for (i = 0; i < used; i++) { + if (!memcmp(ssid, cache.data[i]->ssid, cache.data[i]->ssid_len)) { + pmksa_cache_free(cache.data[i]); + continue; + } + + cache.data[remaining] = cache.data[i]; + remaining += 1; + } + + cache.used = remaining; + + for (i = cache.used >> 1; i >= 0; i--) + __minheap_sift_down(cache.data, cache.used, i, &ops); + + return used - remaining; +} + int pmksa_cache_free(struct pmksa *pmksa) { if (driver_remove) diff --git a/src/pmksa.h b/src/pmksa.h index 6a624504..946ef0b2 100644 --- a/src/pmksa.h +++ b/src/pmksa.h @@ -45,6 +45,7 @@ struct pmksa *pmksa_cache_get(const uint8_t spa[static 6], int pmksa_cache_put(struct pmksa *pmksa); int pmksa_cache_expire(uint64_t cutoff); int pmksa_cache_flush(void); +int pmksa_cache_flush_ssid(const char ssid[static 32]); int pmksa_cache_free(struct pmksa *pmksa); uint64_t pmksa_lifetime(void);