go proto 插件开发
约 919 字，大约需要 3 分钟阅读
2024-07-20
通过 proto 生成 Gin 代码（项目地址）
main.go
package main
import (
"flag"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
"protobufProject/generate"
)
// main is the protoc plugin entry point. protogen.Options.Run handles the
// protoc plugin protocol and invokes the callback below. The callback asks
// generate.ProtocGenGoFile to emit Gin wrappers for each requested file.
func main() {
	flag.Parse()

	// Plugin parameters that protogen does not consume itself are handed to
	// this FlagSet. NOTE(review): no flags are registered on it, so any such
	// parameter makes Set return an error — confirm that is intended.
	var params flag.FlagSet
	opts := protogen.Options{ParamFunc: params.Set}

	generateAll := func(plugin *protogen.Plugin) error {
		// Advertise proto3 `optional` support so protoc accepts such files.
		plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
		for _, file := range plugin.Files {
			// Only files protoc asked us to generate are processed;
			// imported dependencies are skipped.
			if file.Generate {
				generate.ProtocGenGoFile(plugin, file)
			}
		}
		return nil
	}
	opts.Run(generateAll)
}
generate.go
package generate
import (
"errors"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
)
// ProtocGenGoFile generates "<prefix>_gin.pb.go" for file and returns the
// generated file. The output contains a Gin wrapper for every service the file
// declares. It returns nil, without generating anything, when file declares no
// services.
func ProtocGenGoFile(plugin *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
	if len(file.Services) == 0 {
		return nil
	}
	fileName := file.GeneratedFilenamePrefix + "_gin.pb.go"
	gf := plugin.NewGeneratedFile(fileName, file.GoImportPath)
	gf.P("// Code generated by protoc-gen-go-gin. DO NOT EDIT.")
	gf.P()
	gf.P("package ", file.GoPackageName)

	// The template refers to gin and http by their bare package names, so the
	// imports are registered here. QualifiedGoIdent with an empty identifier
	// is called only for its side effect of recording the import.
	// gin is always used (gin.IRouter in the wrapper struct). http only
	// appears inside per-method handlers. Registering it for a file with no
	// HTTP-annotated methods would leave an unused import, and the generated
	// code would not compile.
	gf.QualifiedGoIdent(protogen.GoImportPath("github.com/gin-gonic/gin").Ident(""))
	if hasHTTPMethods(file) {
		gf.QualifiedGoIdent(protogen.GoImportPath("net/http").Ident(""))
	}

	for _, service := range file.Services {
		genService(gf, service)
	}
	return gf
}

// hasHTTPMethods reports whether any method in file would get a generated Gin
// handler. It uses the same acceptance test as genService (genMethod succeeds).
func hasHTTPMethods(file *protogen.File) bool {
	for _, service := range file.Services {
		for _, method := range service.Methods {
			if sm, err := genMethod(method); err == nil && sm != nil {
				return true
			}
		}
	}
	return false
}
// genService renders the Gin wrapper for one service and writes it to gf.
//
// Methods for which genMethod fails (no usable google.api.http annotation)
// are deliberately skipped: they are simply not exposed over HTTP.
//
// The request type is qualified through gf. A request message from another Go
// package (e.g. google.protobuf.Empty) is therefore written with its package
// prefix, and its import is added. For messages in the same package the result
// is the plain type name. Reply is left as genMethod set it because the
// template never references it; qualifying it would register an import the
// generated code does not use.
func genService(gf *protogen.GeneratedFile, service *protogen.Service) {
	s := Service{Name: service.GoName}
	for _, method := range service.Methods {
		sm, err := genMethod(method)
		if err != nil || sm == nil {
			continue
		}
		sm.Request = gf.QualifiedGoIdent(method.Input.GoIdent)
		s.Methods = append(s.Methods, sm)
	}
	gf.P(s.execute())
}
// genMethod builds the template data for one RPC method from its
// google.api.http annotation. The route is converted to gin syntax by
// initPathParams.
//
// It returns an error when:
//   - the method has no google.api.http annotation, or
//   - the annotation yields no HTTP verb or no path, e.g. an empty rule or a
//     custom pattern with an empty kind.
//
// A route with an empty verb would generate a gin Handle call with an empty
// method, which gin rejects with a panic when routes are registered.
// Callers treat any error as "do not expose this method over HTTP".
func genMethod(m *protogen.Method) (*ServiceMethod, error) {
	rule, ok := proto.GetExtension(m.Desc.Options(), annotations.E_Http).(*annotations.HttpRule)
	if !ok || rule == nil {
		return nil, errors.New("proto.GetExtension 失败")
	}

	var path, method string
	switch pattern := rule.Pattern.(type) {
	case *annotations.HttpRule_Get:
		path, method = pattern.Get, "GET"
	case *annotations.HttpRule_Put:
		path, method = pattern.Put, "PUT"
	case *annotations.HttpRule_Post:
		path, method = pattern.Post, "POST"
	case *annotations.HttpRule_Delete:
		path, method = pattern.Delete, "DELETE"
	case *annotations.HttpRule_Patch:
		path, method = pattern.Patch, "PATCH"
	case *annotations.HttpRule_Custom:
		// The generated getters are nil-safe, so an unset Custom message
		// yields empty strings instead of a nil dereference.
		path, method = pattern.Custom.GetPath(), pattern.Custom.GetKind()
	}
	if method == "" || path == "" {
		return nil, errors.New("http rule has no method or path")
	}

	sm := &ServiceMethod{
		Name:    m.GoName,
		Request: m.Input.GoIdent.GoName,
		Reply:   m.Output.GoIdent.GoName,
		Method:  method,
		Path:    path,
	}
	sm.initPathParams()
	return sm, nil
}
service.go
package generate
import (
"bytes"
_ "embed"
"html/template"
"strings"
)
// tpl is the source text of the Gin wrapper template, embedded from
// template.go.tpl at build time.
//
//go:embed template.go.tpl
var tpl string
// Service is the data the Gin wrapper template is executed with; execute
// passes a *Service as the template's root value.
type Service struct {
// Name is the service's Go name; the template derives the generated
// wrapper and registration function names from it.
Name string
// Methods lists the methods that get a generated Gin handler and route.
Methods []*ServiceMethod
}
// ServiceName returns the name of the server interface that the generated Gin
// wrapper delegates to: the service name followed by "Server". This is
// expected to be the interface emitted by protoc-gen-go-grpc, e.g.
// GreeterServer.
func (s *Service) ServiceName() string {
	const suffix = "Server"
	return s.Name + suffix
}
// execute renders the embedded Gin template with s as the root value and
// returns the resulting Go source. It panics if the template fails to parse
// or execute.
//
// NOTE(review): the template package in scope is html/template. It
// HTML-escapes interpolated values (for example '&', '\'', '+', '<'), which is
// wrong for generating Go source. Consider switching the import to
// text/template.
func (s *Service) execute() string {
	tmpl, err := template.New("http").Parse(strings.TrimSpace(tpl))
	if err != nil {
		panic(err)
	}

	var out bytes.Buffer
	if err = tmpl.Execute(&out, s); err != nil {
		panic(err)
	}
	return out.String()
}
// GoCamelCase exposes the package-level GoCamelCase helper to the template.
// The template calls it as {{ $.GoCamelCase $item }} to turn a path parameter
// name into the corresponding Go struct field name.
func (*Service) GoCamelCase(s string) string {
	camel := GoCamelCase(s)
	return camel
}
// ServiceMethod is the per-method data consumed by the Gin wrapper template.
type ServiceMethod struct {
	Name       string   // Go name of the RPC method, e.g. "SayHello"
	Request    string   // request message type as written in the generated code
	Reply      string   // reply message type name
	Method     string   // HTTP verb, e.g. "GET"
	Path       string   // route path; initPathParams rewrites it to gin syntax
	PathParams []string // path parameter names, in order of appearance
}

// initPathParams rewrites Path from google.api.http template syntax to gin
// route syntax. It records every path parameter name in PathParams.
//
// Segments are handled as follows:
//   - "{field}" and "{field=*}" become ":field". The two forms are equivalent
//     single-segment variables in google.api.http.
//   - Segments already in gin form (":field") are kept and recorded.
//   - Segments with an empty name ("{}", "{=*}", ":") are left untouched and
//     not recorded: gin rejects unnamed wildcards, and an empty name would
//     generate an invalid field assignment.
//   - All other segments, including variables spanning several segments such
//     as "{name=messages/*}", are left as-is.
func (m *ServiceMethod) initPathParams() {
	segments := strings.Split(m.Path, "/")
	for i, seg := range segments {
		switch {
		case len(seg) >= 2 && seg[0] == '{' && seg[len(seg)-1] == '}':
			name := strings.TrimSuffix(seg[1:len(seg)-1], "=*")
			if name == "" {
				continue
			}
			segments[i] = ":" + name
			m.PathParams = append(m.PathParams, name)
		case len(seg) >= 2 && seg[0] == ':':
			m.PathParams = append(m.PathParams, seg[1:])
		}
	}
	m.Path = strings.Join(segments, "/")
}
// GoCamelCase camel-cases a protobuf name for use as a Go identifier.
//
// If there is an interior underscore followed by a lower case letter,
// drop the underscore and convert the letter to upper case.
//
// This function is copied from google.golang.org/protobuf/internal/strs
// (strings.go). That package is internal and cannot be imported from outside
// the protobuf module. Here it maps path parameter names from
// google.api.http routes to Go struct field names.
func GoCamelCase(s string) string {
// Invariant: if the next letter is lower case, it must be converted
// to upper case.
// That is, we process a word at a time, where words are marked by _ or
// upper case letter. Digits are treated as words.
var b []byte
for i := 0; i < len(s); i++ {
c := s[i]
switch {
case c == '.' && i+1 < len(s) && isASCIILower(s[i+1]):
// Skip over '.' in ".{{lowercase}}".
case c == '.':
b = append(b, '_') // convert '.' to '_'
case c == '_' && (i == 0 || s[i-1] == '.'):
// Convert initial '_' to ensure we start with a capital letter.
// Do the same for '_' after '.' to match historic behavior.
b = append(b, 'X') // convert '_' to 'X'
case c == '_' && i+1 < len(s) && isASCIILower(s[i+1]):
// Skip over '_' in "_{{lowercase}}".
case isASCIIDigit(c):
b = append(b, c)
default:
// Assume we have a letter now - if not, it's a bogus identifier.
// The next word is a sequence of characters that must start upper case.
if isASCIILower(c) {
c -= 'a' - 'A' // convert lowercase to uppercase
}
b = append(b, c)
// Accept lower case sequence that follows.
for ; i+1 < len(s) && isASCIILower(s[i+1]); i++ {
b = append(b, s[i+1])
}
}
}
return string(b)
}
// isASCIILower reports whether c is an ASCII lower-case letter ('a'..'z').
func isASCIILower(c byte) bool {
	// byte subtraction wraps around, so bytes below 'a' become large values
	// and one unsigned comparison checks both bounds.
	return c-'a' < 26
}
// isASCIIDigit reports whether c is an ASCII decimal digit ('0'..'9').
func isASCIIDigit(c byte) bool {
	// byte subtraction wraps around, so bytes below '0' become large values
	// and one unsigned comparison checks both bounds.
	return c-'0' < 10
}
template.go.tpl
{{/* Gin wrapper for one protobuf service. Service.execute runs this template
with a pointer to Service as the root value ($): $.Name is the service's Go
name, $.ServiceName is the server interface the wrapper delegates to, and
$.Methods are the HTTP-annotated methods. The output is Go source. */}}
type {{ $.Name }}GinServer struct {
server {{ $.ServiceName }}
router gin.IRouter
}
{{/* Public entry point: wraps srv and registers one route per method on r. */}}
func Register{{ $.Name }}GinServer(srv {{ $.ServiceName }}, r gin.IRouter) {
s := {{ $.Name }}GinServer{
server: srv,
router: r,
}
s.RegisterService()
}
{{/* One handler per method: bind the request, overlay path parameters,
call the server implementation, then write the reply or error as JSON. */}}
{{ range $.Methods }}
func (gs *{{ $.Name }}GinServer) {{ .Name }} (ctx *gin.Context) {
var request {{ .Request }}
{{/* Binding depends on the verb: query string for GET/DELETE, JSON body for
POST/PUT/PATCH, and gin's content-type based ShouldBind for custom verbs. */}}
{{ if eq .Method "GET" "DELETE" }}
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
{{ else if eq .Method "POST" "PUT" "PATCH" }}
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
{{ else }}
if err := ctx.ShouldBind(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
{{ end }}
{{/* Path parameters are assigned after binding, so they override bound
values. NOTE(review): ByName returns a string, so this assumes every path
parameter maps to a string field - confirm for non-string fields. */}}
{{ range $item := .PathParams }}
request.{{ $.GoCamelCase $item }} = ctx.Params.ByName("{{ $item }}")
{{ end }}
{{/* Any server error is reported as 500 with its message; success returns
the reply serialized by ctx.JSON. */}}
if out, err := gs.server.{{ .Name }}(ctx, &request); err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} else {
ctx.JSON(http.StatusOK, out)
}
}
{{ end }}
{{/* Registers each method's handler under its HTTP verb and gin-style path. */}}
func (gs *{{ $.Name }}GinServer) RegisterService() {
{{ range .Methods }}
gs.router.Handle("{{ .Method }}", "{{ .Path }}", gs.{{ .Name }})
{{ end }}
}
插件使用
build
go build -o protoc-gen-go-gin main.go # 构建后把 protoc-gen-go-gin 可执行文件所在目录加入 PATH 环境变量，protoc 才能按 --go-gin_out 找到名为 protoc-gen-go-gin 的插件
生成命令
protoc --proto_path=. --proto_path=../third_party --go_out=. --go-grpc_out=. --go-gin_out=. api.proto
api.proto
syntax = "proto3";
package template;
// Provides the google.api.http option used below to map RPCs to HTTP routes.
import "google/api/annotations.proto";
// Go package "proto", with generated files written to the current directory.
option go_package = ".;proto";
// Generate the messages, gRPC stubs and Gin wrappers with:
// protoc --proto_path=. --proto_path=../third_party --go_out=. --go-grpc_out=. --go-gin_out=. api.proto
// Greeter is an example service served over gRPC and, through the generated
// Gin wrapper, over HTTP.
service Greeter{
// SayHello is exposed as GET /greeter/sayHello/{name}; the {name} path
// segment fills HelloRequest.name.
rpc SayHello(HelloRequest) returns(HelloReply) {
option (google.api.http) = {
get : "/greeter/sayHello/{name}",
};
}
}
// HelloRequest carries the name to greet.
message HelloRequest{
string name = 1;
}
// HelloReply carries the greeting produced by SayHello.
message HelloReply {
string message = 1;
}