diff --git a/internal/errors/errors.go b/internal/errors/errors.go index d64470404..ce87e86f4 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -1,270 +1,153 @@ -// Package errors provides simple error handling primitives. +// Package errors provides error handling primitives with stack traces. // -// The traditional error handling idiom in Go is roughly akin to -// -// if err != nil { -// return err -// } -// -// which applied recursively up the call stack results in error reports -// without context or debugging information. The errors package allows -// programmers to add context to the failure path in their code in a way -// that does not destroy the original value of the error. +// It is a thin layer over the standard library's errors package, adding a +// stack trace at the point an error is created or wrapped. The wrapping +// chain is fully compatible with errors.Is, errors.As, and errors.Unwrap. // // # Adding context to an error // -// The errors.Wrap function returns a new error that adds context to the -// original error by recording a stack trace at the point Wrap is called, -// and the supplied message. For example -// -// _, err := ioutil.ReadAll(r) +// _, err := io.ReadAll(r) // if err != nil { // return errors.Wrap(err, "read failed") // } // -// If additional control is required the errors.WithStack and errors.WithMessage -// functions destructure errors.Wrap into its component operations of annotating -// an error with a stack trace and an a message, respectively. -// -// # Retrieving the cause of an error -// -// Using errors.Wrap constructs a stack of errors, adding context to the -// preceding error. Depending on the nature of the error it may be necessary -// to reverse the operation of errors.Wrap to retrieve the original error -// for inspection. Any error value which implements this interface -// -// type causer interface { -// Cause() error -// } -// -// can be inspected by errors.Cause. errors.Cause will recursively retrieve -// the topmost error which does not implement causer, which is assumed to be -// the original cause. For example: -// -// switch err := errors.Cause(err).(type) { -// case *MyError: -// // handle specifically -// default: -// // unknown error -// } -// -// causer interface is not exported by this package, but is considered a part -// of stable public API. -// // # Formatted printing of errors // -// All error values returned from this package implement fmt.Formatter and can -// be formatted by the fmt package. The following verbs are supported +// %s the error message (full wrap chain) +// %v same as %s +// %+v the error message followed by the captured stack trace +// %q the error message, quoted // -// %s print the error. If the error has a Cause it will be -// printed recursively -// %v see %s -// %+v extended format. Each Frame of the error's StackTrace will -// be printed in detail. +// # Retrieving the stack trace // -// # Retrieving the stack trace of an error or wrapper -// -// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are -// invoked. This information can be retrieved with the following interface. +// Errors returned by this package satisfy the following interface: // // type stackTracer interface { -// StackTrace() errors.StackTrace +// StackTrace() []uintptr // } -// -// Where errors.StackTrace is defined as -// -// type StackTrace []Frame -// -// The Frame type represents a call site in the stack trace. Frame supports -// the fmt.Formatter interface that can be used for printing information about -// the stack trace of this error. For example: -// -// if err, ok := err.(stackTracer); ok { -// for _, f := range err.StackTrace() { -// fmt.Printf("%+s:%d", f) -// } -// } -// -// stackTracer interface is not exported by this package, but is considered a part -// of stable public API. -// -// See the documentation for Frame.Format for more details. -// Fork from https://github.com/pkg/errors package errors import ( + "errors" "fmt" - "io" + "runtime" ) -// New returns an error with the supplied message. -// New also records the stack trace at the point it was called. -func New(message string) error { - return &fundamental{ - msg: message, - stack: callers(), - } +// Re-exported stdlib primitives so callers can use a single import. +var ( + Is = errors.Is + As = errors.As + Unwrap = errors.Unwrap + Join = errors.Join +) + +// withStack wraps an error with a captured stack trace. +type withStack struct { + err error + pcs []uintptr } -// Errorf formats according to a format specifier and returns the string -// as a value that satisfies error. -// Errorf also records the stack trace at the point it was called. -func Errorf(format string, args ...interface{}) error { - return &fundamental{ - msg: fmt.Sprintf(format, args...), - stack: callers(), - } +func (e *withStack) Error() string { + return e.err.Error() } -// fundamental is an error that has a message and a stack, but no caller. -type fundamental struct { - msg string - *stack +func (e *withStack) Unwrap() error { + return e.err } -func (f *fundamental) Error() string { return f.msg } +func (e *withStack) StackTrace() []uintptr { + return e.pcs +} -func (f *fundamental) Format(s fmt.State, verb rune) { +func (e *withStack) Format(s fmt.State, verb rune) { switch verb { case 'v': if s.Flag('+') { - io.WriteString(s, f.msg) - f.stack.Format(s, verb) + fmt.Fprint(s, e.err.Error()) + frames := runtime.CallersFrames(e.pcs) + for { + f, more := frames.Next() + fmt.Fprintf(s, "\n%s\n\t%s:%d", f.Function, f.File, f.Line) + if !more { + break + } + } return } fallthrough case 's': - io.WriteString(s, f.msg) + fmt.Fprint(s, e.err.Error()) case 'q': - fmt.Fprintf(s, "%q", f.msg) + fmt.Fprintf(s, "%q", e.err.Error()) } } +func callers() []uintptr { + var pcs [32]uintptr + n := runtime.Callers(3, pcs[:]) + return pcs[:n] +} + +func attach(err error) error { + return &withStack{err: err, pcs: callers()} +} + +// New returns an error with the supplied message and a captured stack trace. +func New(message string) error { + return attach(errors.New(message)) +} + +// Errorf formats according to a format specifier and returns a new error with +// a captured stack trace. It supports %w for wrapping an existing error. +func Errorf(format string, args ...any) error { + return attach(fmt.Errorf(format, args...)) +} + // WithStack annotates err with a stack trace at the point WithStack was called. // If err is nil, WithStack returns nil. func WithStack(err error) error { if err == nil { return nil } - return &withStack{ - err, - callers(), - } + return attach(err) } -type withStack struct { - error - *stack -} - -func (w *withStack) Cause() error { return w.error } - -func (w *withStack) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v", w.Cause()) - w.stack.Format(s, verb) - return - } - fallthrough - case 's': - io.WriteString(s, w.Error()) - case 'q': - fmt.Fprintf(s, "%q", w.Error()) - } -} - -// Wrap returns an error annotating err with a stack trace -// at the point Wrap is called, and the supplied message. -// If err is nil, Wrap returns nil. -func Wrap(err error, message string) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: message, - } - return &withStack{ - err, - callers(), - } -} - -// Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is call, and the format specifier. -// If err is nil, Wrapf returns nil. -func Wrapf(err error, format string, args ...interface{}) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), - } - return &withStack{ - err, - callers(), - } -} - -// WithMessage annotates err with a new message. +// WithMessage annotates err with a new message, without capturing a stack. // If err is nil, WithMessage returns nil. func WithMessage(err error, message string) error { if err == nil { return nil } - return &withMessage{ - cause: err, - msg: message, + return fmt.Errorf("%s: %w", message, err) +} + +// Wrap returns an error annotating err with a message and a captured stack. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil } + return attach(fmt.Errorf("%s: %w", message, err)) } -type withMessage struct { - cause error - msg string -} - -func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -func (w *withMessage) Cause() error { return w.cause } - -func (w *withMessage) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v\n", w.Cause()) - io.WriteString(s, w.msg) - return - } - fallthrough - case 's', 'q': - io.WriteString(s, w.Error()) +// Wrapf is the formatting variant of Wrap. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...any) error { + if err == nil { + return nil } + return attach(fmt.Errorf(fmt.Sprintf(format, args...)+": %w", err)) } -// Cause returns the underlying cause of the error, if possible. -// An error value has a cause if it implements the following -// interface: -// -// type causer interface { -// Cause() error -// } -// -// If the error does not implement Cause, the original error will -// be returned. If the error is nil, nil will be returned without further -// investigation. +// Cause walks the error's Unwrap chain and returns the root error. +// New code should prefer errors.Is or errors.As. func Cause(err error) error { - type causer interface { - Cause() error - } - for err != nil { - cause, ok := err.(causer) - if !ok { - break + u := errors.Unwrap(err) + if u == nil { + return err } - err = cause.Cause() + err = u } - return err + return nil } diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go new file mode 100644 index 000000000..a348c2705 --- /dev/null +++ b/internal/errors/errors_test.go @@ -0,0 +1,233 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package errors + +import ( + stderrors "errors" + "fmt" + "strings" + "testing" +) + +func TestNew_MessageAndStack(t *testing.T) { + err := New("boom") + if err == nil { + t.Fatal("New returned nil") + } + if err.Error() != "boom" { + t.Fatalf("Error() = %q, want %q", err.Error(), "boom") + } + ws, ok := err.(*withStack) + if !ok { + t.Fatalf("New did not return *withStack, got %T", err) + } + if len(ws.StackTrace()) == 0 { + t.Fatal("StackTrace is empty") + } +} + +func TestErrorf_FormatsMessage(t *testing.T) { + err := Errorf("code=%d reason=%s", 42, "oops") + if err.Error() != "code=42 reason=oops" { + t.Fatalf("Error() = %q", err.Error()) + } +} + +func TestErrorf_SupportsWrapVerb(t *testing.T) { + root := stderrors.New("root") + err := Errorf("ctx: %w", root) + if !stderrors.Is(err, root) { + t.Fatal("errors.Is did not find root through Errorf(%w)") + } +} + +func TestWithStack_NilReturnsNil(t *testing.T) { + if got := WithStack(nil); got != nil { + t.Fatalf("WithStack(nil) = %v, want nil", got) + } +} + +func TestWithStack_PreservesMessage(t *testing.T) { + inner := stderrors.New("plain") + err := WithStack(inner) + if err.Error() != "plain" { + t.Fatalf("Error() = %q, want %q", err.Error(), "plain") + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not find inner through WithStack") + } +} + +func TestWithMessage_NilReturnsNil(t *testing.T) { + if got := WithMessage(nil, "ignored"); got != nil { + t.Fatalf("WithMessage(nil) = %v, want nil", got) + } +} + +func TestWithMessage_PrependsAndWraps(t *testing.T) { + inner := stderrors.New("root") + err := WithMessage(inner, "ctx") + if err.Error() != "ctx: root" { + t.Fatalf("Error() = %q", err.Error()) + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse WithMessage") + } + // WithMessage must not capture a stack — verify the result is not a *withStack. + if _, ok := err.(*withStack); ok { + t.Fatal("WithMessage should not attach a stack") + } +} + +func TestWrap_NilReturnsNil(t *testing.T) { + if got := Wrap(nil, "ignored"); got != nil { + t.Fatalf("Wrap(nil) = %v, want nil", got) + } +} + +func TestWrap_MessageAndStackAndChain(t *testing.T) { + inner := stderrors.New("root") + err := Wrap(inner, "ctx") + if err.Error() != "ctx: root" { + t.Fatalf("Error() = %q", err.Error()) + } + ws, ok := err.(*withStack) + if !ok { + t.Fatalf("Wrap did not return *withStack, got %T", err) + } + if len(ws.StackTrace()) == 0 { + t.Fatal("StackTrace is empty") + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse Wrap") + } +} + +func TestWrapf_NilReturnsNil(t *testing.T) { + if got := Wrapf(nil, "ignored %d", 1); got != nil { + t.Fatalf("Wrapf(nil) = %v, want nil", got) + } +} + +func TestWrapf_FormatsAndChains(t *testing.T) { + inner := stderrors.New("root") + err := Wrapf(inner, "ctx=%d op=%s", 7, "read") + if err.Error() != "ctx=7 op=read: root" { + t.Fatalf("Error() = %q", err.Error()) + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse Wrapf") + } +} + +func TestCause_NilReturnsNil(t *testing.T) { + if got := Cause(nil); got != nil { + t.Fatalf("Cause(nil) = %v, want nil", got) + } +} + +func TestCause_NoUnwrapReturnsSelf(t *testing.T) { + root := stderrors.New("root") + if got := Cause(root); got != root { + t.Fatalf("Cause(root) = %v, want root", got) + } +} + +func TestCause_WalksToRoot(t *testing.T) { + root := stderrors.New("root") + err := Wrap(Wrap(WithMessage(root, "a"), "b"), "c") + if got := Cause(err); got != root { + t.Fatalf("Cause = %v, want root", got) + } +} + +func TestUnwrap_ReturnsInner(t *testing.T) { + inner := stderrors.New("inner") + err := WithStack(inner) + if got := stderrors.Unwrap(err); got != inner { + t.Fatalf("Unwrap = %v, want inner", got) + } +} + +func TestFormat_S(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%s", err) + if got != "msg" { + t.Fatalf("%%s = %q, want %q", got, "msg") + } +} + +func TestFormat_VFallsThroughToS(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%v", err) + if got != "msg" { + t.Fatalf("%%v = %q, want %q", got, "msg") + } +} + +func TestFormat_VPlusIncludesStack(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%+v", err) + if !strings.HasPrefix(got, "msg") { + t.Fatalf("%%+v output does not start with message: %q", got) + } + // Must include this test function in the captured stack. + if !strings.Contains(got, "TestFormat_VPlusIncludesStack") { + t.Fatalf("%%+v output missing caller frame:\n%s", got) + } + // Must include a file:line reference. + if !strings.Contains(got, "errors_test.go:") { + t.Fatalf("%%+v output missing file:line:\n%s", got) + } +} + +func TestFormat_Q(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%q", err) + if got != `"msg"` { + t.Fatalf("%%q = %q, want %q", got, `"msg"`) + } +} + +func TestIs_ThroughWrapChain(t *testing.T) { + sentinel := stderrors.New("sentinel") + err := Wrap(WithMessage(WithStack(sentinel), "mid"), "outer") + if !stderrors.Is(err, sentinel) { + t.Fatal("errors.Is failed to traverse Wrap/WithMessage/WithStack chain") + } +} + +type typedErr struct{ code int } + +func (t *typedErr) Error() string { return fmt.Sprintf("typed(%d)", t.code) } + +func TestAs_ThroughWrapChain(t *testing.T) { + target := &typedErr{code: 7} + err := Wrap(WithStack(target), "ctx") + var got *typedErr + if !stderrors.As(err, &got) { + t.Fatal("errors.As failed to find *typedErr in chain") + } + if got.code != 7 { + t.Fatalf("As returned code=%d, want 7", got.code) + } +} + +func TestReExports_AreStdlib(t *testing.T) { + // Sanity: the re-exports must actually be the stdlib functions. + a := stderrors.New("a") + b := stderrors.New("b") + joined := Join(a, b) + if !Is(joined, a) || !Is(joined, b) { + t.Fatal("Join/Is re-exports do not match stdlib behavior") + } + if Unwrap(WithStack(a)) != a { + t.Fatal("Unwrap re-export does not match stdlib behavior") + } + var target *typedErr + te := &typedErr{code: 1} + if !As(WithStack(te), &target) { + t.Fatal("As re-export does not match stdlib behavior") + } +} diff --git a/internal/errors/stack.go b/internal/errors/stack.go deleted file mode 100644 index 7e5aacc48..000000000 --- a/internal/errors/stack.go +++ /dev/null @@ -1,187 +0,0 @@ -// Fork from https://github.com/pkg/errors -package errors - -import ( - "fmt" - "io" - "path" - "runtime" - "strings" -) - -// Frame represents a program counter inside a stack frame. -type Frame uintptr - -// pc returns the program counter for this frame; -// multiple frames may have the same PC value. -func (f Frame) pc() uintptr { return uintptr(f) - 1 } - -// file returns the full path to the file that contains the -// function for this Frame's pc. -func (f Frame) file() string { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return "unknown" - } - file, _ := fn.FileLine(f.pc()) - return file -} - -// line returns the line number of source code of the -// function for this Frame's pc. -func (f Frame) line() int { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return 0 - } - _, line := fn.FileLine(f.pc()) - return line -} - -// Format formats the frame according to the fmt.Formatter interface. -// -// %s source file -// %d source line -// %n function name -// %v equivalent to %s:%d -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+s path of source file relative to the compile time GOPATH -// %+v equivalent to %+s:%d -func (f Frame) Format(s fmt.State, verb rune) { - switch verb { - case 's': - switch { - case s.Flag('+'): - pc := f.pc() - fn := runtime.FuncForPC(pc) - if fn == nil { - io.WriteString(s, "unknown") - } else { - file, _ := fn.FileLine(pc) - fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) - } - default: - io.WriteString(s, path.Base(f.file())) - } - case 'd': - fmt.Fprintf(s, "%d", f.line()) - case 'n': - name := runtime.FuncForPC(f.pc()).Name() - io.WriteString(s, funcname(name)) - case 'v': - f.Format(s, 's') - io.WriteString(s, ":") - f.Format(s, 'd') - } -} - -// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). -type StackTrace []Frame - -// Format formats the stack of Frames according to the fmt.Formatter interface. -// -// %s lists source files for each Frame in the stack -// %v lists the source file and line number for each Frame in the stack -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+v Prints filename, function, and line number for each Frame in the stack. -func (st StackTrace) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case s.Flag('+'): - for _, f := range st { - fmt.Fprintf(s, "\n%+v", f) - } - case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st)) - default: - fmt.Fprintf(s, "%v", []Frame(st)) - } - case 's': - fmt.Fprintf(s, "%s", []Frame(st)) - } -} - -// stack represents a stack of program counters. -type stack []uintptr - -func (s *stack) Format(st fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case st.Flag('+'): - for _, pc := range *s { - f := Frame(pc) - fmt.Fprintf(st, "\n%+v", f) - } - } - } -} - -func (s *stack) StackTrace() StackTrace { - f := make([]Frame, len(*s)) - for i := 0; i < len(f); i++ { - f[i] = Frame((*s)[i]) - } - return f -} - -func callers() *stack { - const depth = 32 - var pcs [depth]uintptr - n := runtime.Callers(3, pcs[:]) - var st stack = pcs[0:n] - return &st -} - -// funcname removes the path prefix component of a function's name reported by func.Name(). -func funcname(name string) string { - i := strings.LastIndex(name, "/") - name = name[i+1:] - i = strings.Index(name, ".") - return name[i+1:] -} - -func trimGOPATH(name, file string) string { - // Here we want to get the source file path relative to the compile time - // GOPATH. As of Go 1.6.x there is no direct way to know the compiled - // GOPATH at runtime, but we can infer the number of path segments in the - // GOPATH. We note that fn.Name() returns the function name qualified by - // the import path, which does not include the GOPATH. Thus we can trim - // segments from the beginning of the file path until the number of path - // separators remaining is one more than the number of path separators in - // the function name. For example, given: - // - // GOPATH /home/user - // file /home/user/src/pkg/sub/file.go - // fn.Name() pkg/sub.Type.Method - // - // We want to produce: - // - // pkg/sub/file.go - // - // From this we can easily see that fn.Name() has one less path separator - // than our desired output. We count separators from the end of the file - // path until it finds two more than in the function name and then move - // one character forward to preserve the initial path segment without a - // leading separator. - const sep = "/" - goal := strings.Count(name, sep) + 2 - i := len(file) - for n := 0; n < goal; n++ { - i = strings.LastIndex(file[:i], sep) - if i == -1 { - // not enough separators found, set i so that the slice expression - // below leaves file unmodified - i = -len(sep) - break - } - } - // get back to 0 or trim the leading separator - file = file[i+len(sep):] - return file -}