Files
be.ems/lib/eval/evaluate.go
2024-11-22 10:06:51 +08:00

120 lines
2.4 KiB
Go

package evaluate
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"math"
"regexp"
"strconv"
"strings"
)
// calcParamPattern matches a parameter reference written as 'name' inside a
// CalcExpr expression; submatch 1 is the parameter name. It is compiled once
// at package scope instead of on every CalcExpr call.
var calcParamPattern = regexp.MustCompile(`'([^']+)'`)

// CalcExpr parses and calculates an arithmetic expression whose operands may
// reference parameters written as 'name' (single-quoted).
//
// Each 'name' is replaced by paramValues[name], formatted with %v and wrapped
// in parentheses, before the expression is evaluated. The parentheses keep a
// substituted value atomic: "2-'a'" with a = -3 becomes "2-(-3)" instead of
// the unparsable "2--3". Substitution is a single left-to-right pass, so text
// inserted for one parameter is never rescanned for further references.
//
// It returns an error if a referenced parameter is missing from paramValues
// (the first missing one, in order of appearance), if the substituted
// expression cannot be parsed or evaluated, or on division by zero; in all
// error cases the returned value is 0. A NaN result without an error (e.g.
// from a NaN parameter value) is returned as 0 with a nil error.
func CalcExpr(expr string, paramValues map[string]any) (float64, error) {
	var b strings.Builder
	b.Grow(len(expr))
	last := 0
	for _, loc := range calcParamPattern.FindAllStringSubmatchIndex(expr, -1) {
		// loc[0]:loc[1] spans the quoted reference, loc[2]:loc[3] the bare name.
		name := expr[loc[2]:loc[3]]
		value, ok := paramValues[name]
		if !ok {
			return 0, fmt.Errorf("parameter '%s' not found", name)
		}
		b.WriteString(expr[last:loc[0]])
		b.WriteByte('(')
		fmt.Fprintf(&b, "%v", value)
		b.WriteByte(')')
		last = loc[1]
	}
	b.WriteString(expr[last:])

	result, err := evalExpr(b.String())
	if err != nil {
		return 0, err
	}
	if math.IsNaN(result) {
		return 0, nil
	}
	return result, nil
}
// evalExpr parses expr as a single Go expression and evaluates its syntax
// tree with evalNode. If parsing fails, it returns 0 and the parser's error.
func evalExpr(expr string) (float64, error) {
	tree, parseErr := parser.ParseExpr(expr)
	if parseErr == nil {
		return evalNode(tree)
	}
	return 0, parseErr
}
// EvalExpr parses expr as a Go expression and evaluates it.
//
// Before evaluation, every identifier in expr whose name is a key of values
// is rewritten in the syntax tree to that value's %v text form. Identifiers
// not present in values are left unchanged and are then evaluated by
// evalNode, which only accepts names that parse as floating-point numbers.
// A parse failure is returned as (0, err).
func EvalExpr(expr string, values map[string]any) (float64, error) {
	root, parseErr := parser.ParseExpr(expr)
	if parseErr != nil {
		return 0, parseErr
	}

	substitute := func(visited ast.Node) bool {
		id, isIdent := visited.(*ast.Ident)
		if !isIdent {
			return true
		}
		if v, found := values[id.Name]; found {
			// Rename the identifier in place so evalNode sees numeric text.
			id.Name = fmt.Sprintf("%v", v)
		}
		return true
	}
	ast.Inspect(root, substitute)

	return evalNode(root)
}
// evalNode recursively evaluates an expression syntax tree (as produced by
// parser.ParseExpr) and returns its value as a float64.
//
// Supported nodes:
//   - *ast.BasicLit: literal text parsed with strconv.ParseFloat
//   - *ast.Ident: a name that itself parses as a float (e.g. a value that
//     EvalExpr substituted into the tree, or "Inf"/"NaN")
//   - *ast.ParenExpr: a parenthesised sub-expression
//   - *ast.UnaryExpr: unary + and -
//   - *ast.BinaryExpr: +, -, * and /
//
// Any other node type or operator yields an error. Division by zero returns
// NaN together with a non-nil error; every other error path returns 0.
func evalNode(node ast.Node) (float64, error) {
	switch n := node.(type) {
	case *ast.BinaryExpr:
		left, err := evalNode(n.X)
		if err != nil {
			return 0, err
		}
		right, err := evalNode(n.Y)
		if err != nil {
			return 0, err
		}
		switch n.Op {
		case token.ADD:
			return left + right, nil
		case token.SUB:
			return left - right, nil
		case token.MUL:
			return left * right, nil
		case token.QUO:
			if right == 0 {
				// NaN is kept so callers that check math.IsNaN keep working.
				return math.NaN(), fmt.Errorf("divisor cannot be zero")
			}
			return left / right, nil
		default:
			// Reject %, bitwise, shift, comparison and logical operators
			// explicitly rather than producing a meaningless 0.
			return 0, fmt.Errorf("unsupported operator: %s", n.Op)
		}
	case *ast.UnaryExpr:
		// Needed for negative operands such as "-3 * 2" or "2 - (-3)".
		operand, err := evalNode(n.X)
		if err != nil {
			return 0, err
		}
		switch n.Op {
		case token.SUB:
			return -operand, nil
		case token.ADD:
			return operand, nil
		default:
			return 0, fmt.Errorf("unsupported operator: %s", n.Op)
		}
	case *ast.BasicLit:
		val, err := strconv.ParseFloat(n.Value, 64)
		if err != nil {
			return 0, err
		}
		return val, nil
	case *ast.Ident:
		val, err := strconv.ParseFloat(n.Name, 64)
		if err != nil {
			return 0, fmt.Errorf("unsupported expression: %s", n.Name)
		}
		return val, nil
	case *ast.ParenExpr:
		// Parentheses only group; evaluate the inner expression.
		return evalNode(n.X)
	default:
		return 0, fmt.Errorf("unsupported expression: %T", n)
	}
}