diff --git a/autotests/testPMKSA-SAE/connection_test.py b/autotests/testPMKSA-SAE/connection_test.py index 5bab3ff8..749ebd44 100644 --- a/autotests/testPMKSA-SAE/connection_test.py +++ b/autotests/testPMKSA-SAE/connection_test.py @@ -4,7 +4,7 @@ import sys sys.path.append('../util') -from iwd import IWD +from iwd import IWD, FailedEx from iwd import PSKAgent from iwd import NetworkType from hostapd import HostapdCLI @@ -94,6 +94,36 @@ def test_pmksa_sae(self): self.hostapd.wait_for_event("AP-ENABLED") self.validate_connection(self.wd, "ssidSAE", self.hostapd, 19) + def test_pmksa_forget_network(self): + psk_agent = PSKAgent(["secret123", "wrong_password"]) + self.wd.register_psk_agent(psk_agent) + + devices = self.wd.list_devices(1) + self.assertIsNotNone(devices) + device = devices[0] + + device.disconnect() + + network = device.get_ordered_network("ssidSAE", full_scan=True) + + self.assertEqual(network.type, NetworkType.psk) + + network.network_object.connect() + + condition = 'obj.state == DeviceState.connected' + self.wd.wait_for_object_condition(device, condition) + + self.wd.wait(2) + + testutil.test_iface_operstate(intf=device.name) + testutil.test_ifaces_connected(if0=device.name, if1=self.hostapd.ifname) + + known_network = self.wd.list_known_networks()[0] + known_network.forget() + + with self.assertRaises(FailedEx): + network.network_object.connect() + def setUp(self): self.hostapd.default() self.wd = IWD(True) diff --git a/src/common.c b/src/common.c index 91979423..1e5b3b7b 100644 --- a/src/common.c +++ b/src/common.c @@ -28,6 +28,8 @@ #include #include +#include + #include "src/iwd.h" #include "src/common.h" #include "src/ie.h" @@ -64,18 +66,15 @@ bool security_from_str(const char *str, enum security *security) return true; } -#define AKM_IS_PSK(akm) \ -( \ - akm & (IE_RSN_AKM_SUITE_PSK | \ +#define AKMS_PSK \ + (IE_RSN_AKM_SUITE_PSK | \ IE_RSN_AKM_SUITE_PSK_SHA256 | \ IE_RSN_AKM_SUITE_FT_USING_PSK | \ IE_RSN_AKM_SUITE_SAE_SHA256 | \ - IE_RSN_AKM_SUITE_FT_OVER_SAE_SHA256) \ -) + IE_RSN_AKM_SUITE_FT_OVER_SAE_SHA256) -#define AKM_IS_8021X(akm) \ -( \ - akm & (IE_RSN_AKM_SUITE_8021X | \ +#define AKMS_8021X \ + (IE_RSN_AKM_SUITE_8021X | \ IE_RSN_AKM_SUITE_8021X_SHA256 | \ IE_RSN_AKM_SUITE_FT_OVER_8021X | \ IE_RSN_AKM_SUITE_FT_OVER_8021X_SHA384 | \ @@ -83,8 +82,11 @@ bool security_from_str(const char *str, enum security *security) IE_RSN_AKM_SUITE_FILS_SHA384 | \ IE_RSN_AKM_SUITE_FT_OVER_FILS_SHA256 | \ IE_RSN_AKM_SUITE_FT_OVER_FILS_SHA384 | \ - IE_RSN_AKM_SUITE_OSEN) \ -) + IE_RSN_AKM_SUITE_OSEN) + +#define AKM_IS_PSK(akm) (akm & AKMS_PSK) + +#define AKM_IS_8021X(akm) (akm & AKMS_8021X) enum security security_determine(uint16_t bss_capability, const struct ie_rsn_info *info) @@ -103,3 +105,22 @@ enum security security_determine(uint16_t bss_capability, return SECURITY_NONE; } + +/* Returns all possible AKMs (as bitmask) for a given security type */ +uint32_t security_to_akms(enum security security) +{ + switch (security) { + case SECURITY_WEP: + l_warn("WEP security type not supported"); + /* Fall through */ + case SECURITY_NONE: + return 0; + case SECURITY_PSK: + return AKMS_PSK; + case SECURITY_8021X: + return AKMS_8021X; + default: + l_warn("Unhandled security type: %u", security); + return 0; + } +} diff --git a/src/common.h b/src/common.h index ca12d813..7e877299 100644 --- a/src/common.h +++ b/src/common.h @@ -36,3 +36,4 @@ const char *security_to_str(enum security security); bool security_from_str(const char *str, enum security *security); enum security security_determine(uint16_t bss_capability, const struct ie_rsn_info *info); +uint32_t security_to_akms(enum security security); \ No newline at end of file diff --git a/src/network.c b/src/network.c index a5a2375a..2d67383f 100644 --- a/src/network.c +++ b/src/network.c @@ -58,6 +58,7 @@ #include "src/handshake.h" #include "src/band.h" #include "src/util.h" +#include "src/pmksa.h" #define SAE_PT_SETTING "SAE-PT-Group%u" @@ -2051,6 +2052,10 @@ static void emit_known_network_removed(struct station *station, void *user_data) l_queue_destroy(network->secrets, eap_secret_info_free); network->secrets = NULL; + + pmksa_cache_flush_ssid((uint8_t *)info->ssid, + sizeof(info->ssid), + security_to_akms(network->security)); } connected_network = station_get_connected_network(station); diff --git a/src/pmksa.c b/src/pmksa.c index a50c8208..7f37f0c7 100644 --- a/src/pmksa.c +++ b/src/pmksa.c @@ -213,6 +213,36 @@ int pmksa_cache_flush(void) return 0; } +/* + * Flushes all PMKSA entries that match an SSID + */ +int pmksa_cache_flush_ssid(const uint8_t *ssid, size_t ssid_len, uint32_t akms) +{ + int i; + int used = cache.used; + int remaining = 0; + + for (i = 0; i < used; i++) { + /* Check that the both the AKM matches as well as the SSID */ + if ((cache.data[i]->akm & akms) && + !memcmp(ssid, cache.data[i]->ssid, + cache.data[i]->ssid_len)) { + pmksa_cache_free(cache.data[i]); + continue; + } + + cache.data[remaining] = cache.data[i]; + remaining += 1; + } + + cache.used = remaining; + + for (i = cache.used >> 1; i >= 0; i--) + __minheap_sift_down(cache.data, cache.used, i, &ops); + + return used - remaining; +} + int pmksa_cache_free(struct pmksa *pmksa) { if (driver_remove) diff --git a/src/pmksa.h b/src/pmksa.h index 6a624504..9f059a11 100644 --- a/src/pmksa.h +++ b/src/pmksa.h @@ -45,6 +45,7 @@ struct pmksa *pmksa_cache_get(const uint8_t spa[static 6], int pmksa_cache_put(struct pmksa *pmksa); int pmksa_cache_expire(uint64_t cutoff); int pmksa_cache_flush(void); +int pmksa_cache_flush_ssid(const uint8_t *ssid, size_t ssid_len, uint32_t akms); int pmksa_cache_free(struct pmksa *pmksa); uint64_t pmksa_lifetime(void);