2025-07-24 16:15:24 +08:00
|
|
|
package task_group
|
|
|
|
|
|
|
|
import (
|
|
|
|
"sync"
|
|
|
|
|
|
|
|
"github.com/sirupsen/logrus"
|
|
|
|
)
|
|
|
|
|
2025-07-26 00:27:46 +08:00
|
|
|
// OnCompletionFunc is the callback signature used by TaskGroupCoordinator.
// groupID identifies the group that finished and payloads carries the non-nil
// payloads collected for that group, in the order they were added.
type OnCompletionFunc func(groupID string, payloads ...any)
|
2025-07-24 16:15:24 +08:00
|
|
|
type TaskGroupCoordinator struct {
|
|
|
|
name string
|
|
|
|
mu sync.Mutex
|
|
|
|
|
|
|
|
groupPayloads map[string][]any
|
|
|
|
groupStates map[string]groupState
|
|
|
|
onCompletion OnCompletionFunc
|
|
|
|
}
|
|
|
|
|
|
|
|
// groupState is the per-group bookkeeping kept by TaskGroupCoordinator.
// Field order matters: values are logged with %+v.
type groupState struct {
	// pending counts tasks registered via AddTask that have not yet been
	// reported through Done.
	pending int
	// hasSuccess records whether any Done call for the group reported success.
	hasSuccess bool
}
|
|
|
|
|
|
|
|
func NewTaskGroupCoordinator(name string, f OnCompletionFunc) *TaskGroupCoordinator {
|
|
|
|
return &TaskGroupCoordinator{
|
|
|
|
name: name,
|
|
|
|
groupPayloads: map[string][]any{},
|
|
|
|
groupStates: map[string]groupState{},
|
|
|
|
onCompletion: f,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// payload可为nil
|
|
|
|
func (tgc *TaskGroupCoordinator) AddTask(groupID string, payload any) {
|
|
|
|
tgc.mu.Lock()
|
|
|
|
defer tgc.mu.Unlock()
|
|
|
|
state := tgc.groupStates[groupID]
|
|
|
|
state.pending++
|
|
|
|
tgc.groupStates[groupID] = state
|
|
|
|
logrus.Debugf("AddTask:%s ,count=%+v", groupID, state)
|
|
|
|
if payload == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
tgc.groupPayloads[groupID] = append(tgc.groupPayloads[groupID], payload)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tgc *TaskGroupCoordinator) AppendPayload(groupID string, payload any) {
|
|
|
|
if payload == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
tgc.mu.Lock()
|
|
|
|
defer tgc.mu.Unlock()
|
|
|
|
tgc.groupPayloads[groupID] = append(tgc.groupPayloads[groupID], payload)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tgc *TaskGroupCoordinator) Done(groupID string, success bool) {
|
|
|
|
tgc.mu.Lock()
|
|
|
|
defer tgc.mu.Unlock()
|
|
|
|
state, ok := tgc.groupStates[groupID]
|
|
|
|
if !ok || state.pending == 0 {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if success {
|
|
|
|
state.hasSuccess = true
|
|
|
|
}
|
|
|
|
logrus.Debugf("Done:%s ,state=%+v", groupID, state)
|
|
|
|
if state.pending == 1 {
|
|
|
|
payloads := tgc.groupPayloads[groupID]
|
|
|
|
delete(tgc.groupStates, groupID)
|
|
|
|
delete(tgc.groupPayloads, groupID)
|
|
|
|
if tgc.onCompletion != nil && state.hasSuccess {
|
|
|
|
logrus.Debugf("OnCompletion:%s", groupID)
|
|
|
|
tgc.mu.Unlock()
|
2025-07-26 00:27:46 +08:00
|
|
|
tgc.onCompletion(groupID, payloads...)
|
2025-07-24 16:15:24 +08:00
|
|
|
tgc.mu.Lock()
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
state.pending--
|
|
|
|
tgc.groupStates[groupID] = state
|
|
|
|
}
|