feat: add timeout capabilities to worker pool

Alexander Navarro 2024-11-29 17:38:57 -03:00
parent eeefceb2fc
commit f8c068a7f3
2 changed files with 36 additions and 7 deletions
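
The change threads a context into spawn_worker and gives each task its own deadline: the work function runs in a goroutine while the worker selects on either its completion or the timeout context. Below is a minimal, self-contained sketch of that pattern; runWithTimeout and its parameters are illustrative names rather than the pool's API, and only the context/select shape mirrors the diff.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// runWithTimeout is an illustrative stand-in for the worker loop's body:
// it runs work(arg) in a goroutine and returns either its result or the
// context error if the deadline (timeout > 0) expires first.
func runWithTimeout[T, S any](ctx context.Context, work func(T) (S, error), arg T, timeout time.Duration) (S, error) {
	var cancel context.CancelFunc
	if timeout != 0 {
		ctx, cancel = context.WithTimeout(ctx, timeout)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	done := make(chan struct{})
	var value S
	var err error
	go func() {
		value, err = work(arg) // note: keeps running even if the deadline fires
		close(done)
	}()

	select {
	case <-done:
		return value, err
	case <-ctx.Done():
		var zero S
		return zero, ctx.Err()
	}
}

func main() {
	slow := func(d time.Duration) (string, error) {
		time.Sleep(d)
		return "done", nil
	}

	if _, err := runWithTimeout(context.Background(), slow, 3*time.Second, time.Second); errors.Is(err, context.DeadlineExceeded) {
		fmt.Println("task timed out:", err)
	}
}

A zero timeout falls back to a plain cancellable context, so callers that never set the new field are not subject to any deadline.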

@@ -78,6 +78,7 @@ type Worker[T, S any] struct {
 	wg         *sync.WaitGroup  // wg is the wait group to synchronize the completion of tasks.
 	work       Work[T, S]       // work is the function that processes tasks.
 	rate_limit <-chan time.Time
+	timeout    time.Duration
 }
 
 type WorkConfig struct {
@@ -86,6 +87,7 @@ type WorkConfig struct {
 	max_retries     uint8
 	base_retry_time time.Duration
 	rate_limit      <-chan time.Time
+	timeout         time.Duration
 }
 
 type Channels[T, S any] struct {
@@ -96,15 +98,40 @@ type Channels[T, S any] struct {
 	units_receiver chan WorkUnit[T, S]
 }
 
-func spawn_worker[T, S any](worker *Worker[T, S]) {
+func spawn_worker[T, S any](ctx context.Context, worker *Worker[T, S]) {
 	// TODO: handle timeouts
 	for workUnit := range worker.receptor {
 		// Wait for rate-limit
 		<-worker.rate_limit
 
-		value, err := worker.work(workUnit.argument)
-		workUnit.result = value
-		workUnit.err = err
+		var timeout context.Context
+		var cancel context.CancelFunc
+		if worker.timeout != 0 {
+			timeout, cancel = context.WithTimeout(ctx, worker.timeout)
+		} else {
+			timeout, cancel = context.WithCancel(ctx)
+		}
+
+		done := make(chan struct{})
+		var value S
+		var err error
+		go func() {
+			value, err = worker.work(workUnit.argument)
+			close(done)
+		}()
+
+		select {
+		case <-done:
+			workUnit.result = value
+			workUnit.err = err
+		case <-timeout.Done():
+			workUnit.err = timeout.Err()
+		}
+		cancel()
+
 		worker.transmiter <- workUnit
 	}
@@ -248,10 +275,11 @@ func asyncTaskRunner[T, S any](
 			receptor:   channels.units_dispatcher,
 			transmiter: channels.units_receiver,
 			rate_limit: config.rate_limit,
+			timeout:    config.timeout,
 			work:       work,
 		}
-		go spawn_worker(worker)
+		go spawn_worker(ctx, worker)
 	}
 
 	go listenForWorkResults(done, channels, config)

@@ -19,7 +19,7 @@ func (platform *Platform) FetchCollections(fetcher Fetcher, start_pagination Pag
 		// fmt.Printf("Requesting offset: %v\n", offset)
 		if offset == 10 {
-			return nil, fmt.Errorf("Simulated error jeje")
+			time.Sleep(time.Second * 5)
 		}
 
 		pagination := start_pagination
@@ -34,12 +34,14 @@ func (platform *Platform) FetchCollections(fetcher Fetcher, start_pagination Pag
 	}
 
 	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
 	config := &WorkConfig{
 		max_workers:     5,
 		max_retries:     2,
 		base_retry_time: time.Second,
 		rate_limit:      NewRateLimit(5, time.Minute),
+		timeout:         time.Second * 2,
 	}
 
 	tasks := make(chan int)
@@ -69,7 +71,6 @@ loop:
 			}
 			fmt.Printf("There was an error: %v\n", error)
-			cancel()
 		case <-ctx.Done():
 			break loop
 		case <-done:
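
Each per-task context in spawn_worker is derived from the ctx the pool receives, so the caller's cancel (now installed with defer in FetchCollections) also unblocks a worker that is still waiting on a task's deadline. A small self-contained sketch of that parent/child behaviour, independent of the pool's types (all names here are illustrative):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Parent context owned by the caller, as in FetchCollections.
	parent, cancel := context.WithCancel(context.Background())

	// Per-task context derived from the parent, as in spawn_worker.
	task, taskCancel := context.WithTimeout(parent, 2*time.Second)
	defer taskCancel()

	// Cancel the parent early; the task context is done immediately,
	// well before its own 2-second deadline.
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()

	<-task.Done()
	fmt.Println("task context finished:", task.Err()) // context.Canceled
}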