0
Fork 0
mirror of https://github.com/project-zot/zot.git synced 2024-12-16 21:56:37 -05:00
zot/pkg/cli/discover.go

169 lines
3.9 KiB
Go
Raw Normal View History

//go:build search
// +build search
package cli
import (
"context"
"fmt"
distext "github.com/opencontainers/distribution-spec/specs-go/v1/extensions"
zerr "zotregistry.io/zot/errors"
"zotregistry.io/zot/pkg/api/constants"
zcommon "zotregistry.io/zot/pkg/common"
)
// field is one root query as reported in the introspection result's
// queryType.fields list. Only the query name and the names of its
// arguments are captured.
type field struct {
Name string `json:"name"`
// Args holds the query's declared arguments, in the order the server lists them.
Args []struct {
Name string `json:"name"`
} `json:"args"`
}
// schemaList is the response body of a GraphQL introspection query on
// __schema. It holds:
//   - the server's root query fields (data.__schema.queryType.fields);
//   - the types the server knows about (data.__schema.types);
//   - any GraphQL errors returned in the top-level "errors" array.
type schemaList struct {
Data struct {
Schema struct {
QueryType struct {
Fields []field `json:"fields"`
} `json:"queryType"` //nolint:tagliatelle // graphQL schema
Types []typeInfo `json:"types"`
} `json:"__schema"` //nolint:tagliatelle // graphQL schema
} `json:"data"`
// Errors carries the top-level GraphQL "errors" array, if any.
Errors []zcommon.ErrorGQL `json:"errors"`
}
// typeInfo describes one entry of the introspection result's __schema.types
// list: the type's name and the fields it declares.
type typeInfo struct {
Name string `json:"name"`
Fields []typeField `json:"fields"`
}
// typeField is a single field of a typeInfo; only its name is captured.
type typeField struct {
Name string `json:"name"`
}
// containsGQLQueryWithParams verifies that the server supports every query in
// requiredQueries.
//
// queryList is the server's list of root query fields and serverGQLTypesList
// the list of types, both taken from an introspection result. A required query
// passes when two conditions hold:
//   - a root field with the same name and the same arguments (as decided by
//     haveSameArgs) exists;
//   - the name of its expected return type is one of the server's types.
//
// The first query failing either check yields an error wrapping
// zerr.ErrGQLQueryNotSupported; nil means every query is supported.
func containsGQLQueryWithParams(queryList []field, serverGQLTypesList []typeInfo, requiredQueries ...GQLQuery) error {
	// Index the server types by name. The fields are kept so the return type
	// could later be checked field by field.
	knownTypes := make(map[string][]typeField, len(serverGQLTypesList))
	for _, serverType := range serverGQLTypesList {
		knownTypes[serverType.Name] = serverType.Fields
	}

	for _, wanted := range requiredQueries {
		supported := false
		for i := 0; i < len(queryList) && !supported; i++ {
			supported = queryList[i].Name == wanted.Name && haveSameArgs(queryList[i], wanted)
		}

		if !supported {
			return fmt.Errorf("%w: %s", zerr.ErrGQLQueryNotSupported, wanted.Name)
		}

		// Only the return type's name is verified for now; its fields could be
		// compared recursively against knownTypes in the future.
		returnType := wanted.ReturnType.Name
		if _, known := knownTypes[returnType]; !known {
			return fmt.Errorf("%w: server doesn't support needed type '%s'", zerr.ErrGQLQueryNotSupported, returnType)
		}
	}

	return nil
}
// haveSameArgs reports whether the server-side query field declares exactly
// the argument names listed in reqQuery.Args.
//
// GraphQL arguments are passed by name, and their declaration order carries no
// meaning. The two lists are therefore compared as multisets of names rather
// than position by position, so a server that declares the same arguments in a
// different order is not rejected. The counts must still match: a server that
// declares an extra argument, or lacks one, is reported as different.
func haveSameArgs(query field, reqQuery GQLQuery) bool {
	if len(query.Args) != len(reqQuery.Args) {
		return false
	}

	// Count every required name, then consume one occurrence per server arg.
	pending := make(map[string]int, len(reqQuery.Args))
	for _, name := range reqQuery.Args {
		pending[name]++
	}

	for _, arg := range query.Args {
		if pending[arg.Name] == 0 {
			return false
		}

		pending[arg.Name]--
	}

	return true
}
// CheckExtEndPointQuery checks that the server at config.servURL exposes the
// search extension and that its GraphQL schema supports every query in
// requiredQueries.
//
// The check runs in two steps:
//  1. The OCI extension discovery endpoint (constants.RoutePrefix +
//     constants.ExtOciDiscoverPrefix) is fetched. The extension named
//     constants.BaseExtension must list constants.FullSearchPrefix among its
//     endpoints; otherwise an error wrapping zerr.ErrExtensionNotEnabled is
//     returned.
//  2. An introspection query is sent to the search endpoint. The root query
//     fields and types it reports are validated by containsGQLQueryWithParams.
//
// URL-building and discovery-request errors are returned unchanged. Two kinds
// of failure are wrapped with "gql query failed":
//   - a failed introspection request;
//   - GraphQL errors in its response, as judged by checkResultGraphQLQuery.
func CheckExtEndPointQuery(config searchConfig, requiredQueries ...GQLQuery) error {
	username, password := getUsernameAndPassword(config.user)
	ctx := context.Background()

	discoverEndPoint, err := combineServerAndEndpointURL(config.servURL, fmt.Sprintf("%s%s",
		constants.RoutePrefix, constants.ExtOciDiscoverPrefix))
	if err != nil {
		return err
	}

	// Pass the pointer itself, not its address. Decoding into a
	// **ExtensionList would let a JSON `null` body reset discoverResponse to
	// nil (assuming makeGETRequest decodes with encoding/json), and the loop
	// below would then panic.
	discoverResponse := &distext.ExtensionList{}

	_, err = makeGETRequest(ctx, discoverEndPoint, username, password, config.verifyTLS,
		config.debug, discoverResponse, config.resultWriter)
	if err != nil {
		return err
	}

	searchEnabled := false

	for _, extension := range discoverResponse.Extensions {
		if extension.Name != constants.BaseExtension {
			continue
		}

		for _, endpoint := range extension.Endpoints {
			if endpoint == constants.FullSearchPrefix {
				searchEnabled = true

				break
			}
		}
	}

	if !searchEnabled {
		return fmt.Errorf("%w: search extension gql endpoints not found", zerr.ErrExtensionNotEnabled)
	}

	// An invalid search URL must be reported, not silently queried.
	searchEndPoint, err := combineServerAndEndpointURL(config.servURL, constants.FullSearchPrefix)
	if err != nil {
		return err
	}

	// Introspection query: root query fields with their argument names and
	// return type, plus every type known to the server with its field names.
	schemaQuery := `
{
__schema() {
queryType {
fields {
name
args {
name
}
type {
name
kind
}
}
__typename
}
types {
name
fields {
name
}
}
}
}`

	queryResponse := &schemaList{}

	err = makeGraphQLRequest(ctx, searchEndPoint, schemaQuery, username, password, config.verifyTLS,
		config.debug, queryResponse, config.resultWriter)
	if err != nil {
		return fmt.Errorf("gql query failed: %w", err)
	}

	if err = checkResultGraphQLQuery(ctx, err, queryResponse.Errors); err != nil {
		return fmt.Errorf("gql query failed: %w", err)
	}

	return containsGQLQueryWithParams(queryResponse.Data.Schema.QueryType.Fields,
		queryResponse.Data.Schema.Types, requiredQueries...)
}