[SHELLBTRFS] Addendum to 1725ddf
[reactos.git] / dll / shellext / shellbtrfs / mountmgr_local.cpp
1 #include "shellext.h"
2 #ifndef __REACTOS__
3 #include "mountmgr.h"
4 #else
5 #include "mountmgr_local.h"
6 #endif
7 #include <mountmgr.h>
8
9 using namespace std;
10
11 mountmgr::mountmgr() {
12 UNICODE_STRING us;
13 OBJECT_ATTRIBUTES attr;
14 IO_STATUS_BLOCK iosb;
15 NTSTATUS Status;
16
17 RtlInitUnicodeString(&us, MOUNTMGR_DEVICE_NAME);
18 InitializeObjectAttributes(&attr, &us, 0, nullptr, nullptr);
19
20 Status = NtOpenFile(&h, FILE_GENERIC_READ | FILE_GENERIC_WRITE, &attr, &iosb,
21 FILE_SHARE_READ, FILE_SYNCHRONOUS_IO_ALERT);
22
23 if (!NT_SUCCESS(Status))
24 throw ntstatus_error(Status);
25 }
26
27 mountmgr::~mountmgr() {
28 NtClose(h);
29 }
30
31 void mountmgr::create_point(const wstring_view& symlink, const wstring_view& device) const {
32 NTSTATUS Status;
33 IO_STATUS_BLOCK iosb;
34
35 vector<uint8_t> buf(sizeof(MOUNTMGR_CREATE_POINT_INPUT) + ((symlink.length() + device.length()) * sizeof(WCHAR)));
36 #ifndef __REACTOS__
37 auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(buf.data());
38 #else
39 auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(&buf[0]);
40 #endif
41
42 mcpi->SymbolicLinkNameOffset = sizeof(MOUNTMGR_CREATE_POINT_INPUT);
43 mcpi->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
44 mcpi->DeviceNameOffset = (USHORT)(mcpi->SymbolicLinkNameOffset + mcpi->SymbolicLinkNameLength);
45 mcpi->DeviceNameLength = (USHORT)(device.length() * sizeof(WCHAR));
46
47 memcpy((uint8_t*)mcpi + mcpi->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
48 memcpy((uint8_t*)mcpi + mcpi->DeviceNameOffset, device.data(), device.length() * sizeof(WCHAR));
49
50 Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_CREATE_POINT,
51 #ifndef __REACTOS__
52 buf.data(), (ULONG)buf.size(), nullptr, 0);
53 #else
54 &buf[0], (ULONG)buf.size(), nullptr, 0);
55 #endif
56
57 if (!NT_SUCCESS(Status))
58 throw ntstatus_error(Status);
59 }
60
61 void mountmgr::delete_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
62 NTSTATUS Status;
63 IO_STATUS_BLOCK iosb;
64
65 vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
66 #ifndef __REACTOS__
67 auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
68 #else
69 auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(&buf[0]);
70 #endif
71
72 memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
73
74 if (symlink.length() > 0) {
75 mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
76 mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
77 memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
78 }
79
80 if (unique_id.length() > 0) {
81 if (mmp->SymbolicLinkNameLength == 0)
82 mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
83 else
84 mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
85
86 mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
87 memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
88 }
89
90 if (device_name.length() > 0) {
91 if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
92 mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
93 else if (mmp->SymbolicLinkNameLength != 0)
94 mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
95 else
96 mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
97
98 mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
99 memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
100 }
101
102 vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
103 #ifndef __REACTOS__
104 auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
105 #else
106 auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
107 #endif
108
109 Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
110 #ifndef __REACTOS__
111 buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
112 #else
113 &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
114 #endif
115
116 if (Status == STATUS_BUFFER_OVERFLOW) {
117 buf2.resize(mmps->Size);
118
119 Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
120 #ifndef __REACTOS__
121 buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
122 #else
123 &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
124 #endif
125 }
126
127 if (!NT_SUCCESS(Status))
128 throw ntstatus_error(Status);
129 }
130
131 vector<mountmgr_point> mountmgr::query_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
132 NTSTATUS Status;
133 IO_STATUS_BLOCK iosb;
134 vector<mountmgr_point> v;
135
136 vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
137 #ifndef __REACTOS__
138 auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
139 #else
140 auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(&buf[0]);
141 #endif
142
143 memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
144
145 if (symlink.length() > 0) {
146 mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
147 mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
148 memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
149 }
150
151 if (unique_id.length() > 0) {
152 if (mmp->SymbolicLinkNameLength == 0)
153 mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
154 else
155 mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
156
157 mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
158 memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
159 }
160
161 if (device_name.length() > 0) {
162 if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
163 mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
164 else if (mmp->SymbolicLinkNameLength != 0)
165 mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
166 else
167 mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
168
169 mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
170 memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
171 }
172
173 vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
174 #ifndef __REACTOS__
175 auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
176 #else
177 auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
178 #endif
179
180 Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
181 #ifndef __REACTOS__
182 buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
183 #else
184 &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
185 #endif
186
187 if (!NT_SUCCESS(Status) && Status != STATUS_BUFFER_OVERFLOW)
188 throw ntstatus_error(Status);
189
190 buf2.resize(mmps->Size);
191 #ifndef __REACTOS__
192 mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
193 #else
194 mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
195 #endif
196
197 Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
198 #ifndef __REACTOS__
199 buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
200 #else
201 &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
202 #endif
203
204 if (!NT_SUCCESS(Status))
205 throw ntstatus_error(Status);
206
207 for (ULONG i = 0; i < mmps->NumberOfMountPoints; i++) {
208 wstring_view mpsl, mpdn;
209 string_view mpuid;
210
211 if (mmps->MountPoints[i].SymbolicLinkNameLength)
212 mpsl = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].SymbolicLinkNameOffset), mmps->MountPoints[i].SymbolicLinkNameLength / sizeof(WCHAR));
213
214 if (mmps->MountPoints[i].UniqueIdLength)
215 mpuid = string_view((char*)((uint8_t*)mmps + mmps->MountPoints[i].UniqueIdOffset), mmps->MountPoints[i].UniqueIdLength);
216
217 if (mmps->MountPoints[i].DeviceNameLength)
218 mpdn = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].DeviceNameOffset), mmps->MountPoints[i].DeviceNameLength / sizeof(WCHAR));
219
220 v.emplace_back(mpsl, mpuid, mpdn);
221 }
222
223 return v;
224 }