Skip to content

Instantly share code, notes, and snippets.

@paprikati
Last active December 3, 2021 15:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paprikati/8369b7c00d1bd0a740ca45b3813d3167 to your computer and use it in GitHub Desktop.
Save paprikati/8369b7c00d1bd0a740ca45b3813d3167 to your computer and use it in GitHub Desktop.
CheckOrganisationScope: using middleware to check that API responses are scoped by organisation ID

CheckOrganisationScope

This gist is associated with the blog post "Building safe-by-default tools in our Go web application".

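It contains the goa middleware that we use to ensure our queries are correctly scoped by organisation: after each request, it checks that every resource in the API response belongs to the organisation making the request.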

package mw
import (
"context"
"reflect"
"github.com/incident-io/core/server/api/gen/billing"
"github.com/incident-io/core/server/api/gen/insights"
"github.com/incident-io/core/server/api/gen/system"
"github.com/incident-io/core/server/api/gen/typeaheads"
"github.com/incident-io/core/server/db"
"github.com/incident-io/core/server/internal/errors"
"github.com/incident-io/core/server/log"
goa "goa.design/goa/v3/pkg"
)
// CheckOrganisationScope returns goa middleware that inspects each successful
// response and reports resources that don't belong to the organisation associated
// with the API request scope.
//
// Violations are logged (event "check_organisation_scope_violation") rather than
// returned: the wrapped endpoint's response is always passed through, so a failed
// check never breaks the request itself.
//
// The check is skipped when the endpoint returned an error, returned a nil result,
// or the request identity has no organisation (an unauthenticated request). See
// CheckOrganisationScopeResponse for the response shapes that are inspected.
//
// The db argument is currently unused.
func CheckOrganisationScope(db *db.Postgres) func(goa.Endpoint) goa.Endpoint {
	return func(e goa.Endpoint) goa.Endpoint {
		return goa.Endpoint(func(ctx context.Context, req interface{}) (interface{}, error) {
			res, err := e(ctx, req)
			if err != nil {
				return res, err
			}
			if res == nil {
				return res, nil // There is nothing to check!
			}

			// If we don't have an org, we're an unauthenticated request, so running this
			// middleware doesn't make sense.
			id, _, _ := GetIdentity(ctx)
			if id.OrganisationID == "" {
				return res, nil
			}

			if scopeErr := CheckOrganisationScopeResponse(id.OrganisationID, res); scopeErr != nil {
				// Checked assertions: a context without goa's service/method keys (for
				// example an endpoint invoked directly) must not turn this logging call
				// into a panic. Missing values are logged as empty strings.
				service, _ := ctx.Value(goa.ServiceKey).(string)
				method, _ := ctx.Value(goa.MethodKey).(string)
				log.Error(ctx, scopeErr, map[string]interface{}{
					"event":            "check_organisation_scope_violation",
					"endpoint_service": service,
					"endpoint_method":  method,
				})
			}

			// Scope violations are reported via logging only; the response is returned
			// unchanged.
			return res, nil
		})
	}
}
// CheckOrganisationScopeResponse verifies that the resources in res belong to the
// organisation orgID.
//
// res is expected to be a pointer to a response struct. Pointers around it are
// unwrapped, and anything that is not a non-zero struct once unwrapped (nil, zero
// structs, scalars, slices) is accepted without checks. Each field of the struct is
// then inspected, after unwrapping interface-typed fields:
//
//   - a pointer field is checked as a single resource
//   - a slice field has each of its elements checked as a resource
//
// Every field is inspected, not only the first pointer field. The first violation
// found is returned; see checkOrganisationScope for the errors it produces.
func CheckOrganisationScopeResponse(orgID string, res interface{}) error {
	val := reflect.ValueOf(res)
	// Unwrap pointers up front: calling Elem/NumField on a value of the wrong kind
	// panics, and a panicking scope check is worse than a skipped one.
	for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface {
		val = val.Elem()
	}
	if !val.IsValid() || val.IsZero() || val.Kind() != reflect.Struct {
		return nil // there is nothing to check!
	}

	for idx := 0; idx < val.NumField(); idx++ {
		fieldVal := val.Field(idx)
		// If we're wrapped in an interface, unpack it to get the real pointer type.
		if fieldVal.Kind() == reflect.Interface {
			fieldVal = fieldVal.Elem()
		}

		switch fieldVal.Kind() {
		case reflect.Slice:
			for elemIdx := 0; elemIdx < fieldVal.Len(); elemIdx++ {
				if err := checkOrganisationScope(orgID, fieldVal.Index(elemIdx)); err != nil {
					return err
				}
			}
		case reflect.Ptr:
			// Keep scanning after a pointer field: returning its result directly would
			// leave every later field of the response unchecked.
			if err := checkOrganisationScope(orgID, fieldVal); err != nil {
				return err
			}
		}
	}

	return nil
}
// ErrCheckOrganisationScopeMissingID is returned when a checked resource does not
// carry an organisation ID (for example, it has no OrganisationID field or the field
// is empty).
var ErrCheckOrganisationScopeMissingID = errors.New("this response does not have any organisation ID")

// ErrCheckOrganisationScopeIncorrectID is returned when a checked resource belongs
// to a different organisation than the one the API request is scoped to.
type ErrCheckOrganisationScopeIncorrectID struct {
	// ExpectedOrganisationID is the organisation the request is scoped to.
	ExpectedOrganisationID string
	// ResourceOrganisationID is the organisation found on the offending resource.
	ResourceOrganisationID string
}

// Error implements the error interface. Both organisation IDs are included so a
// logged violation identifies which organisation's data was about to be returned.
func (e ErrCheckOrganisationScopeIncorrectID) Error() string {
	return "response includes data for an organisation outside of this API scope" +
		" (expected organisation " + e.ExpectedOrganisationID +
		", got " + e.ResourceOrganisationID + ")"
}
// checkOrganisationScope ensures val, which is expected to be a pointer to a struct,
// has an OrganisationID field matching orgID.
//
// Pointers and interfaces are unwrapped first, so val may also be a struct value
// (e.g. an element of a []T slice) or an interface holding a pointer. Calling Elem
// directly on such values would panic. Nil pointers, and anything that is not a
// struct once unwrapped (strings, numbers, ...), carry no resource and are accepted.
//
// It returns ErrCheckOrganisationScopeMissingID when the struct has no
// OrganisationID field, or that field is empty, nil, or not string-kinded. It
// returns ErrCheckOrganisationScopeIncorrectID when the ID differs from orgID.
func checkOrganisationScope(orgID string, val reflect.Value) error {
	for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface {
		if val.IsNil() {
			return nil
		}
		val = val.Elem()
	}
	if val.Kind() != reflect.Struct {
		return nil
	}

	organisationField := val.FieldByName("OrganisationID")
	// Follow a pointer-typed ID field (e.g. *string). A nil pointer is left as-is so
	// the zero check below reports it as missing.
	if organisationField.Kind() == reflect.Ptr && !organisationField.IsNil() {
		organisationField = organisationField.Elem()
	}
	if !organisationField.IsValid() || organisationField.IsZero() || organisationField.Kind() != reflect.String {
		return ErrCheckOrganisationScopeMissingID
	}

	// String (rather than Interface().(string)) also works for named string types
	// and for values reached through unexported fields, where Interface panics.
	if resourceOrgID := organisationField.String(); resourceOrgID != orgID {
		return ErrCheckOrganisationScopeIncorrectID{
			ExpectedOrganisationID: orgID,
			ResourceOrganisationID: resourceOrgID,
		}
	}

	return nil
}
package mw_test
import (
"github.com/incident-io/core/server/domain"
"github.com/incident-io/core/server/api/mw"
. "github.com/incident-io/core/server/spec"
. "github.com/nauyey/factory"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// Specs for CheckOrganisationScopeResponse, covering single-resource and slice
// envelopes whose resources match, mismatch, or lack an organisation ID.
var _ = Describe("CheckOrganisationScopeResponse", func() {
	var (
		org domain.Organisation
	)

	BeforeEach(func() {
		MustCreate(ctx, tx, &org, Build(domain.OrganisationFactory))
	})

	// envelope mimics a response struct whose single field holds the payload.
	type envelope struct {
		Envelope interface{}
	}

	type resWithOrg struct {
		ID             string
		OrganisationID string
	}

	type resWithoutOrg struct {
		ID string
	}

	check := func(res interface{}) error {
		return mw.CheckOrganisationScopeResponse(org.ID, res)
	}

	Context("when singular response matches organisation", func() {
		It("returns no error", func() {
			err := check(&envelope{
				Envelope: &resWithOrg{
					ID:             "my-id",
					OrganisationID: org.ID,
				},
			})
			Expect(err).NotTo(HaveOccurred())
		})
	})

	Context("when singular response has different organisation", func() {
		It("returns an incorrect organisation ID error", func() {
			err := check(&envelope{
				Envelope: &resWithOrg{
					ID:             "my-id",
					OrganisationID: "different-org-id",
				},
			})
			Expect(err).To(MatchError(mw.ErrCheckOrganisationScopeIncorrectID{
				ExpectedOrganisationID: org.ID,
				ResourceOrganisationID: "different-org-id",
			}))
		})
	})

	Context("when singular response has no organisation", func() {
		It("returns a missing organisation ID error", func() {
			err := check(&envelope{
				Envelope: &resWithoutOrg{
					ID: "my-id",
				},
			})
			Expect(err).To(MatchError(mw.ErrCheckOrganisationScopeMissingID))
		})
	})

	Context("when slice response matches organisation", func() {
		It("returns no error", func() {
			err := check(&envelope{
				Envelope: []*resWithOrg{
					{
						ID:             "my-id",
						OrganisationID: org.ID,
					},
				},
			})
			Expect(err).NotTo(HaveOccurred())
		})
	})

	Context("when slice response has different organisation", func() {
		It("returns an incorrect organisation ID error", func() {
			err := check(&envelope{
				Envelope: []*resWithOrg{
					{
						ID:             "my-id",
						OrganisationID: "different-org-id",
					},
				},
			})
			Expect(err).To(MatchError(mw.ErrCheckOrganisationScopeIncorrectID{
				ExpectedOrganisationID: org.ID,
				ResourceOrganisationID: "different-org-id",
			}))
		})
	})

	Context("when slice response has no organisation", func() {
		It("returns a missing organisation ID error", func() {
			err := check(&envelope{
				Envelope: []*resWithoutOrg{
					{
						ID: "my-id",
					},
				},
			})
			Expect(err).To(MatchError(mw.ErrCheckOrganisationScopeMissingID))
		})
	})
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment