diff --git a/errors/errors.go b/errors/errors.go index a3fb55c..9675fd5 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -1,6 +1,9 @@ package errors -import "github.com/Xavier-Lam/go-wechat/internal/client" +import ( + "github.com/Xavier-Lam/go-wechat/internal/auth" + "github.com/Xavier-Lam/go-wechat/internal/client" +) // Exported errors type ( @@ -8,7 +11,6 @@ type ( ) var ( - // client - ErrCacheNotSet = client.ErrCacheNotSet + ErrCacheNotSet = auth.ErrCacheNotSet ErrInvalidResponse = client.ErrInvalidResponse ) diff --git a/internal/auth/accesstoken.go b/internal/auth/accesstoken.go index e30de65..3898f88 100644 --- a/internal/auth/accesstoken.go +++ b/internal/auth/accesstoken.go @@ -2,10 +2,22 @@ package auth import ( "encoding/json" + "errors" + "fmt" + "net/http" "time" + + "github.com/Xavier-Lam/go-wechat/caches" ) -const DefaultTokenExpiresIn = 7200 +const DefaultAccessTokenExpiresIn = 7200 + +// AccessTokenClient is an client to request the newest access token +type AccessTokenClient interface { + PrepareRequest(auth Auth) (*http.Request, error) + SendRequest(auth Auth, req *http.Request) (*http.Response, error) + HandleResponse(auth Auth, resp *http.Response, req *http.Request) (*AccessToken, error) +} type AccessToken struct { accessToken string @@ -16,7 +28,7 @@ type AccessToken struct { // NewAccessToken creates a new `AccessToken` instance. func NewAccessToken(accessToken string, expiresIn int) *AccessToken { if expiresIn <= 0 { - expiresIn = DefaultTokenExpiresIn + expiresIn = DefaultAccessTokenExpiresIn } return &AccessToken{ accessToken: accessToken, @@ -46,6 +58,116 @@ func (t *AccessToken) GetExpiresAt() time.Time { return t.createdAt.Add(timeDiff) } +// accessTokenManager is an implement of the `auth.CredentialManager` +// which is used to manage access token credentials. +type accessTokenManager struct { + atc AccessTokenClient + auth Auth + cache caches.Cache +} + +// NewAccessTokenManager creates a new instance of `auth.CredentialManager` +// to manage access token credentials. +func NewAccessTokenManager(atc AccessTokenClient, auth Auth, cache caches.Cache) CredentialManager { + return &accessTokenManager{ + atc: atc, + auth: auth, + cache: cache, + } +} + +func (cm *accessTokenManager) Get() (interface{}, error) { + cachedValue, err := cm.get() + if err == nil { + return cachedValue, nil + } + + return cm.Renew() +} + +func (cm *accessTokenManager) Set(credential interface{}) error { + return errors.New("not settable") +} + +func (cm *accessTokenManager) Renew() (interface{}, error) { + cm.Delete() + + // TODO: prevent concurrent fetching + token, err := cm.getAccessToken() + if err != nil { + return nil, err + } + + if cm.cache == nil { + err = fmt.Errorf("cache is not set") + } else { + serializedToken, err := SerializeAccessToken(token) + if err != nil { + return nil, err + } + err = cm.cache.Set( + cm.auth.GetAppId(), + caches.BizAccessToken, + serializedToken, + token.GetExpiresIn(), + ) + } + + return token, err +} + +func (cm *accessTokenManager) Delete() error { + token, err := cm.get() + if err != nil { + return err + } + serializedToken, err := SerializeAccessToken(token) + if err != nil { + return err + } + return cm.cache.Delete( + cm.auth.GetAppId(), + caches.BizAccessToken, + serializedToken, + ) +} + +func (cm *accessTokenManager) get() (*AccessToken, error) { + if cm.cache == nil { + return nil, ErrCacheNotSet + } + + cachedValue, err := cm.cache.Get(cm.auth.GetAppId(), caches.BizAccessToken) + if err != nil { + return nil, err + } + + token, err := DeserializeAccessToken(cachedValue) + if err != nil { + return nil, err + } + + return token, nil +} + +func (cm *accessTokenManager) getAccessToken() (*AccessToken, error) { + req, err := cm.atc.PrepareRequest(cm.auth) + if err != nil { + return nil, err + } + + resp, err := cm.atc.SendRequest(cm.auth, req) + if resp != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, err + } + + return cm.atc.HandleResponse(cm.auth, resp, req) +} + +// TODO: make private type accessToken struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` diff --git a/internal/auth/accesstoken_test.go b/internal/auth/accesstoken_test.go index a3923dc..8e40d51 100644 --- a/internal/auth/accesstoken_test.go +++ b/internal/auth/accesstoken_test.go @@ -4,11 +4,13 @@ import ( "testing" "time" + "github.com/Xavier-Lam/go-wechat/caches" "github.com/Xavier-Lam/go-wechat/internal/auth" + "github.com/Xavier-Lam/go-wechat/internal/test" "github.com/stretchr/testify/assert" ) -func TestTokenGetExpires(t *testing.T) { +func TestAccessTokenGetExpires(t *testing.T) { token := auth.NewAccessToken("access_token", 2) assert.Equal(t, 2, token.GetExpiresIn()) @@ -27,7 +29,69 @@ func TestTokenGetExpires(t *testing.T) { assert.WithinDuration(t, time.Now().Add(time.Second*-1), token.GetExpiresAt(), time.Millisecond*50) } -func TestTokenSerialize(t *testing.T) { +func TestAccessTokenManager(t *testing.T) { + mockAuth := auth.New("", "") + oldToken := "old" + newToken := "token" + + cache := caches.NewDummyCache() + atc := test.NewMockAccessTokenClient(oldToken) + cm := auth.NewAccessTokenManager(atc, mockAuth, cache) + + token, err := cm.Get() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) + + atc = test.NewMockAccessTokenClient(newToken) + cm = auth.NewAccessTokenManager(atc, mockAuth, cache) + + token, err = cm.Get() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) + + token, err = cm.Renew() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) + + token, err = cm.Get() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) +} + +func TestAccessTokenManagerDelete(t *testing.T) { + oldToken := "old" + newToken := "token" + mockAuth := auth.New("", "") + + cache := caches.NewDummyCache() + atc := test.NewMockAccessTokenClient(oldToken) + cm := auth.NewAccessTokenManager(atc, mockAuth, cache) + + err := cm.Delete() + assert.Error(t, err) + + token, err := cm.Get() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) + + atc = test.NewMockAccessTokenClient(newToken) + cm = auth.NewAccessTokenManager(atc, mockAuth, cache) + + err = cm.Delete() + assert.NoError(t, err) + + token, err = cm.Get() + assert.NoError(t, err) + assert.IsType(t, &auth.AccessToken{}, token) + assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) +} + +func TestAccessTokenSerialize(t *testing.T) { token := auth.NewAccessToken("access_token", 2) // Serialize the token bytes, err := auth.SerializeAccessToken(token) diff --git a/internal/auth/credential.go b/internal/auth/credential.go index 2487764..ebcfc63 100644 --- a/internal/auth/credential.go +++ b/internal/auth/credential.go @@ -1,7 +1,5 @@ package auth -import "errors" - // It would be much better if Go supports covariance... type CredentialManager interface { // Get the latest credential @@ -16,31 +14,3 @@ type CredentialManager interface { // Delete a credential Delete() error } - -type AuthCredentialManager struct { - auth Auth -} - -// Provide `Auth` -func NewAuthCredentialManager(auth Auth) CredentialManager { - return &AuthCredentialManager{auth: auth} -} - -func (cm *AuthCredentialManager) Get() (interface{}, error) { - if cm.auth == nil { - return errors.New("auth not set"), nil - } - return cm.auth, nil -} - -func (cm *AuthCredentialManager) Set(credential interface{}) error { - return errors.New("not settable") -} - -func (cm *AuthCredentialManager) Renew() (interface{}, error) { - return nil, errors.New("not renewable") -} - -func (cm *AuthCredentialManager) Delete() error { - return errors.New("not deletable") -} diff --git a/internal/auth/errors.go b/internal/auth/errors.go new file mode 100644 index 0000000..b853137 --- /dev/null +++ b/internal/auth/errors.go @@ -0,0 +1,7 @@ +package auth + +import "errors" + +var ( + ErrCacheNotSet = errors.New("cache not set") +) diff --git a/internal/client/client.go b/internal/client/client.go index 5b0767e..ba3dd41 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -48,10 +48,10 @@ type WeChatClient interface { // Config is a configuration struct used to set up a `client.WeChatClient`. type Config struct { - // AccessTokenManagerFactory is a factory function that creates a `CredentialManager` managing the access token. + // AccessTokenClient is used for request a latest access token when it is needed // This option should be left as the default value (nil), unless you want to customize the client // For example, if you want to request your access token from a different service than Tencent's. - AccessTokenManagerFactory AccessTokenManagerProvider + AccessTokenClient auth.AccessTokenClient // BaseApiUrl is the base URL used for making API requests. // If not provided, the default value is 'https://api.weixin.qq.com'. @@ -87,11 +87,11 @@ type DefaultWeChatClient struct { } // Create a new `WeChatClient` -func New(auth auth.Auth, conf Config) WeChatClient { +func New(a auth.Auth, conf Config) WeChatClient { var ( + atc auth.AccessTokenClient baseApiUrl *url.URL - client http.Client - factory AccessTokenManagerProvider + baseClient http.Client ) if conf.BaseApiUrl == nil { @@ -101,28 +101,30 @@ func New(auth auth.Auth, conf Config) WeChatClient { } if conf.HttpClient == nil { - client = http.Client{Transport: http.DefaultTransport} + baseClient = http.Client{Transport: http.DefaultTransport} } else { - client = *conf.HttpClient + baseClient = *conf.HttpClient } - if conf.AccessTokenManagerFactory == nil { - factory = AccessTokenManagerFactory + if conf.AccessTokenClient == nil { + client := baseClient + atc = NewAccessTokenClient(&client, "") } else { - factory = conf.AccessTokenManagerFactory + atc = conf.AccessTokenClient } - cm := factory(auth, client, conf.Cache, nil) - client.Transport = + cm := auth.NewAccessTokenManager(atc, a, conf.Cache) + + baseClient.Transport = NewCredentialRoundTripper(cm, NewAccessTokenRoundTripper( NewCommonRoundTripper( - baseApiUrl, client.Transport))) + baseApiUrl, baseClient.Transport))) return &DefaultWeChatClient{ cm: cm, - auth: auth, - client: &client, + auth: a, + client: &baseClient, } } diff --git a/internal/client/client_test.go b/internal/client/client_test.go index d886cc7..054daf0 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -188,11 +188,9 @@ func TestWeChatClientDoWithoutToken(t *testing.T) { atc := test.NewMockAccessTokenClient(accessToken) cache := caches.NewDummyCache() config := client.Config{ - AccessTokenManagerFactory: func(auth auth.Auth, c http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - return client.NewAccessTokenManager(atc, auth, cache) - }, - HttpClient: mc, - Cache: cache, + AccessTokenClient: atc, + HttpClient: mc, + Cache: cache, } c := client.New(mockAuth, config) @@ -237,11 +235,9 @@ func TestWeChatClientDoWithInvalidToken(t *testing.T) { serializedToken, _ := auth.SerializeAccessToken(auth.NewAccessToken(invalidToken, 3600)) cache.Set(appID, caches.BizAccessToken, serializedToken, 3600) config := client.Config{ - AccessTokenManagerFactory: func(auth auth.Auth, c http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - return client.NewAccessTokenManager(atc, auth, cache) - }, - HttpClient: mc, - Cache: cache, + AccessTokenClient: atc, + HttpClient: mc, + Cache: cache, } c := client.New(mockAuth, config) @@ -270,7 +266,7 @@ func TestWeChatClientDoWithInvalidTokenAndInvalidCredential(t *testing.T) { return test.Responses.Json(`{"errcode": 40014, "errmsg": "Invalid access token"}`) } else if calls == 2 { assert.Equal(t, "GET", req.Method) - test.AssertEndpointEqual(t, client.DefaultAccessTokenUri, req.URL) + test.AssertEndpointEqual(t, client.DefaultAccessTokenUrl, req.URL) assert.Equal(t, "client_credential", req.URL.Query().Get("grant_type")) assert.Equal(t, appID, req.URL.Query().Get("appid")) assert.Equal(t, appSecret, req.URL.Query().Get("secret")) @@ -306,10 +302,8 @@ func TestWeChatClientGetAccessToken(t *testing.T) { cache := caches.NewDummyCache() atc := test.NewMockAccessTokenClient(oldToken) config := client.Config{ - AccessTokenManagerFactory: func(auth auth.Auth, c http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - return client.NewAccessTokenManager(atc, auth, cache) - }, - Cache: cache, + AccessTokenClient: atc, + Cache: cache, } c := client.New(mockAuth, config) @@ -319,10 +313,8 @@ func TestWeChatClientGetAccessToken(t *testing.T) { atc = test.NewMockAccessTokenClient(newToken) config = client.Config{ - AccessTokenManagerFactory: func(auth auth.Auth, c http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - return client.NewAccessTokenManager(atc, auth, cache) - }, - Cache: cache, + AccessTokenClient: atc, + Cache: cache, } c = client.New(mockAuth, config) diff --git a/internal/client/errors.go b/internal/client/errors.go index 4acb821..b2c4963 100644 --- a/internal/client/errors.go +++ b/internal/client/errors.go @@ -1,8 +1,9 @@ package client -import "errors" +import ( + "errors" +) var ( - ErrCacheNotSet = errors.New("cache not set") ErrInvalidResponse = errors.New("invalid response") ) diff --git a/internal/client/token.go b/internal/client/token.go index 3e4305e..6b20234 100644 --- a/internal/client/token.go +++ b/internal/client/token.go @@ -1,214 +1,80 @@ package client import ( - "context" - "errors" "fmt" "net/http" "net/url" - "github.com/Xavier-Lam/go-wechat/caches" "github.com/Xavier-Lam/go-wechat/internal/auth" ) const ( - DefaultAccessTokenUri = "https://api.weixin.qq.com/cgi-bin/token" + DefaultAccessTokenUrl = "https://api.weixin.qq.com/cgi-bin/token" ) -// AccessTokenClient is an client to request the newest access token -type AccessTokenClient interface { - GetAccessToken() (*auth.AccessToken, error) -} - -// AccessTokenResponse represents the response data received from the server -// for an access token request. -type AccessTokenResponse interface { - GetAccessToken() string - GetExpiresIn() int -} - -// AccessTokenManagerProvider is a factory function to create a `auth.CredentialManager` -// for manage access token of a WeChat application -type AccessTokenManagerProvider = func( - auth auth.Auth, - client http.Client, - cache caches.Cache, - accessTokenUrl *url.URL, -) auth.CredentialManager - -// AccessTokenManager is an implement of the `auth.CredentialManager` -// which is used to manage access token credentials. -type AccessTokenManager struct { - atc AccessTokenClient - auth auth.Auth - cache caches.Cache -} - -// NewAccessTokenManager creates a new instance of `auth.CredentialManager` -// to manage access token credentials. -func NewAccessTokenManager(atc AccessTokenClient, auth auth.Auth, cache caches.Cache) auth.CredentialManager { - return &AccessTokenManager{ - atc: atc, - auth: auth, - cache: cache, - } -} - -func (cm *AccessTokenManager) Get() (interface{}, error) { - cachedValue, err := cm.get() - if err == nil { - return cachedValue, nil - } - - return cm.Renew() -} - -func (cm *AccessTokenManager) Set(credential interface{}) error { - return errors.New("not settable") +type tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` } -func (cm *AccessTokenManager) Renew() (interface{}, error) { - cm.Delete() - - // TODO: prevent concurrent fetching - token, err := cm.atc.GetAccessToken() - if err != nil { - return nil, err - } - - if cm.cache == nil { - err = fmt.Errorf("cache is not set") - } else { - serializedToken, err := auth.SerializeAccessToken(token) - if err != nil { - return nil, err - } - err = cm.cache.Set( - cm.auth.GetAppId(), - caches.BizAccessToken, - serializedToken, - token.GetExpiresIn(), - ) - } - - return token, err +type accessTokenClient struct { + client *http.Client + requestUrl *url.URL } -func (cm *AccessTokenManager) Delete() error { - token, err := cm.get() - if err != nil { - return err - } - serializedToken, err := auth.SerializeAccessToken(token) - if err != nil { - return err +// NewAccessTokenClient creates an `auth.AccessTokenClient` to get the latest access token +func NewAccessTokenClient(baseClient *http.Client, rawRequestUrl string) auth.AccessTokenClient { + if baseClient == nil { + baseClient = &http.Client{Transport: http.DefaultTransport} } - return cm.cache.Delete( - cm.auth.GetAppId(), - caches.BizAccessToken, - serializedToken, - ) -} -func (m *AccessTokenManager) get() (*auth.AccessToken, error) { - if m.cache == nil { - return nil, ErrCacheNotSet - } + client := *baseClient + client.Transport = NewCommonRoundTripper(nil, client.Transport) - cachedValue, err := m.cache.Get(m.auth.GetAppId(), caches.BizAccessToken) - if err != nil { - return nil, err + if rawRequestUrl == "" { + rawRequestUrl = DefaultAccessTokenUrl } - - token, err := auth.DeserializeAccessToken(cachedValue) + requestUrl, err := url.Parse(rawRequestUrl) if err != nil { - return nil, err + requestUrl, _ = url.Parse(DefaultAccessTokenUrl) } - return token, nil -} - -func AccessTokenManagerFactory(auth auth.Auth, client http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - atc := AccessTokenClientFactory(accessTokenUrl, auth, &client) - return NewAccessTokenManager(atc, auth, cache) -} - -type tokenResponse struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` -} - -func (t *tokenResponse) GetAccessToken() string { - return t.AccessToken -} - -func (t *tokenResponse) GetExpiresIn() int { - return t.ExpiresIn -} - -type accessTokenClient struct { - client *http.Client - requestUrl *url.URL // The url to request a new token, default value is 'https://api.weixin.qq.com/cgi-bin/token' - dto AccessTokenResponse -} - -// NewAccessTokenClient creates the default access token client which is used to -// request the latest access token from server -func NewAccessTokenClient(client *http.Client, dto AccessTokenResponse, requestUrl *url.URL) AccessTokenClient { return &accessTokenClient{ - client: client, + client: &client, requestUrl: requestUrl, - dto: &tokenResponse{}, } } -func (c *accessTokenClient) GetAccessToken() (*auth.AccessToken, error) { - // Prepare request +func (c *accessTokenClient) PrepareRequest(a auth.Auth) (*http.Request, error) { req, err := http.NewRequest(http.MethodGet, c.requestUrl.String(), nil) if err != nil { return nil, err } - ctx := context.Background() - ctx = context.WithValue(ctx, RequestContextWithCredential, true) - req = req.WithContext(ctx) - // Send request - // Use `fetchAccessTokenRoundTripper` to set up request parameters - // TODO: to use a request maker instead of RoundTrippers to send request (too complicated) - resp, err := c.client.Do(req) - if err != nil { - return nil, err + query := url.Values{ + "grant_type": {"client_credential"}, + "appid": {a.GetAppId()}, + "secret": {a.GetAppSecret()}, } - defer resp.Body.Close() + req.URL.RawQuery = query.Encode() - // Parse token - token := c.dto - err = GetJson(resp, token) - if err != nil { - return nil, fmt.Errorf("malformed access token response: %w", err) - } - if token.GetAccessToken() == "" { - return nil, fmt.Errorf("invalid access token response") - } + return req, nil +} - rv := auth.NewAccessToken(token.GetAccessToken(), token.GetExpiresIn()) - return rv, nil +func (c *accessTokenClient) SendRequest(a auth.Auth, req *http.Request) (*http.Response, error) { + return c.client.Do(req) } -// AccessTokenClientFactory is a factory to creates the default access token client -// to request the latest access token from server -func AccessTokenClientFactory(requestUrl *url.URL, a auth.Auth, client *http.Client) AccessTokenClient { - if client == nil { - client = &http.Client{Transport: http.DefaultTransport} +func (c *accessTokenClient) HandleResponse(a auth.Auth, resp *http.Response, req *http.Request) (*auth.AccessToken, error) { + data := &tokenResponse{} + err := GetJson(resp, data) + if err != nil { + return nil, fmt.Errorf("malformed access token response: %w", err) } - client.Transport = - NewCredentialRoundTripper(auth.NewAuthCredentialManager(a), - NewFetchAccessTokenRoundTripper( - NewCommonRoundTripper(nil, client.Transport))) - - if requestUrl == nil { - requestUrl, _ = url.Parse(DefaultAccessTokenUri) + if data.AccessToken == "" { + return nil, fmt.Errorf("invalid access token response") } - return NewAccessTokenClient(client, &tokenResponse{}, requestUrl) + token := auth.NewAccessToken(data.AccessToken, data.ExpiresIn) + return token, nil } diff --git a/internal/client/token_test.go b/internal/client/token_test.go index bf695e5..f6da296 100644 --- a/internal/client/token_test.go +++ b/internal/client/token_test.go @@ -2,98 +2,64 @@ package client_test import ( "net/http" - "net/url" "testing" - "time" - "github.com/Xavier-Lam/go-wechat" - "github.com/Xavier-Lam/go-wechat/caches" - "github.com/Xavier-Lam/go-wechat/internal/auth" "github.com/Xavier-Lam/go-wechat/internal/client" "github.com/Xavier-Lam/go-wechat/internal/test" "github.com/stretchr/testify/assert" ) -func TestTokenGetAccessToken(t *testing.T) { - a := wechat.NewAuth("app-id", "app-secret") - - httpClient := test.NewMockHttpClient(func(req *http.Request, calls int) (*http.Response, error) { - assert.Equal(t, 1, calls) - assert.Equal(t, "GET", req.Method) - test.AssertEndpointEqual(t, client.DefaultAccessTokenUri, req.URL) - assert.Equal(t, "client_credential", req.URL.Query().Get("grant_type")) - assert.Equal(t, "app-id", req.URL.Query().Get("appid")) - assert.Equal(t, "app-secret", req.URL.Query().Get("secret")) - - return test.Responses.Json(`{"access_token": "access-token", "expires_in": 7200}`) - }) - url, _ := url.Parse(client.DefaultAccessTokenUri) - c := client.AccessTokenClientFactory(url, a, httpClient) - - token, err := c.GetAccessToken() +func TestPrepareRequest(t *testing.T) { + c := client.NewAccessTokenClient(nil, "") + req, err := c.PrepareRequest(mockAuth) assert.NoError(t, err) - assert.Equal(t, "access-token", token.GetAccessToken()) - assert.Equal(t, 7200, token.GetExpiresIn()) - assert.WithinDuration(t, time.Now().Add(time.Second*7200), token.GetExpiresAt(), time.Millisecond*50) -} + assert.Equal(t, "GET", req.Method) + test.AssertEndpointEqual(t, client.DefaultAccessTokenUrl, req.URL) + assert.Equal(t, "client_credential", req.URL.Query().Get("grant_type")) + assert.Equal(t, appID, req.URL.Query().Get("appid")) + assert.Equal(t, appSecret, req.URL.Query().Get("secret")) -func TestWeChatAccessTokenCredential(t *testing.T) { - oldToken := "old" - newToken := "token" + requestUrl := "https://example.com/test" + c = client.NewAccessTokenClient(nil, requestUrl) - cache := caches.NewDummyCache() - atc := test.NewMockAccessTokenClient(oldToken) - cm := client.NewAccessTokenManager(atc, mockAuth, cache) - - token, err := cm.Get() + req, err = c.PrepareRequest(mockAuth) assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) + assert.Equal(t, "GET", req.Method) + test.AssertEndpointEqual(t, requestUrl, req.URL) + assert.Equal(t, "client_credential", req.URL.Query().Get("grant_type")) + assert.Equal(t, appID, req.URL.Query().Get("appid")) + assert.Equal(t, appSecret, req.URL.Query().Get("secret")) +} - atc = test.NewMockAccessTokenClient(newToken) - cm = client.NewAccessTokenManager(atc, mockAuth, cache) +func TestSendRequest(t *testing.T) { + httpClient := test.NewMockHttpClient(func(req *http.Request, calls int) (*http.Response, error) { + assert.Equal(t, 1, calls) + assert.Equal(t, "GET", req.Method) + test.AssertEndpointEqual(t, client.DefaultAccessTokenUrl, req.URL) + assert.Equal(t, "client_credential", req.URL.Query().Get("grant_type")) + assert.Equal(t, appID, req.URL.Query().Get("appid")) + assert.Equal(t, appSecret, req.URL.Query().Get("secret")) - token, err = cm.Get() - assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) + return test.Responses.Empty() + }) + c := client.NewAccessTokenClient(httpClient, "") - token, err = cm.Renew() + req, err := c.PrepareRequest(mockAuth) assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) + resp, err := c.SendRequest(mockAuth, req) - token, err = cm.Get() assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) + assert.Equal(t, emptyResponse, resp) } -func TestWeChatAccessTokenCredentialDelete(t *testing.T) { - oldToken := "old" - newToken := "token" - - cache := caches.NewDummyCache() - atc := test.NewMockAccessTokenClient(oldToken) - cm := client.NewAccessTokenManager(atc, mockAuth, cache) +func TestHandleResponse(t *testing.T) { + resp, _ := test.Responses.Json(`{"access_token": "access-token", "expires_in": 7200}`) + client := client.NewAccessTokenClient(nil, "") - err := cm.Delete() - assert.Error(t, err) + token, err := client.HandleResponse(mockAuth, resp, &http.Request{}) - token, err := cm.Get() assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken()) - - atc = test.NewMockAccessTokenClient(newToken) - cm = client.NewAccessTokenManager(atc, mockAuth, cache) - - err = cm.Delete() - assert.NoError(t, err) - - token, err = cm.Get() - assert.NoError(t, err) - assert.IsType(t, &auth.AccessToken{}, token) - assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken()) + assert.Equal(t, "access-token", token.GetAccessToken()) + assert.Equal(t, 7200, token.GetExpiresIn()) } diff --git a/internal/miniprogram/apis/apis_test.go b/internal/miniprogram/apis/apis_test.go index 2e16333..d41e88d 100644 --- a/internal/miniprogram/apis/apis_test.go +++ b/internal/miniprogram/apis/apis_test.go @@ -17,8 +17,8 @@ func newMockMiniProgram(handler test.RequestHandler) *miniprogram.App { return miniprogram.New( mockAuth, miniprogram.Config{ - AccessTokenManagerFactory: test.MockAccessTokenCredentialManagerFactoryProvider(accessToken), - HttpClient: test.NewMockHttpClient(handler), + AccessTokenClient: test.NewMockAccessTokenClient(accessToken), + HttpClient: test.NewMockHttpClient(handler), }, ) } diff --git a/internal/miniprogram/app.go b/internal/miniprogram/app.go index 05e4787..ba8ccca 100644 --- a/internal/miniprogram/app.go +++ b/internal/miniprogram/app.go @@ -17,10 +17,10 @@ type App struct { func New(auth auth.Auth, conf Config) *App { // Set up base dependencies if not given c := client.New(auth, client.Config{ - AccessTokenManagerFactory: conf.AccessTokenManagerFactory, - BaseApiUrl: conf.BaseApiUrl, - Cache: conf.Cache, - HttpClient: conf.HttpClient, + AccessTokenClient: conf.AccessTokenClient, + BaseApiUrl: conf.BaseApiUrl, + Cache: conf.Cache, + HttpClient: conf.HttpClient, }) a := apis.NewApis(c) return &App{ diff --git a/internal/officialaccount/apis/apis_test.go b/internal/officialaccount/apis/apis_test.go index 6a00bfc..54fd72b 100644 --- a/internal/officialaccount/apis/apis_test.go +++ b/internal/officialaccount/apis/apis_test.go @@ -17,8 +17,8 @@ func newMockOfficialAccount(handler test.RequestHandler) *officialaccount.App { return officialaccount.New( mockAuth, officialaccount.Config{ - AccessTokenManagerFactory: test.MockAccessTokenCredentialManagerFactoryProvider(accessToken), - HttpClient: test.NewMockHttpClient(handler), + AccessTokenClient: test.NewMockAccessTokenClient(accessToken), + HttpClient: test.NewMockHttpClient(handler), }, ) } diff --git a/internal/officialaccount/app.go b/internal/officialaccount/app.go index f72a7ef..c0d60b2 100644 --- a/internal/officialaccount/app.go +++ b/internal/officialaccount/app.go @@ -19,10 +19,10 @@ type App struct { func New(auth auth.Auth, conf Config) *App { // Set up base dependencies if not given c := client.New(auth, client.Config{ - AccessTokenManagerFactory: conf.AccessTokenManagerFactory, - BaseApiUrl: conf.BaseApiUrl, - Cache: conf.Cache, - HttpClient: conf.HttpClient, + AccessTokenClient: conf.AccessTokenClient, + BaseApiUrl: conf.BaseApiUrl, + Cache: conf.Cache, + HttpClient: conf.HttpClient, }) a := apis.NewApis(c) return &App{ diff --git a/internal/test/client.go b/internal/test/client.go index 01e132f..c7fcc7f 100644 --- a/internal/test/client.go +++ b/internal/test/client.go @@ -6,9 +6,7 @@ import ( "net/url" "testing" - "github.com/Xavier-Lam/go-wechat/caches" "github.com/Xavier-Lam/go-wechat/internal/auth" - "github.com/Xavier-Lam/go-wechat/internal/client" "github.com/stretchr/testify/assert" ) @@ -58,28 +56,18 @@ type mockAccessTokenClient struct { token string } -func NewMockAccessTokenClient(token string) client.AccessTokenClient { +func NewMockAccessTokenClient(token string) auth.AccessTokenClient { return &mockAccessTokenClient{token: token} } -func (c *mockAccessTokenClient) GetAccessToken() (*auth.AccessToken, error) { - return auth.NewAccessToken(c.token, auth.DefaultTokenExpiresIn), nil +func (c *mockAccessTokenClient) PrepareRequest(a auth.Auth) (*http.Request, error) { + return &http.Request{}, nil } -type mockAccessTokenCredentialManager struct { - client.AccessTokenManager - token string -} - -func MockAccessTokenCredentialManagerFactoryProvider(token string) client.AccessTokenManagerProvider { - return func(auth auth.Auth, c http.Client, cache caches.Cache, accessTokenUrl *url.URL) auth.CredentialManager { - return &mockAccessTokenCredentialManager{ - client.AccessTokenManager{}, - token, - } - } +func (c *mockAccessTokenClient) SendRequest(a auth.Auth, req *http.Request) (*http.Response, error) { + return Responses.Empty() } -func (m *mockAccessTokenCredentialManager) Get() (interface{}, error) { - return auth.NewAccessToken(m.token, 7200), nil +func (c *mockAccessTokenClient) HandleResponse(a auth.Auth, resp *http.Response, req *http.Request) (*auth.AccessToken, error) { + return auth.NewAccessToken(c.token, auth.DefaultAccessTokenExpiresIn), nil } diff --git a/wechat.go b/wechat.go index a5a9dac..9433fc8 100644 --- a/wechat.go +++ b/wechat.go @@ -10,16 +10,10 @@ import ( // Exported interfaces type ( Auth = auth.Auth - AccessTokenClient = client.AccessTokenClient + AccessTokenClient = auth.AccessTokenClient WeChatClient = client.WeChatClient ) -// Exported factories -var ( - AccessTokenClientFactory = client.AccessTokenClientFactory - AccessTokenManagerFactory = client.AccessTokenManagerFactory -) - // Exported constructors var ( NewAuth = auth.New @@ -28,9 +22,8 @@ var ( NewWeChatClient = client.New // less commonly used - NewAccessToken = auth.NewAccessToken - NewAccessTokenClient = client.NewAccessTokenClient - NewAccessTokenManager = client.NewAccessTokenManager + NewAccessToken = auth.NewAccessToken + NewAccessTokenClient = client.NewAccessTokenClient ) // Exported configurations