
Vulkan k-quant mmq and ggml-backend offload functionality #6155

Merged
merged 12 commits into master from 0cc4m/vulkan-improvements on Mar 29, 2024

Conversation

0cc4m
Collaborator

@0cc4m 0cc4m commented Mar 19, 2024

I added k-quant mmq shaders and cleaned up the cpu-assist functions, as they are now replaced with the ggml-backend code.

I also fixed the working buffer allocation code; it should now use noticeably less VRAM.

Should hopefully fix #5848

@slaren
Collaborator

slaren commented Mar 19, 2024

Note: this is on master; I was trying to test it there first as a baseline, but the same happens on the PR.

I have a 3080 and a 3090 Ti, and I am having some trouble getting ggml-vulkan to use the 3090 Ti. It looks like it only finds the 3080.

ggml_vulkan: Found 1 Vulkan devices:
Vulkan0: NVIDIA GeForce RTX 3080 | uma: 0 | fp16: 1 | warp size: 32

vulkaninfo lists both GPUs:


Layers: count = 11
==================
VK_LAYER_KHRONOS_profiles (Khronos Profiles layer) Vulkan version 1.3.275, layer version 1:
	Layer Extensions: count = 0
	Devices: count = 2
		GPU id = 0 (NVIDIA GeForce RTX 3080)
		Layer-Device Extensions: count = 1
			VK_EXT_tooling_info : extension revision 1

		GPU id = 1 (NVIDIA GeForce RTX 3090 Ti)
		Layer-Device Extensions: count = 1
			VK_EXT_tooling_info : extension revision 1

@0cc4m
Collaborator Author

0cc4m commented Mar 19, 2024

@slaren Because Vulkan supports a lot more devices, I haven't settled on a multi-GPU default yet. Currently it just uses the first one by default; if you want more, you need to set the GGML_VK_VISIBLE_DEVICES environment variable, similar to the CUDA one (0,1 in your case).

We could use this chance to discuss and set a better default as well. I guess all dedicated GPUs would be sane?
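For illustration, here is a minimal sketch of what such a policy could look like with the plain Vulkan API: default to every discrete GPU, let a GGML_VK_VISIBLE_DEVICES-style override win, and fall back to the first device if nothing discrete is found. The function name and the parsing details are assumptions for the example, not the actual ggml-vulkan code.

```cpp
// Sketch only: pick default Vulkan devices, preferring discrete GPUs, with an
// optional GGML_VK_VISIBLE_DEVICES-style override ("0,1"). Illustrative, not
// the ggml-vulkan implementation.
#include <vulkan/vulkan.h>
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

static std::vector<uint32_t> pick_default_devices(VkInstance instance) {
    uint32_t count = 0;
    vkEnumeratePhysicalDevices(instance, &count, nullptr);
    std::vector<VkPhysicalDevice> devices(count);
    vkEnumeratePhysicalDevices(instance, &count, devices.data());

    std::vector<uint32_t> selected;

    // Explicit override: comma-separated device indices, e.g. "0,1".
    if (const char * env = std::getenv("GGML_VK_VISIBLE_DEVICES")) {
        std::stringstream ss(env);
        std::string tok;
        while (std::getline(ss, tok, ',')) {
            uint32_t idx = static_cast<uint32_t>(std::stoul(tok));
            if (idx < count) {
                selected.push_back(idx);
            }
        }
        return selected;
    }

    // Default: all dedicated (discrete) GPUs.
    for (uint32_t i = 0; i < count; i++) {
        VkPhysicalDeviceProperties props;
        vkGetPhysicalDeviceProperties(devices[i], &props);
        if (props.deviceType == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) {
            selected.push_back(i);
        }
    }
    // Fallback: if no discrete GPU exists (e.g. an APU-only system),
    // use the first device so the backend still works.
    if (selected.empty() && count > 0) {
        selected.push_back(0);
    }
    return selected;
}
```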

@slaren
Collaborator

slaren commented Mar 19, 2024

Sorry, I had forgotten about GGML_VK_VISIBLE_DEVICES. This is the performance I get with the 3090 Ti. There was a big improvement in Q4_K performance with -ngl 0, and moderate improvements in most cases, although not as big as with the CUDA backend.

| model | size | params | backend | ngl | n_batch | n_ubatch | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 0 | 1024 | 1024 | pp 1024 | 221.47 ± 23.77 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 10 | 1024 | 1024 | pp 1024 | 256.31 ± 5.29 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 20 | 1024 | 1024 | pp 1024 | 389.19 ± 11.24 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 30 | 1024 | 1024 | pp 1024 | 712.05 ± 72.24 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | 1024 | 1024 | pp 1024 | 972.84 ± 2.19 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 0 | 1024 | 1024 | pp 1024 | 33.71 ± 4.48 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 10 | 1024 | 1024 | pp 1024 | 274.82 ± 26.43 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 20 | 1024 | 1024 | pp 1024 | 345.44 ± 4.25 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 30 | 1024 | 1024 | pp 1024 | 574.53 ± 31.10 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 99 | 1024 | 1024 | pp 1024 | 790.17 ± 0.47 |

build: d0d5de4 (2464)

| model | size | params | backend | ngl | n_batch | n_ubatch | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 0 | 1024 | 1024 | pp 1024 | 328.63 ± 8.48 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 10 | 1024 | 1024 | pp 1024 | 354.88 ± 7.73 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 20 | 1024 | 1024 | pp 1024 | 494.57 ± 33.10 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 30 | 1024 | 1024 | pp 1024 | 797.17 ± 52.51 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | 1024 | 1024 | pp 1024 | 983.79 ± 1.87 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 0 | 1024 | 1024 | pp 1024 | 221.21 ± 7.87 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 10 | 1024 | 1024 | pp 1024 | 282.71 ± 12.36 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 20 | 1024 | 1024 | pp 1024 | 394.85 ± 20.59 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 30 | 1024 | 1024 | pp 1024 | 682.70 ± 13.70 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | Vulkan | 99 | 1024 | 1024 | pp 1024 | 811.51 ± 0.51 |

build: 86386e2 (2460)

@slaren
Collaborator

slaren commented Mar 19, 2024

> We could use this chance to discuss and set a better default as well. I guess all dedicated GPUs would be sane?

I would think so. It will still take a while, but in the future I would like to improve the device selection, allow the users to select the devices that they want to use by name, and improve the defaults. It is already possible to select what devices to use with -ts and -mg, but admittedly it is not very intuitive.

@0cc4m
Collaborator Author

0cc4m commented Mar 19, 2024

I found two issues that I have to fix before this can get merged:

  • The backend-offload code does not take APUs into account yet, so they suffer a big performance penalty with the current version
  • There's a validation issue that manifests as incoherence on Intel GPUs

I'll look into these problems over the next few days.

@MaggotHATE
Contributor

MaggotHATE commented Mar 19, 2024

Thanks for this update! Good news: restarting is fixed!
Bad news: RAM consumption is now higher than before; q6_k uses almost as much as q8 used to (tested on 7B).
VRAM use is unaffected, and speeds are a bit better.
Win 8.1, 16 GB DDR3, 1060 3GB, only 7 layers offloaded.

@daniandtheweb
Contributor

daniandtheweb commented Mar 20, 2024

With this PR, k-quant performance on my card has almost reached, and in q4_k_m even surpassed, ROCm's speed. RAM usage is slightly higher than on master and still higher than with ROCm.
Great work!

Model: llama 2

AMD Radeon RX 5700 XT (RADV NAVI10)

| model | size | params | backend | ngl | test | t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm | 99 | pp 512 | 325.12 ± 0.29 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm | 99 | tg 128 | 60.25 ± 0.03 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 320.24 ± 0.49 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 37.73 ± 0.08 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 322.11 ± 0.43 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 38.02 ± 0.09 | |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | pp 512 | 282.11 ± 0.28 | |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | ROCm | 99 | tg 128 | 45.72 ± 0.05 | |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 183.45 ± 0.36 | |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 50.65 ± 0.22 | |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 282.72 ± 0.58 | 1.54 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 51.34 ± 0.58 | |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | ROCm | 99 | pp 512 | 281.48 ± 0.30 | |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | ROCm | 99 | tg 128 | 43.92 ± 0.04 | |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 183.63 ± 1.09 | |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 37.10 ± 0.23 | |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 269.24 ± 0.37 | 1.47 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 37.51 ± 0.09 | |

@Nindaleth
Contributor

Nindaleth commented Mar 20, 2024

Here are my numbers:
Radeon RX 6700 XT, Ryzen 5700X ECO, model mistral-7b-instruct-v0.2 fully offloaded to the GPU (unless specified otherwise). ROCm 6.0.

Only listing the quants where the pp result changed.
Vulkan0: AMD Radeon RX 6700 XT (RADV NAVI22) | uma: 0 | fp16: 1 | warp size: 64

| model | size | params | backend | ngl | test | tk/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 350.44 ± 0.60 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 415.73 ± 0.39 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 348.05 ± 2.50 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 394.29 ± 0.45 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 348.34 ± 0.73 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 400.94 ± 0.11 |
| llama 7B Q2_K - Small | 2.35 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 338.56 ± 0.71 |
| llama 7B Q2_K - Small | 2.35 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 385.64 ± 0.33 |
| llama 7B Q2_K - Medium | 2.53 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 333.81 ± 1.93 |
| llama 7B Q2_K - Medium | 2.53 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 368.42 ± 0.60 |
| llama 7B IQ3_XS - 3.3 bpw | 2.79 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 336.30 ± 2.64 |
| llama 7B IQ3_XS - 3.3 bpw | 2.79 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 353.17 ± 0.80 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 345.65 ± 0.51 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 356.07 ± 1.00 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 345.98 ± 0.40 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 358.87 ± 0.45 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 343.78 ± 2.35 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 350.55 ± 0.95 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 327.14 ± 1.96 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 360.30 ± 1.30 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 345.40 ± 0.61 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 405.50 ± 0.38 |

And a couple of tests with -ngl 0:

| model | size | params | backend | ngl | test | tk/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (master) | 0 | pp 512 | 108.21 ± 0.75 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (PR) | 0 | pp 512 | 252.58 ± 2.44 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (master) | 0 | pp 512 | 100.25 ± 0.56 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (PR) | 0 | pp 512 | 231.54 ± 1.65 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (master) | 0 | pp 512 | 94.35 ± 1.39 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (PR) | 0 | pp 512 | 210.63 ± 1.53 |

Fix validation issue
@0cc4m
Collaborator Author

0cc4m commented Mar 23, 2024

@slaren Do you have an idea why test-backend-ops doesn't build with Vulkan support with cmake?

» build_vk/bin/test-backend-ops
Testing 1 backends

Backend 1/1 (CPU)
  Skipping CPU backend
1/1 backends passed
OK

It works fine with cublas.

@slaren
Collaborator

slaren commented Mar 23, 2024

Maybe ggml-backend.c is being built without GGML_USE_VULKAN defined, so it never gets added to the registry. Edit: or more likely, because ggml_backend_vk_reg_devices is being called before device_indices is populated.

@0cc4m
Collaborator Author

0cc4m commented Mar 23, 2024

> Maybe ggml-backend.c is being built without GGML_USE_VULKAN defined, so it never gets added to the registry. Edit: or more likely, because ggml_backend_vk_reg_devices is being called before device_indices is populated.

Yeah, it was not populated yet. Thanks.
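For readers following along, the pitfall is an initialization-order issue: if the backend registry queries the device count before the device list is filled, it registers nothing, and test-backend-ops only sees the CPU backend. A minimal sketch of the problem and the lazy-initialization fix, using illustrative names rather than the real ggml symbols:

```cpp
// Illustrative only: the registration-order pitfall discussed above.
// If registration reads the device list before it is filled, nothing is registered.
#include <cstdio>
#include <vector>

static std::vector<int> device_indices;          // filled during backend init

static void vk_populate_devices() {              // e.g. parse env var / enumerate GPUs
    device_indices = {0};
}

static int vk_reg_devices() {                    // analogous role to ggml_backend_vk_reg_devices
    // Bug: if this runs before vk_populate_devices(), it reports 0 devices and the
    // backend is never registered. Fix: populate lazily here before reporting.
    if (device_indices.empty()) {
        vk_populate_devices();                   // lazy initialization avoids the ordering issue
    }
    return static_cast<int>(device_indices.size());
}

int main() {
    printf("registered Vulkan devices: %d\n", vk_reg_devices());
}
```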

@netrunnereve
Contributor

netrunnereve commented Mar 24, 2024

Interestingly enough I'm actually seeing 25% slower prompt processing speed with this PR compared to master. This only happens on the K-quants and inference speed remains the same.

Maybe this is an architecture thing but it could also be due to the fact that I'm running in fp32 mode whereas all the other commenters have fp16 cards.

Vulkan0: AMD Radeon FirePro W8100 (RADV HAWAII) | uma: 0 | fp16: 0 | warp size: 64

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 93.89 ± 0.65 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 93.17 ± 0.37 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 11.95 ± 0.02 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 11.92 ± 0.09 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 93.63 ± 0.17 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 73.58 ± 0.27 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 8.79 ± 0.03 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 8.79 ± 0.08 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 93.08 ± 0.22 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 69.47 ± 0.26 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 6.07 ± 0.05 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 5.97 ± 0.02 |
| llama 7B Q6_K | 5.15 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 98.62 ± 0.38 |
| llama 7B Q6_K | 5.15 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 74.50 ± 0.27 |
| llama 7B Q6_K | 5.15 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 6.64 ± 0.02 |
| llama 7B Q6_K | 5.15 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 6.67 ± 0.07 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 100.23 ± 0.32 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 92.47 ± 0.31 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 7.69 ± 0.03 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 7.67 ± 0.03 |

@0cc4m
Collaborator Author

0cc4m commented Mar 27, 2024

I fixed GET_ROWS and pulled upstream changes; it seems to work. There might be an issue with f16, but that's not that important. I just have to fix UMA, and then this PR can be merged.

@0cc4m
Collaborator Author

0cc4m commented Mar 27, 2024

> Interestingly enough I'm actually seeing 25% slower prompt processing speed with this PR compared to master. This only happens on the K-quants and inference speed remains the same.
>
> Maybe this is an architecture thing but it could also be due to the fact that I'm running in fp32 mode whereas all the other commenters have fp16 cards.

Thanks for testing this. I can confirm it's related to fp16 on AMD (GCN?) GPUs. I can only guess it's related to register pressure, since float32 uses twice the space for its variables. The k-quants need more registers to dequantize, and now that happens in the same shader as the matrix multiplication itself. I'll have to take a look in the future at whether I can mitigate that.
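As a rough illustration of that argument, here is a CPU-side analogy in plain C++ (not the actual GLSL shaders, and with a simplified 4-bit block layout): a separate dequantization pass keeps the inner matmul loop working purely on floats, while a fused kernel must also keep the block, scale, and unpacking state live inside the hot loop, which is where the extra register pressure comes from, doubled when everything stays in fp32.

```cpp
// CPU-side analogy of the two approaches; the 4-bit block layout is simplified.
#include <cstdint>

struct block_q4 {            // simplified 4-bit block: 1 scale + 16 packed bytes
    float d;                 // scale
    uint8_t qs[16];          // 32 x 4-bit quants
};

// Approach 1: dequantize into a temporary buffer, then run a plain float matmul.
// The inner matmul loop afterwards only needs registers for float accumulators.
static void dequantize_row(const block_q4 * x, float * y, int k) {
    for (int i = 0; i < k / 32; i++) {
        for (int j = 0; j < 16; j++) {
            y[i*32 + j*2 + 0] = ((x[i].qs[j] & 0x0F) - 8) * x[i].d;
            y[i*32 + j*2 + 1] = ((x[i].qs[j] >>   4) - 8) * x[i].d;
        }
    }
}

// Approach 2 (what the fused mmq shaders do conceptually): dequantize inside the
// dot-product loop. The loop now also keeps block, scale, and unpacking state live,
// which raises register pressure -- and fp32 state needs twice the space of fp16.
static float fused_dot(const block_q4 * x, const float * v, int k) {
    float acc = 0.0f;
    for (int i = 0; i < k / 32; i++) {
        const float d = x[i].d;
        for (int j = 0; j < 16; j++) {
            acc += ((x[i].qs[j] & 0x0F) - 8) * d * v[i*32 + j*2 + 0];
            acc += ((x[i].qs[j] >>   4) - 8) * d * v[i*32 + j*2 + 1];
        }
    }
    return acc;
}
```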

@MaggotHATE
Contributor

Since I reported previously on memory consumption, adding a console log with the latest changes. At the moment both RAM and VRAM usage are significantly higher than with CLBlast (which is missing the backend implementation and some operations): around 400 MB more VRAM and 4 GB more RAM.

At the same time, the main benefit of Vulkan now is higher prompt processing speed and more stable generation speed over time.

VK_log.txt

@0cc4m
Collaborator Author

0cc4m commented Mar 28, 2024

> Since I reported previously on memory consumption, adding a console log with the latest changes. At the moment both RAM and VRAM usage are significantly higher than with CLBlast (which is missing the backend implementation and some operations): around 400 MB more VRAM and 4 GB more RAM.
>
> At the same time, the main benefit of Vulkan now is higher prompt processing speed and more stable generation speed over time.
>
> VK_log.txt

Thanks for the report. I don't pay enough attention to RAM use, since my development server has 128 GB. If Vulkan uses 4 GB more RAM for the same number of offloaded layers, it's probably some issue with the staging buffers. I'll take a look at some point.
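For background, staging buffers are host-visible allocations that data is written into before being copied to device-local VRAM, so an oversized or long-lived staging buffer shows up as extra host RAM rather than VRAM. A generic Vulkan sketch of the pattern (illustrative only, not ggml-vulkan's actual allocator, with error handling omitted):

```cpp
// Generic sketch of a host-visible staging buffer; error handling omitted.
#include <vulkan/vulkan.h>

// Find a memory type that is host-visible and host-coherent.
static uint32_t find_host_memtype(VkPhysicalDevice phys, uint32_t type_bits) {
    VkPhysicalDeviceMemoryProperties mem{};
    vkGetPhysicalDeviceMemoryProperties(phys, &mem);
    const VkMemoryPropertyFlags want =
        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
    for (uint32_t i = 0; i < mem.memoryTypeCount; i++) {
        if ((type_bits & (1u << i)) && (mem.memoryTypes[i].propertyFlags & want) == want) {
            return i;
        }
    }
    return 0; // a real implementation would report an error here
}

// Allocate a staging buffer of `size` bytes; this allocation lives in host RAM,
// which is why oversized staging buffers inflate RAM rather than VRAM use.
static void create_staging(VkPhysicalDevice phys, VkDevice dev, VkDeviceSize size,
                           VkBuffer * buf, VkDeviceMemory * memory, void ** mapped) {
    VkBufferCreateInfo bci{VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};
    bci.size        = size;
    bci.usage       = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
    bci.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
    vkCreateBuffer(dev, &bci, nullptr, buf);

    VkMemoryRequirements req{};
    vkGetBufferMemoryRequirements(dev, *buf, &req);

    VkMemoryAllocateInfo mai{VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO};
    mai.allocationSize  = req.size;
    mai.memoryTypeIndex = find_host_memtype(phys, req.memoryTypeBits);
    vkAllocateMemory(dev, &mai, nullptr, memory);
    vkBindBufferMemory(dev, *buf, *memory, 0);

    // Map once and keep it mapped; data is memcpy'd here before the GPU copy.
    vkMapMemory(dev, *memory, 0, VK_WHOLE_SIZE, 0, mapped);
}
```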

@0cc4m
Collaborator Author

0cc4m commented Mar 29, 2024

I checked again and it seems that Vulkan UMA has not regressed. It works fine if you put no tensors or all tensors on GPU, but slows down if you do anything else. I should fix that in the future, but it is not necessary to hold back this PR for that reason, since master is in the same state.
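(UMA here refers to integrated GPUs/APUs that share system memory with the CPU, so tensors can be used in place rather than copied into dedicated VRAM. As a minimal sketch, one simple detection heuristic is shown below; whether ggml-vulkan uses exactly this check is an assumption of the example.)

```cpp
// Sketch: a simple UMA heuristic -- integrated GPUs share system memory with the
// CPU, so buffers can be accessed directly instead of being copied into VRAM.
#include <vulkan/vulkan.h>

static bool is_uma_device(VkPhysicalDevice phys) {
    VkPhysicalDeviceProperties props{};
    vkGetPhysicalDeviceProperties(phys, &props);
    return props.deviceType == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU;
}
```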

I'll wait for the CI checks and merge this afterwards.

@0cc4m 0cc4m merged commit ba0c7c7 into master Mar 29, 2024
58 of 59 checks passed
@0cc4m 0cc4m deleted the 0cc4m/vulkan-improvements branch March 29, 2024 16:29
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
…6155)

* Fix Vulkan no kv offload incoherence

* Add k-quant mul mat mat shaders

* Rework working buffer allocation, reduces vram use noticeably

Clean up cpu assist code, replaced with ggml-backend offload function

* Default to all dedicated GPUs

* Add fallback for integrated GPUs if no dedicated GPUs are found

* Add debug info which device is allocating memory

* Fix Intel dequant issue

Fix validation issue

* Fix Vulkan GGML_OP_GET_ROWS implementation

* Clean up merge artifacts

* Remove Vulkan warning
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
tybalex pushed a commit to tybalex/function.cpp that referenced this pull request Apr 17, 2024
Successfully merging this pull request may close these issues:

Multi GPU with Vulkan out of memory issue.