diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..887ad96 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 64f0356..35c1a23 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # stable-diffusion -pure go for stable-diffusion and support cross-platform. +Pure Go (cgo-free) bindings for stable-diffusion with cross-platform support. [![Go Reference](https://pkg.go.dev/badge/github.com/seasonjs/stable-diffusion.svg)](https://pkg.go.dev/github.com/seasonjs/stable-diffusion) -sd.go is a wrapper around [stable-diffusion-cpp](https://github.com/leejet/stable-diffusion.cpp), which is an adaption +sd.go is a wrapper around [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp), which is an adaption of ggml.cpp.

@@ -43,45 +43,52 @@ If you are worried about the security of the dynamic library during the use proc This `stable-diffusion` golang library provide two api `Predict` and `ImagePredict`. -You don't need to download stable-diffusion dynamic library. +Usually you can use `NewAutoModel`, so you don't need to load the dynamic library. + +You can find a complete example in [examples](./exmaples) folder. + +Here is a simple example: ```go package main import ( + "github.com/seasonjs/hf-hub/api" sd "github.com/seasonjs/stable-diffusion" "io" "os" - "path/filepath" ) func main() { - options := sd.DefaultStableDiffusionOptions - options.Width = 512 - options.Height = 512 + options := sd.DefaultOptions - model, err := sd.NewStableDiffusionAutoModel(options) + model, err := sd.NewAutoModel(options) if err != nil { print(err.Error()) return } defer model.Close() - err = model.LoadFromFile("./models/mysd.safetensors") + hapi, err := api.NewApi() if err != nil { print(err.Error()) return } - var writers []io.Writer - girl, err := filepath.Abs("./assets/a_girl.png") + modelPath, err := hapi.Model("justinpinkney/miniSD").Get("miniSD.ckpt") if err != nil { print(err.Error()) return } + err = model.LoadFromFile(modelPath) + if err != nil { + print(err.Error()) + return + } + var writers []io.Writer filenames := []string{ - girl, + "../assets/love_cat0.png", } for _, filename := range filenames { file, err := os.Create(filename) @@ -93,49 +100,13 @@ func main() { writers = append(writers, file) } - err = model.Predict("a girl, high quality", writers) + err = model.Predict("british short hair cat, high quality", sd.DefaultFullParams, writers) if err != nil { print(err.Error()) } } - ``` -If `NewStableDiffusionAutoModel` can't automatic loading of dynamic library, please use `NewStableDiffusionModel` method load manually. 
- -```go -package main - -import ( - "fmt" - sd "github.com/seasonjs/stable-diffusion" - "runtime" -) - -func getLibrary() string { - switch runtime.GOOS { - case "darwin": - return "./deps/darwin/libstable-diffusion_arm64.dylib" - case "linux": - return "./deps/linux/libstable-diffusion.so" - case "windows": - return "./deps/windows/stable-diffusion_avx2_x64.dll" - default: - panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)) - } -} -func main() { - options := sd.DefaultStableDiffusionOptions - //It's important to set image size,different model support different size - options.Width = 256 - options.Height = 256 - model, err := sd.NewStableDiffusionModel(getLibrary(), options) - print(model, err) - //.... the usage is same as `NewStableDiffusionAutoModel`. they both return `StableDiffusionModel` struct - -} - -``` ## Packaging To ship a working program that includes this AI, you will need to include the following files: diff --git a/assets/love_cat0.png b/assets/love_cat0.png index 7a2c280..ef3b982 100644 Binary files a/assets/love_cat0.png and b/assets/love_cat0.png differ diff --git a/assets/love_cat1.png b/assets/love_cat1.png index 0b95c85..7c1bfd8 100644 Binary files a/assets/love_cat1.png and b/assets/love_cat1.png differ diff --git a/binding.go b/binding.go index 16a00ce..74da8db 100644 --- a/binding.go +++ b/binding.go @@ -5,309 +5,267 @@ package sd import ( "github.com/ebitengine/purego" + "runtime" "unsafe" ) -type SDLogLevel string +type LogLevel int -type RNGType string +type RNGType int -type SampleMethod string +type SampleMethod int -type Schedule string +type Schedule int -type GgmlType string +type WType int const ( - DEBUG SDLogLevel = "DEBUG" - INFO = "INFO" - WARN = "WARN" - ERROR = "ERROR" + DEBUG LogLevel = iota + INFO + WARN + ERROR ) const ( - STD_DEFAULT_RNG RNGType = "STD_DEFAULT_RNG" - CUDA_RNG = "CUDA_RNG" + STD_DEFAULT_RNG RNGType = iota + CUDA_RNG ) const ( - EULER_A SampleMethod = "EULER_A" - EULER = "EULER" - HEUN = "HEUN" - 
DPM2 = "DPM2" - DPMPP2S_A = "DPMPP2S_A" - DPMPP2M = "DPMPP2M" - DPMPP2Mv2 = "DPMPP2Mv2" - LCM = "LCM" - N_SAMPLE_METHODS = "N_SAMPLE_METHODS" + EULER_A SampleMethod = iota + EULER + HEUN + DPM2 + DPMPP2S_A + DPMPP2M + DPMPP2Mv2 + LCM + N_SAMPLE_METHODS ) const ( - DEFAULT Schedule = "DEFAULT" - DISCRETE = "DISCRETE" - KARRAS = "KARRAS" - N_SCHEDULES = "N_SCHEDULES" + DEFAULT Schedule = iota + DISCRETE + KARRAS + N_SCHEDULES ) const ( - T_DEFAULT GgmlType = "DEFAULT" - F32 = "F32" - F16 = "F16" - Q4_0 = "Q4_0" - Q4_1 = "Q4_1" - Q5_0 = "Q5_0" - Q5_1 = "Q5_1" - Q8_0 = "Q8_0" + F32 WType = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + I8 = 16 + I16 = 17 + I32 = 18 + COUNT = 19 // don't use this when specifying a type ) -const ( - cStableDiffusionFullDefaultParamsRef = "stable_diffusion_full_default_params_ref" - cStableDiffusionFullParamsSetNegativePrompt = "stable_diffusion_full_params_set_negative_prompt" - cStableDiffusionFullParamsSetCfgScale = "stable_diffusion_full_params_set_cfg_scale" - cStableDiffusionFullParamsSetWidth = "stable_diffusion_full_params_set_width" - cStableDiffusionFullParamsSetHeight = "stable_diffusion_full_params_set_height" - cStableDiffusionFullParamsSetSampleMethod = "stable_diffusion_full_params_set_sample_method" - cStableDiffusionFullParamsSetSampleSteps = "stable_diffusion_full_params_set_sample_steps" - cStableDiffusionFullParamsSetSeed = "stable_diffusion_full_params_set_seed" - cStableDiffusionFullParamsSetBatchCount = "stable_diffusion_full_params_set_batch_count" - cStableDiffusionFullParamsSetStrength = "stable_diffusion_full_params_set_strength" - - cStableDiffusionInit = "stable_diffusion_init" - cStableDiffusionLoadFromFile = "stable_diffusion_load_from_file" - cStableDiffusionPredictImage = "stable_diffusion_predict_image" - cStableDiffusionImagePredictImage = "stable_diffusion_image_predict_image" - 
cStableDiffusionSetLogLevel = "stable_diffusion_set_log_level" - cStableDiffusionGetSystemInfo = "stable_diffusion_get_system_info" - cStableDiffusionFree = "stable_diffusion_free" - cStableDiffusionFreeBuffer = "stable_diffusion_free_buffer" - cStableDiffusionFreeFullParams = "stable_diffusion_free_full_params" -) - -type StableDiffusionFullParams struct { - params uintptr - negativePrompt string - cfgScale float32 - width int - height int - sampleMethod SampleMethod - sampleSteps int - seed int64 - batchCount int - strength float32 +type CStableDiffusionCtx struct { + ctx uintptr } -type StableDiffusionCtx struct { +type CUpScalerCtx struct { ctx uintptr } +type CLogCallback func(level LogLevel, text string) + type CStableDiffusion interface { - StableDiffusionFullDefaultParamsRef() *StableDiffusionFullParams - StableDiffusionFullParamsSetNegativePrompt(params *StableDiffusionFullParams, negativePrompt string) - StableDiffusionFullParamsSetCfgScale(params *StableDiffusionFullParams, cfgScale float32) - StableDiffusionFullParamsSetWidth(params *StableDiffusionFullParams, width int) - StableDiffusionFullParamsSetHeight(params *StableDiffusionFullParams, height int) - StableDiffusionFullParamsSetSampleMethod(params *StableDiffusionFullParams, sampleMethod SampleMethod) - StableDiffusionFullParamsSetSampleSteps(params *StableDiffusionFullParams, sampleSteps int) - StableDiffusionFullParamsSetSeed(params *StableDiffusionFullParams, seed int64) - StableDiffusionFullParamsSetBatchCount(params *StableDiffusionFullParams, batchCount int) - StableDiffusionFullParamsSetStrength(params *StableDiffusionFullParams, strength float32) - - StableDiffusionInit(nThreads int, vaeDecodeOnly bool, taesdPath string, freeParamsImmediately bool, loraModelDir string, rngType RNGType) *StableDiffusionCtx - StableDiffusionLoadFromFile(ctx *StableDiffusionCtx, filePath string, vaePath string, wtype GgmlType, schedule Schedule) - StableDiffusionPredictImage(ctx *StableDiffusionCtx, params 
*StableDiffusionFullParams, prompt string) []byte - StableDiffusionImagePredictImage(ctx *StableDiffusionCtx, params *StableDiffusionFullParams, initImage []byte, prompt string) []byte - StableDiffusionSetLogLevel(level SDLogLevel) - StableDiffusionGetSystemInfo() string - StableDiffusionFree(ctx *StableDiffusionCtx) - StableDiffusionFreeBuffer(buffer uintptr) - StableDiffusionFreeFullParams(params *StableDiffusionFullParams) + NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx + PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image + ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image + SetLogCallBack(cb CLogCallback) + GetSystemInfo() string + FreeCtx(ctx *CStableDiffusionCtx) + + NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx + FreeUpscalerCtx(ctx *CUpScalerCtx) + UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image + + Close() error +} + +type cImage struct { + width uint32 + height uint32 + channel uint32 + data uintptr +} + +type Image struct { + Width uint32 + Height uint32 + Channel uint32 + Data []byte } type CStableDiffusionImpl struct { libSd uintptr - cStableDiffusionFullDefaultParamsRef func() uintptr - cStableDiffusionFullParamsSetNegativePrompt func(params uintptr, negative_prompt string) - cStableDiffusionFullParamsSetCfgScale func(params uintptr, cfg_scale float32) - cStableDiffusionFullParamsSetWidth func(params uintptr, width int) - cStableDiffusionFullParamsSetHeight func(params 
uintptr, height int) - cStableDiffusionFullParamsSetSampleMethod func(params uintptr, sampleMethod string) - cStableDiffusionFullParamsSetSampleSteps func(params uintptr, sampleSteps int) - cStableDiffusionFullParamsSetSeed func(params uintptr, seed int64) - cStableDiffusionFullParamsSetBatchCount func(params uintptr, batchCount int) - cStableDiffusionFullParamsSetStrength func(params uintptr, strength float32) - - cStableDiffusionInit func(nThreads int, vaeDecodeOnly bool, taesdPath string, freeParamsImmediately bool, loraModelDir string, rngType string) uintptr - cStableDiffusionLoadFromFile func(ctx uintptr, filePath string, vaePath string, wtype string, schedule string) bool - cStableDiffusionPredictImage func(ctx uintptr, params uintptr, prompt string) *byte - cStableDiffusionImagePredictImage func(ctx uintptr, params uintptr, initImage *byte, prompt string) *byte - cStableDiffusionSetLogLevel func(level string) - cStableDiffusionGetSystemInfo func() string - cStableDiffusionFree func(options uintptr) - cStableDiffusionFreeBuffer func(options uintptr) - cStableDiffusionFreeFullParams func(options uintptr) + sdGetSystemInfo func() string + + newSdCtx func(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wtype int, rngType int, schedule int) uintptr + + sdSetLogCallback func(callback func(level int, text uintptr, data uintptr) uintptr, data uintptr) + + txt2img func(ctx uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, seed int64, batchCount int) uintptr + + img2img func(ctx uintptr, img uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, strength float32, seed int64, batchCount int) uintptr + + freeSdCtx func(ctx uintptr) + + newUpscalerCtx func(esrganPath string, nThreads int, wtype int) 
uintptr + + freeUpscalerCtx func(ctx uintptr) + + upscale func(ctx uintptr, img uintptr, upscaleFactor uint32) uintptr } -func NewCStableDiffusion(libraryPath string) (CStableDiffusion, error) { +func NewCStableDiffusion(libraryPath string) (*CStableDiffusionImpl, error) { libSd, err := openLibrary(libraryPath) if err != nil { return nil, err } - var ( - stableDiffusionFullDefaultParamsRef func() uintptr - stableDiffusionFullParamsSetNegativePrompt func(params uintptr, negativePrompt string) - stableDiffusionFullParamsSetCfgScale func(params uintptr, cfgScale float32) - stableDiffusionFullParamsSetWidth func(params uintptr, width int) - stableDiffusionFullParamsSetHeight func(params uintptr, height int) - stableDiffusionFullParamsSetSampleMethod func(params uintptr, sampleMethod string) - stableDiffusionFullParamsSetSampleSteps func(params uintptr, sampleSteps int) - stableDiffusionFullParamsSetSeed func(params uintptr, seed int64) - stableDiffusionFullParamsSetBatchCount func(params uintptr, batchCount int) - stableDiffusionFullParamsSetStrength func(params uintptr, strength float32) - - stableDiffusionInit func(nThreads int, vaeDecodeOnly bool, taesdPath string, freeParamsImmediately bool, loraModelDir string, rngType string) uintptr - stableDiffusionLoadFromFile func(ctx uintptr, filePath string, vaePath string, wtype string, schedule string) bool - stableDiffusionPredictImage func(ctx uintptr, params uintptr, prompt string) *byte - stableDiffusionImagePredictImage func(ctx uintptr, params uintptr, initImage *byte, prompt string) *byte - stableDiffusionSetLogLevel func(level string) - stableDiffusionGetSystemInfo func() string - stableDiffusionFree func(options uintptr) - stableDiffusionFreeBuffer func(options uintptr) - stableDiffusionFreeFullParams func(options uintptr) - ) - purego.RegisterLibFunc(&stableDiffusionFullDefaultParamsRef, libSd, cStableDiffusionFullDefaultParamsRef) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetNegativePrompt, libSd, 
cStableDiffusionFullParamsSetNegativePrompt) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetCfgScale, libSd, cStableDiffusionFullParamsSetCfgScale) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetWidth, libSd, cStableDiffusionFullParamsSetWidth) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetHeight, libSd, cStableDiffusionFullParamsSetHeight) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetSampleMethod, libSd, cStableDiffusionFullParamsSetSampleMethod) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetSampleSteps, libSd, cStableDiffusionFullParamsSetSampleSteps) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetSeed, libSd, cStableDiffusionFullParamsSetSeed) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetBatchCount, libSd, cStableDiffusionFullParamsSetBatchCount) - purego.RegisterLibFunc(&stableDiffusionFullParamsSetStrength, libSd, cStableDiffusionFullParamsSetStrength) - - purego.RegisterLibFunc(&stableDiffusionInit, libSd, cStableDiffusionInit) - purego.RegisterLibFunc(&stableDiffusionLoadFromFile, libSd, cStableDiffusionLoadFromFile) - purego.RegisterLibFunc(&stableDiffusionPredictImage, libSd, cStableDiffusionPredictImage) - purego.RegisterLibFunc(&stableDiffusionImagePredictImage, libSd, cStableDiffusionImagePredictImage) - purego.RegisterLibFunc(&stableDiffusionSetLogLevel, libSd, cStableDiffusionSetLogLevel) - purego.RegisterLibFunc(&stableDiffusionGetSystemInfo, libSd, cStableDiffusionGetSystemInfo) - purego.RegisterLibFunc(&stableDiffusionFree, libSd, cStableDiffusionFree) - purego.RegisterLibFunc(&stableDiffusionFreeBuffer, libSd, cStableDiffusionFreeBuffer) - purego.RegisterLibFunc(&stableDiffusionFreeFullParams, libSd, cStableDiffusionFreeFullParams) - - return &CStableDiffusionImpl{ - libSd, - stableDiffusionFullDefaultParamsRef, - stableDiffusionFullParamsSetNegativePrompt, - stableDiffusionFullParamsSetCfgScale, - stableDiffusionFullParamsSetWidth, - stableDiffusionFullParamsSetHeight, - 
stableDiffusionFullParamsSetSampleMethod, - stableDiffusionFullParamsSetSampleSteps, - stableDiffusionFullParamsSetSeed, - stableDiffusionFullParamsSetBatchCount, - stableDiffusionFullParamsSetStrength, - - stableDiffusionInit, - stableDiffusionLoadFromFile, - stableDiffusionPredictImage, - stableDiffusionImagePredictImage, - stableDiffusionSetLogLevel, - stableDiffusionGetSystemInfo, - stableDiffusionFree, - stableDiffusionFreeBuffer, - stableDiffusionFreeFullParams, - }, nil -} -func (c *CStableDiffusionImpl) StableDiffusionFullDefaultParamsRef() *StableDiffusionFullParams { - return &StableDiffusionFullParams{ - params: c.cStableDiffusionFullDefaultParamsRef(), - } -} + impl := CStableDiffusionImpl{libSd: libSd} -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetNegativePrompt(params *StableDiffusionFullParams, negativePrompt string) { - c.cStableDiffusionFullParamsSetNegativePrompt(params.params, negativePrompt) - params.negativePrompt = negativePrompt -} + purego.RegisterLibFunc(&impl.sdGetSystemInfo, libSd, "sd_get_system_info") -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetCfgScale(params *StableDiffusionFullParams, cfgScale float32) { - c.cStableDiffusionFullParamsSetCfgScale(params.params, cfgScale) - params.cfgScale = cfgScale -} + purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx") + purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_set_log_callback") + purego.RegisterLibFunc(&impl.txt2img, libSd, "txt2img") + purego.RegisterLibFunc(&impl.img2img, libSd, "img2img") + purego.RegisterLibFunc(&impl.freeSdCtx, libSd, "free_sd_ctx") -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetWidth(params *StableDiffusionFullParams, width int) { - c.cStableDiffusionFullParamsSetWidth(params.params, width) - params.width = width -} + purego.RegisterLibFunc(&impl.newUpscalerCtx, libSd, "new_upscaler_ctx") + purego.RegisterLibFunc(&impl.freeUpscalerCtx, libSd, "free_upscaler_ctx") + purego.RegisterLibFunc(&impl.upscale, libSd,
"upscale") -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetHeight(params *StableDiffusionFullParams, height int) { - c.cStableDiffusionFullParamsSetHeight(params.params, height) - params.height = height + return &impl, nil } -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetSampleMethod(params *StableDiffusionFullParams, sampleMethod SampleMethod) { - c.cStableDiffusionFullParamsSetSampleMethod(params.params, string(sampleMethod)) - params.sampleMethod = sampleMethod +func (c *CStableDiffusionImpl) NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx { + ctx := c.newSdCtx(modelPath, vaePath, taesdPath, loraModelDir, vaeDecodeOnly, vaeTiling, freeParamsImmediately, nThreads, int(wType), int(rngType), int(schedule)) + return &CStableDiffusionCtx{ + ctx: ctx, + } } -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetSampleSteps(params *StableDiffusionFullParams, sampleSteps int) { - c.cStableDiffusionFullParamsSetSampleSteps(params.params, sampleSteps) - params.sampleSteps = sampleSteps +func (c *CStableDiffusionImpl) PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image { + images := c.txt2img(ctx.ctx, prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, seed, batchCount) + return goImageSlice(images, batchCount) } -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetSeed(params *StableDiffusionFullParams, seed int64) { - c.cStableDiffusionFullParamsSetSeed(params.params, seed) - params.seed = seed +func (c *CStableDiffusionImpl) ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, 
height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image { + images := c.img2img(ctx.ctx, uintptr(unsafe.Pointer(&img)), prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, strength, seed, batchCount) + return goImageSlice(images, batchCount) } -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetBatchCount(params *StableDiffusionFullParams, batchCount int) { - c.cStableDiffusionFullParamsSetBatchCount(params.params, batchCount) - params.batchCount = batchCount +func (c *CStableDiffusionImpl) SetLogCallBack(cb CLogCallback) { + c.sdSetLogCallback(func(level int, text uintptr, data uintptr) uintptr { + cb(LogLevel(level), goString(text)) + return 0 + }, 0) } -func (c *CStableDiffusionImpl) StableDiffusionFullParamsSetStrength(params *StableDiffusionFullParams, strength float32) { - c.cStableDiffusionFullParamsSetStrength(params.params, strength) - params.strength = strength +func (c *CStableDiffusionImpl) GetSystemInfo() string { + return c.sdGetSystemInfo() } -func (c *CStableDiffusionImpl) StableDiffusionInit(nThreads int, vaeDecodeOnly bool, taesdPath string, freeParamsImmediately bool, loraModelDir string, rngType RNGType) *StableDiffusionCtx { - return &StableDiffusionCtx{ - ctx: c.cStableDiffusionInit(nThreads, vaeDecodeOnly, taesdPath, freeParamsImmediately, loraModelDir, string(rngType)), +func (c *CStableDiffusionImpl) FreeCtx(ctx *CStableDiffusionCtx) { + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx)) + if ptr != nil { + c.freeSdCtx(ctx.ctx) } + ctx.ctx = 0 + runtime.GC() } -func (c *CStableDiffusionImpl) StableDiffusionLoadFromFile(ctx *StableDiffusionCtx, filePath string, vaePath string, wtype GgmlType, schedule Schedule) { - c.cStableDiffusionLoadFromFile(ctx.ctx, filePath, vaePath, string(wtype), string(schedule)) -} - -func (c *CStableDiffusionImpl) StableDiffusionPredictImage(ctx *StableDiffusionCtx, params *StableDiffusionFullParams, prompt string)
[]byte { - data := c.cStableDiffusionPredictImage(ctx.ctx, params.params, prompt) - return unsafe.Slice(data, params.width*params.height*3*params.batchCount) -} +func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx { + ctx := c.newUpscalerCtx(esrganPath, nThreads, int(wType)) -func (c *CStableDiffusionImpl) StableDiffusionImagePredictImage(ctx *StableDiffusionCtx, params *StableDiffusionFullParams, initImage []byte, prompt string) []byte { - data := c.cStableDiffusionImagePredictImage(ctx.ctx, params.params, &initImage[0], prompt) - return unsafe.Slice(data, params.width*params.height*3) + return &CUpScalerCtx{ctx: ctx} } -func (c *CStableDiffusionImpl) StableDiffusionSetLogLevel(level SDLogLevel) { - c.cStableDiffusionSetLogLevel(string(level)) +func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) { + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx)) + if ptr != nil { + c.freeUpscalerCtx(ctx.ctx) + } + ctx.ctx = 0 + runtime.GC() } -func (c *CStableDiffusionImpl) StableDiffusionGetSystemInfo() string { - return c.cStableDiffusionGetSystemInfo() +func (c *CStableDiffusionImpl) Close() error { + if c.libSd != 0 { + err := closeLibrary(c.libSd) + return err + } + return nil } -func (c *CStableDiffusionImpl) StableDiffusionFree(ctx *StableDiffusionCtx) { - c.cStableDiffusionFree(ctx.ctx) +func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image { + uptr := c.upscale(ctx.ctx, uintptr(unsafe.Pointer(&img)), upscaleFactor) + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&uptr)) + if ptr == nil { + return Image{} + } + cimg := (*cImage)(ptr) + dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&cimg.data)) + return Image{ + Width: cimg.width, + Height: cimg.height, + Channel: cimg.channel, + Data: unsafe.Slice((*byte)(dataPtr), cimg.channel*cimg.width*cimg.height), + } } -func (c *CStableDiffusionImpl) StableDiffusionFreeBuffer(buffer uintptr) { -
c.cStableDiffusionFreeBuffer(buffer) +func goString(c uintptr) string { + // We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c)) + if ptr == nil { + return "" + } + var length int + for { + if *(*byte)(unsafe.Add(ptr, uintptr(length))) == '\x00' { + break + } + length++ + } + return unsafe.String((*byte)(ptr), length) } -func (c *CStableDiffusionImpl) StableDiffusionFreeFullParams(params *StableDiffusionFullParams) { - c.cStableDiffusionFreeFullParams(params.params) +func goImageSlice(c uintptr, size int) []Image { + // We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer + ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c)) + if ptr == nil { + return nil + } + img := (*cImage)(ptr) + goImages := make([]Image, 0, size) + imgSlice := unsafe.Slice(img, size) + for _, image := range imgSlice { + var gImg Image + gImg.Channel = image.channel + gImg.Width = image.width + gImg.Height = image.height + dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&image.data)) + if dataPtr != nil { + gImg.Data = unsafe.Slice((*byte)(dataPtr), image.channel*image.width*image.height) + } + goImages = append(goImages, gImg) + } + return goImages } diff --git a/binding_test.go b/binding_test.go index 4a0be48..1e11f52 100644 --- a/binding_test.go +++ b/binding_test.go @@ -1,10 +1,11 @@ // Copyright (c) seasonjs. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information.
-package sd +package sd_test import ( "fmt" + sd "github.com/seasonjs/stable-diffusion" "image" "image/color" "image/png" @@ -16,11 +17,11 @@ import ( func getLibrary() string { switch runtime.GOOS { case "darwin": - return "./deps/darwin/libstable-diffusion_arm64.dylib" + return "./deps/darwin/libsd-abi.dylib" case "linux": return "./deps/linux/libstable-diffusion.so" case "windows": - return "./deps/windows/stable-diffusion_avx2.dll" + return "./deps/windows/sd-abi_avx2.dll" default: panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS)) } @@ -73,7 +74,7 @@ func writeToFile(t *testing.T, byteData []byte, height int, width int, outputPat t.Log("Image saved at", outputPath) } -func readFromFile(t *testing.T, path string) []byte { +func readFromFile(t *testing.T, path string) *sd.Image { file, err := os.Open(path) if err != nil { t.Error(err) @@ -98,36 +99,44 @@ func readFromFile(t *testing.T, path string) []byte { img[idx+2] = byte(b) } } - return img + return &sd.Image{ + Width: uint32(width), + Height: uint32(height), + Data: img, + } +} + +func TestNewCStableDiffusionText2Img(t *testing.T) { + diffusion, err := sd.NewCStableDiffusion(getLibrary()) + if err != nil { + t.Error(err) + return + } + diffusion.SetLogCallBack(func(level sd.LogLevel, text string) { + fmt.Printf("%s", text) + }) + ctx := diffusion.NewCtx("./models/miniSD.ckpt", "", "", "", false, false, true, 4, sd.F16, sd.CUDA_RNG, sd.DEFAULT) + defer diffusion.FreeCtx(ctx) + + images := diffusion.PredictImage(ctx, "british short hair cat, high quality", "", 0, 7.0, 256, 256, sd.EULER_A, 10, 43, 1) + + writeToFile(t, images[0].Data, 256, 256, "./assets/love_cat1.png") } -//func TestStableDiffusionTextToImage(t *testing.T) { -// sd, err := NewCStableDiffusion(getLibrary()) -// if err != nil { -// t.Log(err) -// } -// ctx := sd.NewStableDiffusionCtx(8, true, true, "", CUDA_RNG) -// defer ctx.Close() -// ctx.StableDiffusionLoadFromFile("./models/miniSD-ggml-model-q5_0.bin", DEFAULT) -// data, _ := 
ctx.StableDiffusionTextToImage("A lovely cat, high quality", "", 7.0, 256, 256, EULER_A, 20, 42, 1) -// writeToFile(t, data[1], 256, 256, "./data/love_cat2.png") -//} -// -//func TestStableDiffusionImgToImage(t *testing.T) { -// sd, err := NewCStableDiffusion(getLibrary()) -// if err != nil { -// t.Log(err) -// } -// ctx := sd.NewStableDiffusionCtx(8, false, true, "", CUDA_RNG) -// defer ctx.Close() -// ctx.StableDiffusionLoadFromFile("./models/miniSD-ggml-model-q5_0.bin", DEFAULT) -// img := readFromFile(t, "./data/love_cat2.png") -// data, _ := ctx.StableDiffusionImageToImage(img, "A lovely cat that theme pink", "", 7.0, 256, 256, EULER_A, 20, 0.4, 42) -// writeToFile(t, data, 256, 256, "./data/output1.png") -//} -// -//func TestBase64(t *testing.T) { -// img := readFromFile(t, "./assets/love_cat2.png") -// imgBase64 := base64.StdEncoding.EncodeToString(img) -// t.Log(imgBase64) -//} +func TestNewCStableDiffusionImg2Img(t *testing.T) { + diffusion, err := sd.NewCStableDiffusion(getLibrary()) + if err != nil { + t.Error(err) + return + } + diffusion.SetLogCallBack(func(level sd.LogLevel, text string) { + fmt.Printf("%s", text) + }) + ctx := diffusion.NewCtx("./models/miniSD.ckpt", "", "", "", false, false, true, -1, sd.F16, sd.CUDA_RNG, sd.DEFAULT) + defer diffusion.FreeCtx(ctx) + + img := readFromFile(t, "./assets/test.png") + images := diffusion.ImagePredictImage(ctx, *img, "cat wears shoes, high quality", "", 0, 7.0, 256, 256, sd.EULER_A, 20, 0.4, 42, 1) + + writeToFile(t, images[0].Data, 256, 256, "./assets/test1.png") +} diff --git a/deps/darwin/libsd-abi.dylib b/deps/darwin/libsd-abi.dylib index 163a244..9c160e3 100755 Binary files a/deps/darwin/libsd-abi.dylib and b/deps/darwin/libsd-abi.dylib differ diff --git a/deps/linux/libsd-abi.so b/deps/linux/libsd-abi.so index 5d42d75..48c15f8 100755 Binary files a/deps/linux/libsd-abi.so and b/deps/linux/libsd-abi.so differ diff --git a/deps/stable-diffusion.version b/deps/stable-diffusion.version index 
29d65ba..9324828 100644 --- a/deps/stable-diffusion.version +++ b/deps/stable-diffusion.version @@ -1 +1 @@ -main-eefcc62 \ No newline at end of file +main-2c60396 \ No newline at end of file diff --git a/deps/windows/sd-abi_avx.dll b/deps/windows/sd-abi_avx.dll index 015fb7a..00d0bda 100644 Binary files a/deps/windows/sd-abi_avx.dll and b/deps/windows/sd-abi_avx.dll differ diff --git a/deps/windows/sd-abi_avx2.dll b/deps/windows/sd-abi_avx2.dll index 8c6d615..3463394 100644 Binary files a/deps/windows/sd-abi_avx2.dll and b/deps/windows/sd-abi_avx2.dll differ diff --git a/deps/windows/sd-abi_avx512.dll b/deps/windows/sd-abi_avx512.dll index a57c457..d0e41d8 100644 Binary files a/deps/windows/sd-abi_avx512.dll and b/deps/windows/sd-abi_avx512.dll differ diff --git a/deps/windows/sd-abi_cuda12.dll b/deps/windows/sd-abi_cuda12.dll index 14d4485..8b9eddd 100644 Binary files a/deps/windows/sd-abi_cuda12.dll and b/deps/windows/sd-abi_cuda12.dll differ diff --git a/embed_windows.go b/embed_windows.go index 9a97906..f9b98d8 100644 --- a/embed_windows.go +++ b/embed_windows.go @@ -34,24 +34,31 @@ func getDl(gpu bool) []byte { if err != nil { log.Println(err) } - log.Print("get gpu info: ", info) - // Name - // DriverVersion - // AdapterCompatibility - if info["AdapterCompatibility"] == "NVIDIA" { + log.Print("get gpu info: ", info["Name"]) + + if strings.Contains(info["Name"], "NVIDIA") { + log.Println("Use GPU CUDA12 instead.") return libStableDiffusionCuda12 } + log.Println("GPU not support, use CPU instead.") } if cpu.X86.HasAVX512 { + log.Println("Use CPU AVX512 instead.") return libStableDiffusionAvx512 } if cpu.X86.HasAVX2 { + log.Println("Use CPU AVX2 instead.") return libStableDiffusionAvx2 } - return libStableDiffusionAvx + if cpu.X86.HasAVX { + log.Println("Use CPU AVX instead.") + return libStableDiffusionAvx + } + + panic("Automatic loading of dynamic library failed, please use `NewRwkvModel` method load manually. 
") } func runPowerShellCommand(command string) (string, error) { diff --git a/exmaples/custom_model/main.go b/exmaples/custom_model/main.go index aa21f14..6bb01ff 100644 --- a/exmaples/custom_model/main.go +++ b/exmaples/custom_model/main.go @@ -3,23 +3,26 @@ package main import ( sd "github.com/seasonjs/stable-diffusion" "io" + "log" "os" "path/filepath" ) func main() { - options := sd.DefaultStableDiffusionOptions - options.Width = 512 - options.Height = 512 + options := sd.DefaultOptions - model, err := sd.NewStableDiffusionAutoModel(options) + model, err := sd.NewAutoModel(options) if err != nil { print(err.Error()) return } defer model.Close() - err = model.LoadFromFile("./models/mysd.safetensors") + model.SetLogCallback(func(level sd.LogLevel, msg string) { + log.Println(msg) + }) + + err = model.LoadFromFile("./models/miniSD.ckpt") if err != nil { print(err.Error()) return @@ -45,8 +48,14 @@ func main() { writers = append(writers, file) } - err = model.Predict("a girl, high quality", writers) + params := sd.DefaultFullParams + params.BatchCount = 1 + params.Width = 256 + params.Height = 256 + + err = model.Predict("a girl, high quality", params, writers) if err != nil { print(err.Error()) + return } } diff --git a/exmaples/main.go b/exmaples/main.go index 76cf696..c7089b6 100644 --- a/exmaples/main.go +++ b/exmaples/main.go @@ -8,12 +8,9 @@ import ( ) func main() { - options := sd.DefaultStableDiffusionOptions - options.Width = 256 - options.Height = 256 - options.SampleSteps = 1 + options := sd.DefaultOptions - model, err := sd.NewStableDiffusionAutoModel(options) + model, err := sd.NewAutoModel(options) if err != nil { print(err.Error()) return @@ -51,7 +48,7 @@ func main() { writers = append(writers, file) } - err = model.Predict("british short hair cat, high quality", writers) + err = model.Predict("british short hair cat, high quality", sd.DefaultFullParams, writers) if err != nil { print(err.Error()) } diff --git a/go.mod b/go.mod index 
52ea271..1e865cb 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/seasonjs/stable-diffusion go 1.21 require ( - github.com/ebitengine/purego v0.5.1 + github.com/ebitengine/purego v0.6.0-alpha.2.0.20231129131118-33b97fd6a58b github.com/seasonjs/hf-hub v0.0.3 - golang.org/x/sys v0.15.0 + golang.org/x/sys v0.16.0 ) require ( diff --git a/go.sum b/go.sum index 383bfff..c711a65 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ebitengine/purego v0.5.1 h1:hNunhThpOf1vzKl49v6YxIsXLhl92vbBEv1/2Ez3ZrY= github.com/ebitengine/purego v0.5.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/ebitengine/purego v0.6.0-alpha.2/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/ebitengine/purego v0.6.0-alpha.2.0.20231129131118-33b97fd6a58b h1:Bv4UCplZd8EjKx6/55D+o9pNi2gJGCaoHCLW4X3VJLA= +github.com/ebitengine/purego v0.6.0-alpha.2.0.20231129131118-33b97fd6a58b/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= @@ -15,6 +18,7 @@ github.com/schollz/progressbar/v3 v3.14.1 h1:VD+MJPCr4s3wdhTc7OEJ/Z3dAeBzJ7yKH/P github.com/schollz/progressbar/v3 v3.14.1/go.mod h1:Zc9xXneTzWXF81TGoqL71u0sBPjULtEHYtj/WVgVy8E= github.com/seasonjs/hf-hub v0.0.2 h1:wPWBAwrcbVQd9UDLFHEp1oEua2n95ZKgRfk2k7sBKow= github.com/seasonjs/hf-hub v0.0.2/go.mod h1:bIgK32RZh/Zn6FYeTvfIttO19hEmfNz0QdaPqFG1UIA= +github.com/seasonjs/hf-hub v0.0.3 h1:CUtkyuOjbIGFpvQtRxqFP24rF8GVIvxnJjM478jsG6s= github.com/seasonjs/hf-hub v0.0.3/go.mod 
h1:bIgK32RZh/Zn6FYeTvfIttO19hEmfNz0QdaPqFG1UIA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= @@ -23,6 +27,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= diff --git a/sd.go b/sd.go index 47888a8..65010ed 100644 --- a/sd.go +++ b/sd.go @@ -12,6 +12,7 @@ import ( "io" "log" "os" + "path/filepath" ) type OutputsImageType string @@ -21,18 +22,23 @@ const ( JPEG = "JPEG" ) -type StableDiffusionOptions struct { - Threads int - VaeDecodeOnly bool +type Options struct { + VaePath string TaesdPath string - FreeParamsImmediately bool LoraModelDir string + VaeDecodeOnly bool + VaeTiling bool + FreeParamsImmediately bool + Threads int + Wtype WType RngType RNGType - VaePath string - WType GgmlType Schedule Schedule + GpuEnable bool +} +type FullParams struct { NegativePrompt string + ClipSkip int CfgScale float32 Width int Height int @@ -41,32 +47,23 @@ type StableDiffusionOptions struct { Strength float32 Seed int64 BatchCount int - GpuEnable bool OutputsImageType OutputsImageType } -type StableDiffusionModel struct { - ctx *StableDiffusionCtx - options *StableDiffusionOptions - params *StableDiffusionFullParams - csd CStableDiffusion - isAutoLoad bool - dylibPath string -} - -var 
DefaultStableDiffusionOptions = StableDiffusionOptions{ +var DefaultOptions = Options{ Threads: -1, // auto VaeDecodeOnly: true, FreeParamsImmediately: true, - LoraModelDir: "", RngType: CUDA_RNG, - WType: T_DEFAULT, + Wtype: F32, Schedule: DEFAULT, +} +var DefaultFullParams = FullParams{ NegativePrompt: "out of frame, lowers, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature", CfgScale: 7.0, - Width: 500, - Height: 500, + Width: 512, + Height: 512, SampleMethod: EULER_A, SampleSteps: 20, Strength: 0.4, @@ -75,39 +72,18 @@ var DefaultStableDiffusionOptions = StableDiffusionOptions{ OutputsImageType: PNG, } -func (s *StableDiffusionOptions) toStableDiffusionFullParamsRef(c CStableDiffusion) *StableDiffusionFullParams { - params := c.StableDiffusionFullDefaultParamsRef() - if len(s.NegativePrompt) != 0 { - c.StableDiffusionFullParamsSetNegativePrompt(params, s.NegativePrompt) - } - if s.CfgScale != 0 { - c.StableDiffusionFullParamsSetCfgScale(params, s.CfgScale) - } - if s.Width != 0 { - c.StableDiffusionFullParamsSetWidth(params, s.Width) - } - if s.Height != 0 { - c.StableDiffusionFullParamsSetHeight(params, s.Height) - } - c.StableDiffusionFullParamsSetSampleMethod(params, s.SampleMethod) - if s.SampleSteps != 0 { - c.StableDiffusionFullParamsSetSampleSteps(params, s.SampleSteps) - } - if s.Strength != 0 { - c.StableDiffusionFullParamsSetStrength(params, s.Strength) - } - if s.Seed != 0 { - c.StableDiffusionFullParamsSetSeed(params, s.Seed) - } - //default batch count is 1 in c++ - if s.BatchCount != 0 { - c.StableDiffusionFullParamsSetBatchCount(params, 
s.BatchCount) - } - - return params +type Model struct { + ctx *CStableDiffusionCtx + options *Options + csd CStableDiffusion + isAutoLoad bool + dylibPath string + diffusionModelPath string + esrganPath string + upscalerCtx *CUpScalerCtx } -func NewStableDiffusionAutoModel(options StableDiffusionOptions) (*StableDiffusionModel, error) { +func NewAutoModel(options Options) (*Model, error) { file, err := dumpSDLibrary(options.GpuEnable) if err != nil { return nil, err @@ -115,11 +91,10 @@ func NewStableDiffusionAutoModel(options StableDiffusionOptions) (*StableDiffusi if options.GpuEnable { log.Printf("If you want to try offload your model to the GPU. " + - "Please confirm the size of your GPU memory to prevent memory overflow." + - "If the model is larger than GPU memory, please specify the layers to offload.") + "Please confirm the size of your GPU memory to prevent memory overflow.") } dylibPath := file.Name() - model, err := NewStableDiffusionModel(dylibPath, options) + model, err := NewModel(dylibPath, options) if err != nil { return nil, err } @@ -127,56 +102,111 @@ func NewStableDiffusionAutoModel(options StableDiffusionOptions) (*StableDiffusi return model, nil } -func NewStableDiffusionModel(dylibPath string, options StableDiffusionOptions) (*StableDiffusionModel, error) { - sd, err := NewCStableDiffusion(dylibPath) +func NewModel(dylibPath string, options Options) (*Model, error) { + csd, err := NewCStableDiffusion(dylibPath) if err != nil { return nil, err } - if options.BatchCount < 1 { - options.BatchCount = 1 - } - - ctx := sd.StableDiffusionInit(options.Threads, options.VaeDecodeOnly, options.TaesdPath, options.FreeParamsImmediately, options.LoraModelDir, options.RngType) - params := options.toStableDiffusionFullParamsRef(sd) - return &StableDiffusionModel{ + return &Model{ dylibPath: dylibPath, - ctx: ctx, options: &options, - params: params, - csd: sd, + csd: csd, }, nil } -func (sd *StableDiffusionModel) LoadFromFile(path string) error { +func (sd 
*Model) LoadFromFile(path string) error { + if sd.ctx != nil { + sd.csd.FreeCtx(sd.ctx) + sd.ctx = nil + log.Printf("model already loaded, free old model") + } + _, err := os.Stat(path) if err != nil { return errors.New("the system cannot find the model file specified") } - sd.csd.StableDiffusionLoadFromFile(sd.ctx, path, sd.options.TaesdPath, sd.options.WType, sd.options.Schedule) + + if !filepath.IsAbs(path) { + sd.diffusionModelPath, err = filepath.Abs(path) + if err != nil { + return err + } + } else { + sd.diffusionModelPath = path + } + + ctx := sd.csd.NewCtx(path, + sd.options.VaePath, + sd.options.TaesdPath, + sd.options.LoraModelDir, + sd.options.VaeDecodeOnly, + sd.options.VaeTiling, + sd.options.FreeParamsImmediately, + sd.options.Threads, + sd.options.Wtype, + sd.options.RngType, + sd.options.Schedule) + sd.ctx = ctx return nil } -func (sd *StableDiffusionModel) SetOptions(options StableDiffusionOptions) { +func (sd *Model) SetOptions(options Options) { + if sd.ctx != nil { + sd.csd.FreeCtx(sd.ctx) + sd.ctx = nil + log.Printf("model already loaded, free old model and set new options") + } sd.options = &options - sd.params = options.toStableDiffusionFullParamsRef(sd.csd) + ctx := sd.csd.NewCtx( + sd.diffusionModelPath, + sd.options.VaePath, + sd.options.TaesdPath, + sd.options.LoraModelDir, + sd.options.VaeDecodeOnly, + sd.options.VaeTiling, + sd.options.FreeParamsImmediately, + sd.options.Threads, + sd.options.Wtype, + sd.options.RngType, + sd.options.Schedule) + sd.ctx = ctx } -func (sd *StableDiffusionModel) Predict(prompt string, writer []io.Writer) error { - if len(writer) != sd.options.BatchCount { +func (sd *Model) Predict(prompt string, params FullParams, writer []io.Writer) error { + if len(writer) != params.BatchCount { return errors.New("writer count not match batch count") } - data := sd.csd.StableDiffusionPredictImage( + if sd.ctx == nil { + return errors.New("model not loaded") + } + + if params.Width%8 != 0 || params.Height%8 != 0 { + 
return errors.New("width and height must be multiples of 8") + } + + images := sd.csd.PredictImage( sd.ctx, - sd.params, prompt, + params.NegativePrompt, + params.ClipSkip, + params.CfgScale, + params.Width, + params.Height, + params.SampleMethod, + params.SampleSteps, + params.Seed, + params.BatchCount, ) - result := chunkBytes(data, sd.options.BatchCount) + if images == nil || len(images) != params.BatchCount { + return errors.New("predict failed") + } - for i := 0; i < sd.options.BatchCount; i++ { - outputsImage := bytesToImage(result[i], sd.options.Width, sd.options.Height) - err := imageToWriter(outputsImage, sd.options.OutputsImageType, writer[i]) + for i, img := range images { + outputsImage := bytesToImage(img.Data, int(img.Width), int(img.Height)) + + err := imageToWriter(outputsImage, params.OutputsImageType, writer[i]) if err != nil { return err } @@ -185,32 +215,81 @@ func (sd *StableDiffusionModel) Predict(prompt string, writer []io.Writer) error return nil } -func (sd *StableDiffusionModel) ImagePredict(reader io.Reader, prompt string, writer io.Writer) error { +func (sd *Model) ImagePredict(reader io.Reader, prompt string, params FullParams, writer []io.Writer) error { + + if len(writer) != params.BatchCount { + return errors.New("writer count not match batch count") + } + + if sd.ctx == nil { + return errors.New("model not loaded") + } + decode, _, err := image.Decode(reader) if err != nil { return err } - bytesImg := imageToBytes(decode) - outputsBytes := sd.csd.StableDiffusionImagePredictImage( + initImage := imageToBytes(decode) + images := sd.csd.ImagePredictImage( sd.ctx, - sd.params, - bytesImg, + initImage, prompt, + params.NegativePrompt, + params.ClipSkip, + params.CfgScale, + params.Width, + params.Height, + params.SampleMethod, + params.SampleSteps, + params.Strength, + params.Seed, + params.BatchCount, ) - outputsImage := bytesToImage(outputsBytes, sd.options.Width, sd.options.Height) - return imageToWriter(outputsImage, 
sd.options.OutputsImageType, writer) + for i, img := range images { + outputsImage := bytesToImage(img.Data, int(img.Width), int(img.Height)) + err = imageToWriter(outputsImage, params.OutputsImageType, writer[i]) + if err != nil { + return err + } + } + return nil +} + +func (sd *Model) UpscaleImage(reader io.Reader, esrganPath string, upscaleFactor uint32, writer io.Writer) error { + if sd.upscalerCtx == nil { + sd.upscalerCtx = sd.csd.NewUpscalerCtx(esrganPath, sd.options.Threads, sd.options.Wtype) + } + decode, _, err := image.Decode(reader) + if err != nil { + return err + } + initImage := imageToBytes(decode) + img := sd.csd.UpscaleImage(sd.upscalerCtx, initImage, upscaleFactor) + outputsImage := bytesToImage(img.Data, int(img.Width), int(img.Height)) + err = imageToWriter(outputsImage, PNG, writer) + return err +} + +func (sd *Model) SetLogCallback(cb CLogCallback) { + sd.csd.SetLogCallBack(cb) } -func (sd *StableDiffusionModel) Close() error { +func (sd *Model) Close() error { if sd.ctx != nil { - sd.csd.StableDiffusionFree(sd.ctx) + sd.csd.FreeCtx(sd.ctx) sd.ctx = nil } - if sd.params != nil { - sd.csd.StableDiffusionFreeFullParams(sd.params) - sd.params = nil + if sd.upscalerCtx != nil { + sd.csd.FreeUpscalerCtx(sd.upscalerCtx) + sd.upscalerCtx = nil + } + if sd.csd != nil { + err := sd.csd.Close() + if err != nil { + return err + } } if sd.isAutoLoad { @@ -222,7 +301,7 @@ func (sd *StableDiffusionModel) Close() error { return nil } -func imageToBytes(decode image.Image) []byte { +func imageToBytes(decode image.Image) Image { bounds := decode.Bounds() width := bounds.Max.X - bounds.Min.X height := bounds.Max.Y - bounds.Min.Y @@ -237,7 +316,11 @@ func imageToBytes(decode image.Image) []byte { bytesImg[idx+2] = byte(b >> 8) } } - return bytesImg + return Image{ + Width: uint32(width), + Height: uint32(height), + Data: bytesImg, + } } func bytesToImage(byteData []byte, width, height int) image.Image { @@ -275,19 +358,19 @@ func imageToWriter(image 
image.Image, imageType OutputsImageType, writer io.Writ return nil } -func chunkBytes(data []byte, chunks int) [][]byte { - length := len(data) - chunkSize := (length + chunks - 1) / chunks - result := make([][]byte, chunks) - - for i := 0; i < chunks; i++ { - start := i * chunkSize - end := (i + 1) * chunkSize - if end > length { - end = length - } - result[i] = data[start:end:end] - } - - return result -} +//func chunkBytes(data []byte, chunks int) [][]byte { +// length := len(data) +// chunkSize := (length + chunks - 1) / chunks +// result := make([][]byte, chunks) +// +// for i := 0; i < chunks; i++ { +// start := i * chunkSize +// end := (i + 1) * chunkSize +// if end > length { +// end = length +// } +// result[i] = data[start:end:end] +// } +// +// return result +//} diff --git a/sd_test.go b/sd_test.go index 67dcd5b..cddff27 100644 --- a/sd_test.go +++ b/sd_test.go @@ -1,25 +1,24 @@ -package sd +package sd_test import ( - "image" + sd "github.com/seasonjs/stable-diffusion" "io" "os" "testing" ) func TestNewStableDiffusionAutoModelPredict(t *testing.T) { - options := DefaultStableDiffusionOptions - options.Width = 256 - options.Height = 256 - options.BatchCount = 2 - //options.SampleSteps = 2 + options := sd.DefaultOptions t.Log(options) - model, err := NewStableDiffusionAutoModel(options) + model, err := sd.NewAutoModel(options) if err != nil { t.Error(err) return } defer model.Close() + model.SetLogCallback(func(level sd.LogLevel, msg string) { + t.Log(msg) + }) err = model.LoadFromFile("./models/miniSD.ckpt") if err != nil { t.Error(err) @@ -27,10 +26,7 @@ func TestNewStableDiffusionAutoModelPredict(t *testing.T) { } var writers []io.Writer filenames := []string{ - "./assets/love_cat0.png", - "./assets/love_cat1.png", - //"./assets/love_cat5.png", - //"./assets/love_cat6.png" + "./assets/love_cat2.png", } for _, filename := range filenames { file, err := os.Create(filename) @@ -42,7 +38,12 @@ func TestNewStableDiffusionAutoModelPredict(t *testing.T) { 
writers = append(writers, file) } - err = model.Predict("british short hair cat,high quality", writers) + params := sd.DefaultFullParams + params.BatchCount = 1 + params.Width = 256 + params.Height = 256 + params.NegativePrompt = "" + err = model.Predict("british short hair cat, high quality", params, writers) if err != nil { t.Error(err) return @@ -50,12 +51,10 @@ func TestNewStableDiffusionAutoModelPredict(t *testing.T) { } func TestNewStableDiffusionAutoModelImagePredict(t *testing.T) { - options := DefaultStableDiffusionOptions - options.Width = 256 - options.Height = 256 + options := sd.DefaultOptions options.VaeDecodeOnly = false t.Log(options) - model, err := NewStableDiffusionAutoModel(options) + model, err := sd.NewAutoModel(options) if err != nil { t.Error(err) return @@ -73,44 +72,24 @@ func TestNewStableDiffusionAutoModelImagePredict(t *testing.T) { } defer inFile.Close() - outfile, err := os.Create("./assets/shoes_cat.png") - if err != nil { - t.Error(err) - return - } - defer outfile.Close() - - err = model.ImagePredict(inFile, "the cat that wears shoe", outfile) - if err != nil { - t.Error(err) - return - } -} - -func TestImageToByte(t *testing.T) { - inFile, err := os.Open("./assets/love_cat0.png") - if err != nil { - t.Error(err) - return - } - defer inFile.Close() - - outfile, err := os.Create("./assets/test_cat0.png") - if err != nil { - t.Error(err) - return + var writers []io.Writer + filenames := []string{ + "./assets/love_cat0_m.png", + "./assets/love_cat1_m.png", + //"./assets/love_cat5.png", + //"./assets/love_cat6.png" } - defer outfile.Close() - - decode, _, err := image.Decode(inFile) - if err != nil { - t.Error(err) - return + for _, filename := range filenames { + file, err := os.Create(filename) + if err != nil { + t.Error(err) + return + } + defer file.Close() + writers = append(writers, file) } - outputsBytes := imageToBytes(decode) - outputsImage := bytesToImage(outputsBytes, 256, 256) - err = imageToWriter(outputsImage, "PNG", 
outfile) + err = model.ImagePredict(inFile, "the cat that wears shoe", sd.DefaultFullParams, writers) if err != nil { t.Error(err) return