Skip to content

Commit 5277622

Browse files
committed
feat: add RunCompleted method to Paralleler.
1 parent 4b0db6b commit 5277622

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

paralleler.go

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package async
33
import (
44
"context"
55
"sync"
6+
"sync/atomic"
67
)
78

89
// builtinPool is the Parallelers pool for built-in functions.
@@ -76,7 +77,7 @@ func (p *Paralleler) Run() ([][]any, error) {
7677

7778
ch := make(chan executeResult, len(tasks))
7879

79-
go p.runTasks(ctx, ch, tasks)
80+
go p.runTasks(ctx, ch, tasks, true)
8081

8182
finished := 0
8283
for finished < len(tasks) {
@@ -98,6 +99,41 @@ func (p *Paralleler) Run() ([][]any, error) {
9899
return out, nil
99100
}
100101

102+
// RunCompleted runs the tasks in the paralleler's pending list until all functions are finished,
103+
// it'll clear the pending list and return the results of the tasks.
104+
func (p *Paralleler) RunCompleted() ([][]any, error) {
105+
tasks := p.getTasks()
106+
out := make([][]any, len(tasks))
107+
if len(tasks) == 0 {
108+
return out, nil
109+
}
110+
111+
errs := make([]error, len(tasks))
112+
errNum := atomic.Int32{}
113+
parent := getContext(p.ctx)
114+
ctx, canFunc := context.WithCancel(parent)
115+
defer canFunc()
116+
117+
ch := make(chan executeResult, len(tasks))
118+
119+
go p.runTasks(ctx, ch, tasks, false)
120+
121+
for finished := 0; finished < len(tasks); finished++ {
122+
ret := <-ch
123+
out[ret.Index] = ret.Out
124+
if ret.Error != nil {
125+
errs[ret.Index] = ret.Error
126+
errNum.Add(1)
127+
}
128+
}
129+
130+
if errNum.Load() == 0 {
131+
return out, nil
132+
}
133+
134+
return out, convertErrorListToExecutionErrors(errs, int(errNum.Load()))
135+
}
136+
101137
// getConcurrencyChan creates and returns a concurrency controlling channel by the specific number
102138
// of the concurrency limitation.
103139
func (p *Paralleler) getConcurrencyChan() chan empty {
@@ -124,15 +160,20 @@ func (p *Paralleler) getTasks() []AsyncFn {
124160
}
125161

126162
// runTasks runs the tasks with the concurrency limitation.
127-
func (p *Paralleler) runTasks(ctx context.Context, resCh chan executeResult, tasks []AsyncFn) {
163+
func (p *Paralleler) runTasks(
164+
ctx context.Context,
165+
resCh chan executeResult,
166+
tasks []AsyncFn,
167+
exitWhenDone bool,
168+
) {
128169
conch := p.getConcurrencyChan()
129170

130171
for i := 0; i < len(tasks); i++ {
131172
if conch != nil {
132173
conch <- empty{}
133174
}
134175

135-
go p.runTask(ctx, i, tasks[i], conch, resCh)
176+
go p.runTask(ctx, i, tasks[i], conch, resCh, exitWhenDone)
136177
}
137178
}
138179

@@ -143,6 +184,7 @@ func (p *Paralleler) runTask(
143184
fn AsyncFn,
144185
conch chan empty,
145186
ch chan executeResult,
187+
exitWhenDone bool,
146188
) {
147189
childCtx, childCanFunc := context.WithCancel(ctx)
148190
defer childCanFunc()
@@ -153,14 +195,22 @@ func (p *Paralleler) runTask(
153195
<-conch
154196
}
155197

156-
select {
157-
case <-ctx.Done():
158-
return
159-
default:
198+
if !exitWhenDone {
160199
ch <- executeResult{
161200
Index: n,
162201
Error: err,
163202
Out: ret,
164203
}
204+
} else {
205+
select {
206+
case <-ctx.Done():
207+
return
208+
default:
209+
ch <- executeResult{
210+
Index: n,
211+
Error: err,
212+
Out: ret,
213+
}
214+
}
165215
}
166216
}

paralleler_test.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package async_test
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"sync/atomic"
78
"testing"
@@ -74,7 +75,7 @@ func TestParallelerClear(t *testing.T) {
7475
}
7576

7677
_, err := p.Run()
77-
a.Nil(err)
78+
a.NilNow(err)
7879
a.EqualNow(cnt.Load(), 3)
7980
}
8081

@@ -125,6 +126,29 @@ func TestParallelerWithContext(t *testing.T) {
125126
a.EqualNow(cnt.Load(), 2)
126127
}
127128

129+
func TestParallelerRunCompleted(t *testing.T) {
130+
a := assert.New(t)
131+
cnt := atomic.Int32{}
132+
expectedErr := errors.New("n = 2")
133+
134+
p := new(async.Paralleler).WithConcurrency(2)
135+
for i := 0; i < 5; i++ {
136+
n := i
137+
p.Add(func() error {
138+
cnt.Add(1)
139+
if n == 2 {
140+
return expectedErr
141+
}
142+
return nil
143+
})
144+
}
145+
146+
out, err := p.RunCompleted()
147+
a.NotNilNow(err)
148+
a.EqualNow(out, [][]any{{nil}, {nil}, {expectedErr}, {nil}, {nil}})
149+
a.EqualNow(cnt.Load(), 5)
150+
}
151+
128152
func ExampleParalleler() {
129153
p := new(async.Paralleler)
130154

0 commit comments

Comments
 (0)