package aws import ( "fmt" "path/filepath" "sort" "strconv" "strings" "sync" "github.com/BishopFox/cloudfox/aws/sdk" "github.com/BishopFox/cloudfox/internal" "github.com/aws/aws-sdk-go-v2/aws" apigatewayTypes "github.com/aws/aws-sdk-go-v2/service/apigateway/types" apigatewayV2Types "github.com/aws/aws-sdk-go-v2/service/apigatewayv2/types" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/bishopfox/awsservicemap" "github.com/sirupsen/logrus" ) var CURL_COMMAND string = "curl -X %s %s" type ApiGwModule struct { // General configuration data APIGatewayClient sdk.APIGatewayClientInterface APIGatewayv2Client sdk.APIGatewayv2ClientInterface Caller sts.GetCallerIdentityOutput AWSRegions []string Goroutines int AWSProfile string WrapTable bool // Main module data Gateways []ApiGateway CommandCounter internal.CommandCounter Errors []string // Used to store output data for pretty printing output internal.OutputData2 modLog *logrus.Entry } type ApiGateway struct { AWSService string Region string Name string Endpoint string ApiKey string Public string Method string } func (m *ApiGwModule) PrintApiGws(outputDirectory string, verbosity int) { m.output.Verbosity = verbosity m.output.Directory = outputDirectory m.output.CallingModule = "api-gw" m.modLog = internal.TxtLog.WithFields(logrus.Fields{ "module": m.output.CallingModule, }) if m.AWSProfile == "" { m.AWSProfile = internal.BuildAWSPath(m.Caller) } fmt.Printf("[%s][%s] Enumerating api-gateways for account %s.\n", cyan(m.output.CallingModule), cyan(m.AWSProfile), aws.ToString(m.Caller.Account)) wg := new(sync.WaitGroup) semaphore := make(chan struct{}, m.Goroutines) // Create a channel to signal the spinner aka task status goroutine to finish spinnerDone := make(chan bool) //fire up the the task status spinner/updated go internal.SpinUntil(m.output.CallingModule, &m.CommandCounter, spinnerDone, "tasks") //create a channel to receive the objects dataReceiver := make(chan ApiGateway) // Create a channel to signal to stop receiverDone := make(chan bool) go m.Receiver(dataReceiver, receiverDone) //execute regional checks for _, region := range m.AWSRegions { wg.Add(1) go m.executeChecks(region, wg, semaphore, dataReceiver) } wg.Wait() // Send a message to the spinner goroutine to close the channel and stop spinnerDone <- true <-spinnerDone receiverDone <- true <-receiverDone sort.Slice(m.Gateways, func(i, j int) bool { return m.Gateways[i].AWSService < m.Gateways[j].AWSService }) m.output.Headers = []string{ "Service", "Region", "Name", "Method", "Endpoint", "ApiKey", "Public", } // Table rows for i := range m.Gateways { m.output.Body = append( m.output.Body, []string{ m.Gateways[i].AWSService, m.Gateways[i].Region, m.Gateways[i].Name, m.Gateways[i].Method, m.Gateways[i].Endpoint, m.Gateways[i].ApiKey, m.Gateways[i].Public, }, ) } if len(m.output.Body) > 0 { filepath := filepath.Join(outputDirectory, "cloudfox-output", "aws", fmt.Sprintf("%s-%s", m.AWSProfile, aws.ToString(m.Caller.Account))) o := internal.OutputClient{ Verbosity: verbosity, CallingModule: m.output.CallingModule, Table: internal.TableClient{ Wrap: m.WrapTable, DirectoryName: filepath, }, Loot: internal.LootClient{ DirectoryName: filepath, }, } o.Table.TableFiles = append(o.Table.TableFiles, internal.TableFile{ Header: m.output.Headers, Body: m.output.Body, Name: m.output.CallingModule, }) o.PrefixIdentifier = m.AWSProfile loot := m.writeLoot(filepath, verbosity) o.Loot.LootFiles = append(o.Loot.LootFiles, internal.LootFile{ Name: m.output.CallingModule, Contents: loot, }) o.WriteFullOutput(o.Table.TableFiles, o.Loot.LootFiles) fmt.Printf("[%s][%s] %s API gateways found.\n", cyan(m.output.CallingModule), cyan(m.AWSProfile), strconv.Itoa(len(m.output.Body))) } else { fmt.Printf("[%s][%s] No API gateways found, skipping the creation of an output file.\n", cyan(m.output.CallingModule), cyan(m.AWSProfile)) } fmt.Printf("[%s][%s] For context and next steps: https://github.com/BishopFox/cloudfox/wiki/AWS-Commands#%s\n", cyan(m.output.CallingModule), cyan(m.AWSProfile), m.output.CallingModule) } func (m *ApiGwModule) Receiver(receiver chan ApiGateway, receiverDone chan bool) { defer close(receiverDone) for { select { case data := <-receiver: m.Gateways = append(m.Gateways, data) case <-receiverDone: receiverDone <- true return } } } func (m *ApiGwModule) executeChecks(r string, wg *sync.WaitGroup, semaphore chan struct{}, dataReceiver chan ApiGateway) { defer wg.Done() // check the concurrency semaphore // semaphore <- struct{}{} // defer func() { // <-semaphore // }() servicemap := &awsservicemap.AwsServiceMap{ JsonFileSource: "DOWNLOAD_FROM_AWS", } res, err := servicemap.IsServiceInRegion("apigateway", r) if err != nil { m.modLog.Error(err) } if res { m.CommandCounter.Total++ wg.Add(1) go m.getAPIGatewayAPIsPerRegion(r, wg, semaphore, dataReceiver) m.CommandCounter.Total++ wg.Add(1) go m.getAPIGatewayVIPsPerRegion(r, wg, semaphore, dataReceiver) m.CommandCounter.Total++ wg.Add(1) go m.getAPIGatewayv2APIsPerRegion(r, wg, semaphore, dataReceiver) m.CommandCounter.Total++ wg.Add(1) go m.getAPIGatewayv2VIPsPerRegion(r, wg, semaphore, dataReceiver) } } func (m *ApiGwModule) writeLoot(outputDirectory string, verbosity int) string { path := filepath.Join(outputDirectory, "loot") f := filepath.Join(path, "api-gws.txt") var out string for _, endpoint := range m.Gateways { method := endpoint.Method // Write a GET and POST for ANY if endpoint.Method == "ANY" { line := fmt.Sprintf(CURL_COMMAND, "GET", endpoint.Endpoint) if endpoint.ApiKey != "" { line += fmt.Sprintf(" -H 'X-Api-Key: %s'", endpoint.ApiKey) } out += line + "\n" method = "POST" } line := fmt.Sprintf(CURL_COMMAND, method, endpoint.Endpoint) if endpoint.ApiKey != "" { line += fmt.Sprintf(" -H 'X-Api-Key: %s'", endpoint.ApiKey) } if method == "DELETE" || method == "PATCH" || method == "POST" || method == "PUT" { line += " -H 'Content-Type: application/json' -d '{}'" } out += line + "\n" } // err = os.WriteFile(f, []byte(out), 0644) // if err != nil { // m.modLog.Error(err.Error()) // m.CommandCounter.Error++ // panic(err.Error()) // } if verbosity > 2 { fmt.Println() fmt.Printf("[%s][%s] %s \n", cyan(m.output.CallingModule), cyan(m.AWSProfile), green("Send these requests through your favorite interception proxy")) fmt.Print(out) fmt.Printf("[%s][%s] %s \n\n", cyan(m.output.CallingModule), cyan(m.AWSProfile), green("End of loot file.")) } fmt.Printf("[%s][%s] Loot written to [%s]\n", cyan(m.output.CallingModule), cyan(m.AWSProfile), f) return out } func (m *ApiGwModule) getAPIGatewayAPIsPerRegion(r string, wg *sync.WaitGroup, semaphore chan struct{}, dataReceiver chan ApiGateway) { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ wg.Done() }() semaphore <- struct{}{} defer func() { <-semaphore }() // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. Items, err := sdk.CachedApiGatewayGetRestAPIs(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) if err != nil { m.modLog.Error(err.Error()) m.CommandCounter.Error++ return } for _, api := range Items { m.CommandCounter.Total++ for _, endpoint := range m.getEndpointsPerAPIGateway(r, api) { dataReceiver <- endpoint } } } func (m *ApiGwModule) getAPIGatewayVIPsPerRegion(r string, wg *sync.WaitGroup, semaphore chan struct{}, dataReceiver chan ApiGateway) { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ wg.Done() }() semaphore <- struct{}{} defer func() { <-semaphore }() // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. Items, err := sdk.CachedApiGatewayGetRestAPIs(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) if err != nil { m.modLog.Error(err.Error()) m.CommandCounter.Error++ return } GetDomainNames, err := sdk.CachedApiGatewayGetDomainNames(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ return } for _, item := range GetDomainNames { domain := aws.ToString(item.DomainName) GetBasePathMappings, err := sdk.CachedApiGatewayGetBasePathMappings(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, item.DomainName) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ break } for _, mapping := range GetBasePathMappings { stage := aws.ToString(mapping.Stage) basePath := aws.ToString(mapping.BasePath) if basePath == "(none)" { basePath = "" // Empty string since '/' is already prepended } for _, api := range Items { if api.Id != nil && aws.ToString(api.Id) == aws.ToString(mapping.RestApiId) { m.CommandCounter.Total++ endpoints := m.getEndpointsPerAPIGateway(r, api) for _, endpoint := range endpoints { old := fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.RestApiId), r, stage) if strings.HasPrefix(endpoint.Endpoint, old) { var new string if basePath == "" { new = fmt.Sprintf("https://%s/", domain) } else { new = fmt.Sprintf("https://%s/%s/", domain, basePath) } endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) endpoint.Name = domain dataReceiver <- endpoint } } break } } } } } func (m *ApiGwModule) getEndpointsPerAPIGateway(r string, api apigatewayTypes.RestApi) []ApiGateway { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ }() var gateways []ApiGateway //var PaginationControl2 *string awsService := "APIGateway" var public string name := aws.ToString(api.Name) id := aws.ToString(api.Id) raw_endpoint := fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com", id, r) endpointType := *api.EndpointConfiguration //fmt.Println(endpointType) if endpointType.Types[0] == "PRIVATE" { public = "False" } else { public = "True" } GetStages, err := sdk.CachedApiGatewayGetStages(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, id) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ return gateways } resources, err := sdk.CachedApiGatewayGetResources(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, id) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } for _, stage := range GetStages.Item { stageName := aws.ToString(stage.StageName) for _, resource := range resources { if len(resource.ResourceMethods) != 0 { for method := range resource.ResourceMethods { // Check if API Key is required for endpoint apiKey := "" if m.ApiGatewayApiKeyRequired(r, api.Id, resource.Id, method) { apiKey, err = m.GetApiGatewayApiKey(r, id, stageName) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } } path := aws.ToString(resource.Path) endpoint := fmt.Sprintf("%s/%s%s", raw_endpoint, stageName, path) gateways = append(gateways, ApiGateway{ AWSService: awsService, Region: r, Name: name, Endpoint: endpoint, Method: method, Public: public, ApiKey: apiKey, }) } } } } return gateways } func (m *ApiGwModule) getAPIGatewayv2APIsPerRegion(r string, wg *sync.WaitGroup, semaphore chan struct{}, dataReceiver chan ApiGateway) { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ wg.Done() }() semaphore <- struct{}{} defer func() { <-semaphore }() // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. Items, err := sdk.CachedAPIGatewayv2GetAPIs(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r) if err != nil { m.modLog.Error(err.Error()) m.CommandCounter.Error++ return } for _, api := range Items { m.CommandCounter.Total++ for _, endpoint := range m.getEndpointsPerAPIGatewayv2(r, api) { dataReceiver <- endpoint } } } func (m *ApiGwModule) getAPIGatewayv2VIPsPerRegion(r string, wg *sync.WaitGroup, semaphore chan struct{}, dataReceiver chan ApiGateway) { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ wg.Done() }() semaphore <- struct{}{} defer func() { <-semaphore }() // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. Items, err := sdk.CachedAPIGatewayv2GetAPIs(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r) if err != nil { m.modLog.Error(err.Error()) m.CommandCounter.Error++ return } GetDomainNames, err := sdk.CachedAPIGatewayv2GetDomainNames(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } for _, item := range GetDomainNames { domain := aws.ToString(item.DomainName) GetApiMappings, err := sdk.CachedAPIGatewayv2GetApiMappings(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, domain) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ break } for _, mapping := range GetApiMappings { stage := aws.ToString(mapping.Stage) if stage == "$default" { stage = "" } path := aws.ToString(mapping.ApiMappingKey) for _, api := range Items { if api.ApiId != nil && aws.ToString(api.ApiId) == aws.ToString(mapping.ApiId) { m.CommandCounter.Total++ endpoints := m.getEndpointsPerAPIGatewayv2(r, api) for _, endpoint := range endpoints { var old string if stage == "" { old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/", aws.ToString(mapping.ApiId), r) } else { old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.ApiId), r, stage) } if strings.HasPrefix(endpoint.Endpoint, old) { var new string if path == "" { new = fmt.Sprintf("https://%s/", domain) } else { new = fmt.Sprintf("https://%s/%s/", domain, path) } endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) endpoint.Name = domain dataReceiver <- endpoint } } break } } } } } func (m *ApiGwModule) getEndpointsPerAPIGatewayv2(r string, api apigatewayV2Types.Api) []ApiGateway { defer func() { m.CommandCounter.Executing-- m.CommandCounter.Complete++ }() var gateways []ApiGateway awsService := "APIGatewayv2" var public string name := aws.ToString(api.Name) raw_endpoint := aws.ToString(api.ApiEndpoint) id := aws.ToString(api.ApiId) var stages []string GetStages, err := sdk.CachedAPIGatewayv2GetStages(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, id) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } for _, stage := range GetStages { s := aws.ToString(stage.StageName) if s == "$default" { s = "" } stages = append(stages, s) } GetRoutes, err := sdk.CachedAPIGatewayv2GetRoutes(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, id) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } for _, stage := range stages { for _, route := range GetRoutes { routeKey := route.RouteKey var method string var path string if len(strings.Fields(*routeKey)) == 2 { method = strings.Fields(*routeKey)[0] path = strings.Fields(*routeKey)[1] } var endpoint string if stage == "" { endpoint = fmt.Sprintf("%s%s", raw_endpoint, path) } else { endpoint = fmt.Sprintf("%s/%s%s", raw_endpoint, stage, path) } public = "True" gateways = append(gateways, ApiGateway{ AWSService: awsService, Region: r, Name: name, Method: method, Endpoint: endpoint, Public: public, }) } } return gateways } func (m *ApiGwModule) ApiGatewayApiKeyRequired(r string, ApiId *string, ResourceId *string, method string) bool { GetMethod, err := sdk.CachedApiGatewayGetMethod(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, aws.ToString(ApiId), aws.ToString(ResourceId), method) if err != nil { m.Errors = append(m.Errors, fmt.Sprintf(" Error: Region: %s", r)) m.modLog.Error(err.Error()) m.CommandCounter.Error++ } else { return aws.ToBool(GetMethod.ApiKeyRequired) } return false } func (m *ApiGwModule) GetApiGatewayApiKey(r string, ApiId string, Stage string) (string, error) { var items []apigatewayTypes.UsagePlan GetUsagePlans, err := sdk.CachedApiGatewayGetUsagePlans(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) if err != nil { return "", err } for _, item := range GetUsagePlans { items = append(items, item) } for _, item := range items { for _, apiStage := range item.ApiStages { if aws.ToString(apiStage.ApiId) == ApiId && aws.ToString(apiStage.Stage) == Stage { GetUsagePlanKeys, err := sdk.CachedApiGatewayGetUsagePlanKeys(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, aws.ToString(item.Id)) if err != nil { return "", err } for _, i := range GetUsagePlanKeys { if aws.ToString(i.Type) == "API_KEY" { return aws.ToString(i.Value), nil } } } } } return "", nil }