Skip to content

Commit

Permalink
Introduce ModelConfig and ModelVisionConfig to hold relevant parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Animesh Bohara committed May 2, 2024
1 parent 7c01023 commit 83a5243
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 48 deletions.
137 changes: 99 additions & 38 deletions cpp/json_ffi/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,99 @@ namespace json_ffi {

using namespace mlc::llm;

/****************** Model vision config ******************/

/*!
 * \brief Parse a ModelVisionConfig from the "vision_config" JSON object.
 * Every field is optional (required=false): a key that is absent from the
 * JSON leaves the corresponding member untouched and appends nothing to *err.
 * \param json_obj The parsed "vision_config" JSON object.
 * \param err Accumulates parse error messages, if any.
 * \return The populated config.
 */
ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj, std::string* err) {
  ModelVisionConfig config;

  // All integer fields follow the same optional-parse pattern; factor it out.
  // The int64_t -> int assignment mirrors the narrowing the per-field code did.
  auto read_int_field = [&](const std::string& key, int& field) {
    int64_t parsed;
    if (json::ParseJSONField(json_obj, key, parsed, err, false)) {
      field = parsed;
    }
  };

  read_int_field("hidden_size", config.hidden_size);
  read_int_field("image_size", config.image_size);
  read_int_field("intermediate_size", config.intermediate_size);
  read_int_field("num_attention_heads", config.num_attention_heads);
  read_int_field("num_hidden_layers", config.num_hidden_layers);
  read_int_field("patch_size", config.patch_size);
  read_int_field("projection_dim", config.projection_dim);
  read_int_field("vocab_size", config.vocab_size);

  std::string dtype_value;
  if (json::ParseJSONField(json_obj, "dtype", dtype_value, err, false)) {
    config.dtype = dtype_value;
  }

  read_int_field("num_channels", config.num_channels);

  double layer_norm_eps_value;
  if (json::ParseJSONField(json_obj, "layer_norm_eps", layer_norm_eps_value, err, false)) {
    config.layer_norm_eps = layer_norm_eps_value;
  }

  return config;
}

/****************** Model config ******************/

/*!
 * \brief Parse a ModelConfig from the "model_config" JSON object.
 * All scalar fields are optional (required=false): absent keys leave the
 * member defaults in place and append nothing to *err.
 * \param json_obj The parsed "model_config" JSON object.
 * \param err Accumulates parse error messages, if any.
 * \return The populated config.
 */
ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj, std::string* err) {
  ModelConfig config;

  int64_t vocab_size;
  if (json::ParseJSONField(json_obj, "vocab_size", vocab_size, err, false)) {
    config.vocab_size = vocab_size;
  }
  int64_t context_window_size;
  if (json::ParseJSONField(json_obj, "context_window_size", context_window_size, err, false)) {
    config.context_window_size = context_window_size;
  }
  int64_t sliding_window_size;
  if (json::ParseJSONField(json_obj, "sliding_window_size", sliding_window_size, err, false)) {
    config.sliding_window_size = sliding_window_size;
  }
  int64_t prefill_chunk_size;
  if (json::ParseJSONField(json_obj, "prefill_chunk_size", prefill_chunk_size, err, false)) {
    config.prefill_chunk_size = prefill_chunk_size;
  }
  int64_t tensor_parallel_shards;
  if (json::ParseJSONField(json_obj, "tensor_parallel_shards", tensor_parallel_shards, err,
                           false)) {
    config.tensor_parallel_shards = tensor_parallel_shards;
  }
  int64_t max_batch_size;
  if (json::ParseJSONField(json_obj, "max_batch_size", max_batch_size, err, false)) {
    config.max_batch_size = max_batch_size;
  }

  // "vision_config" is only present for vision-language models. Check the
  // value's type before get<>(): calling get<picojson::object>() on a
  // non-object value aborts inside picojson instead of reporting through err
  // like every other parse failure in this function.
  if (json_obj.count("vision_config")) {
    if (json_obj.at("vision_config").is<picojson::object>()) {
      const picojson::object& vision_config_obj =
          json_obj.at("vision_config").get<picojson::object>();
      config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj, err);
    } else {
      *err += "vision_config should be an object";
    }
  }

  return config;
}

/****************** Model-defined generation config ******************/

TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);
Expand Down Expand Up @@ -63,7 +156,7 @@ std::vector<std::string> Conversation::CheckMessageSeps(std::vector<std::string>
return seps;
}

std::optional<std::vector<Data>> Conversation::AsPrompt(picojson::object config, DLDevice device,
std::optional<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device,
std::string* err) {
// Get the system message
std::string system_msg = system_template;
Expand Down Expand Up @@ -155,47 +248,15 @@ std::optional<std::vector<Data>> Conversation::AsPrompt(picojson::object config,
// we are just assuming this as the URL for now
std::string base64_image = image_url.substr(image_url.find(",") + 1);
std::optional<NDArray> image_data = LoadImageFromBase64(base64_image, err);
if (!image_data) {
return std::nullopt;
}

if (config.find("model_config") == config.end()) {
*err += "model_config is required in config";
return std::nullopt;
}
if (!config["model_config"].is<picojson::object>()) {
*err += "model_config should be an object";
if (!image_data.has_value()) {
return std::nullopt;
}
picojson::object model_config = config["model_config"].get<picojson::object>();
if (model_config.find("vision_config") == model_config.end()) {
*err += "vision_config is required in model_config";
if (!config.vision_config.has_value()) {
*err += "Vision config is required for image input";
return std::nullopt;
}
if (!model_config["vision_config"].is<picojson::object>()) {
*err += "vision_config should be an object";
return std::nullopt;
}
picojson::object vision_config = model_config["vision_config"].get<picojson::object>();
if (vision_config.find("image_size") == vision_config.end()) {
*err += "image_size is required in vision_config";
return std::nullopt;
}
if (!vision_config["image_size"].is<int64_t>()) {
*err += "image_size should be an integer";
return std::nullopt;
}
if (vision_config.find("patch_size") == vision_config.end()) {
*err += "patch_size is required in vision_config";
return std::nullopt;
}
if (!vision_config["patch_size"].is<int64_t>()) {
*err += "patch_size should be an integer";
return std::nullopt;
}

int image_size = vision_config["image_size"].get<int64_t>();
int patch_size = vision_config["patch_size"].get<int64_t>();
int image_size = config.vision_config.value().image_size;
int patch_size = config.vision_config.value().patch_size;

int embed_size = (image_size * image_size) / (patch_size * patch_size);

Expand Down
40 changes: 38 additions & 2 deletions cpp/json_ffi/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,43 @@ namespace mlc {
namespace llm {
namespace json_ffi {

/****************** Model vision config ******************/

/*! \brief Defines the Vision config of the model (if present) */
/*! \brief Defines the Vision config of the model (if present).
 *  Populated from the "vision_config" object of "model_config" in
 *  mlc-chat-config.json via FromJSON.
 */
class ModelVisionConfig {
 public:
  // Default-initialize every member: FromJSON only overwrites keys that are
  // present in the JSON, so without these initializers any absent optional
  // field would be left indeterminate (undefined behavior when read).
  int hidden_size = 0;
  int image_size = 0;
  int intermediate_size = 0;
  int num_attention_heads = 0;
  int num_hidden_layers = 0;
  int patch_size = 0;
  int projection_dim = 0;
  int vocab_size = 0;
  std::string dtype;  // dtype string from JSON; empty when not specified
  int num_channels = 0;
  double layer_norm_eps = 0.0;

  /*! \brief Parse from the given JSON object; missing keys keep the defaults
   *  above. Parse errors are appended to *err. */
  static ModelVisionConfig FromJSON(const picojson::object& json_obj, std::string* err);
};

/****************** Model config ******************/

/*! \brief Defines the config of the model.
Populated from "model_config" field in mlc-chat-config.json */
/*! \brief Defines the config of the model.
    Populated from "model_config" field in mlc-chat-config.json via FromJSON. */
class ModelConfig {
 public:
  // Default-initialize scalars: FromJSON only overwrites keys present in the
  // JSON, so absent fields would otherwise be indeterminate (UB when read).
  // NOTE(review): 0 is used as the "unset" value here; confirm callers do not
  // expect a -1 sentinel for context_window_size / sliding_window_size.
  int vocab_size = 0;
  int context_window_size = 0;
  int sliding_window_size = 0;
  int prefill_chunk_size = 0;
  int tensor_parallel_shards = 0;
  int max_batch_size = 0;
  // Only engaged for vision-language models ("vision_config" key present).
  std::optional<ModelVisionConfig> vision_config = std::nullopt;

  /*! \brief Parse from the given JSON object; missing keys keep the defaults
   *  above. Parse errors are appended to *err. */
  static ModelConfig FromJSON(const picojson::object& json_obj, std::string* err);
};

/****************** Model-defined generation config ******************/

class ModelDefinedGenerationConfigNode : public Object {
Expand Down Expand Up @@ -129,8 +166,7 @@ struct Conversation {
* \brief Create the list of prompts from the messages based on the conversation template.
* When creation fails, errors are dumped to the input error string, and nullopt is returned.
*/
std::optional<std::vector<Data>> AsPrompt(picojson::object config, DLDevice device,
std::string* err);
std::optional<std::vector<Data>> AsPrompt(ModelConfig config, DLDevice device, std::string* err);

/*!
* \brief Create a Conversation instance from the given JSON object.
Expand Down
2 changes: 1 addition & 1 deletion cpp/json_ffi/image_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ size_t Base64DecodedSize(const std::string& base64_str) {
return 3 * len / 4 - padding;
}

std::optional<NDArray> LoadImageFromBase64(std::string base64_str, std::string* err) {
std::optional<NDArray> LoadImageFromBase64(const std::string& base64_str, std::string* err) {
MemoryBufferStream stream(base64_str.c_str(), base64_str.size());
tvm::support::Base64InStream base64_stream(&stream);
size_t decoded_size = Base64DecodedSize(base64_str);
Expand Down
10 changes: 6 additions & 4 deletions cpp/json_ffi/image_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ namespace mlc {
namespace llm {
namespace json_ffi {

using namespace tvm::runtime;
/*! \brief Load a base64 encoded image string into a CPU NDArray of shape {height, width, 3} */
std::optional<tvm::runtime::NDArray> LoadImageFromBase64(const std::string& base64_str,
std::string* err);

std::optional<NDArray> LoadImageFromBase64(std::string base64_str, std::string* err);

NDArray ClipPreprocessor(NDArray image_data, int target_size, DLDevice device, std::string* err);
/*! \brief Preprocess the CPU image for CLIP encoder and return an NDArray on the given device */
tvm::runtime::NDArray ClipPreprocessor(tvm::runtime::NDArray image_data, int target_size,
DLDevice device, std::string* err);

} // namespace json_ffi
} // namespace llm
Expand Down
8 changes: 6 additions & 2 deletions cpp/json_ffi/json_ffi_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
}
std::string model_config_str = std::string((std::istreambuf_iterator<char>(model_config_file)),
std::istreambuf_iterator<char>());
this->model_config_ = json::LoadJSONFromString(model_config_str, &err_).value();
picojson::object model_config_obj = json::LoadJSONFromString(model_config_str, &err_).value();
this->model_config_ =
ModelConfig::FromJSON(model_config_obj.at("model_config").get<picojson::object>(), &err_);

this->device_ = std::move(device);

Expand Down Expand Up @@ -183,7 +185,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
}
std::string model_config_str = std::string((std::istreambuf_iterator<char>(model_config_file)),
std::istreambuf_iterator<char>());
this->model_config_ = json::LoadJSONFromString(model_config_str, &err_).value();
picojson::object model_config_obj = json::LoadJSONFromString(model_config_str, &err_).value();
this->model_config_ =
ModelConfig::FromJSON(model_config_obj.at("model_config").get<picojson::object>(), &err_);
}

void Unload() { this->engine_->Unload(); }
Expand Down
2 changes: 1 addition & 1 deletion cpp/json_ffi/json_ffi_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class JSONFFIEngine {
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
Conversation conv_template_;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
picojson::object model_config_;
ModelConfig model_config_;
DLDevice device_;
};

Expand Down

0 comments on commit 83a5243

Please sign in to comment.