|
25 | 25 | import platform
|
26 | 26 | import weakref
|
27 | 27 | import gc
|
| 28 | +from pathlib import Path |
| 29 | +from functools import lru_cache |
| 30 | + |
28 | 31 |
|
29 | 32 | class VRAMState(Enum):
|
30 | 33 | DISABLED = 0 #No vram present: no need to move models to vram
|
@@ -177,7 +180,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
177 | 180 | dev = get_torch_device()
|
178 | 181 |
|
179 | 182 | if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
180 |
| - mem_total = psutil.virtual_memory().total |
| 183 | + mem_total = _cgroup_limit_bytes() or psutil.virtual_memory().total |
181 | 184 | mem_total_torch = mem_total
|
182 | 185 | else:
|
183 | 186 | if directml_enabled:
|
@@ -218,6 +221,35 @@ def mac_version():
|
218 | 221 | except:
|
219 | 222 | return None
|
220 | 223 |
|
| 224 | + |
# Locate the cgroup memory-accounting files once at import time.
# cgroup v2 (unified hierarchy) exposes `memory.max` / `memory.current`
# directly under the mount root; cgroup v1 keeps the equivalents inside the
# `memory/` controller directory.  If neither file exists (non-Linux, or no
# cgroup mount) the reader helpers below just return None.
# NOTE(review): this reads the hierarchy root — a process in a *nested*
# cgroup may have a different effective limit; confirm against
# /proc/self/cgroup if that case matters.
_CG = Path("/sys/fs/cgroup")
_IS_CGROUP_V2 = (_CG / "memory.max").exists()
_LIMIT_F = _CG / ("memory.max" if _IS_CGROUP_V2 else "memory/memory.limit_in_bytes")
_USED_F = _CG / ("memory.current" if _IS_CGROUP_V2 else "memory/memory.usage_in_bytes")
| 232 | + |
| 233 | + |
@lru_cache(maxsize=None)
def _cgroup_limit_bytes():
    """Return the cgroup hard memory limit in bytes, or ``None``.

    ``None`` means unlimited, unreadable, or not running under a cgroup.
    Cached for the process lifetime: the hard limit is assumed never to
    change once the container is up.
    """
    return _read_int(_LIMIT_F)
| 237 | + |
| 238 | + |
def _cgroup_used_bytes():
    """Return current cgroup memory usage in bytes, or ``None`` if unavailable.

    Deliberately *not* cached — unlike the limit, usage changes constantly.
    """
    return _read_int(_USED_F)
| 241 | + |
| 242 | + |
def _read_int(p: Path):
    """Read a single integer (bytes) from cgroup file *p*.

    Returns the value, or ``None`` when the file is missing/unreadable or
    holds a sentinel meaning "unlimited / no useful data":

    - cgroup v2 stores the literal string ``"max"`` when no limit is set;
      ``int()`` raises ``ValueError`` and we return ``None``,
    - cgroup v1 reports ~2**63 (``PAGE_COUNTER_MAX``, page-rounded) when
      unlimited; the ``>= 2**60`` check catches that,
    - ``0`` is also treated as "no data" so callers fall back to psutil.
    """
    try:
        v = int(p.read_text().strip())
    except (OSError, ValueError):
        # OSError covers FileNotFoundError / PermissionError plus the other
        # filesystem failures a cgroup probe can hit (IsADirectoryError,
        # ENODEV on a removed cgroup); ValueError covers v2's literal "max".
        return None
    if v == 0 or v >= (1 << 60):
        return None  # sentinel value: unlimited / nothing useful
    return v
| 251 | + |
| 252 | + |
221 | 253 | total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
222 | 254 | total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
223 | 255 | logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
@@ -1081,7 +1113,15 @@ def get_free_memory(dev=None, torch_free_too=False):
|
1081 | 1113 | if dev is None:
|
1082 | 1114 | dev = get_torch_device()
|
1083 | 1115 |
|
1084 |
| - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): |
| 1116 | + if hasattr(dev, 'type') and dev.type == 'cpu': |
| 1117 | + limit = _cgroup_limit_bytes() |
| 1118 | + used = _cgroup_used_bytes() if limit is not None else None |
| 1119 | + if limit is not None and used is not None: |
| 1120 | + mem_free_total = max(limit - used, 0) |
| 1121 | + else: |
| 1122 | + mem_free_total = psutil.virtual_memory().available |
| 1123 | + mem_free_torch = mem_free_total |
| 1124 | + elif hasattr(dev, 'type') and dev.type == 'mps': |
1085 | 1125 | mem_free_total = psutil.virtual_memory().available
|
1086 | 1126 | mem_free_torch = mem_free_total
|
1087 | 1127 | else:
|
|
0 commit comments