diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 654500429a..fe6d99985f 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -161,6 +161,8 @@ func pluginQueries(r *compiler.Result) []*plugin.Query { Params: params, Filename: q.Metadata.Filename, InsertIntoTable: iit, + IsReplace: q.IsReplace, + IgnoreErr: q.IgnoreErr, }) } return out diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..a6bdce0d19 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -266,7 +266,9 @@ type Query struct { Ret QueryValue Arg QueryValue // Used for :copyfrom - Table *plugin.Identifier + Table *plugin.Identifier + IsReplace bool + IgnoreErr bool } func (q Query) hasRetType() bool { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 0820488f9d..0f2de179a0 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -223,6 +223,8 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] SQL: query.Text, Comments: comments, Table: query.InsertIntoTable, + IsReplace: query.IsReplace, + IgnoreErr: query.IgnoreErr, } sqlpkg := parseDriver(options.SqlPackage) diff --git a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl index e21475b148..ba8dea98f1 100644 --- a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl +++ b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl @@ -40,7 +40,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArg go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) // The string interpolation is necessary because LOAD DATA INFILE requires // the file name to be given as a literal string. - result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) + result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' {{if .IsReplace}}REPLACE {{else if .IgnoreErr}}IGNORE {{end}}INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) if err != nil { return 0, err } diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..cf98323cd7 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -14,6 +14,8 @@ import ( type analysis struct { Table *ast.TableName + IsReplace bool + IgnoreErr bool Columns []*Column Parameters []Parameter Named *named.ParamSet @@ -142,6 +144,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) var table *ast.TableName + var isReplace bool + var ignoreErr bool switch n := raw.Stmt.(type) { case *ast.InsertStmt: if err := check(validate.InsertStmt(n)); err != nil { @@ -152,6 +156,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } + isReplace = n.IsReplace + ignoreErr = n.IgnoreErr } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { @@ -207,6 +213,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return &analysis{ Table: table, + IsReplace: isReplace, + IgnoreErr: ignoreErr, Columns: cols, Parameters: params, Query: expanded, diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 751cb3271a..5da00ca468 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -178,6 +178,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, Columns: anlys.Columns, SQL: trimmed, InsertIntoTable: anlys.Table, + IsReplace: anlys.IsReplace, + IgnoreErr: anlys.IgnoreErr, }, nil } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index b3cf9d6154..2ec6338be2 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -50,6 +50,8 @@ type Query struct { // Needed for CopyFrom InsertIntoTable *ast.TableName + IsReplace bool + IgnoreErr bool // Needed for vet RawStmt *ast.RawStmt diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/db/copyfrom.go b/internal/endtoend/testdata/mysql_copyfrom_replace/db/copyfrom.go new file mode 100644 index 0000000000..95c7593816 --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/db/copyfrom.go @@ -0,0 +1,92 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: copyfrom.go + +package db + +import ( + "context" + "fmt" + "io" + "sync/atomic" + + "github.com/go-sql-driver/mysql" + "github.com/hexon/mysqltsv" +) + +var readerHandlerSequenceForIgnoreLocations uint32 = 1 + +func convertRowsForIgnoreLocations(w *io.PipeWriter, arg []IgnoreLocationsParams) { + e := mysqltsv.NewEncoder(w, 5, nil) + for _, row := range arg { + e.AppendString(row.ID) + e.AppendString(row.Name) + e.AppendString(row.Address) + e.AppendValue(row.Latitude) + e.AppendValue(row.Longitude) + } + w.CloseWithError(e.Close()) +} + +// IgnoreLocations uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. +// +// Errors and duplicate keys are treated as warnings and insertion will +// continue, even without an error for some cases. Use this in a transaction +// and use SHOW WARNINGS to check for any problems and roll back if you want to. +// +// Check the documentation for more information: +// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +func (q *Queries) IgnoreLocations(ctx context.Context, arg []IgnoreLocationsParams) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("IgnoreLocations_%d", atomic.AddUint32(&readerHandlerSequenceForIgnoreLocations, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsForIgnoreLocations(pw, arg) + // The string interpolation is necessary because LOAD DATA INFILE requires + // the file name to be given as a literal string. + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' IGNORE INTO TABLE `locations` %s (id, name, address, latitude, longitude)", "Reader::"+rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +var readerHandlerSequenceForUpsertLocations uint32 = 1 + +func convertRowsForUpsertLocations(w *io.PipeWriter, arg []UpsertLocationsParams) { + e := mysqltsv.NewEncoder(w, 5, nil) + for _, row := range arg { + e.AppendString(row.ID) + e.AppendString(row.Name) + e.AppendString(row.Address) + e.AppendValue(row.Latitude) + e.AppendValue(row.Longitude) + } + w.CloseWithError(e.Close()) +} + +// UpsertLocations uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. +// +// Errors and duplicate keys are treated as warnings and insertion will +// continue, even without an error for some cases. Use this in a transaction +// and use SHOW WARNINGS to check for any problems and roll back if you want to. +// +// Check the documentation for more information: +// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +func (q *Queries) UpsertLocations(ctx context.Context, arg []UpsertLocationsParams) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("UpsertLocations_%d", atomic.AddUint32(&readerHandlerSequenceForUpsertLocations, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsForUpsertLocations(pw, arg) + // The string interpolation is necessary because LOAD DATA INFILE requires + // the file name to be given as a literal string. + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' REPLACE INTO TABLE `locations` %s (id, name, address, latitude, longitude)", "Reader::"+rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/db/db.go b/internal/endtoend/testdata/mysql_copyfrom_replace/db/db.go new file mode 100644 index 0000000000..cd5bbb8e08 --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/db/models.go b/internal/endtoend/testdata/mysql_copyfrom_replace/db/models.go new file mode 100644 index 0000000000..50270dcb56 --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/db/models.go @@ -0,0 +1,17 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package db + +import ( + "database/sql" +) + +type Location struct { + ID string + Name string + Address string + Latitude sql.NullFloat64 + Longitude sql.NullFloat64 +} diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/db/query.sql.go b/internal/endtoend/testdata/mysql_copyfrom_replace/db/query.sql.go new file mode 100644 index 0000000000..e8030b2652 --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/db/query.sql.go @@ -0,0 +1,36 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package db + +import ( + "database/sql" +) + +const ignoreLocations = `-- name: IgnoreLocations :copyfrom +INSERT IGNORE INTO locations (id, name, address, latitude, longitude) +VALUES (?, ?, ?, ?, ?) +` + +type IgnoreLocationsParams struct { + ID string + Name string + Address string + Latitude sql.NullFloat64 + Longitude sql.NullFloat64 +} + +const upsertLocations = `-- name: UpsertLocations :copyfrom +REPLACE INTO locations (id, name, address, latitude, longitude) +VALUES (?, ?, ?, ?, ?) +` + +type UpsertLocationsParams struct { + ID string + Name string + Address string + Latitude sql.NullFloat64 + Longitude sql.NullFloat64 +} diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/query.sql b/internal/endtoend/testdata/mysql_copyfrom_replace/query.sql new file mode 100644 index 0000000000..f8a283eebf --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/query.sql @@ -0,0 +1,7 @@ +-- name: UpsertLocations :copyfrom +REPLACE INTO locations (id, name, address, latitude, longitude) +VALUES (?, ?, ?, ?, ?); + +-- name: IgnoreLocations :copyfrom +INSERT IGNORE INTO locations (id, name, address, latitude, longitude) +VALUES (?, ?, ?, ?, ?); diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/schema.sql b/internal/endtoend/testdata/mysql_copyfrom_replace/schema.sql new file mode 100644 index 0000000000..ceeff152e3 --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/schema.sql @@ -0,0 +1,7 @@ +CREATE TABLE locations ( + id VARCHAR(512) PRIMARY KEY, + name TEXT NOT NULL, + address TEXT NOT NULL, + latitude FLOAT, + longitude FLOAT +); diff --git a/internal/endtoend/testdata/mysql_copyfrom_replace/sqlc.yaml b/internal/endtoend/testdata/mysql_copyfrom_replace/sqlc.yaml new file mode 100644 index 0000000000..6e9ce8374d --- /dev/null +++ b/internal/endtoend/testdata/mysql_copyfrom_replace/sqlc.yaml @@ -0,0 +1,11 @@ +version: "2" +sql: + - engine: "mysql" + queries: "query.sql" + schema: "schema.sql" + gen: + go: + package: "db" + out: "db" + sql_package: "database/sql" + sql_driver: "github.com/go-sql-driver/mysql" diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 1f68358ce4..2a588e0823 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -493,6 +493,8 @@ func (c *cc) convertInsertStmt(n *pcast.InsertStmt) *ast.InsertStmt { Relation: rangeVar, Cols: c.convertColumnNames(n.Columns), ReturningList: &ast.List{}, + IsReplace: n.IsReplace, + IgnoreErr: n.IgnoreErr, } if ss, ok := c.convert(n.Select).(*ast.SelectStmt); ok { ss.ValuesLists = c.convertLists(n.Lists) diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index 525ffc72ef..325f2d173b 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -816,6 +816,8 @@ type Query struct { Comments []string `protobuf:"bytes,6,rep,name=comments,proto3" json:"comments,omitempty"` Filename string `protobuf:"bytes,7,opt,name=filename,proto3" json:"filename,omitempty"` InsertIntoTable *Identifier `protobuf:"bytes,8,opt,name=insert_into_table,proto3" json:"insert_into_table,omitempty"` + IsReplace bool `protobuf:"varint,9,opt,name=is_replace,json=isReplace,proto3" json:"is_replace,omitempty"` + IgnoreErr bool `protobuf:"varint,10,opt,name=ignore_err,json=ignoreErr,proto3" json:"ignore_err,omitempty"` } func (x *Query) Reset() { @@ -906,6 +908,20 @@ func (x *Query) GetInsertIntoTable() *Identifier { return nil } +func (x *Query) GetIsReplace() bool { + if x != nil { + return x.IsReplace + } + return false +} + +func (x *Query) GetIgnoreErr() bool { + if x != nil { + return x.IgnoreErr + } + return false +} + type Parameter struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 4d5c8d1df2..b1b8118e5e 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -12,6 +12,8 @@ type InsertStmt struct { WithClause *WithClause Override OverridingKind DefaultValues bool // SQLite-specific: INSERT INTO ... DEFAULT VALUES + IsReplace bool // MySQL-specific + IgnoreErr bool // MySQL-specific } func (n *InsertStmt) Pos() int {