diff --git a/cmd/maincmd/main.go b/cmd/maincmd/main.go index 7393bfe516afd9840c47723444cde0d53535eae8..5334aacd1bd6f203836d57a63da36ce6fd2e9c00 100644 --- a/cmd/maincmd/main.go +++ b/cmd/maincmd/main.go @@ -1,6 +1,7 @@ package maincmd import ( + "context" "crypto/tls" "flag" "fmt" @@ -8,7 +9,9 @@ import ( "net" "net/http" "os" + "os/signal" "strconv" + "syscall" "time" "github.com/mccutchen/go-httpbin/httpbin" @@ -43,7 +46,7 @@ func Main() { if maxBodySize == httpbin.DefaultMaxBodySize && os.Getenv("MAX_BODY_SIZE") != "" { maxBodySize, err = strconv.ParseInt(os.Getenv("MAX_BODY_SIZE"), 10, 64) if err != nil { - fmt.Printf("invalid value %#v for env var MAX_BODY_SIZE: %s\n", os.Getenv("MAX_BODY_SIZE"), err) + fmt.Fprintf(os.Stderr, "Error: invalid value %#v for env var MAX_BODY_SIZE: %s\n\n", os.Getenv("MAX_BODY_SIZE"), err) flag.Usage() os.Exit(1) } @@ -51,7 +54,7 @@ func Main() { if maxDuration == httpbin.DefaultMaxDuration && os.Getenv("MAX_DURATION") != "" { maxDuration, err = time.ParseDuration(os.Getenv("MAX_DURATION")) if err != nil { - fmt.Printf("invalid value %#v for env var MAX_DURATION: %s\n", os.Getenv("MAX_DURATION"), err) + fmt.Fprintf(os.Stderr, "Error: invalid value %#v for env var MAX_DURATION: %s\n\n", os.Getenv("MAX_DURATION"), err) flag.Usage() os.Exit(1) } @@ -62,7 +65,7 @@ func Main() { if port == defaultPort && os.Getenv("PORT") != "" { port, err = strconv.Atoi(os.Getenv("PORT")) if err != nil { - fmt.Printf("invalid value %#v for env var PORT: %s\n", os.Getenv("PORT"), err) + fmt.Fprintf(os.Stderr, "Error: invalid value %#v for env var PORT: %s\n\n", os.Getenv("PORT"), err) flag.Usage() os.Exit(1) } @@ -77,6 +80,17 @@ func Main() { logger := log.New(os.Stderr, "", 0) + // A hacky log helper function to ensure that shutdown messages are + // formatted the same as other messages. See StdLogObserver in + // httpbin/middleware.go for the format we're matching here. + serverLog := func(msg string, args ...interface{}) { + const ( + logFmt = "time=%q msg=%q" + dateFmt = "2006-01-02T15:04:05.9999" + ) + logger.Printf(logFmt, time.Now().Format(dateFmt), fmt.Sprintf(msg, args...)) + } + h := httpbin.New( httpbin.WithMaxBodySize(maxBodySize), httpbin.WithMaxDuration(maxDuration), @@ -90,22 +104,48 @@ func Main() { Handler: h.Handler(), } + // shutdownCh triggers graceful shutdown on SIGINT or SIGTERM + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, syscall.SIGINT, syscall.SIGTERM) + + // exitCh will be closed when it is safe to exit, after graceful shutdown + exitCh := make(chan struct{}) + + go func() { + sig := <-shutdownCh + serverLog("shutdown started by signal: %s", sig) + + shutdownTimeout := maxDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + + server.SetKeepAlivesEnabled(false) + if err := server.Shutdown(ctx); err != nil { + serverLog("shutdown error: %s", err) + } + + close(exitCh) + }() + var listenErr error if httpsCertFile != "" && httpsKeyFile != "" { cert, err := tls.LoadX509KeyPair(httpsCertFile, httpsKeyFile) if err != nil { - logger.Fatal("Failed to generate https key pair: ", err) + logger.Fatalf("failed to generate https key pair: %s", err) } server.TLSConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, } - logger.Printf("go-httpbin listening on https://%s", listenAddr) + serverLog("go-httpbin listening on https://%s", listenAddr) listenErr = server.ListenAndServeTLS("", "") } else { - logger.Printf("go-httpbin listening on http://%s", listenAddr) + serverLog("go-httpbin listening on http://%s", listenAddr) listenErr = server.ListenAndServe() } - if listenErr != nil { - logger.Fatalf("Failed to listen: %s", listenErr) + if listenErr != nil && listenErr != http.ErrServerClosed { + logger.Fatalf("failed to listen: %s", listenErr) } + + <-exitCh + serverLog("shutdown finished") }