diff --git a/main.go b/main.go index 4bc7fad..85bcb98 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ var ( redirectNoPath = os.Getenv("GOREDIRECT_NOPATH") ) -func wrapper(coll *mongo.Collection) func(w http.ResponseWriter, r *http.Request) { +func lookupHandler(ctx context.Context, coll *mongo.Collection) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { counterQueries.Inc() slug := r.URL.Path[1:] @@ -65,9 +65,9 @@ func wrapper(coll *mongo.Collection) func(w http.ResponseWriter, r *http.Request // Query var doc inventoryItemType - err := coll.FindOne(context.TODO(), bson.M{"shortener.slug": slug}).Decode(&doc) - if err != nil { + if err := coll.FindOne(ctx, bson.M{"shortener.slug": slug}).Decode(&doc); err != nil { counterNotFound.Inc() + if redirectNotFound == "" { http.NotFound(w, r) return @@ -113,8 +113,10 @@ var ( ) func main() { + ctx := context.Background() + if val := os.Getenv("GOREDIRECT_REGEX"); val != "" { - reValid = regexp.MustCompile(val) + regexValid = regexp.MustCompile(val) } if val := os.Getenv("MONGO_URI"); val != "" { @@ -122,29 +124,47 @@ func main() { } if val := os.Getenv("GOREDIRECT_COLLECTION"); val != "" { - collectionName = val + mongoCollection = val } + // Mongo database // + cs, err := connstring.ParseAndValidate(mongoURI) + if err != nil { + log.Fatal(err) + + os.Exit(1) + } + client, err := mongo.NewClient(options.Client().ApplyURI(mongoURI)) if err != nil { log.Fatal(err) + + os.Exit(1) } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - err = client.Connect(ctx) - if err != nil { + + connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + + if err := client.Connect(connectCtx); err != nil { log.Fatal(err) + + os.Exit(1) } cancel() - coll := client.Database(cs.Database).Collection(collectionName) + + coll := client.Database(cs.Database).Collection(mongoCollection) defer client.Disconnect(ctx) + + // HTTP Server // + http.Handle("/metrics", promhttp.Handler()) - http.HandleFunc("/", wrapper(coll)) + http.HandleFunc("/", lookupHandler(ctx, coll)) - log.Printf("Starting HTTP server\n") - err2 := http.ListenAndServe(":8080", nil) - if err2 != nil { - log.Fatal("ListenAndServe: ", err2) + log.Printf("Starting HTTP server on :8080\n") + + if err := http.ListenAndServe(":8080", nil); err != nil { + log.Fatal("ListenAndServe: ", err) + + os.Exit(1) } - }