本文讲解的是
golang.org/x/sync
这个包中的errgroup
1、errgroup 的基础介绍
学习过 Go 的朋友都知道 Go 实现并发编程是比较容易的事情,只需要使用go
关键字就可以开启一个 goroutine。那对于并发场景中,如何实现goroutine
的协调控制呢?常见的一种方式是使用sync.WaitGroup
来进行协调控制。
使用过sync.WaitGroup
的朋友知道,sync.WaitGroup
虽然可以实现协调控制,但是不能传递错误,那该如何解决呢?聪明的你可能马上想到使用 chan 或者是 context
来传递错误,确实是可以的。那接下来,我们一起看看官方是怎么实现上面的需求的呢?
1.1 errgroup的安装
安装命令:
go get golang.org/x/sync
//下面的案例是基于v0.1.0 演示的
go get golang.org/x/sync@v0.1.0
1.2 errgroup的基础例子
这里我们需要请求3个url来获取数据,假设请求url2时报错,url3耗时比较久,需要等一秒。
package main
import (
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"strings"
"time"
)
func main() {
queryUrls := map[string]string{
"url1": "http://localhost/url1",
"url2": "http://localhost/url2",
"url3": "http://localhost/url3",
}
var eg errgroup.Group
var results []string
for _, url := range queryUrls {
url := url
eg.Go(func() error {
result, err := query(url)
if err != nil {
return err
}
results = append(results, fmt.Sprintf("url:%s -- ret: %v", url, result))
return nil
})
}
// group 的wait方法,等待上面的 eg.Go 的协程执行完成,并且可以接受错误
err := eg.Wait()
if err != nil {
fmt.Println("eg.Wait error:", err)
return
}
for k, v := range results {
fmt.Printf("%v ---> %v\n", k, v)
}
}
func query(url string) (ret string, err error) {
// 假设这里是发送请求,获取数据
if strings.Contains(url, "url2") {
// 假设请求 url2 时出现错误
fmt.Printf("请求 %s 中....\n", url)
return "", errors.New("请求超时")
} else if strings.Contains(url, "url3") {
// 假设 请求 url3 需要1秒
time.Sleep(time.Second*1)
}
fmt.Printf("请求 %s 中....\n", url)
return "success", nil
}
执行结果:
请求 http://localhost/url2 中....
请求 http://localhost/url1 中....
请求 http://localhost/url3 中....
eg.Wait error: 请求超时
果然,当其中一个goroutine
出现错误时,会把goroutine
中的错误传递出来。
我们自己运行一下上面的代码就会发现这样一个问题,请求 url2 出错了,但是依旧在请求 url3 。因为我们需要聚合 url1、url2、url3 的结果,所以当其中一个出现问题时,我们是可以做一个优化的,就是当其中一个出现错误时,取消还在执行的任务,直接返回结果,不用等待任务执行结果。
那应该如何做呢?
这里假设 url1 执行1秒,url2 执行报错,url3执行3秒。所以当url2报错后,就不用等url3执行结束就可以返回了。
package main
import (
"context"
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"strings"
"time"
)
func main() {
queryUrls := map[string]string{
"url1": "http://localhost/url1",
"url2": "http://localhost/url2",
"url3": "http://localhost/url3",
}
var results []string
ctx, cancel := context.WithCancel(context.Background())
eg, errCtx := errgroup.WithContext(ctx)
for _, url := range queryUrls {
url := url
eg.Go(func() error {
result, err := query(errCtx, url)
if err != nil {
//其实这里不用手动取消,看完源码就知道为啥了
cancel()
return err
}
results = append(results, fmt.Sprintf("url:%s -- ret: %v", url, result))
return nil
})
}
err := eg.Wait()
if err != nil {
fmt.Println("eg.Wait error:", err)
return
}
for k, v := range results {
fmt.Printf("%v ---> %v\n", k, v)
}
}
func query(errCtx context.Context, url string) (ret string, err error) {