Skip to content

Commit

Permalink
use rs/cors library to handle multiple CORS origins [#191]
Browse files Browse the repository at this point in the history
  • Loading branch information
bdon committed Jan 22, 2025
1 parent cb7d438 commit 07660e3
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/dustin/go-humanize v1.0.1
github.com/paulmach/orb v0.10.0
github.com/prometheus/client_golang v1.19.1
github.com/rs/cors v1.11.1
github.com/schollz/progressbar/v3 v3.13.1
github.com/stretchr/testify v1.9.0
go.uber.org/zap v1.27.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
Expand Down
21 changes: 17 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/alecthomas/kong"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/cors"

"github.com/protomaps/go-pmtiles/pmtiles"
_ "gocloud.dev/blob/azureblob"
Expand Down Expand Up @@ -93,7 +95,7 @@ var cli struct {
Interface string `default:"0.0.0.0"`
Port int `default:"8080"`
AdminPort int `default:"-1"`
Cors string `help:"Value of HTTP CORS header."`
Cors string `help:"Comma-separated list of of allowed HTTP CORS origins."`
CacheSize int `default:"64" help:"Size of cache in Megabytes."`
Bucket string `help:"Remote bucket"`
PublicURL string `help:"Public base URL of tile endpoint for TileJSON e.g. https://example.com/tiles/"`
Expand Down Expand Up @@ -139,7 +141,7 @@ func main() {
logger.Fatalf("Failed to show tile, %v", err)
}
case "serve <path>":
server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.Cors, cli.Serve.PublicURL)
server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.PublicURL)

if err != nil {
logger.Fatalf("Failed to create new server, %v", err)
Expand All @@ -148,7 +150,9 @@ func main() {
pmtiles.SetBuildInfo(version, commit, date)
server.Start()

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
mux := http.NewServeMux()

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
statusCode := server.ServeHTTP(w, r)
logger.Printf("served %d %s in %s", statusCode, url.PathEscape(r.URL.Path), time.Since(start))
Expand All @@ -164,7 +168,16 @@ func main() {
logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+adminPort, adminMux))
}()
}
logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), nil))

if cli.Serve.Cors != "" {
c := cors.New(cors.Options{
AllowedOrigins: strings.Split(cli.Serve.Cors, ","),
})
muxWithCors := c.Handler(mux)
logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), muxWithCors))
} else {
logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), mux))
}
case "extract <input> <output>":
err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Minzoom, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Bbox, cli.Extract.Output, cli.Extract.DownloadThreads, cli.Extract.Overfetch, cli.Extract.DryRun)
if err != nil {
Expand Down
24 changes: 4 additions & 20 deletions pmtiles/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,12 @@ type Server struct {
bucket Bucket
logger *log.Logger
cacheSize int
cors string
publicURL string
metrics *metrics
}

// NewServer creates a new pmtiles HTTP server.
func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) {
func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) {

ctx := context.Background()

Expand All @@ -71,11 +70,11 @@ func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize in
return nil, err
}

return NewServerWithBucket(bucket, prefix, logger, cacheSize, cors, publicURL)
return NewServerWithBucket(bucket, prefix, logger, cacheSize, publicURL)
}

// NewServerWithBucket creates a new HTTP server for a gocloud Bucket.
func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) {
func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) {

reqs := make(chan request, 8)

Expand All @@ -84,7 +83,6 @@ func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize
bucket: bucket,
logger: logger,
cacheSize: cacheSize,
cors: cors,
publicURL: publicURL,
metrics: createMetrics("", logger), // change scope string if there are multiple servers running in one process
}
Expand Down Expand Up @@ -474,9 +472,6 @@ func (server *Server) get(ctx context.Context, unsanitizedPath string) (archive,
handler = ""
archive = ""
headers = make(map[string]string)
if len(server.cors) > 0 {
headers["Access-Control-Allow-Origin"] = server.cors
}

if ok, key, z, x, y, ext := parseTilePath(unsanitizedPath); ok {
archive, handler = key, "tile"
Expand Down Expand Up @@ -518,18 +513,7 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
// Serve an HTTP response from the archive
func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) int {
tracker := server.metrics.startRequest()
if r.Method == http.MethodOptions {
if len(server.cors) > 0 {
w.Header().Set("Access-Control-Allow-Origin", server.cors)
}
w.WriteHeader(204)
tracker.finish(r.Context(), "", r.Method, 204, 0, false)
return 204
} else if r.Method != http.MethodGet && r.Method != http.MethodHead {
w.WriteHeader(405)
tracker.finish(r.Context(), "", r.Method, 405, 0, false)
return 405
}

archive, handler, statusCode, headers, body := server.get(r.Context(), r.URL.Path)
for k, v := range headers {
w.Header().Set(k, v)
Expand Down

0 comments on commit 07660e3

Please sign in to comment.