diff --git a/CHANGELOG.md b/CHANGELOG.md index 73dba9b..865fbe5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - `zb store object delete` is no longer flaky - ([#135](https://github.com/256lights/zb/issues/135)) + ([#135](https://github.com/256lights/zb/issues/135)). +- Lua operator metamethods now receive their arguments in the correct order + when one of the operands is a constant + ([#152](https://github.com/256lights/zb/issues/152)). ## [0.1.0][] - 2025-06-15 diff --git a/internal/lua/vm.go b/internal/lua/vm.go index 15a87b8..a2d72cd 100644 --- a/internal/lua/vm.go +++ b/internal/lua/vm.go @@ -692,12 +692,11 @@ func (l *State) exec(ctx context.Context) (err error) { if err != nil { return err } - result, err := l.callArithmeticMetamethod( - ctx, - prevOperator.TagMethod(), - *ra, - integerValue(luacode.SignedArg(i.ArgB())), - ) + var arg1, arg2 value = *ra, integerValue(luacode.SignedArg(i.ArgB())) + if i.K() { + arg1, arg2 = arg2, arg1 + } + result, err := l.callArithmeticMetamethod(ctx, prevOperator.TagMethod(), arg1, arg2) if err != nil { return err } @@ -725,12 +724,11 @@ func (l *State) exec(ctx context.Context) (err error) { if err != nil { return err } - result, err := l.callArithmeticMetamethod( - ctx, - prevOperator.TagMethod(), - *ra, - importConstant(kb), - ) + arg1, arg2 := *ra, importConstant(kb) + if i.K() { + arg1, arg2 = arg2, arg1 + } + result, err := l.callArithmeticMetamethod(ctx, prevOperator.TagMethod(), arg1, arg2) if err != nil { return err } diff --git a/internal/lua/vm_test.go b/internal/lua/vm_test.go index c013dac..48fbeab 100644 --- a/internal/lua/vm_test.go +++ b/internal/lua/vm_test.go @@ -785,4 +785,157 @@ func TestVM(t *testing.T) { t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult) } }) + + t.Run("FlippedAddImmediateMetamethod", func(t *testing.T) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + // Set a metatable for nil which treats the nil like 0 when adding. + state.PushNil() + var isNil [2]bool + NewPureLib(state, map[string]Function{ + luacode.TagMethodAdd.String(): func(ctx context.Context, l *State) (int, error) { + for i := range 2 { + isNil[i] = l.IsNil(1 + i) + if isNil[i] { + l.PushInteger(0) + l.Replace(1 + i) + } + } + if err := l.Arithmetic(ctx, luacode.Add); err != nil { + return 0, err + } + return 1, nil + }, + }) + if err := state.SetMetatable(-2); err != nil { + t.Fatal(err) + } + + const source = `return 1 + nil` + if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { + t.Fatal(err) + } + if err := state.Call(ctx, 0, 1); err != nil { + t.Fatal(err) + } + + if isNil[0] || !isNil[1] { + t.Errorf("when calling __add, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1]) + } + + const wantResult = 1 + if got, want := state.Type(-1), TypeNumber; got != want { + t.Fatalf("state.Type(-1) = %v; want %v", got, want) + } else if got, ok := state.ToInteger(-1); got != wantResult || !ok { + t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult) + } + }) + + t.Run("FlippedBOrImmediateMetamethod", func(t *testing.T) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + // Set a metatable for nil which treats the nil like 0 when adding. + state.PushNil() + var isNil [2]bool + NewPureLib(state, map[string]Function{ + luacode.TagMethodBOr.String(): func(ctx context.Context, l *State) (int, error) { + for i := range 2 { + isNil[i] = l.IsNil(1 + i) + if isNil[i] { + l.PushInteger(0) + l.Replace(1 + i) + } + } + if err := l.Arithmetic(ctx, luacode.BitwiseOr); err != nil { + return 0, err + } + return 1, nil + }, + }) + if err := state.SetMetatable(-2); err != nil { + t.Fatal(err) + } + + const source = `return 1 | nil` + if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { + t.Fatal(err) + } + if err := state.Call(ctx, 0, 1); err != nil { + t.Fatal(err) + } + + if isNil[0] || !isNil[1] { + t.Errorf("when calling __bor, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1]) + } + + const wantResult = 1 + if got, want := state.Type(-1), TypeNumber; got != want { + t.Fatalf("state.Type(-1) = %v; want %v", got, want) + } else if got, ok := state.ToInteger(-1); got != wantResult || !ok { + t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult) + } + }) + + t.Run("FlippedAddConstantMetamethod", func(t *testing.T) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + // Set a metatable for nil which treats the nil like 0 when adding. + state.PushNil() + var isNil [2]bool + NewPureLib(state, map[string]Function{ + luacode.TagMethodAdd.String(): func(ctx context.Context, l *State) (int, error) { + for i := range 2 { + isNil[i] = l.IsNil(1 + i) + if isNil[i] { + l.PushInteger(0) + l.Replace(1 + i) + } + } + if err := l.Arithmetic(ctx, luacode.Add); err != nil { + return 0, err + } + return 1, nil + }, + }) + if err := state.SetMetatable(-2); err != nil { + t.Fatal(err) + } + + const source = `return 3.14 + nil` + if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { + t.Fatal(err) + } + if err := state.Call(ctx, 0, 1); err != nil { + t.Fatal(err) + } + + if isNil[0] || !isNil[1] { + t.Errorf("when calling __add, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1]) + } + + const wantResult = 3.14 + if got, want := state.Type(-1), TypeNumber; got != want { + t.Fatalf("state.Type(-1) = %v; want %v", got, want) + } else if got, ok := state.ToNumber(-1); got != wantResult || !ok { + t.Errorf("state.ToNumber(-1) = %g, %t; want %g, true", got, ok, wantResult) + } + }) } diff --git a/internal/luacode/code.go b/internal/luacode/code.go index fbc62b9..3d18615 100644 --- a/internal/luacode/code.go +++ b/internal/luacode/code.go @@ -561,7 +561,6 @@ func (p *parser) codeCommutative(fs *funcState, operator binaryOperator, e1, e2 flip := e1.isNumeral() if flip { e1, e2 = e2, e1 - flip = true } if i, isInt := e2.intConstant(); isInt && fitsSignedArg(i) && operator == binaryOperatorAdd { return p.codeBinaryExpImmediate(fs, OpAddI, e1, e2, flip, line, TagMethodAdd) @@ -570,22 +569,24 @@ func (p *parser) codeCommutative(fs *funcState, operator binaryOperator, e1, e2 } // codeBitwise appends instructions for bitwise operators +// to fs.Code. // // Equivalent to `codebitwise` in upstream Lua. func (p *parser) codeBitwise(fs *funcState, operator binaryOperator, e1, e2 expressionDescriptor, line int) (expressionDescriptor, error) { - // All operations are commutative, - // so if first operand is a numeric constant, - // change order of operands to try to use an immediate or K operator. - flip := e1.kind == expressionKindIntConstant - if flip { - e1, e2 = e2, e1 - } - if e2.kind == expressionKindIntConstant { + switch { + case e1.kind == expressionKindIntConstant: + // All operations are commutative, + // so if first operand is a numeric constant, + // change order of operands to try to use an immediate or K operator. + if e1, _, ok := p.toConstantTable(fs, e1); ok { + return p.codeBinaryExpConstant(fs, operator, e2, e1, true, line) + } + case e2.kind == expressionKindIntConstant: if e2, _, ok := p.toConstantTable(fs, e2); ok { - return p.codeBinaryExpConstant(fs, operator, e1, e2, flip, line) + return p.codeBinaryExpConstant(fs, operator, e1, e2, false, line) } } - return p.codeBinaryExpNoConstants(fs, operator, e1, e2, flip, line) + return p.codeBinaryExp(fs, operator, e1, e2, line) } // codeArithmetic appends instructions for an arithmetic binary operator @@ -598,17 +599,7 @@ func (p *parser) codeArithmetic(fs *funcState, operator binaryOperator, e1, e2 e return p.codeBinaryExpConstant(fs, operator, e1, e2, flip, line) } } - return p.codeBinaryExpNoConstants(fs, operator, e1, e2, flip, line) -} - -// codeBinaryExpNoConstants appends the instructions -// for a binary expression without constant operands -// to fs.Code. -// -// Equivalent to `codebinNoK` in upstream Lua. -func (p *parser) codeBinaryExpNoConstants(fs *funcState, operator binaryOperator, e1, e2 expressionDescriptor, flip bool, line int) (expressionDescriptor, error) { if flip { - // Back to original order. e1, e2 = e2, e1 } return p.codeBinaryExp(fs, operator, e1, e2, line)