347 lines
8.2 KiB
Go
347 lines
8.2 KiB
Go
package engine
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"git.kingecg.top/kingecg/gomog/pkg/types"
|
|
)
|
|
|
|
// TestAggregationPipelineIntegration 测试聚合管道集成
|
|
func TestAggregationPipelineIntegration(t *testing.T) {
|
|
store := NewMemoryStore(nil)
|
|
collection := "test.agg_integration"
|
|
|
|
// Setup test data
|
|
store.collections[collection] = &Collection{
|
|
name: collection,
|
|
documents: map[string]types.Document{
|
|
"doc1": {
|
|
ID: "doc1",
|
|
Data: map[string]interface{}{"category": "A", "score": 85, "quantity": 10},
|
|
},
|
|
"doc2": {
|
|
ID: "doc2",
|
|
Data: map[string]interface{}{"category": "A", "score": 92, "quantity": 5},
|
|
},
|
|
"doc3": {
|
|
ID: "doc3",
|
|
Data: map[string]interface{}{"category": "B", "score": 78, "quantity": 15},
|
|
},
|
|
"doc4": {
|
|
ID: "doc4",
|
|
Data: map[string]interface{}{"category": "B", "score": 95, "quantity": 8},
|
|
},
|
|
},
|
|
}
|
|
|
|
engine := &AggregationEngine{store: store}
|
|
|
|
tests := []struct {
|
|
name string
|
|
pipeline []types.AggregateStage
|
|
expectedLen int
|
|
checkField string
|
|
expectedVal interface{}
|
|
}{
|
|
{
|
|
name: "match and group with sum",
|
|
pipeline: []types.AggregateStage{
|
|
{Stage: "$match", Spec: types.Filter{"score": types.Filter{"$gte": float64(80)}}},
|
|
{
|
|
Stage: "$group",
|
|
Spec: map[string]interface{}{
|
|
"_id": "$category",
|
|
"total": map[string]interface{}{"$sum": "$quantity"},
|
|
},
|
|
},
|
|
},
|
|
expectedLen: 2,
|
|
},
|
|
{
|
|
name: "project with switch expression",
|
|
pipeline: []types.AggregateStage{
|
|
{
|
|
Stage: "$project",
|
|
Spec: map[string]interface{}{
|
|
"category": 1,
|
|
"grade": map[string]interface{}{
|
|
"$switch": map[string]interface{}{
|
|
"branches": []interface{}{
|
|
map[string]interface{}{
|
|
"case": map[string]interface{}{
|
|
"$gte": []interface{}{"$score", float64(90)},
|
|
},
|
|
"then": "A",
|
|
},
|
|
map[string]interface{}{
|
|
"case": map[string]interface{}{
|
|
"$gte": []interface{}{"$score", float64(80)},
|
|
},
|
|
"then": "B",
|
|
},
|
|
},
|
|
"default": "C",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
expectedLen: 4,
|
|
},
|
|
{
|
|
name: "addFields with arithmetic",
|
|
pipeline: []types.AggregateStage{
|
|
{
|
|
Stage: "$addFields",
|
|
Spec: map[string]interface{}{
|
|
"totalValue": map[string]interface{}{
|
|
"$multiply": []interface{}{"$score", "$quantity"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
expectedLen: 4,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
results, err := engine.Execute(collection, tt.pipeline)
|
|
if err != nil {
|
|
t.Fatalf("Execute() error = %v", err)
|
|
}
|
|
|
|
if len(results) != tt.expectedLen {
|
|
t.Errorf("Expected %d results, got %d", tt.expectedLen, len(results))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestQueryWithExprAndJsonSchema 测试 $expr 和 $jsonSchema 组合查询
|
|
func TestQueryWithExprAndJsonSchema(t *testing.T) {
|
|
store := NewMemoryStore(nil)
|
|
collection := "test.expr_schema_integration"
|
|
|
|
store.collections[collection] = &Collection{
|
|
name: collection,
|
|
documents: map[string]types.Document{
|
|
"doc1": {
|
|
ID: "doc1",
|
|
Data: map[string]interface{}{
|
|
"name": "Alice",
|
|
"age": 25,
|
|
"salary": float64(5000),
|
|
"bonus": float64(1000),
|
|
},
|
|
},
|
|
"doc2": {
|
|
ID: "doc2",
|
|
Data: map[string]interface{}{
|
|
"name": "Bob",
|
|
"age": 30,
|
|
"salary": float64(6000),
|
|
"bonus": float64(500),
|
|
},
|
|
},
|
|
"doc3": {
|
|
ID: "doc3",
|
|
Data: map[string]interface{}{
|
|
"name": "Charlie",
|
|
"age": 35,
|
|
"salary": float64(7000),
|
|
"bonus": float64(2000),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
filter types.Filter
|
|
expectedLen int
|
|
}{
|
|
{
|
|
name: "$expr with field comparison",
|
|
filter: types.Filter{
|
|
"$expr": types.Filter{
|
|
"$gt": []interface{}{"$bonus", map[string]interface{}{
|
|
"$multiply": []interface{}{"$salary", float64(0.1)},
|
|
}},
|
|
},
|
|
},
|
|
expectedLen: 2, // Alice and Charlie have bonus > 10% of salary
|
|
},
|
|
{
|
|
name: "$jsonSchema validation",
|
|
filter: types.Filter{
|
|
"$jsonSchema": map[string]interface{}{
|
|
"bsonType": "object",
|
|
"required": []interface{}{"name", "age"},
|
|
"properties": map[string]interface{}{
|
|
"name": map[string]interface{}{"bsonType": "string"},
|
|
"age": map[string]interface{}{"bsonType": "int", "minimum": float64(0)},
|
|
},
|
|
},
|
|
},
|
|
expectedLen: 3, // All documents match
|
|
},
|
|
{
|
|
name: "combined $expr and regular filter",
|
|
filter: types.Filter{
|
|
"age": types.Filter{"$gte": float64(30)},
|
|
"$expr": types.Filter{
|
|
"$gt": []interface{}{"$salary", float64(5500)},
|
|
},
|
|
},
|
|
expectedLen: 2, // Bob and Charlie
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
results, err := store.Find(collection, tt.filter)
|
|
if err != nil {
|
|
t.Fatalf("Find() error = %v", err)
|
|
}
|
|
|
|
if len(results) != tt.expectedLen {
|
|
t.Errorf("Expected %d results, got %d", tt.expectedLen, len(results))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestUpdateWithProjectionRoundTrip 测试更新后查询投影的完整流程
|
|
func TestUpdateWithProjectionRoundTrip(t *testing.T) {
|
|
store := NewMemoryStore(nil)
|
|
collection := "test.roundtrip"
|
|
|
|
store.collections[collection] = &Collection{
|
|
name: collection,
|
|
documents: map[string]types.Document{
|
|
"doc1": {
|
|
ID: "doc1",
|
|
Data: map[string]interface{}{
|
|
"name": "Product A",
|
|
"prices": []interface{}{float64(100), float64(150), float64(200)},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Update with array position operator
|
|
update := types.Update{
|
|
Set: map[string]interface{}{
|
|
"prices.$[]": float64(99),
|
|
},
|
|
}
|
|
|
|
matched, modified, _, err := store.Update(collection, types.Filter{"name": "Product A"}, update, false, nil)
|
|
if err != nil {
|
|
t.Fatalf("Update() error = %v", err)
|
|
}
|
|
|
|
if matched != 1 {
|
|
t.Errorf("Expected 1 match, got %d", matched)
|
|
}
|
|
if modified != 1 {
|
|
t.Errorf("Expected 1 modified, got %d", modified)
|
|
}
|
|
|
|
// Find with projection
|
|
filter := types.Filter{"name": "Product A"}
|
|
results, err := store.Find(collection, filter)
|
|
if err != nil {
|
|
t.Fatalf("Find() error = %v", err)
|
|
}
|
|
|
|
if len(results) != 1 {
|
|
t.Errorf("Expected 1 result, got %d", len(results))
|
|
}
|
|
|
|
// Verify all prices are updated to 99
|
|
prices, ok := results[0].Data["prices"].([]interface{})
|
|
if !ok {
|
|
t.Fatal("prices is not an array")
|
|
}
|
|
|
|
for i, price := range prices {
|
|
if price != float64(99) {
|
|
t.Errorf("Price at index %d = %v, want 99", i, price)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestComplexAggregationPipeline 测试复杂聚合管道
|
|
func TestComplexAggregationPipeline(t *testing.T) {
|
|
store := NewMemoryStore(nil)
|
|
collection := "test.complex_agg"
|
|
|
|
store.collections[collection] = &Collection{
|
|
name: collection,
|
|
documents: map[string]types.Document{
|
|
"doc1": {ID: "doc1", Data: map[string]interface{}{"status": "A", "qty": 10, "price": 5.5}},
|
|
"doc2": {ID: "doc2", Data: map[string]interface{}{"status": "A", "qty": 20, "price": 3.0}},
|
|
"doc3": {ID: "doc3", Data: map[string]interface{}{"status": "B", "qty": 15, "price": 4.0}},
|
|
"doc4": {ID: "doc4", Data: map[string]interface{}{"status": "B", "qty": 5, "price": 6.0}},
|
|
},
|
|
}
|
|
|
|
engine := &AggregationEngine{store: store}
|
|
|
|
pipeline := []types.AggregateStage{
|
|
{Stage: "$match", Spec: types.Filter{"status": "A"}},
|
|
{
|
|
Stage: "$addFields",
|
|
Spec: map[string]interface{}{
|
|
"total": map[string]interface{}{
|
|
"$multiply": []interface{}{"$qty", "$price"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Stage: "$group",
|
|
Spec: map[string]interface{}{
|
|
"_id": "$status",
|
|
"avgTotal": map[string]interface{}{
|
|
"$avg": "$total",
|
|
},
|
|
"maxTotal": map[string]interface{}{
|
|
"$max": "$total",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Stage: "$project",
|
|
Spec: map[string]interface{}{
|
|
"_id": 0,
|
|
"status": "$_id",
|
|
"avgTotal": 1,
|
|
"maxTotal": 1,
|
|
},
|
|
},
|
|
}
|
|
|
|
results, err := engine.Execute(collection, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("Execute() error = %v", err)
|
|
}
|
|
|
|
if len(results) != 1 {
|
|
t.Errorf("Expected 1 result, got %d", len(results))
|
|
}
|
|
|
|
// Verify the result has the expected fields
|
|
result := results[0].Data
|
|
if _, exists := result["status"]; !exists {
|
|
t.Error("Expected 'status' field")
|
|
}
|
|
if _, exists := result["avgTotal"]; !exists {
|
|
t.Error("Expected 'avgTotal' field")
|
|
}
|
|
if _, exists := result["maxTotal"]; !exists {
|
|
t.Error("Expected 'maxTotal' field")
|
|
}
|
|
}
|