fix: improved table output, vram estimation (#134)
* fix: improved table output, vram estimation

* fix: americanisations

* fix: americanisations

* fix: improved table output, vram estimation
sammcj authored Nov 19, 2024
1 parent 850d8f8 commit 99b775c
Showing 6 changed files with 222 additions and 113 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -144,9 +144,12 @@ Inspect (`i`)
- `-v`: Print the version and exit
- `-h`, or `--host`: Specify the host for the Ollama API
- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API) **new**
- `--vram`: Estimate vRAM usage for an existing (pulled) Ollama model name (e.g. `llama3.1:8b-instruct-q6_K`) huggingface model ID (e.g. `NousResearch/Hermes-2-Theta-Llama-3-8B`)
- `--vram`: Estimate vRAM usage for a model. Accepts:
- Ollama models (e.g. `llama3.1:8b-instruct-q6_K`, `qwen2:14b-q4_0`)
- HuggingFace models (e.g. `NousResearch/Hermes-2-Theta-Llama-3-8B`)
- `--fits`: Available memory in GB for context calculation (e.g. `6` for 6GB)
- `--vram-to-nth`: Top context length to search for (e.g. `40k` or `40000`)
- `--vram-to-nth` or `--context`: Maximum context length to analyze (e.g. `32k` or `128k`)
- `--quant`: Override quantisation level (e.g. `Q4_0`, `Q5_K_M`)

##### Simple model listing

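The new `--quant` flag only makes sense with a quantisation label the estimator recognises; later in this diff, main.go rejects unknown overrides by checking them against `vramestimator.GGUFMapping` and printing the valid options. Below is a minimal, self-contained sketch of that style of validation — the map contents are an illustrative subset, not the package's actual table:

```go
package main

import (
	"fmt"
	"os"
	"sort"
	"strings"
)

// Illustrative subset only; the real table lives in the vramestimator package.
var ggufMapping = map[string]float64{
	"Q4_0":   4.34,
	"Q4_K_M": 4.83,
	"Q5_K_M": 5.67,
	"Q6_K":   6.59,
	"Q8_0":   8.50,
}

// validateQuant returns the normalised level, or exits after listing the known levels.
func validateQuant(quant string) string {
	key := strings.ToUpper(quant)
	if _, ok := ggufMapping[key]; ok {
		return key
	}
	fmt.Printf("Warning: Unknown quantisation level '%s'. Available levels:\n", quant)
	levels := make([]string, 0, len(ggufMapping))
	for level := range ggufMapping {
		levels = append(levels, level)
	}
	sort.Strings(levels)
	for _, level := range levels {
		fmt.Printf(" - %s\n", level)
	}
	os.Exit(1)
	return ""
}

func main() {
	fmt.Println("using quant:", validateQuant("q6_k")) // normalised to Q6_K
}
```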
2 changes: 1 addition & 1 deletion app_model.go
@@ -801,7 +801,7 @@ func (m *AppModel) inspectModelView(model Model) string {
{"Name", model.Name},
{"ID", model.ID},
{"Size (GB)", fmt.Sprintf("%.2f", model.Size)},
{"Quantization Level", model.QuantizationLevel},
{"Quantisation Level", model.QuantizationLevel},
{"Modified", model.Modified.Format("2006-01-02")},
{"Family", model.Family},
}
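The rows above are plain label/value pairs; how they are rendered sits outside this hunk. As a rough illustration of the pattern, here is a self-contained sketch that feeds the same pairs through `olekukonko/tablewriter` (a dependency listed in go.sum) — the app's actual inspect view may style its output differently, and the stand-in struct's field types are assumptions:

```go
package main

import (
	"fmt"
	"os"
	"time"

	"github.com/olekukonko/tablewriter"
)

// Minimal stand-in for the app's Model type; field types are assumptions.
type modelInfo struct {
	Name              string
	ID                string
	Size              float64
	QuantizationLevel string
	Modified          time.Time
	Family            string
}

func renderInspectTable(m modelInfo) {
	rows := [][]string{
		{"Name", m.Name},
		{"ID", m.ID},
		{"Size (GB)", fmt.Sprintf("%.2f", m.Size)},
		{"Quantisation Level", m.QuantizationLevel},
		{"Modified", m.Modified.Format("2006-01-02")},
		{"Family", m.Family},
	}

	t := tablewriter.NewWriter(os.Stdout)
	t.SetHeader([]string{"Field", "Value"})
	for _, r := range rows {
		t.Append(r)
	}
	t.Render()
}

func main() {
	renderInspectTable(modelInfo{
		Name:              "llama3.1:8b-instruct-q6_K",
		ID:                "abc123",
		Size:              6.60,
		QuantizationLevel: "Q6_K",
		Modified:          time.Now(),
		Family:            "llama",
	})
}
```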
22 changes: 0 additions & 22 deletions go.sum
@@ -8,24 +8,16 @@ github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWp
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.1.2 h1:naQXF2laRxyLyil/i7fxdpiz1/k06IKquhm4vBfHsIc=
github.com/charmbracelet/bubbletea v1.1.2/go.mod h1:9HIU/hBV24qKjlehyj8z1r/tR9TYTQEag+cWZnuXo8E=
github.com/charmbracelet/bubbletea v1.2.2 h1:EMz//Ky/aFS2uLcKqpCst5UOE6z5CFDGRsUpyXz0chs=
github.com/charmbracelet/bubbletea v1.2.2/go.mod h1:Qr6fVQw+wX7JkWWkVyXYk/ZUQ92a6XNekLXa3rR18MM=
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
github.com/charmbracelet/lipgloss v0.13.1 h1:Oik/oqDTMVA01GetT4JdEC033dNzWoQHdWnHnQmXE2A=
github.com/charmbracelet/lipgloss v0.13.1/go.mod h1:zaYVJ2xKSKEnTEEbX6uAHabh2d975RJ+0yfkFpRBz5U=
github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg=
github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo=
github.com/charmbracelet/x/ansi v0.4.0 h1:NqwHA4B23VwsDn4H3VcNX1W1tOmgnvY1NDx5tOXdnOU=
github.com/charmbracelet/x/ansi v0.4.0/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
github.com/charmbracelet/x/ansi v0.4.5 h1:LqK4vwBNaXw2AyGIICa5/29Sbdq58GbGdFngSexTdRM=
github.com/charmbracelet/x/ansi v0.4.5/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0=
github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@@ -36,8 +28,6 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
@@ -83,8 +73,6 @@ github.com/natefinch/lumberjack v2.0.0+incompatible h1:4QJd3OLAMgj7ph+yZTuX13Ld4
github.com/natefinch/lumberjack v2.0.0+incompatible/go.mod h1:Wi9p2TTF5DG5oU+6YfsmYQpsTIOm0B1VNzQg9Mw6nPk=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/ollama/ollama v0.3.14 h1:e94+Fb1PDqmD3O90g5cqUSkSxfNm9U3fHMIyaKQ8aSc=
github.com/ollama/ollama v0.3.14/go.mod h1:YrWoNkFnPOYsnDvsf/Ztb1wxU9/IXrNsQHqcxbY2r94=
github.com/ollama/ollama v0.4.2 h1:LEbpKDoCGnFoX9h5U+lkzA6xZ10CfV01jiaU8RL5VlQ=
github.com/ollama/ollama v0.4.2/go.mod h1:1GP0mGWnV3x930mGdgpXYEjmoe6xbMyp+XtLRsIH6XU=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
@@ -132,12 +120,8 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -147,16 +131,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU=
golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
87 changes: 67 additions & 20 deletions main.go
@@ -120,9 +120,12 @@ func main() {
localHostFlag := flag.Bool("H", false, "Shortcut to connect to http://localhost:11434")
editFlag := flag.Bool("e", false, "Edit a model's modelfile")
// vRAM estimation flags
flag.Float64Var(&fitsVRAM, "fits", 0, "Highlight quant sizes and context sizes that fit in this amount of vRAM (in GB)")
vramFlag := flag.String("vram", "", "Estimate vRAM usage - Model ID or Ollama model name")
topContextFlag := flag.String("vram-to-nth", "65536", "Top context length to search for (e.g., 65536, 32k, 2m)")
// flag.Float64Var(&fitsVRAM, "fits", 0, "Highlight quant sizes and context sizes that fit in this amount of vRAM (in GB)")
vramFlag := flag.String("vram", "", "Model to estimate VRAM usage for (e.g., 'qwen2:q4_0' or 'meta-llama/Llama-2-7b')")
fitsVRAMFlag := flag.Float64("fits", 0, "Target VRAM constraint in GB (default: auto-detect)")
contextFlag := flag.String("context", "", "Maximum context length (e.g., '32k' or '128k')")
quantFlag := flag.String("quant", "", "Specific quantisation level (e.g., 'Q4_0', 'Q5_K_M')")
vramToNthFlag := flag.String("vram-to-nth", "65536", "Top context length to search for (e.g., 65536, 32k, 2m)")

flag.Parse()

@@ -154,34 +157,78 @@ func main() {
// Handle --vram flag
if *vramFlag != "" {
modelName := *vramFlag
logging.DebugLogger.Println("vRAM estimation flag detected")
if *vramFlag == "" {
fmt.Println("Error: Model ID or Ollama model name is required for vRAM estimation")
logging.DebugLogger.Printf("Processing vRAM estimation for model: %s", modelName)

// Parse the model identifier and quantisation level
baseModel, quantLevel, err := vramestimator.ParseModelIdentifier(modelName)
if err != nil {
fmt.Printf("Error parsing model identifier: %v\n", err)
os.Exit(1)
}

logging.DebugLogger.Println("Generating VRAM estimation table")
logging.DebugLogger.Printf("Parsed model identifier: base=%s, quant=%s", baseModel, quantLevel)

var ollamaModelInfo *vramestimator.OllamaModelInfo
var err error
// Override quantisation level if specified via flag
if *quantFlag != "" {
logging.DebugLogger.Printf("Overriding quantisation level from flag: %s", *quantFlag)
quantLevel = *quantFlag
}

// Check if the input is an Ollama model name (contains a colon)
if strings.Contains(modelName, ":") {
ollamaModelInfo, err = vramestimator.FetchOllamaModelInfo(cfg.OllamaAPIURL, modelName)
if err != nil {
fmt.Printf("Error fetching Ollama model info: %v\n", err)
os.Exit(1)
}
var isHuggingFaceModel = strings.Contains(baseModel, "/")
var isOllamaModel = !isHuggingFaceModel

// Parse the context size
var topContext int
var contextSource string
if *contextFlag != "" && *contextFlag != "65536" {
topContext, err = parseContextSize(*contextFlag)
contextSource = "context"
} else if *vramToNthFlag != "" {
topContext, err = parseContextSize(*vramToNthFlag)
contextSource = "vram-to-nth"
} else {
topContext = 65536
contextSource = "default"
}

// Parse the top context size
topContext, err := parseContextSize(*topContextFlag)
if err != nil {
fmt.Printf("Error parsing top context size: %v\n", err)
fmt.Printf("Error parsing context size from --%s flag: %v\n", contextSource, err)
os.Exit(1)
}

table, err := vramestimator.GenerateQuantTable(modelName, fitsVRAM, ollamaModelInfo, topContext)
logging.DebugLogger.Printf("Using context size %d from --%s", topContext, contextSource)

// If a specific quantisation level is provided, verify it exists
if quantLevel != "" {
if _, exists := vramestimator.GGUFMapping[strings.ToUpper(quantLevel)]; !exists {
fmt.Printf("Warning: Unknown quantisation level '%s'. Available levels:\n", quantLevel)
var levels []string
for level := range vramestimator.GGUFMapping {
levels = append(levels, level)
}
sort.Strings(levels)
for _, level := range levels {
fmt.Printf(" - %s\n", level)
}
os.Exit(1)
}
}

// Fetch model information from appropriate source
var ollamaModelInfo *vramestimator.OllamaModelInfo
if isOllamaModel {
logging.DebugLogger.Printf("Fetching model info from Ollama API for %s", baseModel)
ollamaModelInfo, err = vramestimator.FetchOllamaModelInfo(cfg.OllamaAPIURL, modelName)
if err != nil {
fmt.Printf("Error: Could not fetch Ollama model info: %v\n", err)
os.Exit(1)
}
} else {
logging.DebugLogger.Printf("Using HuggingFace model ID: %s", baseModel)
}

// Generate and display the table
table, err := vramestimator.GenerateQuantTable(baseModel, *fitsVRAMFlag, ollamaModelInfo, topContext)
if err != nil {
fmt.Printf("Error generating VRAM estimation table: %v\n", err)
os.Exit(1)
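In short, the rewritten flag handling (1) prefers `--context` over the older `--vram-to-nth`, defaulting to 65536 tokens, and (2) infers the model source from its name: a `/` marks a HuggingFace repo ID, while a pulled Ollama model uses its `name:tag` form. The sketch below compresses that decision flow; `parseContext` is a stand-in for the repository's `parseContextSize` helper, and the 1024-based multipliers are an assumption:

```go
package main

import (
	"flag"
	"fmt"
	"os"
	"strconv"
	"strings"
)

// parseContext is a stand-in for the repo's parseContextSize helper;
// the k/m multipliers here are assumed to be 1024-based.
func parseContext(s string) (int, error) {
	s = strings.ToLower(strings.TrimSpace(s))
	mult := 1
	switch {
	case strings.HasSuffix(s, "k"):
		mult, s = 1024, strings.TrimSuffix(s, "k")
	case strings.HasSuffix(s, "m"):
		mult, s = 1024*1024, strings.TrimSuffix(s, "m")
	}
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("invalid context size %q: %w", s, err)
	}
	return n * mult, nil
}

func main() {
	vram := flag.String("vram", "", "model to estimate vRAM usage for")
	context := flag.String("context", "", "maximum context length (e.g. 32k, 128k)")
	vramToNth := flag.String("vram-to-nth", "65536", "older alias for --context")
	flag.Parse()

	// --context wins; otherwise fall back to --vram-to-nth, whose default is 65536.
	raw := *vramToNth
	if *context != "" {
		raw = *context
	}
	topContext, err := parseContext(raw)
	if err != nil {
		fmt.Println("error:", err)
		os.Exit(1)
	}

	// A "/" marks a HuggingFace repo ID; anything else (e.g. name:tag) is an Ollama model.
	source := "ollama"
	if strings.Contains(*vram, "/") {
		source = "huggingface"
	}
	fmt.Printf("model=%q source=%s topContext=%d\n", *vram, source, topContext)
}
```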
104 changes: 52 additions & 52 deletions operations.go
@@ -692,44 +692,44 @@ func unloadModel(client *api.Client, modelName string) (string, error) {

// editModelfile opens the modelfile in the user's editor and updates the model on the server with the new content
func editModelfile(client *api.Client, modelName string) (string, error) {
if client == nil {
return "", fmt.Errorf("error: Client is nil")
}
ctx := context.Background()

// Fetch the current modelfile from the server
showResp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
if err != nil {
return "", fmt.Errorf("error fetching modelfile for %s: %v", modelName, err)
}
modelfileContent := showResp.Modelfile

// Get editor from environment or config
editor := getEditor()
if editor == "" {
editor = "vim" // Default fallback
}

logging.DebugLogger.Printf("Using editor: %s for model: %s\n", editor, modelName)

// Write the fetched content to a temporary file
tempDir := os.TempDir()
newModelfilePath := filepath.Join(tempDir, fmt.Sprintf("%s_modelfile.txt", modelName))
err = os.WriteFile(newModelfilePath, []byte(modelfileContent), 0644)
if err != nil {
return "", fmt.Errorf("error writing modelfile to temp file: %v", err)
}
defer os.Remove(newModelfilePath)

// Open the local modelfile in the editor
cmd := exec.Command(editor, newModelfilePath)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Run()
if err != nil {
return "", fmt.Errorf("error running editor: %v", err)
}
if client == nil {
return "", fmt.Errorf("error: Client is nil")
}
ctx := context.Background()

// Fetch the current modelfile from the server
showResp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
if err != nil {
return "", fmt.Errorf("error fetching modelfile for %s: %v", modelName, err)
}
modelfileContent := showResp.Modelfile

// Get editor from environment or config
editor := getEditor()
if editor == "" {
editor = "vim" // Default fallback
}

logging.DebugLogger.Printf("Using editor: %s for model: %s\n", editor, modelName)

// Write the fetched content to a temporary file
tempDir := os.TempDir()
newModelfilePath := filepath.Join(tempDir, fmt.Sprintf("%s_modelfile.txt", modelName))
err = os.WriteFile(newModelfilePath, []byte(modelfileContent), 0644)
if err != nil {
return "", fmt.Errorf("error writing modelfile to temp file: %v", err)
}
defer os.Remove(newModelfilePath)

// Open the local modelfile in the editor
cmd := exec.Command(editor, newModelfilePath)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Run()
if err != nil {
return "", fmt.Errorf("error running editor: %v", err)
}

// Read the edited content from the local file
newModelfileContent, err := os.ReadFile(newModelfilePath)
@@ -767,7 +767,7 @@ func isLocalhost(url string) bool {
}

func parseContextSize(input string) (int, error) {
input = strings.ToLower(input)
input = strings.ToLower(strings.TrimSpace(input))
multiplier := 1

if strings.HasSuffix(input, "k") {
@@ -788,17 +788,17 @@

// getEditor returns the user's editor
func getEditor() string {
// First check environment variable
if editor := os.Getenv("EDITOR"); editor != "" {
return editor
}

// Then check config
cfg, err := config.LoadConfig()
if err != nil {
logging.ErrorLogger.Printf("Error loading config for editor: %v\n", err)
return ""
}

return cfg.Editor
// First check environment variable
if editor := os.Getenv("EDITOR"); editor != "" {
return editor
}

// Then check config
cfg, err := config.LoadConfig()
if err != nil {
logging.ErrorLogger.Printf("Error loading config for editor: %v\n", err)
return ""
}

return cfg.Editor
}
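The functional change to `parseContextSize` here is the added `TrimSpace`, so padded input like `" 32K "` no longer fails to parse. A rough table test, assuming the function lives in the same package and that the `k`/`m` suffixes use 1024-based multipliers (the multiplier lines are collapsed in this view):

```go
package main

import "testing"

func TestParseContextSize(t *testing.T) {
	cases := map[string]int{
		"65536": 65536,
		"32k":   32 * 1024,
		" 32K ": 32 * 1024, // TrimSpace + ToLower now normalise padded, upper-case input
		"2m":    2 * 1024 * 1024,
	}
	for in, want := range cases {
		got, err := parseContextSize(in)
		if err != nil {
			t.Fatalf("parseContextSize(%q) returned error: %v", in, err)
		}
		if got != want {
			t.Fatalf("parseContextSize(%q) = %d, want %d", in, got, want)
		}
	}
}
```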