Skip to content

Commit

Permalink
Implemented Allowed IP List Feature and Fixed a bug in Token Denylist
Browse files Browse the repository at this point in the history
  • Loading branch information
CSPF-Founder authored and afr1ka committed Feb 28, 2024
1 parent 855df49 commit a5bfe1e
Show file tree
Hide file tree
Showing 14 changed files with 312 additions and 55 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The **key features** of API Firewall are:
* Discover Shadow API endpoints
* Validate JWT access tokens for OAuth 2.0 protocol-based authentication
* Denylist compromised API tokens, keys, and Cookies
* AllowIPList - Restrict access to endpoints by defining a list of allowed IP addresses

The product is **open source**, available at DockerHub and already got 1 billion (!!!) pulls. To support this project, you can star the [repository](https://hub.docker.com/r/wallarm/api-firewall).

Expand Down
19 changes: 15 additions & 4 deletions cmd/api-firewall/internal/handlers/graphql/routes.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
package graphql

import (
"github.com/savsgio/gotils/strconv"
"github.com/savsgio/gotils/strings"
"net/url"
"os"
"sync"

"github.com/savsgio/gotils/strconv"
"github.com/savsgio/gotils/strings"

"github.com/fasthttp/websocket"
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
"github.com/valyala/fastjson"
"github.com/wallarm/api-firewall/internal/config"
"github.com/wallarm/api-firewall/internal/mid"
"github.com/wallarm/api-firewall/internal/platform/allowiplist"
"github.com/wallarm/api-firewall/internal/platform/denylist"
"github.com/wallarm/api-firewall/internal/platform/proxy"
"github.com/wallarm/api-firewall/internal/platform/web"
"github.com/wundergraph/graphql-go-tools/pkg/graphql"
"github.com/wundergraph/graphql-go-tools/pkg/playground"
)

func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, proxy proxy.Pool, wsClient proxy.WebSocketClient, deniedTokens *denylist.DeniedTokens) fasthttp.RequestHandler {
func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, proxy proxy.Pool, wsClient proxy.WebSocketClient, deniedTokens *denylist.DeniedTokens, AllowedIPCache *allowiplist.AllowedIPsType) fasthttp.RequestHandler {

// Construct the web.App which holds all routes as well as common Middleware.
appOptions := web.AppAdditionalOptions{
Expand All @@ -42,7 +44,16 @@ func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.UR
DeniedTokens: deniedTokens,
Logger: logger,
}
app := web.NewApp(&appOptions, shutdown, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(&proxyOptions), mid.Denylist(&denylistOptions))

ipAllowlistOptions := mid.IPAllowListOptions{
Mode: web.GraphQLMode,
Config: &cfg.AllowIP,
CustomBlockStatusCode: fasthttp.StatusUnauthorized,
AllowedIPs: AllowedIPCache,
Logger: logger,
}

app := web.NewApp(&appOptions, shutdown, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(&proxyOptions), mid.IPAllowlist(&ipAllowlistOptions), mid.Denylist(&denylistOptions))

// define FastJSON parsers pool
var parserPool fastjson.ParserPool
Expand Down
13 changes: 11 additions & 2 deletions cmd/api-firewall/internal/handlers/proxy/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import (
"github.com/valyala/fastjson"
"github.com/wallarm/api-firewall/internal/config"
"github.com/wallarm/api-firewall/internal/mid"
"github.com/wallarm/api-firewall/internal/platform/allowiplist"
"github.com/wallarm/api-firewall/internal/platform/denylist"
woauth2 "github.com/wallarm/api-firewall/internal/platform/oauth2"
"github.com/wallarm/api-firewall/internal/platform/proxy"
"github.com/wallarm/api-firewall/internal/platform/router"
"github.com/wallarm/api-firewall/internal/platform/web"
)

func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, httpClientsPool proxy.Pool, swagRouter *router.Router, deniedTokens *denylist.DeniedTokens) fasthttp.RequestHandler {
func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, httpClientsPool proxy.Pool, swagRouter *router.Router, deniedTokens *denylist.DeniedTokens, AllowedIPCache *allowiplist.AllowedIPsType) fasthttp.RequestHandler {

// define FastJSON parsers pool
var parserPool fastjson.ParserPool
Expand Down Expand Up @@ -85,7 +86,15 @@ func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal
DeniedTokens: deniedTokens,
Logger: logger,
}
app := web.NewApp(&options, shutdown, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(&proxyOptions), mid.Denylist(&denylistOptions), mid.ShadowAPIMonitor(logger, &cfg.ShadowAPI))
ipAllowlistOptions := mid.IPAllowListOptions{
Mode: web.GraphQLMode,
Config: &cfg.AllowIP,
CustomBlockStatusCode: cfg.CustomBlockStatusCode,
AllowedIPs: AllowedIPCache,
Logger: logger,
}

app := web.NewApp(&options, shutdown, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(&proxyOptions), mid.IPAllowlist(&ipAllowlistOptions), mid.Denylist(&denylistOptions), mid.ShadowAPIMonitor(logger, &cfg.ShadowAPI))

serverPath := "/"
if serverURL.Path != "" {
Expand Down
38 changes: 35 additions & 3 deletions cmd/api-firewall/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
handlersProxy "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy"
"github.com/wallarm/api-firewall/cmd/api-firewall/internal/updater"
"github.com/wallarm/api-firewall/internal/config"
"github.com/wallarm/api-firewall/internal/platform/allowiplist"
"github.com/wallarm/api-firewall/internal/platform/database"
"github.com/wallarm/api-firewall/internal/platform/denylist"
"github.com/wallarm/api-firewall/internal/platform/proxy"
Expand Down Expand Up @@ -477,11 +478,26 @@ func runGraphQLMode(logger *logrus.Logger) error {
default:
logger.Infof("%s: Loaded %d tokens to the cache", logPrefix, deniedTokens.ElementsNum)
}
// =========================================================================

logger.Infof("%s: Initializing IP Whitelist Cache", logPrefix)

allowedIPCache, err := allowiplist.New(&cfg.AllowIP, logger)
if err != nil {
return errors.Wrap(err, "allowiplist init error")
}

switch allowedIPCache {
case nil:
logger.Infof("%s: allowiplist not configured", logPrefix)
default:
logger.Infof("%s: Loaded %d Whitelisted IP's to the cache", logPrefix, allowedIPCache.ElementsNum)
}

// =========================================================================
// Init Handlers

requestHandlers := handlersGQL.Handlers(&cfg, schema, serverURL, shutdown, logger, pool, wsPool, deniedTokens)
requestHandlers := handlersGQL.Handlers(&cfg, schema, serverURL, shutdown, logger, pool, wsPool, deniedTokens, allowedIPCache)

// =========================================================================
// Start Health API Service
Expand Down Expand Up @@ -759,7 +775,7 @@ func runProxyMode(logger *logrus.Logger) error {
// =========================================================================
// Init Cache

logger.Infof("%s: Initializing Cache", logPrefix)
logger.Infof("%s: Initializing Token Cache", logPrefix)

deniedTokens, err := denylist.New(&cfg.Denylist, logger)
if err != nil {
Expand All @@ -773,10 +789,26 @@ func runProxyMode(logger *logrus.Logger) error {
logger.Infof("%s: Loaded %d tokens to the cache", logPrefix, deniedTokens.ElementsNum)
}

// =========================================================================

logger.Infof("%s: Initializing IP Whitelist Cache", logPrefix)

AllowedIPCache, err := allowiplist.New(&cfg.AllowIP, logger)
if err != nil {
return errors.Wrap(err, "allowiplist init error")
}

switch AllowedIPCache {
case nil:
logger.Infof("%s: allowiplist not configured", logPrefix)
default:
logger.Infof("%s: Loaded %d Whitelisted IP's to the cache", logPrefix, AllowedIPCache.ElementsNum)
}

// =========================================================================
// Init Handlers

requestHandlers = handlersProxy.Handlers(&cfg, serverURL, shutdown, logger, pool, swagRouter, deniedTokens)
requestHandlers = handlersProxy.Handlers(&cfg, serverURL, shutdown, logger, pool, swagRouter, deniedTokens, AllowedIPCache)

// =========================================================================
// Start Health API Service
Expand Down
2 changes: 1 addition & 1 deletion cmd/api-firewall/tests/main_graphql_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func BenchmarkGraphQL(b *testing.B) {
APIHost: benchHandlerURL,
}

handler := graphqlHandler.Handlers(&cfg, schema, serverURL, shutdown, logger, pool, wsPool, nil)
handler := graphqlHandler.Handlers(&cfg, schema, serverURL, shutdown, logger, pool, wsPool, nil, nil)

srv := fasthttp.Server{
Handler: handler,
Expand Down
26 changes: 13 additions & 13 deletions cmd/api-firewall/tests/main_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (s *ServiceGraphQLTests) testGQLSuccess(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -240,7 +240,7 @@ func (s *ServiceGraphQLTests) testGQLGETSuccess(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -328,7 +328,7 @@ func (s *ServiceGraphQLTests) testGQLGETMutationFailed(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -398,7 +398,7 @@ func (s *ServiceGraphQLTests) testGQLValidationFailed(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -478,7 +478,7 @@ func (s *ServiceGraphQLTests) testGQLInvalidQuerySyntax(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -558,7 +558,7 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxComplexity(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -638,7 +638,7 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxDepth(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -718,7 +718,7 @@ func (s *ServiceGraphQLTests) testGQLInvalidNodeLimit(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -812,7 +812,7 @@ func (s *ServiceGraphQLTests) testGQLDenylistBlock(t *testing.T) {
t.Fatal(err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, deniedTokens)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, deniedTokens, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -1041,7 +1041,7 @@ func (s *ServiceGraphQLTests) testGQLSubscription(t *testing.T) {
serverUrl, err := url.ParseRequestURI(cfg.Server.URL)
assert.Nil(t, err)

handler := graphqlHandler.Handlers(&cfg, schema, serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// connection to the backend
headers := http.Header{}
Expand Down Expand Up @@ -1215,7 +1215,7 @@ func (s *ServiceGraphQLTests) testGQLSubscriptionLogOnly(t *testing.T) {
serverUrl, err := url.ParseRequestURI(cfg.Server.URL)
assert.Nil(t, err)

handler := graphqlHandler.Handlers(&cfg, schema, serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// connection to the backend
headers := http.Header{}
Expand Down Expand Up @@ -1342,7 +1342,7 @@ func (s *ServiceGraphQLTests) testGQLMaxAliasesNum(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down Expand Up @@ -1433,7 +1433,7 @@ func (s *ServiceGraphQLTests) testGQLDuplicateFields(t *testing.T) {
t.Fatalf("Loading GraphQL Schema error: %v", err)
}

handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil)
handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil)

// Construct GraphQL request payload
query := `
Expand Down
6 changes: 3 additions & 3 deletions cmd/api-firewall/tests/main_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestJSONBasic(t *testing.T) {

func (s *ServiceTests) testBasicObjJSONFieldValidation(t *testing.T) {

handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil)
handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil)

// basic object check
p, err := json.Marshal(map[string]interface{}{
Expand Down Expand Up @@ -178,7 +178,7 @@ func (s *ServiceTests) testBasicObjJSONFieldValidation(t *testing.T) {

func (s *ServiceTests) testBasicArrJSONFieldValidation(t *testing.T) {

handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil)
handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil)

p, err := json.Marshal([]map[string]interface{}{{
"valueNum": 10.1,
Expand Down Expand Up @@ -226,7 +226,7 @@ func (s *ServiceTests) testBasicArrJSONFieldValidation(t *testing.T) {

func (s *ServiceTests) testNegativeJSONFieldValidation(t *testing.T) {

handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil)
handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil)

req := fasthttp.AcquireRequest()
req.SetRequestURI("/test")
Expand Down
Loading

0 comments on commit a5bfe1e

Please sign in to comment.