Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 154 additions & 75 deletions pkg/nest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ import (
)

type API struct {
Token string
ExpiresAt time.Time
Token string
ExpiresAt time.Time
ClientID string
ClientSecret string
RefreshToken string

StreamProjectID string
StreamDeviceID string
Expand All @@ -27,6 +30,8 @@ type API struct {
StreamExtensionToken string

extendTimer *time.Timer
extendMu sync.Mutex
extendStop chan struct{}
}

type Auth struct {
Expand All @@ -53,6 +58,25 @@ func NewAPI(clientID, clientSecret, refreshToken string) (*API, error) {
return api, nil
}

token, expiresAt, err := requestAccessToken(clientID, clientSecret, refreshToken)
if err != nil {
return nil, err
}

api := &API{
Token: token,
ExpiresAt: expiresAt,
ClientID: clientID,
ClientSecret: clientSecret,
RefreshToken: refreshToken,
}

cache[key] = api

return api, nil
}

func requestAccessToken(clientID, clientSecret, refreshToken string) (string, time.Time, error) {
data := url.Values{
"grant_type": []string{"refresh_token"},
"client_id": []string{clientID},
Expand All @@ -63,33 +87,25 @@ func NewAPI(clientID, clientSecret, refreshToken string) (*API, error) {
client := &http.Client{Timeout: time.Second * 5000}
res, err := client.PostForm("https://www.googleapis.com/oauth2/v4/token", data)
if err != nil {
return nil, err
return "", time.Time{}, err
}
defer res.Body.Close()

if res.StatusCode != 200 {
return nil, errors.New("nest: wrong status: " + res.Status)
return "", time.Time{}, errors.New("nest: wrong status: " + res.Status)
}

var resv struct {
AccessToken string `json:"access_token"`
ExpiresIn time.Duration `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}

if err = json.NewDecoder(res.Body).Decode(&resv); err != nil {
return nil, err
return "", time.Time{}, err
}

api := &API{
Token: resv.AccessToken,
ExpiresAt: now.Add(resv.ExpiresIn * time.Second),
}

cache[key] = api

return api, nil
now := time.Now()
return resv.AccessToken, now.Add(resv.ExpiresIn * time.Second), nil
}

func (a *API) GetDevices(projectID string) ([]DeviceInfo, error) {
Expand Down Expand Up @@ -228,37 +244,47 @@ func (a *API) ExchangeSDP(projectID, deviceID, offer string) (string, error) {
}

func (a *API) refreshToken() error {
// Get the cached API with matching token to get credentials
var refreshKey string
cacheMu.Lock()
for key, api := range cache {
if api.Token == a.Token {
refreshKey = key
break
clientID := a.ClientID
clientSecret := a.ClientSecret
refreshToken := a.RefreshToken

if clientID == "" || clientSecret == "" || refreshToken == "" {
// Backward-compatible fallback: derive credentials from cache key.
var refreshKey string
cacheMu.Lock()
for key, api := range cache {
if api == a || api.Token == a.Token {
refreshKey = key
break
}
}
}
cacheMu.Unlock()
cacheMu.Unlock()

if refreshKey == "" {
return errors.New("nest: unable to find cached credentials")
}
if refreshKey == "" {
return errors.New("nest: unable to find cached credentials")
}

// Parse credentials from cache key
parts := strings.Split(refreshKey, ":")
if len(parts) != 3 {
return errors.New("nest: invalid cache key format")
parts := strings.Split(refreshKey, ":")
if len(parts) != 3 {
return errors.New("nest: invalid cache key format")
}
clientID, clientSecret, refreshToken = parts[0], parts[1], parts[2]
}
clientID, clientSecret, refreshToken := parts[0], parts[1], parts[2]

// Get new API instance which will refresh the token
newAPI, err := NewAPI(clientID, clientSecret, refreshToken)
token, expiresAt, err := requestAccessToken(clientID, clientSecret, refreshToken)
if err != nil {
return err
}

// Update current API with new token
a.Token = newAPI.Token
a.ExpiresAt = newAPI.ExpiresAt
a.Token = token
a.ExpiresAt = expiresAt
a.ClientID = clientID
a.ClientSecret = clientSecret
a.RefreshToken = refreshToken

cacheMu.Lock()
cache[clientID+":"+clientSecret+":"+refreshToken] = a
cacheMu.Unlock()
return nil
}

Expand Down Expand Up @@ -288,43 +314,74 @@ func (a *API) ExtendStream() error {

uri := "https://smartdevicemanagement.googleapis.com/v1/enterprises/" +
a.StreamProjectID + "/devices/" + a.StreamDeviceID + ":executeCommand"
req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
if err != nil {
return err
}

req.Header.Set("Authorization", "Bearer "+a.Token)
maxRetries := 3
retryDelay := 30 * time.Second

client := &http.Client{Timeout: time.Second * 5000}
res, err := client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
for attempt := 0; attempt < maxRetries; attempt++ {
req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+a.Token)

if res.StatusCode != 200 {
return errors.New("nest: wrong status: " + res.Status)
}
client := &http.Client{Timeout: time.Second * 5000}
res, err := client.Do(req)
if err != nil {
return err
}

var resv struct {
Results struct {
ExpiresAt time.Time `json:"expiresAt"`
MediaSessionID string `json:"mediaSessionId"`
StreamExtensionToken string `json:"streamExtensionToken"`
StreamToken string `json:"streamToken"`
} `json:"results"`
}
// 401 => force refresh token and retry fast
if res.StatusCode == 401 {
res.Body.Close()
if attempt < maxRetries-1 {
if err := a.refreshToken(); err != nil {
return err
}
time.Sleep(time.Second)
continue
}
}

if err = json.NewDecoder(res.Body).Decode(&resv); err != nil {
return err
}
// 409/429 => backoff en retry (zonder meteen token refresh)
if res.StatusCode == 409 || res.StatusCode == 429 {
res.Body.Close()
if attempt < maxRetries-1 {
time.Sleep(retryDelay)
retryDelay *= 2
continue
}
}

a.StreamSessionID = resv.Results.MediaSessionID
a.StreamExpiresAt = resv.Results.ExpiresAt
a.StreamExtensionToken = resv.Results.StreamExtensionToken
a.StreamToken = resv.Results.StreamToken
if res.StatusCode != 200 {
res.Body.Close()
return errors.New("nest: wrong status: " + res.Status)
}

return nil
var resv struct {
Results struct {
ExpiresAt time.Time `json:"expiresAt"`
MediaSessionID string `json:"mediaSessionId"`
StreamExtensionToken string `json:"streamExtensionToken"`
StreamToken string `json:"streamToken"`
} `json:"results"`
}

if err = json.NewDecoder(res.Body).Decode(&resv); err != nil {
res.Body.Close()
return err
}
res.Body.Close()

a.StreamSessionID = resv.Results.MediaSessionID
a.StreamExpiresAt = resv.Results.ExpiresAt
a.StreamExtensionToken = resv.Results.StreamExtensionToken
a.StreamToken = resv.Results.StreamToken

return nil
}

return errors.New("nest: max retries exceeded")
}

func (a *API) GenerateRtspStream(projectID, deviceID string) (string, error) {
Expand Down Expand Up @@ -465,22 +522,44 @@ type Device struct {
}

func (a *API) StartExtendStreamTimer() {
if a.extendTimer != nil {
a.extendMu.Lock()
if a.extendStop != nil {
a.extendMu.Unlock()
return
}
a.extendStop = make(chan struct{})
stop := a.extendStop
a.extendMu.Unlock()

a.extendTimer = time.NewTimer(time.Until(a.StreamExpiresAt) - time.Minute)
go func() {
<-a.extendTimer.C
if err := a.ExtendStream(); err != nil {
return
for {
// plan 1 minuut vóór expiry; clamp zodat het nooit negatief/te klein wordt
d := time.Until(a.StreamExpiresAt) - time.Minute
if d < 10*time.Second {
d = 10 * time.Second
}

t := time.NewTimer(d)
select {
case <-t.C:
// Keep retrying extension on transient failures to avoid avoidable stream drops.
if err := a.ExtendStream(); err != nil {
continue
}
// loop gaat door en plant opnieuw met nieuw StreamExpiresAt
case <-stop:
t.Stop()
return
}
}
}()
}

func (a *API) StopExtendStreamTimer() {
if a.extendTimer != nil {
a.extendTimer.Stop()
a.extendTimer = nil
a.extendMu.Lock()
if a.extendStop != nil {
close(a.extendStop)
a.extendStop = nil
}
a.extendMu.Unlock()
}