diff --git a/pkg/fetcher.go b/pkg/fetcher.go index 7b73061..b7b381a 100644 --- a/pkg/fetcher.go +++ b/pkg/fetcher.go @@ -78,6 +78,7 @@ type Worker[T, S any] struct { wg *sync.WaitGroup // wg is the wait group to synchronize the completion of tasks. work Work[T, S] // work is the function that processes tasks. rate_limit <-chan time.Time + timeout time.Duration } type WorkConfig struct { @@ -86,6 +87,7 @@ type WorkConfig struct { max_retries uint8 base_retry_time time.Duration rate_limit <-chan time.Time + timeout time.Duration } type Channels[T, S any] struct { @@ -96,15 +98,40 @@ type Channels[T, S any] struct { units_receiver chan WorkUnit[T, S] } -func spawn_worker[T, S any](worker *Worker[T, S]) { +func spawn_worker[T, S any](ctx context.Context, worker *Worker[T, S]) { // TODO: handle tiemouts for workUnit := range worker.receptor { // Wait for rate-limit <-worker.rate_limit - value, err := worker.work(workUnit.argument) - workUnit.result = value - workUnit.err = err + var timeout context.Context + var cancel context.CancelFunc + + if worker.timeout != 0 { + timeout, cancel = context.WithTimeout(ctx, worker.timeout) + } else { + timeout, cancel = context.WithCancel(ctx) + } + + done := make(chan struct{}) + + var value S + var err error + + go func() { + value, err = worker.work(workUnit.argument) + close(done) + }() + + select { + case <-done: + workUnit.result = value + workUnit.err = err + case <-timeout.Done(): + workUnit.err = timeout.Err() + } + + cancel() worker.transmiter <- workUnit } @@ -248,10 +275,11 @@ func asyncTaskRunner[T, S any]( receptor: channels.units_dispatcher, transmiter: channels.units_receiver, rate_limit: config.rate_limit, + timeout: config.timeout, work: work, } - go spawn_worker(worker) + go spawn_worker(ctx, worker) } go listenForWorkResults(done, channels, config) diff --git a/pkg/platform.go b/pkg/platform.go index 293fd35..b8822bc 100644 --- a/pkg/platform.go +++ b/pkg/platform.go @@ -19,7 +19,7 @@ func (platform *Platform) FetchCollections(fetcher Fetcher, start_pagination Pag // fmt.Printf("Requesting offset: %v\n", offset) if offset == 10 { - return nil, fmt.Errorf("Simulated error jeje") + time.Sleep(time.Second * 5) } pagination := start_pagination @@ -34,12 +34,14 @@ func (platform *Platform) FetchCollections(fetcher Fetcher, start_pagination Pag } ctx, cancel := context.WithCancel(context.Background()) + defer cancel() config := &WorkConfig{ max_workers: 5, max_retries: 2, base_retry_time: time.Second, rate_limit: NewRateLimit(5, time.Minute), + timeout: time.Second * 2, } tasks := make(chan int) @@ -69,7 +71,6 @@ loop: } fmt.Printf("There was an error: %v\n", error) - cancel() case <-ctx.Done(): break loop case <-done: