Skip to content

Commit

Permalink
refactor creation methods (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrparkers authored Dec 10, 2020
1 parent 6441226 commit 1374b0b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 92 deletions.
140 changes: 52 additions & 88 deletions go/v1beta1/storage/elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,61 +92,36 @@ func (es *ElasticsearchStorage) CreateProject(ctx context.Context, projectId str
projectName := fmt.Sprintf("projects/%s", projectId)
log := es.logger.Named("CreateProject").With(zap.String("project", projectName))

searchBody := encodeRequest(&esSearch{
// check if project already exists
search := &esSearch{
Query: &filtering.Query{
Term: &filtering.Term{
"name": projectName,
},
},
})

res, err := es.client.Search(
es.client.Search.WithContext(ctx),
es.client.Search.WithIndex(projectsIndex()),
es.client.Search.WithBody(searchBody),
)
if err != nil {
return nil, createError(log, "error sending request to elasticsearch", err)
}
if res.IsError() {
return nil, createError(log, "error searching elasticsearch for projects", nil)
}

var searchResults esSearchResponse
if err := json.NewDecoder(res.Body).Decode(&searchResults); err != nil {
return nil, err
}
if searchResults.Hits.Total.Value > 0 {
err := es.genericGet(ctx, log, search, projectsIndex(), &prpb.Project{})
if err == nil { // project exists
log.Debug("project already exists")
return nil, status.Error(codes.AlreadyExists, fmt.Sprintf("project with name %s already exists", projectName))
} else if status.Code(err) != codes.NotFound { // unexpected error (we expect a not found error here)
return nil, err
}

p.Name = projectName
str, err := protojson.Marshal(proto.MessageV2(p))
if err != nil {
return nil, createError(log, "error marshalling occurrence to json", err)
}

// Create new project document
res, err = es.client.Index(
projectsIndex(),
bytes.NewReader(str),
es.client.Index.WithContext(ctx),
es.client.Index.WithRefresh("true"),
)
// create project document
err = es.genericCreate(ctx, log, projectsIndex(), p)
if err != nil {
return nil, createError(log, "error sending request to elasticsearch", err)
}
if res.IsError() {
return nil, createError(log, "error creating project document within elasticsearch", nil)
return nil, err
}

// Create indices for occurrences and notes
// create indices for occurrences and notes
for _, index := range []string{
occurrencesIndex(projectId),
notesIndex(projectId),
} {
res, err = es.client.Indices.Create(
res, err := es.client.Indices.Create(
index,
es.client.Indices.Create.WithContext(ctx),
withIndexMetadataAndStringMapping(),
Expand Down Expand Up @@ -383,33 +358,11 @@ func (es *ElasticsearchStorage) CreateOccurrence(ctx context.Context, projectId,
}
o.Name = fmt.Sprintf("projects/%s/occurrences/%s", projectId, uuid.New().String())

str, err := protojson.Marshal(proto.MessageV2(o))
err := es.genericCreate(ctx, log, occurrencesIndex(projectId), o)
if err != nil {
return nil, createError(log, "error marshalling occurrence to json", err)
}

res, err := es.client.Index(
occurrencesIndex(projectId),
bytes.NewReader(str),
es.client.Index.WithContext(ctx),
es.client.Index.WithRefresh("true"),
)
if err != nil {
return nil, createError(log, "error creating occurrence in elasticsearch", err)
}

if res.IsError() {
return nil, createError(log, "got unexpected status code from elasticsearch", nil, zap.Int("status", res.StatusCode))
}

esResponse := &esIndexDocResponse{}
err = decodeResponse(res.Body, esResponse)
if err != nil {
return nil, createError(log, "error decoding elasticsearch response", err)
return nil, err
}

log.Debug("elasticsearch response", zap.Any("response", esResponse))

return o, nil
}

Expand Down Expand Up @@ -648,13 +601,14 @@ func (es *ElasticsearchStorage) CreateNote(ctx context.Context, projectId, noteI
log := es.logger.Named("CreateNote").With(zap.String("note", noteName))

// since note IDs are provided up front by the client, we need to search ES to see if this note already exists before creating it
err := es.genericGet(ctx, log, &esSearch{
search := &esSearch{
Query: &filtering.Query{
Term: &filtering.Term{
"name": noteName,
},
},
}, notesIndex(projectId), &pb.Note{})
}
err := es.genericGet(ctx, log, search, notesIndex(projectId), &pb.Note{})
if err == nil { // note exists
log.Debug("note already exists")
return nil, status.Error(codes.AlreadyExists, fmt.Sprintf("note with name %s already exists", noteName))
Expand All @@ -667,32 +621,11 @@ func (es *ElasticsearchStorage) CreateNote(ctx context.Context, projectId, noteI
}
n.Name = noteName

str, err := protojson.Marshal(proto.MessageV2(n))
if err != nil {
return nil, createError(log, "error marshalling note to json", err)
}

res, err := es.client.Index(
notesIndex(projectId),
bytes.NewReader(str),
es.client.Index.WithContext(ctx),
)
if err != nil {
return nil, createError(log, "error sending request to elasticsearch", err)
}

if res.IsError() {
return nil, createError(log, "error creating note in elasticsearch", nil, zap.String("response", res.String()))
}

esResponse := &esIndexDocResponse{}
err = decodeResponse(res.Body, esResponse)
err = es.genericCreate(ctx, log, notesIndex(projectId), n)
if err != nil {
return nil, createError(log, "error decoding elasticsearch response", err)
return nil, err
}

log.Debug("elasticsearch response", zap.Any("response", esResponse))

return n, nil
}

Expand Down Expand Up @@ -726,7 +659,7 @@ func (es *ElasticsearchStorage) GetVulnerabilityOccurrencesSummary(ctx context.C
return &pb.VulnerabilityOccurrencesSummary{}, nil
}

func (es *ElasticsearchStorage) genericGet(ctx context.Context, log *zap.Logger, search *esSearch, index string, i interface{}) error {
func (es *ElasticsearchStorage) genericGet(ctx context.Context, log *zap.Logger, search *esSearch, index string, protoMessage interface{}) error {
res, err := es.client.Search(
es.client.Search.WithContext(ctx),
es.client.Search.WithIndex(index),
Expand All @@ -746,10 +679,41 @@ func (es *ElasticsearchStorage) genericGet(ctx context.Context, log *zap.Logger,

if searchResults.Hits.Total.Value == 0 {
log.Debug("document not found", zap.Any("search", search))
return status.Error(codes.NotFound, fmt.Sprintf("%T not found", i))
return status.Error(codes.NotFound, fmt.Sprintf("%T not found", protoMessage))
}

return protojson.Unmarshal(searchResults.Hits.Hits[0].Source, proto.MessageV2(i))
return protojson.Unmarshal(searchResults.Hits.Hits[0].Source, proto.MessageV2(protoMessage))
}

func (es *ElasticsearchStorage) genericCreate(ctx context.Context, log *zap.Logger, index string, protoMessage interface{}) error {
str, err := protojson.Marshal(proto.MessageV2(protoMessage))
if err != nil {
return createError(log, fmt.Sprintf("error marshalling %T to json", protoMessage), err)
}

res, err := es.client.Index(
index,
bytes.NewReader(str),
es.client.Index.WithContext(ctx),
es.client.Index.WithRefresh("true"),
)
if err != nil {
return createError(log, "error sending request to elasticsearch", err)
}

if res.IsError() {
return createError(log, "error indexing document in elasticsearch", nil, zap.String("response", res.String()), zap.Int("status", res.StatusCode))
}

esResponse := &esIndexDocResponse{}
err = decodeResponse(res.Body, esResponse)
if err != nil {
return createError(log, "error decoding elasticsearch response", err)
}

log.Debug("elasticsearch response", zap.Any("response", esResponse))

return nil
}

// createError is a helper function that allows you to easily log an error and return a gRPC formatted error.
Expand Down
7 changes: 3 additions & 4 deletions go/v1beta1/storage/elasticsearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ var _ = Describe("elasticsearch storage", func() {
},
{
StatusCode: http.StatusOK,
Body: structToJsonBody(&esIndexDocResponse{
Id: gofakeit.LetterN(10),
}),
},
{
StatusCode: http.StatusOK,
Expand Down Expand Up @@ -216,10 +219,6 @@ var _ = Describe("elasticsearch storage", func() {
})

When("the project does not exist", func() {
BeforeEach(func() {
transport.preparedHttpResponses[1] = &http.Response{StatusCode: http.StatusCreated}
})

It("should create a new document for the project", func() {
Expect(transport.receivedHttpRequests[1].URL.Path).To(Equal(fmt.Sprintf("/%s/_doc", expectedProjectIndex)))
Expect(transport.receivedHttpRequests[1].Method).To(Equal(http.MethodPost))
Expand Down

0 comments on commit 1374b0b

Please sign in to comment.