Update module github.com/go-chi/chi/v5 to v5.2.5#125
Update module github.com/go-chi/chi/v5 to v5.2.5#125renovate[bot] wants to merge 1 commit intomasterfrom
Conversation
|
478c883 to
d57eded
Compare
d57eded to
f825e19
Compare
f825e19 to
1e7d00a
Compare
1e7d00a to
6d98ba2
Compare
|
[puLL-Merge] - go-chi/chi@v5.2.2..v5.2.5 Diffdiff --git .github/workflows/ci.yml .github/workflows/ci.yml
index 89ab2b42e..f42ef2d49 100644
--- .github/workflows/ci.yml
+++ .github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
matrix:
- go-version: [1.20.x, 1.21.x, 1.22.x, 1.23.x, 1.24.x]
+ go-version: [1.22.x, 1.23.x, 1.24.x, 1.25.x]
os: [ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
diff --git _examples/README.md _examples/README.md
index 4bc4edc80..18b809b62 100644
--- _examples/README.md
+++ _examples/README.md
@@ -12,6 +12,7 @@ chi examples
* [router-walk](https://github.com/go-chi/chi/blob/master/_examples/router-walk/main.go) - Print to stdout a router's routes
* [todos-resource](https://github.com/go-chi/chi/blob/master/_examples/todos-resource/main.go) - Struct routers/handlers, an example of another code layout style
* [versions](https://github.com/go-chi/chi/blob/master/_examples/versions/main.go) - Demo of `chi/render` subpkg
+* [pathvalue](https://github.com/go-chi/chi/blob/master/_examples/pathvalue/main.go) - Demonstrates `PathValue` usage for retrieving URL parameters
## Usage
diff --git _examples/graceful/main.go _examples/graceful/main.go
index 6ae3750a1..73bce5a51 100644
--- _examples/graceful/main.go
+++ _examples/graceful/main.go
@@ -2,10 +2,10 @@ package main
import (
"context"
+ "errors"
"fmt"
"log"
"net/http"
- "os"
"os/signal"
"syscall"
"time"
@@ -18,41 +18,28 @@ func main() {
// The HTTP Server
server := &http.Server{Addr: "0.0.0.0:3333", Handler: service()}
- // Server run context
- serverCtx, serverStopCtx := context.WithCancel(context.Background())
+ // Create context that listens for the interrupt signal
+ ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer stop()
- // Listen for syscall signals for process to interrupt/quit
- sig := make(chan os.Signal, 1)
- signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
+ // Run server in the background
go func() {
- <-sig
-
- // Shutdown signal with grace period of 30 seconds
- shutdownCtx, _ := context.WithTimeout(serverCtx, 30*time.Second)
-
- go func() {
- <-shutdownCtx.Done()
- if shutdownCtx.Err() == context.DeadlineExceeded {
- log.Fatal("graceful shutdown timed out.. forcing exit.")
- }
- }()
-
- // Trigger graceful shutdown
- err := server.Shutdown(shutdownCtx)
- if err != nil {
+ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
- serverStopCtx()
}()
- // Run the server
- err := server.ListenAndServe()
- if err != nil && err != http.ErrServerClosed {
+ // Listen for the interrupt signal
+ <-ctx.Done()
+
+ // Create shutdown context with 30-second timeout
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // Trigger graceful shutdown
+ if err := server.Shutdown(shutdownCtx); err != nil {
log.Fatal(err)
}
-
- // Wait for server context to be stopped
- <-serverCtx.Done()
}
func service() http.Handler {
@@ -69,7 +56,7 @@ func service() http.Handler {
// Simulates some hard work.
//
// We want this handler to complete successfully during a shutdown signal,
- // so consider the work here as some background routine to fetch a long running
+ // so consider the work here as some background routine to fetch a long-running
// search query to find as many results as possible, but, instead we cut it short
// and respond with what we have so far. How a shutdown is handled is entirely
// up to the developer, as some code blocks are preemptible, and others are not.
diff --git a/_examples/pathvalue/main.go b/_examples/pathvalue/main.go
new file mode 100644
index 000000000..c13f28007
--- /dev/null
+++ _examples/pathvalue/main.go
@@ -0,0 +1,25 @@
+package main
+
+import (
+ "fmt"
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+)
+
+func main() {
+ r := chi.NewRouter()
+
+ // Registering a handler that retrieves a path parameter using PathValue
+ r.Get("/users/{userID}", pathValueHandler)
+
+ http.ListenAndServe(":3333", r)
+}
+
+// pathValueHandler retrieves a URL parameter using PathValue and writes it to the response.
+func pathValueHandler(w http.ResponseWriter, r *http.Request) {
+ userID := r.PathValue("userID")
+
+ // Respond with the extracted userID
+ w.Write([]byte(fmt.Sprintf("User ID: %s", userID)))
+}
diff --git chi.go chi.go
index 2b6ebd337..f650116a8 100644
--- chi.go
+++ chi.go
@@ -1,6 +1,6 @@
// Package chi is a small, idiomatic and composable router for building HTTP services.
//
-// chi requires Go 1.14 or newer.
+// chi supports the four most recent major versions of Go.
//
// Example:
//
diff --git context.go context.go
index aacf6eff7..82220730e 100644
--- context.go
+++ context.go
@@ -133,11 +133,12 @@ func (x *Context) RoutePattern() string {
return routePattern
}
-// replaceWildcards takes a route pattern and recursively replaces all
-// occurrences of "/*/" to "/".
+// replaceWildcards takes a route pattern and replaces all occurrences of
+// "/*/" with "/". It iteratively runs until no wildcards remain to
+// correctly handle consecutive wildcards.
func replaceWildcards(p string) string {
- if strings.Contains(p, "/*/") {
- return replaceWildcards(strings.Replace(p, "/*/", "/", -1))
+ for strings.Contains(p, "/*/") {
+ p = strings.ReplaceAll(p, "/*/", "/")
}
return p
}
diff --git context_test.go context_test.go
index fa3c9f5b5..fa432852e 100644
--- context_test.go
+++ context_test.go
@@ -91,3 +91,14 @@ func TestRoutePattern(t *testing.T) {
t.Fatalf("unexpected non-empty route pattern for nil context: %q", p)
}
}
+
+// TestReplaceWildcardsConsecutive ensures multiple consecutive wildcards are
+// collapsed into a single slash.
+func TestReplaceWildcardsConsecutive(t *testing.T) {
+ if p := replaceWildcards("/foo/*/*/*/bar"); p != "/foo/bar" {
+ t.Fatalf("unexpected wildcard replacement: %s", p)
+ }
+ if p := replaceWildcards("/foo/*/*/*/bar/*"); p != "/foo/bar/*" {
+ t.Fatalf("unexpected trailing wildcard behavior: %s", p)
+ }
+}
diff --git go.mod go.mod
index cd49b2809..6f579ce2f 100644
--- go.mod
+++ go.mod
@@ -2,4 +2,4 @@ module github.com/go-chi/chi/v5
// Chi supports the four most recent major versions of Go.
// See https://github.com/go-chi/chi/issues/963.
-go 1.20
+go 1.22
diff --git middleware/compress_test.go middleware/compress_test.go
index eaafc13b9..5face239a 100644
--- middleware/compress_test.go
+++ middleware/compress_test.go
@@ -96,7 +96,6 @@ func TestCompressor(t *testing.T) {
}
for _, tc := range tests {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
if respString != "textstring" {
diff --git middleware/content_charset.go middleware/content_charset.go
index 07bff9f2e..8e75fe8e4 100644
--- middleware/content_charset.go
+++ middleware/content_charset.go
@@ -2,6 +2,7 @@ package middleware
import (
"net/http"
+ "slices"
"strings"
)
@@ -29,13 +30,7 @@ func contentEncoding(ce string, charsets ...string) bool {
_, ce = split(strings.ToLower(ce), ";")
_, ce = split(ce, "charset=")
ce, _ = split(ce, ";")
- for _, c := range charsets {
- if ce == c {
- return true
- }
- }
-
- return false
+ return slices.Contains(charsets, ce)
}
// Split a string in two parts, cleaning any whitespace.
diff --git middleware/request_id.go middleware/request_id.go
index 4903ecc21..e1d4ccb7d 100644
--- middleware/request_id.go
+++ middleware/request_id.go
@@ -25,7 +25,7 @@ const RequestIDKey ctxKeyRequestID = 0
var RequestIDHeader = "X-Request-Id"
var prefix string
-var reqid uint64
+var reqid atomic.Uint64
// A quick note on the statistics here: we're trying to calculate the chance that
// two randomly generated base62 prefixes will collide. We use the formula from
@@ -69,7 +69,7 @@ func RequestID(next http.Handler) http.Handler {
ctx := r.Context()
requestID := r.Header.Get(RequestIDHeader)
if requestID == "" {
- myid := atomic.AddUint64(&reqid, 1)
+ myid := reqid.Add(1)
requestID = fmt.Sprintf("%s-%06d", prefix, myid)
}
ctx = context.WithValue(ctx, RequestIDKey, requestID)
@@ -92,5 +92,5 @@ func GetReqID(ctx context.Context) string {
// NextRequestID generates the next request ID in the sequence.
func NextRequestID() uint64 {
- return atomic.AddUint64(&reqid, 1)
+ return reqid.Add(1)
}
diff --git middleware/route_headers.go middleware/route_headers.go
index 88743769a..1c3334d35 100644
--- middleware/route_headers.go
+++ middleware/route_headers.go
@@ -79,6 +79,7 @@ func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
if len(hr) == 0 {
// skip if no routes set
next.ServeHTTP(w, r)
+ return
}
// find first matching header route, and continue
diff --git a/middleware/route_headers_test.go b/middleware/route_headers_test.go
new file mode 100644
index 000000000..819101a93
--- /dev/null
+++ middleware/route_headers_test.go
@@ -0,0 +1,212 @@
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+)
+
+func TestRouteHeaders(t *testing.T) {
+ t.Run("empty router should call next handler exactly once", func(t *testing.T) {
+ var callCount atomic.Int32
+
+ hr := RouteHeaders()
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ callCount.Add(1)
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if callCount.Load() != 1 {
+ t.Errorf("expected next handler to be called exactly once, but was called %d times", callCount.Load())
+ }
+ })
+
+ t.Run("matching header should route to correct middleware", func(t *testing.T) {
+ var matchedRoute string
+
+ hr := RouteHeaders().
+ Route("Host", "example.com", func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ matchedRoute = "example.com"
+ next.ServeHTTP(w, r)
+ })
+ }).
+ Route("Host", "other.com", func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ matchedRoute = "other.com"
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Host = "example.com"
+ req.Header.Set("Host", "example.com")
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if matchedRoute != "example.com" {
+ t.Errorf("expected matched route to be 'example.com', got '%s'", matchedRoute)
+ }
+ })
+
+ t.Run("wildcard pattern should match", func(t *testing.T) {
+ var matched bool
+
+ hr := RouteHeaders().
+ Route("Host", "*.example.com", func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ matched = true
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Host", "api.example.com")
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if !matched {
+ t.Error("expected wildcard pattern to match")
+ }
+ })
+
+ t.Run("default route should be used when no match", func(t *testing.T) {
+ var usedDefault bool
+
+ hr := RouteHeaders().
+ Route("Host", "example.com", func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ next.ServeHTTP(w, r)
+ })
+ }).
+ RouteDefault(func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ usedDefault = true
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Host", "other.com")
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if !usedDefault {
+ t.Error("expected default route to be used when no match")
+ }
+ })
+
+ t.Run("RouteAny should match any of the provided patterns", func(t *testing.T) {
+ var matched bool
+
+ hr := RouteHeaders().
+ RouteAny("Content-Type", []string{"application/json", "application/xml"}, func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ matched = true
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // Test with application/json
+ req := httptest.NewRequest("POST", "/", nil)
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if !matched {
+ t.Error("expected RouteAny to match 'application/json'")
+ }
+
+ // Reset and test with application/xml
+ matched = false
+ req = httptest.NewRequest("POST", "/", nil)
+ req.Header.Set("Content-Type", "application/xml")
+ rec = httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if !matched {
+ t.Error("expected RouteAny to match 'application/xml'")
+ }
+ })
+
+ t.Run("no match and no default should call next handler", func(t *testing.T) {
+ var nextCalled bool
+
+ hr := RouteHeaders().
+ Route("Host", "example.com", func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ next.ServeHTTP(w, r)
+ })
+ })
+
+ handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ nextCalled = true
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Host", "other.com")
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if !nextCalled {
+ t.Error("expected next handler to be called when no match and no default")
+ }
+ })
+}
+
+func TestPattern(t *testing.T) {
+ tests := []struct {
+ pattern string
+ value string
+ expected bool
+ }{
+ {"example.com", "example.com", true},
+ {"example.com", "other.com", false},
+ {"*.example.com", "api.example.com", true},
+ {"*.example.com", "example.com", false},
+ {"api.*", "api.example.com", true},
+ {"*", "anything", true},
+ {"prefix*suffix", "prefixmiddlesuffix", true},
+ {"prefix*suffix", "prefixsuffix", true},
+ {"prefix*suffix", "wrongmiddlesuffix", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.pattern+"_"+tt.value, func(t *testing.T) {
+ p := NewPattern(tt.pattern)
+ if got := p.Match(tt.value); got != tt.expected {
+ t.Errorf("Pattern(%q).Match(%q) = %v, want %v", tt.pattern, tt.value, got, tt.expected)
+ }
+ })
+ }
+}
diff --git middleware/strip.go middleware/strip.go
index 17aa9bf32..32d21e90b 100644
--- middleware/strip.go
+++ middleware/strip.go
@@ -47,15 +47,22 @@ func RedirectSlashes(next http.Handler) http.Handler {
} else {
path = r.URL.Path
}
+
if len(path) > 1 && path[len(path)-1] == '/' {
- // Trim all leading and trailing slashes (e.g., "//evil.com", "/some/path//")
- path = "/" + strings.Trim(path, "/")
+ // Normalize backslashes to forward slashes to prevent "/\evil.com" style redirects
+ // that some clients may interpret as protocol-relative.
+ path = strings.ReplaceAll(path, `\`, `/`)
+
+ // Collapse leading/trailing slashes and force a single leading slash.
+ path := "/" + strings.Trim(path, "/")
+
if r.URL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery)
}
http.Redirect(w, r, path, 301)
return
}
+
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
diff --git middleware/strip_test.go middleware/strip_test.go
index 79a0b3480..2fa986803 100644
--- middleware/strip_test.go
+++ middleware/strip_test.go
@@ -4,6 +4,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
"github.com/go-chi/chi/v5"
@@ -271,3 +272,58 @@ func TestStripPrefix(t *testing.T) {
t.Fatalf("got: %q, want: %q", resp, "404 page not found\n")
}
}
+
+func TestRedirectSlashes_PreventBackslashRelativeOpenRedirect(t *testing.T) {
+ h := RedirectSlashes(http.NotFoundHandler())
+
+ tests := []struct {
+ name string
+ target string
+ }{
+ {
+ name: `raw backslash: /\evil.com/`,
+ target: `/\evil.com/`,
+ },
+ {
+ name: `encoded backslash: /%5Cevil.com/`,
+ target: "/%5Cevil.com/",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "http://example.test"+tc.target, nil)
+ rr := httptest.NewRecorder()
+
+ h.ServeHTTP(rr, req)
+ res := rr.Result()
+ defer res.Body.Close()
+
+ if res.StatusCode != http.StatusMovedPermanently {
+ t.Fatalf("expected %d, got %d", http.StatusMovedPermanently, res.StatusCode)
+ }
+
+ loc := res.Header.Get("Location")
+ if loc == "" {
+ t.Fatalf("expected Location header to be set")
+ }
+
+ // The core security assertions:
+ if strings.Contains(loc, `\`) {
+ t.Fatalf("Location must not contain backslashes: %q", loc)
+ }
+ if strings.HasPrefix(loc, "//") {
+ t.Fatalf("Location must not be protocol-relative: %q", loc)
+ }
+ if !strings.HasPrefix(loc, "/") {
+ t.Fatalf("Location must be an absolute-path reference starting with '/': %q", loc)
+ }
+
+ // Optional stronger assertion if your middleware normalizes to /evil.com exactly:
+ // (Keep or remove depending on your chosen behavior.)
+ if loc != "/evil.com" {
+ t.Fatalf("expected Location %q, got %q", "/evil.com", loc)
+ }
+ })
+ }
+}
diff --git middleware/throttle.go middleware/throttle.go
index b1f926de9..7ea482b91 100644
--- middleware/throttle.go
+++ middleware/throttle.go
@@ -83,12 +83,23 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
return
case btok := <-t.backlogTokens:
- timer := time.NewTimer(t.backlogTimeout)
-
defer func() {
t.backlogTokens <- btok
}()
+ // Try to get a processing token immediately first
+ select {
+ case tok := <-t.tokens:
+ defer func() {
+ t.tokens <- tok
+ }()
+ next.ServeHTTP(w, r)
+ return
+ default:
+ // No immediate token available, need to wait with timer
+ }
+
+ timer := time.NewTimer(t.backlogTimeout)
select {
case <-timer.C:
t.setRetryAfterHeaderIfNeeded(w, false)
diff --git middleware/throttle_test.go middleware/throttle_test.go
index d4855f45e..f26d7e7c9 100644
--- middleware/throttle_test.go
+++ middleware/throttle_test.go
@@ -37,7 +37,7 @@ func TestThrottleBacklog(t *testing.T) {
// The throttler processes 10 consecutive requests, each one of those
// requests lasts 1s. The maximum number of requests this can possible serve
// before the clients time out (5s) is 40.
- for i := 0; i < 40; i++ {
+ for i := range 40 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -75,7 +75,7 @@ func TestThrottleClientTimeout(t *testing.T) {
var wg sync.WaitGroup
- for i := 0; i < 10; i++ {
+ for i := range 10 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -108,7 +108,7 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
var wg sync.WaitGroup
// These requests will be processed normally until they finish.
- for i := 0; i < 50; i++ {
+ for i := range 50 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -123,7 +123,7 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
// These requests will wait for the first batch to complete but it will take
// too much time, so they will eventually receive a timeout error.
- for i := 0; i < 50; i++ {
+ for i := range 50 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -161,7 +161,7 @@ func TestThrottleMaximum(t *testing.T) {
var wg sync.WaitGroup
- for i := 0; i < 20; i++ {
+ for i := range 20 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -181,7 +181,7 @@ func TestThrottleMaximum(t *testing.T) {
// At this point the server is still processing, all the following request
// will be beyond the server capacity.
- for i := 0; i < 20; i++ {
+ for i := range 20 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -199,55 +199,67 @@ func TestThrottleMaximum(t *testing.T) {
wg.Wait()
}
-// NOTE: test is disabled as it requires some refactoring. It is prone to intermittent failure.
-/*func TestThrottleRetryAfter(t *testing.T) {
+func TestThrottleRetryAfter(t *testing.T) {
r := chi.NewRouter()
+ retryAfterFn := func(ctxDone bool) time.Duration { return time.Hour }
- retryAfterFn := func(ctxDone bool) time.Duration { return time.Hour * 1 }
- r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 10, RetryAfterFn: retryAfterFn}))
+ r.Use(ThrottleWithOpts(ThrottleOpts{
+ Limit: 5,
+ BacklogLimit: 0,
+ RetryAfterFn: retryAfterFn,
+ }))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(time.Second * 1) // Expensive operation.
w.WriteHeader(http.StatusOK)
- time.Sleep(time.Second * 4) // Expensive operation.
- w.Write(testContent)
+ w.Write([]byte("ok"))
})
server := httptest.NewServer(r)
defer server.Close()
+ client := http.Client{}
- client := http.Client{
- Timeout: time.Second * 60, // Maximum waiting time.
+ type result struct {
+ status int
+ header http.Header
}
var wg sync.WaitGroup
+ totalRequests := 10
+ resultsCh := make(chan result, totalRequests)
- for i := 0; i < 10; i++ {
+ for i := 0; i < totalRequests; i++ {
wg.Add(1)
- go func(i int) {
+ go func() {
defer wg.Done()
-
- res, err := client.Get(server.URL)
- assertNoError(t, err)
- assertEqual(t, http.StatusOK, res.StatusCode)
- }(i)
+ res, _ := client.Get(server.URL)
+ resultsCh <- result{status: res.StatusCode, header: res.Header}
+ }()
}
- time.Sleep(time.Second * 1)
-
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func(i int) {
- defer wg.Done()
-
- res, err := client.Get(server.URL)
- assertNoError(t, err)
- assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
- assertEqual(t, res.Header.Get("Retry-After"), "3600")
- }(i)
+ wg.Wait()
+ close(resultsCh)
+
+ count200 := 0
+ count429 := 0
+ for res := range resultsCh {
+ switch res.status {
+ case http.StatusOK:
+ count200++
+ continue
+ case http.StatusTooManyRequests:
+ count429++
+ assertEqual(t, "3600", res.header.Get("Retry-After"))
+ continue
+ default:
+ t.Fatalf("Unexpected status code: %d", res.status)
+ continue
+ }
}
- wg.Wait()
-}*/
+ assertEqual(t, 5, count200)
+ assertEqual(t, 5, count429)
+}
func TestThrottleCustomStatusCode(t *testing.T) {
const timeout = time.Second * 3
@@ -271,7 +283,7 @@ func TestThrottleCustomStatusCode(t *testing.T) {
codes := make(chan int, totalRequestCount)
errs := make(chan error, totalRequestCount)
client := &http.Client{Timeout: timeout}
- for i := 0; i < totalRequestCount; i++ {
+ for range totalRequestCount {
go func() {
resp, err := client.Get(server.URL)
if err != nil {
@@ -293,9 +305,26 @@ func TestThrottleCustomStatusCode(t *testing.T) {
}
}
- for i := 0; i < totalRequestCount-1; i++ {
+ for range totalRequestCount - 1 {
waitResponse(http.StatusServiceUnavailable)
}
close(wait) // Allow the last request to proceed.
waitResponse(http.StatusOK)
}
+
+func BenchmarkThrottle(b *testing.B) {
+ throttleMiddleware := ThrottleBacklog(1000, 50, time.Second)
+
+ handler := throttleMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, req)
+ }
+}
diff --git middleware/url_format.go middleware/url_format.go
index d8a651e4e..2ec6657e1 100644
--- middleware/url_format.go
+++ middleware/url_format.go
@@ -63,7 +63,9 @@ func URLFormat(next http.Handler) http.Handler {
idx += base
format = path[idx+1:]
- rctx.RoutePath = path[:idx]
+ if rctx != nil {
+ rctx.RoutePath = path[:idx]
+ }
}
}
diff --git mux.go mux.go
index f1266971b..71652dd17 100644
--- mux.go
+++ mux.go
@@ -107,7 +107,8 @@ func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) {
// Handle adds the route `pattern` that matches any http method to
// execute the `handler` http.Handler.
func (mx *Mux) Handle(pattern string, handler http.Handler) {
- if method, rest, found := strings.Cut(pattern, " "); found {
+ if i := strings.IndexAny(pattern, " \t"); i >= 0 {
+ method, rest := pattern[:i], strings.TrimLeft(pattern[i+1:], " \t")
mx.Method(method, rest, handler)
return
}
@@ -118,12 +119,7 @@ func (mx *Mux) Handle(pattern string, handler http.Handler) {
// HandleFunc adds the route `pattern` that matches any http method to
// execute the `handlerFn` http.HandlerFunc.
func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) {
- if method, rest, found := strings.Cut(pattern, " "); found {
- mx.Method(method, rest, handlerFn)
- return
- }
-
- mx.handle(mALL, pattern, handlerFn)
+ mx.Handle(pattern, handlerFn)
}
// Method adds the route `pattern` that matches `method` http method to
@@ -471,8 +467,13 @@ func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
// Find the route
if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil {
- if supportsPathValue {
- setPathValue(rctx, r)
+ // Set http.Request path values from our request context
+ for i, key := range rctx.URLParams.Keys {
+ value := rctx.URLParams.Values[i]
+ r.SetPathValue(key, value)
+ }
+ if supportsPattern {
+ setPattern(rctx, r)
}
h.ServeHTTP(w, r)
diff --git mux_test.go mux_test.go
index 0b7bf1ff0..d69a6f8a5 100644
--- mux_test.go
+++ mux_test.go
@@ -668,6 +668,15 @@ func TestMuxHandlePatternValidation(t *testing.T) {
expectedBody: "with-prefix POST",
expectedStatus: http.StatusOK,
},
+ {
+ name: "Valid pattern with multiple whitespace after method",
+ pattern: "PATCH \t /",
+ shouldPanic: false,
+ method: "PATCH",
+ path: "/",
+ expectedBody: "extended-whitespace PATCH",
+ expectedStatus: http.StatusOK,
+ },
// Invalid patterns
{
name: "Invalid pattern with no method",
@@ -1684,11 +1693,11 @@ func TestMuxContextIsThreadSafe(t *testing.T) {
wg := sync.WaitGroup{}
- for i := 0; i < 100; i++ {
+ for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
- for j := 0; j < 10000; j++ {
+ for range 10000 {
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/ok", nil)
if err != nil {
@@ -1773,6 +1782,26 @@ func TestCustomHTTPMethod(t *testing.T) {
if _, body := testRequest(t, ts, "BOO", "/hi", nil); body != "custom method" {
t.Fatal(body)
}
+
+ var expectRoutes = map[string]string{
+ "GET": "/",
+ "BOO": "/hi",
+ }
+ Walk(r, func(method string, route string, handler http.Handler, _ ...func(http.Handler) http.Handler) error {
+ r, ok := expectRoutes[method]
+ if !ok {
+ t.Fatalf("unexpected method %s", method)
+ }
+ if r != route {
+ t.Fatalf("expected route %s, got %s", r, route)
+ }
+ delete(expectRoutes, method)
+
+ return nil
+ })
+ if len(expectRoutes) != 0 {
+ t.Fatalf("missing expected methods: %v", expectRoutes)
+ }
}
func TestMuxMatch(t *testing.T) {
diff --git path_value.go path_value.go
deleted file mode 100644
index 77c840f01..000000000
--- path_value.go
+++ /dev/null
@@ -1,21 +0,0 @@
-//go:build go1.22 && !tinygo
-// +build go1.22,!tinygo
-
-
-package chi
-
-import "net/http"
-
-// supportsPathValue is true if the Go version is 1.22 and above.
-//
-// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`.
-const supportsPathValue = true
-
-// setPathValue sets the path values in the Request value
-// based on the provided request context.
-func setPathValue(rctx *Context, r *http.Request) {
- for i, key := range rctx.URLParams.Keys {
- value := rctx.URLParams.Values[i]
- r.SetPathValue(key, value)
- }
-}
diff --git path_value_fallback.go path_value_fallback.go
deleted file mode 100644
index 749a8520a..000000000
--- path_value_fallback.go
+++ /dev/null
@@ -1,19 +0,0 @@
-//go:build !go1.22 || tinygo
-// +build !go1.22 tinygo
-
-package chi
-
-import "net/http"
-
-// supportsPathValue is true if the Go version is 1.22 and above.
-//
-// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`.
-const supportsPathValue = false
-
-// setPathValue sets the path values in the Request value
-// based on the provided request context.
-//
-// setPathValue is only supported in Go 1.22 and above so
-// this is just a blank function so that it compiles.
-func setPathValue(rctx *Context, r *http.Request) {
-}
diff --git path_value_test.go path_value_test.go
index 35fe421a0..a762030ce 100644
--- path_value_test.go
+++ path_value_test.go
@@ -1,6 +1,3 @@
-//go:build go1.22 && !tinygo
-// +build go1.22,!tinygo
-
package chi
import (
diff --git a/pattern.go b/pattern.go
new file mode 100644
index 000000000..890a2c217
--- /dev/null
+++ pattern.go
@@ -0,0 +1,16 @@
+//go:build go1.23 && !tinygo
+// +build go1.23,!tinygo
+
+package chi
+
+import "net/http"
+
+// supportsPattern is true if the Go version is 1.23 and above.
+//
+// If this is true, `net/http.Request` has field `Pattern`.
+const supportsPattern = true
+
+// setPattern sets the mux matched pattern in the http Request.
+func setPattern(rctx *Context, r *http.Request) {
+ r.Pattern = rctx.routePattern
+}
diff --git a/pattern_fallback.go b/pattern_fallback.go
new file mode 100644
index 000000000..48a94ef82
--- /dev/null
+++ pattern_fallback.go
@@ -0,0 +1,17 @@
+//go:build !go1.23 || tinygo
+// +build !go1.23 tinygo
+
+package chi
+
+import "net/http"
+
+// supportsPattern is true if the Go version is 1.23 and above.
+//
+// If this is true, `net/http.Request` has field `Pattern`.
+const supportsPattern = false
+
+// setPattern sets the mux matched pattern in the http Request.
+//
+// setPattern is only supported in Go 1.23 and above so
+// this is just a blank function so that it compiles.
+func setPattern(rctx *Context, r *http.Request) {}
diff --git a/pattern_test.go b/pattern_test.go
new file mode 100644
index 000000000..bf8a30454
--- /dev/null
+++ pattern_test.go
@@ -0,0 +1,56 @@
+//go:build go1.23
+// +build go1.23
+
+package chi
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestPattern(t *testing.T) {
+ testCases := []struct {
+ name string
+ pattern string
+ method string
+ requestPath string
+ }{
+ {
+ name: "Basic path value",
+ pattern: "/hubs/{hubID}",
+ method: "GET",
+ requestPath: "/hubs/392",
+ },
+ {
+ name: "Two path values",
+ pattern: "/users/{userID}/conversations/{conversationID}",
+ method: "POST",
+ requestPath: "/users/Gojo/conversations/2948",
+ },
+ {
+ name: "Wildcard path",
+ pattern: "/users/{userID}/friends/*",
+ method: "POST",
+ requestPath: "/users/Gojo/friends/all-of-them/and/more",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ r := NewRouter()
+
+ r.Handle(tc.method+" "+tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.Pattern))
+ }))
+
+ ts := httptest.NewServer(r)
+ defer ts.Close()
+
+ _, body := testRequest(t, ts, tc.method, tc.requestPath, nil)
+ if body != tc.pattern {
+ t.Fatalf("expecting %q, got %q", tc.pattern, body)
+ }
+ })
+ }
+}
diff --git tree.go tree.go
index 85fcfdbb8..8b1ed1995 100644
--- tree.go
+++ tree.go
@@ -71,6 +71,7 @@ func RegisterMethod(method string) {
}
mt := methodTyp(2 << n)
methodMap[method] = mt
+ reverseMethodMap[mt] = method
mALL |= mt
}
@@ -328,7 +329,7 @@ func (n *node) replaceChild(label, tail byte, child *node) {
func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
nds := n.children[ntyp]
- for i := 0; i < len(nds); i++ {
+ for i := range nds {
if nds[i].label == label && nds[i].tail == tail {
if ntyp == ntRegexp && nds[i].prefix != prefix {
continue
@@ -429,9 +430,7 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
}
// serially loop through each node grouped by the tail delimiter
- for idx := 0; idx < len(nds); idx++ {
- xn = nds[idx]
-
+ for _, xn = range nds {
// label for param nodes is the delimiter byte
p := strings.IndexByte(xsearch, xn.tail)
@@ -650,11 +649,9 @@ func (n *node) routes() []Route {
if h.handler == nil {
continue
}
- m := methodTypString(mt)
- if m == "" {
- continue
+ if m, ok := reverseMethodMap[mt]; ok {
+ hs[m] = h.handler
}
- hs[m] = h.handler
}
rt := Route{subroutes, hs, p}
@@ -772,29 +769,14 @@ func patParamKeys(pattern string) []string {
}
}
-// longestPrefix finds the length of the shared prefix
-// of two strings
-func longestPrefix(k1, k2 string) int {
- max := len(k1)
- if l := len(k2); l < max {
- max = l
- }
- var i int
- for i = 0; i < max; i++ {
+// longestPrefix finds the length of the shared prefix of two strings
+func longestPrefix(k1, k2 string) (i int) {
+ for i = 0; i < min(len(k1), len(k2)); i++ {
if k1[i] != k2[i] {
break
}
}
- return i
-}
-
-func methodTypString(method methodTyp) string {
- for s, t := range methodMap {
- if method == t {
- return s
- }
- }
- return ""
+ return
}
type nodes []*node
DescriptionThis PR is a maintenance and modernization update to the Possible Issues
Security Hotspots
ChangesChanges
sequenceDiagram
participant Client
participant Mux
participant Tree
participant Handler
Client->>Mux: HTTP Request
Mux->>Mux: routeHTTP(w, r)
Mux->>Tree: FindRoute(rctx, method, path)
Tree-->>Mux: handler found
Mux->>Mux: r.SetPathValue(key, value) for each param
alt Go >= 1.23
Mux->>Mux: setPattern(rctx, r) → r.Pattern = routePattern
end
Mux->>Handler: h.ServeHTTP(w, r)
Handler->>Handler: r.PathValue("param") / r.Pattern
Handler-->>Client: HTTP Response
|
This PR contains the following updates:
v5.2.2→v5.2.5Release Notes
go-chi/chi (github.com/go-chi/chi/v5)
v5.2.5Compare Source
What's Changed
New Contributors
Full Changelog: go-chi/chi@v5.2.3...v5.2.5
v5.2.4Compare Source
v5.2.3Compare Source
What's Changed
New Contributors
Full Changelog: go-chi/chi@v5.2.2...v5.2.3
Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.