From 3685c0330b8eae46debd4d52fda7e37349fffcb2 Mon Sep 17 00:00:00 2001 From: Syed Hashim Date: Mon, 4 Nov 2024 17:08:46 +0500 Subject: [PATCH] feat(inputs/firehose): update firehose plugin --- plugins/inputs/firehose/firehose.go | 65 +++++------------ plugins/inputs/firehose/firehose_request.go | 78 +++++++++++++++------ 2 files changed, 72 insertions(+), 71 deletions(-) diff --git a/plugins/inputs/firehose/firehose.go b/plugins/inputs/firehose/firehose.go index 6493765854e19..4316d3bfc7807 100644 --- a/plugins/inputs/firehose/firehose.go +++ b/plugins/inputs/firehose/firehose.go @@ -118,64 +118,34 @@ func (f *Firehose) ServeHTTP(res http.ResponseWriter, req *http.Request) { return } - r := &request{req: req} - // Set the default response status code - r.res.statusCode = http.StatusInternalServerError - requestID := r.req.Header.Get("x-amz-firehose-request-id") - if requestID == "" { - r.res.statusCode = http.StatusBadRequest - f.Log.Errorf("Request header x-amz-firehose-request-id not set") - if err := r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response failed: %v", err) - } - return - } + r := f.handleRequest(req) - if err := r.validate(); err != nil { - f.Log.Errorf("Validation failed for request %q: %v", requestID, err) - if err = r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } - return - } - - if err := r.authenticate(f.AccessKey); err != nil { - f.Log.Errorf("Authentication failed for request %q: %v", requestID, err) - if err = r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } - return + if err := r.sendResponse(res); err != nil { + f.Log.Errorf("Sending response failed: %v", err) } +} - records, err := r.decodeData() - if err != nil { - f.Log.Errorf("Decode base64 data from request %q failed: %v", requestID, err) - if err = r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } - return +func (f *Firehose) handleRequest(req *http.Request) (r *request) { + var err error + if r, err = newRequest(req); err != nil { + f.Log.Errorf("Creating request object failed: %v", err) + return r } - paramTags, err := r.extractParameterTags(f.ParameterTags) + records, paramTags, err := r.processRequest(f.AccessKey, f.ParameterTags) if err != nil { - f.Log.Errorf("Extracting parameter tags for request %q failed: %v", requestID, err) - if err = r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } - return + f.Log.Errorf("Processing request failed: %v", err) + return r } var metrics []telegraf.Metric for _, record := range records { m, err := f.parser.Parse(record) if err != nil { - f.Log.Errorf("Parse data from request %q failed: %v", requestID, err) // respond with bad request status code to inform firehose about the failure r.res.statusCode = http.StatusBadRequest - if err = r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } - return + f.Log.Errorf("Parse data from request %q failed: %v", r.body.RequestID, err) + return r } metrics = append(metrics, m...) } @@ -184,21 +154,18 @@ func (f *Firehose) ServeHTTP(res http.ResponseWriter, req *http.Request) { f.once.Do(func() { f.Log.Info(internal.NoMetricsCreatedMsg) }) - return } for _, m := range metrics { for k, v := range paramTags { m.AddTag(k, v) } - m.AddTag("firehose_http_path", req.URL.Path) + m.AddTag("firehose_http_path", r.req.URL.Path) f.acc.AddMetric(m) } r.res.statusCode = http.StatusOK - if err := r.sendResponse(res); err != nil { - f.Log.Errorf("Sending response to request %q failed: %v", requestID, err) - } + return r } func init() { diff --git a/plugins/inputs/firehose/firehose_request.go b/plugins/inputs/firehose/firehose_request.go index 241dff91dc7e9..4cd17e296390a 100644 --- a/plugins/inputs/firehose/firehose_request.go +++ b/plugins/inputs/firehose/firehose_request.go @@ -41,6 +41,37 @@ type responseBody struct { ErrorMessage string `json:"errorMessage,omitempty"` } +func newRequest(req *http.Request) (*request, error) { + r := &request{req: req} + requestID := r.req.Header.Get("x-amz-firehose-request-id") + if requestID == "" { + r.res.statusCode = http.StatusBadRequest + return r, errors.New("x-amz-firehose-request-id header is not set") + } + // Set a default response status code + r.res.statusCode = http.StatusInternalServerError + + encoding := r.req.Header.Get("content-encoding") + body, err := internal.NewStreamContentDecoder(encoding, r.req.Body) + if err != nil { + r.res.statusCode = http.StatusBadRequest + return r, fmt.Errorf("creating %q decoder for request %q failed: %w", encoding, requestID, err) + } + defer r.req.Body.Close() + + if err := json.NewDecoder(body).Decode(&r.body); err != nil { + r.res.statusCode = http.StatusBadRequest + return r, fmt.Errorf("decode body for request %q failed: %w", requestID, err) + } + + if requestID != r.body.RequestID { + r.res.statusCode = http.StatusBadRequest + return r, fmt.Errorf("mismatch between requestID in the request header (%q) and the request body (%s)", requestID, r.body.RequestID) + } + + return r, nil +} + func (r *request) authenticate(expected config.Secret) error { // We completely switch off authentication if no 'access_key' was provided in the config, it's intended! if expected.Empty() { @@ -75,28 +106,9 @@ func (r *request) validate() error { return fmt.Errorf("method %q is not allowed", r.req.Method) } - contentType := r.req.Header.Get("content-type") - if contentType != "application/json" { + if r.req.Header.Get("content-type") != "application/json" { r.res.statusCode = http.StatusBadRequest - return fmt.Errorf("content type %s is not allowed", contentType) - } - - encoding := r.req.Header.Get("content-encoding") - body, err := internal.NewStreamContentDecoder(encoding, r.req.Body) - if err != nil { - r.res.statusCode = http.StatusBadRequest - return fmt.Errorf("creating %q decoder failed: %w", encoding, err) - } - defer r.req.Body.Close() - - if err := json.NewDecoder(body).Decode(&r.body); err != nil { - r.res.statusCode = http.StatusBadRequest - return fmt.Errorf("decode body failed: %w", err) - } - - if r.body.RequestID != r.req.Header.Get("x-amz-firehose-request-id") { - r.res.statusCode = http.StatusBadRequest - return errors.New("requestId in the body does not match x-amz-firehose-request-id request header") + return fmt.Errorf("content type, %s, is not allowed", r.req.Header.Get("content-type")) } return nil @@ -145,6 +157,28 @@ func (r *request) extractParameterTags(parameterTags []string) (map[string]strin return paramTags, nil } +func (r *request) processRequest(key config.Secret, tags []string) ([][]byte, map[string]string, error) { + if err := r.authenticate(key); err != nil { + return nil, nil, fmt.Errorf("authentication for request %q failed: %w", r.body.RequestID, err) + } + + if err := r.validate(); err != nil { + return nil, nil, fmt.Errorf("validation for request %q failed: %w", r.body.RequestID, err) + } + + records, err := r.decodeData() + if err != nil { + return nil, nil, fmt.Errorf("decode base64 data from request %q failed: %w", r.body.RequestID, err) + } + + paramTags, err := r.extractParameterTags(tags) + if err != nil { + return nil, nil, fmt.Errorf("extracting parameter tags for request %q failed: %w", r.body.RequestID, err) + } + + return records, paramTags, nil +} + func (r *request) sendResponse(res http.ResponseWriter) error { var errorMessage string if r.res.statusCode != http.StatusOK { @@ -162,5 +196,5 @@ func (r *request) sendResponse(res http.ResponseWriter) error { res.Header().Set("content-type", "application/json") res.WriteHeader(r.res.statusCode) _, err = res.Write(response) - return err + return fmt.Errorf("writing response to request %s failed: %w", r.res.body.RequestID, err) }