Browse Source

Code dedoublication in models/models.go

Just some code dedoublication in models/models.go
pull/444/head
Tristan Storch 10 years ago
parent
commit
bdfdf3cacb
  1. 44
      models/models.go

44
models/models.go

@ -55,11 +55,12 @@ func LoadModelsConfig() {
DbCfg.Path = setting.Cfg.MustValue("database", "PATH", "data/gogs.db") DbCfg.Path = setting.Cfg.MustValue("database", "PATH", "data/gogs.db")
} }
func NewTestEngine(x *xorm.Engine) (err error) { func getEngine() (*xorm.Engine, error) {
cnnstr := ""
switch DbCfg.Type { switch DbCfg.Type {
case "mysql": case "mysql":
x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", cnnstr = fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name)) DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name)
case "postgres": case "postgres":
var host, port = "127.0.0.1", "5432" var host, port = "127.0.0.1", "5432"
fields := strings.Split(DbCfg.Host, ":") fields := strings.Split(DbCfg.Host, ":")
@ -69,46 +70,31 @@ func NewTestEngine(x *xorm.Engine) (err error) {
if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 { if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
port = fields[1] port = fields[1]
} }
cnnstr := fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s",
DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode) DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode)
x, err = xorm.NewEngine("postgres", cnnstr)
case "sqlite3": case "sqlite3":
if !EnableSQLite3 { if !EnableSQLite3 {
return fmt.Errorf("Unknown database type: %s", DbCfg.Type) return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type)
} }
os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm) os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm)
x, err = xorm.NewEngine("sqlite3", DbCfg.Path) cnnstr = DbCfg.Path
default: default:
return fmt.Errorf("Unknown database type: %s", DbCfg.Type) return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type)
} }
return xorm.NewEngine(DbCfg.Type, cnnstr)
}
func NewTestEngine(x *xorm.Engine) (err error) {
x, err = getEngine()
if err != nil { if err != nil {
return fmt.Errorf("models.init(fail to conntect database): %v", err) return fmt.Errorf("models.init(fail to conntect database): %v", err)
} }
return x.Sync(tables...) return x.Sync(tables...)
} }
func SetEngine() (err error) { func SetEngine() (err error) {
switch DbCfg.Type { x, err = getEngine()
case "mysql":
x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name))
case "postgres":
var host, port = "127.0.0.1", "5432"
fields := strings.Split(DbCfg.Host, ":")
if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 {
host = fields[0]
}
if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
port = fields[1]
}
x, err = xorm.NewEngine("postgres", fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s",
DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode))
case "sqlite3":
os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm)
x, err = xorm.NewEngine("sqlite3", DbCfg.Path)
default:
return fmt.Errorf("Unknown database type: %s", DbCfg.Type)
}
if err != nil { if err != nil {
return fmt.Errorf("models.init(fail to conntect database): %v", err) return fmt.Errorf("models.init(fail to conntect database): %v", err)
} }

Loading…
Cancel
Save