diff --git a/pkg/extensions/extension_search.go b/pkg/extensions/extension_search.go index 0d24a6cb..533bee90 100644 --- a/pkg/extensions/extension_search.go +++ b/pkg/extensions/extension_search.go @@ -179,11 +179,27 @@ func SetupSearchRoutes(config *config.Config, router *mux.Router, storeControlle resConfig := search.GetResolverConfig(log, storeController, repoDB, cveInfo) extRouter := router.PathPrefix(constants.ExtSearchPrefix).Subrouter() + extRouter.Use(SearchACHeadersHandler()) extRouter.Methods("GET", "POST", "OPTIONS"). Handler(addSearchSecurityHeaders(gqlHandler.NewDefaultServer(gql_generated.NewExecutableSchema(resConfig)))) } } +func SearchACHeadersHandler() mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") + resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + + if req.Method == http.MethodOptions { + return + } + + next.ServeHTTP(resp, req) + }) + } +} + func getExtension(name, url, description string, endpoints []string) distext.Extension { return distext.Extension{ Name: name, diff --git a/pkg/extensions/extension_userprefs.go b/pkg/extensions/extension_userprefs.go index 9ab0c46d..8f6d876d 100644 --- a/pkg/extensions/extension_userprefs.go +++ b/pkg/extensions/extension_userprefs.go @@ -31,20 +31,29 @@ func SetupUserPreferencesRoutes(config *config.Config, router *mux.Router, store log.Info().Msg("setting up user preferences routes") userprefsRouter := router.PathPrefix(constants.ExtUserPreferencesPrefix).Subrouter() + userprefsRouter.Use(UserPrefsACHeadersHandler()) userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(zcommon.AllowedMethods(http.MethodPut)...) } } +func UserPrefsACHeadersHandler() mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,PUT,OPTIONS") + resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + + if req.Method == http.MethodOptions { + return + } + + next.ServeHTTP(resp, req) + }) + } +} + func HandleUserPrefs(repoDB repodb.RepoDB, log log.Logger) func(w http.ResponseWriter, r *http.Request) { return func(rsp http.ResponseWriter, req *http.Request) { - rsp.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,PUT,OPTIONS") - rsp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") - - if req.Method == http.MethodOptions { - return - } - if !queryHasParams(req.URL.Query(), []string{"action"}) { rsp.WriteHeader(http.StatusBadRequest)