This commit is contained in:
82
pkg/gitea/gitea.go
Normal file
82
pkg/gitea/gitea.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
mcpContext "gitea.com/gitea/gitea-mcp/pkg/context"
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
|
||||
"gitea.dev/sdk"
|
||||
)
|
||||
|
||||
var (
|
||||
clientCache sync.Map // token -> *gitea.Client
|
||||
sharedTransOnce sync.Once
|
||||
sharedTrans *http.Transport
|
||||
)
|
||||
|
||||
func sharedTransport() *http.Transport {
|
||||
sharedTransOnce.Do(func() {
|
||||
sharedTrans = http.DefaultTransport.(*http.Transport).Clone()
|
||||
if flag.Insecure {
|
||||
sharedTrans.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // user-requested insecure mode
|
||||
}
|
||||
})
|
||||
return sharedTrans
|
||||
}
|
||||
|
||||
// NewClient returns a cached *gitea.Client keyed by host+token. The SDK's per-client
|
||||
// version cache and the shared transport let us reuse keep-alive connections
|
||||
// and avoid the SDK's /api/v1/version preflight on every tool call.
|
||||
func NewClient(token string) (*gitea.Client, error) {
|
||||
key := flag.Host + "\x00" + token
|
||||
if v, ok := clientCache.Load(key); ok {
|
||||
return v.(*gitea.Client), nil
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: sharedTransport(),
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
opts := []gitea.ClientOption{
|
||||
gitea.SetToken(token),
|
||||
gitea.SetHTTPClient(httpClient),
|
||||
}
|
||||
if flag.Debug {
|
||||
opts = append(opts, gitea.SetDebugMode())
|
||||
}
|
||||
client, err := gitea.NewClient(flag.Host, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create gitea client err: %w", err)
|
||||
}
|
||||
client.SetUserAgent("gitea-mcp-server/" + flag.Version)
|
||||
|
||||
actual, _ := clientCache.LoadOrStore(key, client)
|
||||
return actual.(*gitea.Client), nil
|
||||
}
|
||||
|
||||
// checkRedirect prevents Go from silently changing mutating requests (POST, PATCH, etc.)
|
||||
// to GET when following 301/302/303 redirects, which would drop the request body and
|
||||
// make writes appear to succeed when they didn't.
|
||||
func checkRedirect(_ *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if via[0].Method != http.MethodGet && via[0].Method != http.MethodHead {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClientFromContext(ctx context.Context) (*gitea.Client, error) {
|
||||
token, ok := ctx.Value(mcpContext.TokenContextKey).(string)
|
||||
if !ok {
|
||||
token = flag.Token
|
||||
}
|
||||
return NewClient(token)
|
||||
}
|
||||
120
pkg/gitea/redirect_test.go
Normal file
120
pkg/gitea/redirect_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
)
|
||||
|
||||
func TestCheckRedirect(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
method string
|
||||
wantErr error
|
||||
}{
|
||||
{"allows GET", http.MethodGet, nil},
|
||||
{"allows HEAD", http.MethodHead, nil},
|
||||
{"blocks PATCH", http.MethodPatch, http.ErrUseLastResponse},
|
||||
{"blocks POST", http.MethodPost, http.ErrUseLastResponse},
|
||||
{"blocks PUT", http.MethodPut, http.ErrUseLastResponse},
|
||||
{"blocks DELETE", http.MethodDelete, http.ErrUseLastResponse},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
via := []*http.Request{{Method: tc.method}}
|
||||
err := checkRedirect(nil, via)
|
||||
if err != tc.wantErr {
|
||||
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("stops after 10 redirects", func(t *testing.T) {
|
||||
via := make([]*http.Request, 10)
|
||||
for i := range via {
|
||||
via[i] = &http.Request{Method: http.MethodGet}
|
||||
}
|
||||
err := checkRedirect(nil, via)
|
||||
if err == nil || err == http.ErrUseLastResponse {
|
||||
t.Fatalf("expected redirect limit error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDoJSON_RepoRenameRedirect is a regression test for the bug where a PATCH
|
||||
// request to a renamed repo got a 301 redirect, Go's http.Client silently
|
||||
// changed the method to GET, and the write appeared to succeed without error.
|
||||
func TestDoJSON_RepoRenameRedirect(t *testing.T) {
|
||||
// Simulate a Gitea API that returns 301 for the old repo name (like a renamed repo).
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("PATCH /api/v1/repos/owner/old-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/api/v1/repos/owner/new-name/pulls/1", http.StatusMovedPermanently)
|
||||
})
|
||||
mux.HandleFunc("PATCH /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"id":1,"title":"updated"}`)
|
||||
})
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"id":1,"title":"not-updated"}`)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
defer func() { flag.Host = origHost }()
|
||||
flag.Host = srv.URL
|
||||
|
||||
var result map[string]any
|
||||
status, err := DoJSON(context.Background(), http.MethodPatch, "repos/owner/old-name/pulls/1", nil, map[string]string{"title": "updated"}, &result)
|
||||
if err != nil {
|
||||
// The redirect should be blocked, returning the 301 response directly.
|
||||
// DoJSON treats non-2xx as an error, which is the correct behavior.
|
||||
if status != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected status 301, got %d (err: %v)", status, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If we reach here without error, the redirect was followed. Verify the
|
||||
// method was preserved (title should be "updated", not "not-updated").
|
||||
title, _ := result["title"].(string)
|
||||
if title == "not-updated" {
|
||||
t.Fatal("PATCH was silently converted to GET on 301 redirect — write was lost")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoJSON_GETRedirectFollowed verifies that GET requests still follow redirects normally.
|
||||
func TestDoJSON_GETRedirectFollowed(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/old-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/api/v1/repos/owner/new-name/pulls/1", http.StatusMovedPermanently)
|
||||
})
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "title": "found"})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
defer func() { flag.Host = origHost }()
|
||||
flag.Host = srv.URL
|
||||
|
||||
var result map[string]any
|
||||
status, err := DoJSON(context.Background(), http.MethodGet, "repos/owner/old-name/pulls/1", nil, nil, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("GET redirect should be followed, got error: %v (status %d)", err, status)
|
||||
}
|
||||
title, _ := result["title"].(string)
|
||||
if title != "found" {
|
||||
t.Fatalf("expected title 'found', got %q", title)
|
||||
}
|
||||
}
|
||||
184
pkg/gitea/rest.go
Normal file
184
pkg/gitea/rest.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mcpContext "gitea.com/gitea/gitea-mcp/pkg/context"
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
)
|
||||
|
||||
const (
|
||||
httpClientTimeout = 60 * time.Second
|
||||
errBodySnippetSize = 8192
|
||||
)
|
||||
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *HTTPError) Error() string {
|
||||
if e.Body == "" {
|
||||
return fmt.Sprintf("request failed with status %d", e.StatusCode)
|
||||
}
|
||||
return fmt.Sprintf("request failed with status %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
func tokenFromContext(ctx context.Context) string {
|
||||
if ctx != nil {
|
||||
if token, ok := ctx.Value(mcpContext.TokenContextKey).(string); ok && token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
return flag.Token
|
||||
}
|
||||
|
||||
var (
|
||||
restClientOnce sync.Once
|
||||
restClient *http.Client
|
||||
)
|
||||
|
||||
func restHTTPClient() *http.Client {
|
||||
restClientOnce.Do(func() {
|
||||
restClient = &http.Client{
|
||||
Transport: sharedTransport(),
|
||||
Timeout: httpClientTimeout,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
})
|
||||
return restClient
|
||||
}
|
||||
|
||||
func buildAPIURL(path string, query url.Values) (string, error) {
|
||||
host := strings.TrimRight(flag.Host, "/")
|
||||
if host == "" {
|
||||
return "", errors.New("gitea host is empty")
|
||||
}
|
||||
p := strings.TrimLeft(path, "/")
|
||||
u, err := url.Parse(fmt.Sprintf("%s/api/v1/%s", host, p))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if query != nil {
|
||||
u.RawQuery = query.Encode()
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// DoJSON performs an API request and decodes a JSON response into respOut (if non-nil).
|
||||
// It returns the HTTP status code.
|
||||
func DoJSON(ctx context.Context, method, path string, query url.Values, body, respOut any) (int, error) {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(b)
|
||||
}
|
||||
|
||||
u, err := buildAPIURL(path, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, u, bodyReader)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
token := tokenFromContext(ctx)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "token "+token)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := restHTTPClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodySnippet, _ := io.ReadAll(io.LimitReader(resp.Body, errBodySnippetSize))
|
||||
return resp.StatusCode, &HTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(bodySnippet))}
|
||||
}
|
||||
|
||||
if respOut == nil {
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // best-effort
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(respOut); err != nil {
|
||||
return resp.StatusCode, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// DoBytes performs an API request and returns the raw response bytes.
|
||||
// It returns the HTTP status code.
|
||||
func DoBytes(ctx context.Context, method, path string, query url.Values, body any, accept string) ([]byte, int, error) {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(b)
|
||||
}
|
||||
|
||||
u, err := buildAPIURL(path, query)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, u, bodyReader)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
token := tokenFromContext(ctx)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "token "+token)
|
||||
}
|
||||
if accept != "" {
|
||||
req.Header.Set("Accept", accept)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := restHTTPClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodySnippet := respBytes
|
||||
if len(bodySnippet) > errBodySnippetSize {
|
||||
bodySnippet = bodySnippet[:errBodySnippetSize]
|
||||
}
|
||||
return nil, resp.StatusCode, &HTTPError{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(bodySnippet))}
|
||||
}
|
||||
|
||||
return respBytes, resp.StatusCode, nil
|
||||
}
|
||||
30
pkg/gitea/rest_test.go
Normal file
30
pkg/gitea/rest_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
mcpContext "gitea.com/gitea/gitea-mcp/pkg/context"
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
)
|
||||
|
||||
func TestTokenFromContext(t *testing.T) {
|
||||
orig := flag.Token
|
||||
defer func() { flag.Token = orig }()
|
||||
|
||||
flag.Token = "flag-token"
|
||||
|
||||
t.Run("context token wins", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), mcpContext.TokenContextKey, "ctx-token")
|
||||
if got := tokenFromContext(ctx); got != "ctx-token" {
|
||||
t.Fatalf("tokenFromContext() = %q, want %q", got, "ctx-token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback to flag token", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if got := tokenFromContext(ctx); got != "flag-token" {
|
||||
t.Fatalf("tokenFromContext() = %q, want %q", got, "flag-token")
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user