chore: fix postgres driver

prod
Steven 9 months ago
parent a7d48e8059
commit 4c66edc170

@ -2,7 +2,6 @@ package postgres
import (
"context"
"fmt"
"strings"
"github.com/yourselfhosted/slash/store"
@ -38,10 +37,10 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Type != "" {
where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String())
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
}
if find.Level != "" {
where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String())
where, args = append(where, "level = "+placeholder(len(args)+1)), append(args, find.Level.String())
}
if find.Where != nil {
where = append(where, find.Where...)

@ -6,23 +6,20 @@ import (
"fmt"
"strings"
"github.com/lib/pq"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/internal/util"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"}
args := []any{create.CreatorId, create.Name, create.Title, create.Description, pq.Array(create.ShortcutIds), create.Visibility.String()}
stmt := `
INSERT INTO collection (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
INSERT INTO collection (` + strings.Join(set, ", ") + `)
VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts
`
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
@ -39,35 +36,34 @@ func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
set, args := []string{}, []any{}
if update.Name != nil {
set, args = append(set, "name = $1"), append(args, *update.Name)
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = $2"), append(args, *update.Title)
set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = $3"), append(args, *update.Description)
set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *update.Description)
}
if update.ShortcutIDs != nil {
set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
set, args = append(set, "shortcut_ids = "+placeholder(len(args)+1)), append(args, pq.Array(update.ShortcutIDs))
}
if update.Visibility != nil {
set, args = append(set, "visibility = $5"), append(args, update.Visibility.String())
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE collection
SET
` + strings.Join(set, ", ") + `
WHERE
id = $6
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
`
args = append(args, update.ID)
collection := &storepb.Collection{}
var shortcutIDs, visibility string
var shortcutIDs []sql.NullInt32
var visibility string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&collection.Id,
&collection.CreatorId,
@ -76,20 +72,16 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
pq.Array(&shortcutIDs),
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
for _, id := range shortcutIDs {
if id.Valid {
collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
}
}
collection.Visibility = convertVisibilityStringToStorepb(visibility)
@ -99,19 +91,18 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v)
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v)
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = $3"), append(args, *v)
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for i, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
args = append(args, visibility)
for _, visibility := range v {
list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
}
@ -140,7 +131,8 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
list := make([]*storepb.Collection, 0)
for rows.Next() {
collection := &storepb.Collection{}
var shortcutIDs, visibility string
var shortcutIDs []sql.NullInt32
var visibility string
if err := rows.Scan(
&collection.Id,
&collection.CreatorId,
@ -149,20 +141,16 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
pq.Array(&shortcutIDs),
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
for _, id := range shortcutIDs {
if id.Valid {
collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
}
}
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
@ -182,13 +170,3 @@ func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollectio
return nil
}
func vacuumCollection(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

@ -2,7 +2,6 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
@ -17,9 +16,7 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
stmt := `
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
INSERT INTO memo (` + strings.Join(set, ", ") + `)
VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts, row_status
`
@ -41,43 +38,34 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
set, args := []string{}, []any{}
if update.RowStatus != nil {
set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1))
args = append(args, update.RowStatus.String())
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, update.RowStatus.String())
}
if update.Name != nil {
set = append(set, fmt.Sprintf("name = $%d", len(set)+1))
args = append(args, *update.Name)
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
}
if update.Title != nil {
set = append(set, fmt.Sprintf("title = $%d", len(set)+1))
args = append(args, *update.Title)
set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
}
if update.Content != nil {
set = append(set, fmt.Sprintf("content = $%d", len(set)+1))
args = append(args, *update.Content)
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *update.Content)
}
if update.Visibility != nil {
set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1))
args = append(args, update.Visibility.String())
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
}
if update.Tag != nil {
set = append(set, fmt.Sprintf("tag = $%d", len(set)+1))
args = append(args, *update.Tag)
set, args = append(set, "tag = "+placeholder(len(args)+1)), append(args, *update.Tag)
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE memo
SET
` + strings.Join(set, ", ") + `
WHERE
id = $` + fmt.Sprint(len(set)+1) + `
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
`
args = append(args, update.ID)
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
@ -103,27 +91,26 @@ func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v)
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v)
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $3"), append(args, *v)
where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = $4"), append(args, *v)
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for i, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
args = append(args, visibility)
for _, visibility := range v {
list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%")
where, args = append(where, "tag LIKE "+placeholder(len(args)+1)), append(args, "%"+*v+"%")
}
rows, err := d.db.QueryContext(ctx, `
@ -185,24 +172,10 @@ func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
return nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func placeholders(n int) string {
placeholder := ""
list := []string{}
for i := 0; i < n; i++ {
if i == 0 {
placeholder = fmt.Sprintf("$%d", i+1)
} else {
placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1)
}
list = append(list, fmt.Sprintf("$%d", i+1))
}
return placeholder
return strings.Join(list, ", ")
}

@ -1,3 +1,13 @@
-- drop all tables first (PostgreSQL style)
DROP TABLE IF EXISTS migration_history CASCADE;
DROP TABLE IF EXISTS workspace_setting CASCADE;
DROP TABLE IF EXISTS "user" CASCADE;
DROP TABLE IF EXISTS user_setting CASCADE;
DROP TABLE IF EXISTS shortcut CASCADE;
DROP TABLE IF EXISTS activity CASCADE;
DROP TABLE IF EXISTS collection CASCADE;
DROP TABLE IF EXISTS memo CASCADE;
-- migration_history
CREATE TABLE migration_history (
version TEXT NOT NULL PRIMARY KEY,
@ -11,7 +21,7 @@ CREATE TABLE workspace_setting (
);
-- user
CREATE TABLE user (
CREATE TABLE "user" (
id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
@ -22,11 +32,11 @@ CREATE TABLE user (
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
);
CREATE INDEX idx_user_email ON user(email);
CREATE INDEX idx_user_email ON "user"(email);
-- user_setting
CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL,
user_id INTEGER REFERENCES "user"(id) NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (user_id, key)
@ -35,7 +45,7 @@ CREATE TABLE user_setting (
-- shortcut
CREATE TABLE shortcut (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
@ -53,7 +63,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
@ -63,7 +73,7 @@ CREATE TABLE activity (
-- collection
CREATE TABLE collection (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE,
@ -78,7 +88,7 @@ CREATE INDEX idx_collection_name ON collection(name);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',

@ -11,7 +11,7 @@ CREATE TABLE workspace_setting (
);
-- user
CREATE TABLE user (
CREATE TABLE "user" (
id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
@ -22,11 +22,11 @@ CREATE TABLE user (
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
);
CREATE INDEX idx_user_email ON user(email);
CREATE INDEX idx_user_email ON "user"(email);
-- user_setting
CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL,
user_id INTEGER REFERENCES "user"(id) NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (user_id, key)
@ -35,7 +35,7 @@ CREATE TABLE user_setting (
-- shortcut
CREATE TABLE shortcut (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
@ -53,7 +53,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
@ -63,7 +63,7 @@ CREATE TABLE activity (
-- collection
CREATE TABLE collection (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE,
@ -78,7 +78,7 @@ CREATE INDEX idx_collection_name ON collection(name);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',

@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
}
// In demo mode, we should seed the database.
if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil {
if err := d.Seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
}
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
latestSchemaFileName = "LATEST.sql"
)
func (d *DB) applyLatestSchema(ctx context.Context) error {
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
return nil
}
func (d *DB) seed(ctx context.Context) error {
func (d *DB) Seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil {
return errors.Wrap(err, "failed to read seed files")

@ -1,9 +0,0 @@
DELETE FROM activity;
DELETE FROM shortcut;
DELETE FROM user_setting;
DELETE FROM user;
DELETE FROM workspace_setting;

@ -2,7 +2,6 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
@ -207,12 +206,6 @@ func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) e
return err
}
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
return err
}
func filterTags(tags []string) []string {
result := []string{}
for _, tag := range tags {

@ -41,21 +41,20 @@ func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, e
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v)
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v)
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = $"+placeholder(len(args)+1)), append(args, *v)
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v)
set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
}
if len(set) == 0 {
return nil, errors.New("no fields to update")
}
@ -63,7 +62,7 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
stmt := `
UPDATE "user"
SET ` + strings.Join(set, ", ") + `
WHERE id = $` + placeholder(len(args)+1) + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
`
args = append(args, update.ID)
@ -88,19 +87,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String())
where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, v.String())
}
if v := find.Email; v != nil {
where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
}
query := `
@ -149,32 +148,10 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM "user" WHERE id = $1
`, delete.ID); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
if _, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID); err != nil {
return err
}
if err := vacuumShortcut(ctx, tx); err != nil {
return err
}
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
if err := vacuumCollection(ctx, tx); err != nil {
return err
}
return tx.Commit()
return nil
}
func placeholder(n int) string {

@ -2,9 +2,7 @@ package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"google.golang.org/protobuf/encoding/protojson"
@ -51,10 +49,10 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String())
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID)
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
}
query := `
@ -110,13 +108,3 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

@ -55,7 +55,7 @@ func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspac
where, args := []string{"1 = 1"}, []interface{}{}
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String())
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, find.Key.String())
}
query := `

@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
}
// In demo mode, we should seed the database.
if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil {
if err := d.Seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
}
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
latestSchemaFileName = "LATEST.sql"
)
func (d *DB) applyLatestSchema(ctx context.Context) error {
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
return nil
}
func (d *DB) seed(ctx context.Context) error {
func (d *DB) Seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil {
return errors.Wrap(err, "failed to read seed files")

@ -14,6 +14,7 @@ type Driver interface {
Close() error
Migrate(ctx context.Context) error
Seed(ctx context.Context) error
// MigrationHistory model related methods.
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)

@ -12,11 +12,13 @@ import (
func TestActivityStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingAdminUser(ctx, ts)
require.NoError(t, err)
list, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Equal(t, 0, len(list))
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: -1,
CreatorID: user.ID,
Type: store.ActivityShortcutCreate,
Level: store.ActivityInfo,
Payload: "",

@ -22,6 +22,9 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err)
}
if err := dbDriver.Seed(ctx); err != nil {
fmt.Printf("failed to seed db, error: %+v\n", err)
}
store := store.New(dbDriver, profile)
return store

@ -7,7 +7,6 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
@ -27,13 +26,6 @@ func TestUserStore(t *testing.T) {
Nickname: &userPatchNickname,
})
require.NoError(t, err)
_, err = ts.CreateShortcut(ctx, &storepb.Shortcut{
CreatorId: user.ID,
Name: "test_shortcut",
Link: "https://www.google.com",
Visibility: storepb.Visibility_PUBLIC,
})
require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname)
err = ts.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID,
@ -42,9 +34,6 @@ func TestUserStore(t *testing.T) {
users, err = ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
require.Equal(t, 0, len(users))
shortcuts, err := ts.ListShortcuts(ctx, &store.FindShortcut{})
require.NoError(t, err)
require.Equal(t, 0, len(shortcuts))
}
// createTestingAdminUser creates a testing admin user.

Loading…
Cancel
Save